diff --git a/vamb/__main__.py b/vamb/__main__.py index 51b9b53c..759dc438 100755 --- a/vamb/__main__.py +++ b/vamb/__main__.py @@ -1324,6 +1324,7 @@ def predict_taxonomy( cuda: bool, ) -> vamb.taxonomy.PredictedTaxonomy: begintime = time.time() + logger.info("Predicting taxonomy with Taxometer") taxonomies = vamb.taxonomy.Taxonomy.from_file( taxonomy_options.taxonomy.path, comp_metadata, False @@ -1331,7 +1332,7 @@ def predict_taxonomy( nodes, ind_nodes, table_parent = vamb.taxvamb_encode.make_graph( taxonomies.contig_taxonomies ) - logger.info(f"{len(nodes)} nodes in the graph") + logger.info(f"\t{len(nodes)} nodes in the graph") classes_order: list[str] = [] for i in taxonomies.contig_taxonomies: if i is None or len(i.ranks) == 0: @@ -1606,12 +1607,15 @@ def run_reclustering(opt: ReclusteringOptions): taxonomy = predicted_tax.to_taxonomy() else: tax_path = alg.taxonomy_options.path_or_tax_options.path + logger.info(f'Loading taxonomy from file "{tax_path}"') taxonomy = vamb.taxonomy.Taxonomy.from_file( tax_path, composition.metadata, True ) instantiated_alg = vamb.reclustering.DBScanAlgorithm( composition.metadata, taxonomy, opt.general.n_threads ) + logger.info("Reclustering") + logger.info("\tAlgorithm: DBSCAN") reclustered_contigs = vamb.reclustering.recluster_bins( markers, latent, instantiated_alg ) @@ -1634,12 +1638,17 @@ def run_reclustering(opt: ReclusteringOptions): s.add(vamb.reclustering.ContigId(i)) clusters_as_ids.append(s) instantiated_alg = vamb.reclustering.KmeansAlgorithm( - clusters_as_ids, opt.general.seed, composition.metadata.identifiers + clusters_as_ids, + abs(opt.general.seed) % 4294967295, # Note: sklearn seeds must be uint32 + composition.metadata.lengths, ) + logger.info("Reclustering") + logger.info("\tAlgorithm: KMeans") reclustered_contigs = vamb.reclustering.recluster_bins( markers, latent, instantiated_alg ) + logger.info("\tReclustering complete") identifiers = composition.metadata.identifiers clusters_dict: dict[str, set[str]] = dict() for i, cluster in enumerate(reclustered_contigs): diff --git a/vamb/reclustering.py b/vamb/reclustering.py index 63f9c94d..45a3bdb4 100644 --- a/vamb/reclustering.py +++ b/vamb/reclustering.py @@ -23,15 +23,12 @@ EPS_VALUES = np.arange(0.01, 0.35, 0.02) -# TODO: This might be slightly algorithmically inefficient. The big problem is that it re-heapifies -# the heap whenever a bin is emitted, which makes it O(N^2). -# To solve this, I need some datastructure which is like a heap, but which allows me to update -# arbitrary elements in the heap. -# This can be solved with a MutableBinaryMaxHeap in Julia's DataStructures.jl, for inspiration. +# This is a bottleneck, speed wise. Hence, we do a bunch of tricks to remove the +# obviously bad bins to reduce computational cost. def deduplicate( - scoring: Callable[[set[ContigId]], float], + scoring: Callable[[Iterable[ContigId]], tuple[float, float]], bins: dict[BinId, set[ContigId]], -) -> list[tuple[float, set[ContigId]]]: +) -> list[set[ContigId]]: """ deduplicate(f, bins) @@ -41,14 +38,17 @@ def deduplicate( previously been returned from any other bin. Returns `list[tuple[float, set[ContigId]]]` with scored, disjoint bins. """ - # We continuously mutate `bins`, so we need to make a copy here, so the caller - # doesn't have `bins` mutated inadvertently - bins = {b: s.copy() for (b, s) in bins.items()} + contig_sets = remove_duplicate_bins(bins.values()) + scored_contig_sets = remove_badly_contaminated(scoring, contig_sets) + (to_deduplicate, result) = remove_unambigous_bins(scored_contig_sets) + bins = {BinId(i): b for (i, (b, _)) in enumerate(to_deduplicate)} - result: list[tuple[float, set[ContigId]]] = [] # Use a heap, such that `heappop` will return the best bin # (Heaps in Python are min-heaps, so we use the negative score.) - heap: list[tuple[float, BinId]] = [(-scoring(c), b) for (b, c) in bins.items()] + heap = [ + (-score_from_comp_cont(s), BinId(i)) + for (i, (_, s)) in enumerate(to_deduplicate) + ] heapq.heapify(heap) # When removing the best bin, the contigs in that bin must be removed from all @@ -62,9 +62,9 @@ def deduplicate( bins_of_contig[ci].add(b) while len(heap) > 0: - (neg_score, best_bin) = heapq.heappop(heap) + (_, best_bin) = heapq.heappop(heap) contigs = bins[best_bin] - result.append((-neg_score, contigs)) + result.append(contigs) to_recompute.clear() for contig in contigs: @@ -81,11 +81,13 @@ def deduplicate( del bins[bin] to_recompute.add(bin) - # Remove the bin we picked as the best - del bins[best_bin] + # Remove the bin we picked as the best + del bins[best_bin] # Check this here to skip recomputing scores and re-heapifying, since that # takes time. + # TODO: Could possibly skip this sometimes, if we know that any of the recomputed + # has a score lower than the next from the heap (and the heap top is not to be recomputed) if len(to_recompute) > 0: heap = [(s, b) for (s, b) in heap if b not in to_recompute] for bin in to_recompute: @@ -94,18 +96,74 @@ def deduplicate( # These empty bins should be discarded c = bins.get(bin, None) if c is not None: - heap.append((-scoring(c), bin)) + heap.append((-score_from_comp_cont(scoring(c)), bin)) heapq.heapify(heap) return result +def remove_duplicate_bins(sets: Iterable[set[ContigId]]) -> list[set[ContigId]]: + seen_sets: set[frozenset[ContigId]] = set() + seen_singletons: set[ContigId] = set() + for contig_set in sets: + if len(contig_set) == 1: + seen_singletons.add(next(iter(contig_set))) + else: + fz = frozenset(contig_set) + if fz not in seen_sets: + seen_sets.add(fz) + + result: list[set[ContigId]] = [{s} for s in seen_singletons] + for st in seen_sets: + result.append(set(st)) + return result + + +def remove_badly_contaminated( + scorer: Callable[[Iterable[ContigId]], tuple[float, float]], + sets: Iterable[set[ContigId]], +) -> list[tuple[set[ContigId], tuple[float, float]]]: + result: list[tuple[set[ContigId], tuple[float, float]]] = [] + max_contamination = 1.0 + for contig_set in sets: + (completeness, contamination) = scorer(contig_set) + if contamination <= max_contamination: + result.append((contig_set, (completeness, contamination))) + return result + + +def remove_unambigous_bins( + sets: list[tuple[set[ContigId], tuple[float, float]]], +) -> tuple[list[tuple[set[ContigId], tuple[float, float]]], list[set[ContigId]]]: + """Remove all bins from d for which all the contigs are only present in that one bin, + and put them in the returned list. + These contigs have a trivial, unambiguous assignment. + """ + in_single_bin: dict[ContigId, bool] = dict() + for contig_set, _ in sets: + for contig in contig_set: + existing = in_single_bin.get(contig) + if existing is None: + in_single_bin[contig] = True + elif existing is True: + in_single_bin[contig] = False + to_deduplicate: list[tuple[set[ContigId], tuple[float, float]]] = [] + unambiguous: list[set[ContigId]] = [] + for contig_set, scores in sets: + if all(in_single_bin[c] for c in contig_set): + unambiguous.append(contig_set) + else: + to_deduplicate.append((contig_set, scores)) + return (to_deduplicate, unambiguous) + + class KmeansAlgorithm: "Arguments needed specifically when using the KMeans algorithm" def __init__( self, clusters: list[set[ContigId]], random_seed: int, contiglengths: np.ndarray ): + assert np.issubdtype(contiglengths.dtype, np.integer) self.contiglengths = contiglengths self.clusters = clusters self.random_seed = random_seed @@ -138,6 +196,9 @@ def recluster_bins( latent: np.ndarray, algorithm: Union[KmeansAlgorithm, DBScanAlgorithm], ) -> list[set[ContigId]]: + assert np.issubdtype(algorithm.contiglengths.dtype, np.integer) + assert np.issubdtype(latent.dtype, np.floating) + if not (len(algorithm.contiglengths) == markers.n_seqs == len(latent)): raise ValueError( "Number of elements in contiglengths, markers and latent must match" @@ -171,6 +232,9 @@ def recluster_kmeans( random_seed: int, ) -> list[set[ContigId]]: assert len(latent) == len(contiglengths) == markers.n_seqs + assert np.issubdtype(contiglengths.dtype, np.integer) + assert np.issubdtype(latent.dtype, np.floating) + assert latent.ndim == 2 result: list[set[ContigId]] = [] indices_by_medoid: dict[int, set[ContigId]] = defaultdict(set) @@ -203,7 +267,7 @@ def recluster_kmeans( median_counts, ) - cluster_indices = np.ndarray(list(cluster)) + cluster_indices = np.array(list(cluster)) cluter_latent = latent[cluster_indices] cluster_lengths = contiglengths[cluster_indices] seed_latent = latent[seeds] @@ -271,15 +335,18 @@ def get_kmeans_seeds( return result -# An arbitrary score of a bin, where higher numbers is better. -# completeness - 5 * contamination is used by the CheckM group as a heuristic. -def score_bin(bin: set[ContigId], markers: Markers) -> float: - counts = count_markers(bin, markers) +def get_completeness_contamination(counts: np.ndarray) -> tuple[float, float]: n_total = counts.sum() n_unique = (counts > 0).sum() completeness = n_unique / len(counts) contamination = (n_total - n_unique) / len(counts) - return completeness - 5 * contamination + return (completeness, contamination) + + +# An arbitrary score of a bin, where higher numbers is better. +# completeness - 5 * contamination is used by the CheckM group as a heuristic. +def score_from_comp_cont(comp_cont: tuple[float, float]) -> float: + return comp_cont[0] - 5 * comp_cont[1] def recluster_dbscan( @@ -342,8 +409,9 @@ def dbscan_genus( # present in one output bin, using the scoring function to greedily # output the best of the redundant bins. bin_dict = {BinId(i): c for (i, c) in enumerate(redundant_bins)} - scored_deduplicated = deduplicate(lambda x: score_bin(x, markers), bin_dict) - return [b for (_, b) in scored_deduplicated] + return deduplicate( + lambda x: get_completeness_contamination(count_markers(x, markers)), bin_dict + ) def group_indices_by_genus( @@ -358,4 +426,4 @@ def group_indices_by_genus( # Currently, we just skip it here if genus is not None: by_genus[genus].append(ContigId(i)) - return {g: np.ndarray(i, dtype=np.int32) for (g, i) in by_genus.items()} + return {g: np.array(i, dtype=np.int32) for (g, i) in by_genus.items()}