Make reclustering work
jakobnissen committed Sep 4, 2024
1 parent 9749070 commit dbeb097
Showing 2 changed files with 105 additions and 28 deletions.
13 changes: 11 additions & 2 deletions vamb/__main__.py
@@ -1324,14 +1324,15 @@ def predict_taxonomy(
cuda: bool,
) -> vamb.taxonomy.PredictedTaxonomy:
begintime = time.time()
logger.info("Predicting taxonomy with Taxometer")

taxonomies = vamb.taxonomy.Taxonomy.from_file(
taxonomy_options.taxonomy.path, comp_metadata, False
)
nodes, ind_nodes, table_parent = vamb.taxvamb_encode.make_graph(
taxonomies.contig_taxonomies
)
logger.info(f"{len(nodes)} nodes in the graph")
logger.info(f"\t{len(nodes)} nodes in the graph")
classes_order: list[str] = []
for i in taxonomies.contig_taxonomies:
if i is None or len(i.ranks) == 0:
@@ -1606,12 +1607,15 @@ def run_reclustering(opt: ReclusteringOptions):
taxonomy = predicted_tax.to_taxonomy()
else:
tax_path = alg.taxonomy_options.path_or_tax_options.path
logger.info(f'Loading taxonomy from file "{tax_path}"')
taxonomy = vamb.taxonomy.Taxonomy.from_file(
tax_path, composition.metadata, True
)
instantiated_alg = vamb.reclustering.DBScanAlgorithm(
composition.metadata, taxonomy, opt.general.n_threads
)
logger.info("Reclustering")
logger.info("\tAlgorithm: DBSCAN")
reclustered_contigs = vamb.reclustering.recluster_bins(
markers, latent, instantiated_alg
)
@@ -1634,12 +1638,17 @@ def run_reclustering(opt: ReclusteringOptions):
s.add(vamb.reclustering.ContigId(i))
clusters_as_ids.append(s)
instantiated_alg = vamb.reclustering.KmeansAlgorithm(
clusters_as_ids, opt.general.seed, composition.metadata.identifiers
clusters_as_ids,
abs(opt.general.seed) % 4294967295, # Note: sklearn seeds must be uint32
composition.metadata.lengths,
)
logger.info("Reclustering")
logger.info("\tAlgorithm: KMeans")
reclustered_contigs = vamb.reclustering.recluster_bins(
markers, latent, instantiated_alg
)

logger.info("\tReclustering complete")
identifiers = composition.metadata.identifiers
clusters_dict: dict[str, set[str]] = dict()
for i, cluster in enumerate(reclustered_contigs):
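The seed passed to KmeansAlgorithm above is clamped because scikit-learn only accepts random states that fit in an unsigned 32-bit integer, as the inline note says. A minimal, hypothetical sketch of the same idea outside of vamb (the constant is copied from the diff; variable names are illustrative):

import numpy as np
from sklearn.cluster import KMeans

def clamp_seed(seed: int) -> int:
    # abs() handles negative seeds; the modulus keeps the value in uint32 range,
    # mirroring the expression used in __main__.py above.
    return abs(seed) % 4294967295

rng_seed = clamp_seed(-123_456_789_012_345)
data = np.random.default_rng(rng_seed).normal(size=(100, 8)).astype(np.float32)
labels = KMeans(n_clusters=3, random_state=rng_seed, n_init=10).fit_predict(data)
print(labels[:10])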
120 changes: 94 additions & 26 deletions vamb/reclustering.py
@@ -23,15 +23,12 @@
EPS_VALUES = np.arange(0.01, 0.35, 0.02)


# TODO: This might be slightly algorithmically inefficient. The big problem is that it re-heapifies
# the heap whenever a bin is emitted, which makes it O(N^2).
# To solve this, I need some datastructure which is like a heap, but which allows me to update
# arbitrary elements in the heap.
# This can be solved with a MutableBinaryMaxHeap in Julia's DataStructures.jl, for inspiration.
# This is a bottleneck, speed-wise. Hence, we do a bunch of tricks to remove the
# obviously bad bins to reduce computational cost.
def deduplicate(
scoring: Callable[[set[ContigId]], float],
scoring: Callable[[Iterable[ContigId]], tuple[float, float]],
bins: dict[BinId, set[ContigId]],
) -> list[tuple[float, set[ContigId]]]:
) -> list[set[ContigId]]:
"""
deduplicate(f, bins)
@@ -41,14 +38,17 @@ def deduplicate(
previously been returned from any other bin.
Returns `list[set[ContigId]]` with disjoint bins.
"""
# We continuously mutate `bins`, so we need to make a copy here, so the caller
# doesn't have `bins` mutated inadvertently
bins = {b: s.copy() for (b, s) in bins.items()}
contig_sets = remove_duplicate_bins(bins.values())
scored_contig_sets = remove_badly_contaminated(scoring, contig_sets)
(to_deduplicate, result) = remove_unambigous_bins(scored_contig_sets)
bins = {BinId(i): b for (i, (b, _)) in enumerate(to_deduplicate)}

result: list[tuple[float, set[ContigId]]] = []
# Use a heap, such that `heappop` will return the best bin
# (Heaps in Python are min-heaps, so we use the negative score.)
heap: list[tuple[float, BinId]] = [(-scoring(c), b) for (b, c) in bins.items()]
heap = [
(-score_from_comp_cont(s), BinId(i))
for (i, (_, s)) in enumerate(to_deduplicate)
]
heapq.heapify(heap)

# When removing the best bin, the contigs in that bin must be removed from all
@@ -62,9 +62,9 @@
bins_of_contig[ci].add(b)

while len(heap) > 0:
(neg_score, best_bin) = heapq.heappop(heap)
(_, best_bin) = heapq.heappop(heap)
contigs = bins[best_bin]
result.append((-neg_score, contigs))
result.append(contigs)

to_recompute.clear()
for contig in contigs:
@@ -81,11 +81,13 @@
del bins[bin]
to_recompute.add(bin)

# Remove the bin we picked as the best
del bins[best_bin]
# Remove the bin we picked as the best
del bins[best_bin]

# Check this here to skip recomputing scores and re-heapifying, since that
# takes time.
# TODO: Could possibly skip this sometimes, if we know that any of the recomputed
# has a score lower than the next from the heap (and the heap top is not to be recomputed)
if len(to_recompute) > 0:
heap = [(s, b) for (s, b) in heap if b not in to_recompute]
for bin in to_recompute:
@@ -94,18 +96,74 @@ def deduplicate(
# These empty bins should be discarded
c = bins.get(bin, None)
if c is not None:
heap.append((-scoring(c), bin))
heap.append((-score_from_comp_cont(scoring(c)), bin))
heapq.heapify(heap)

return result
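A toy usage sketch of the reworked deduplicate signature, assuming this commit's vamb is importable; ContigId and BinId are int-like at runtime, and the scorer below is a stand-in for the marker-based (completeness, contamination) function:

from vamb.reclustering import deduplicate

def toy_scorer(contigs):
    # Pretend each contig carries one distinct marker out of ten, with no duplicates.
    n = len(set(contigs))
    return (min(n / 10, 1.0), 0.0)  # (completeness, contamination)

bins = {
    0: {1, 2, 3},  # overlaps the next bin on contig 3
    1: {3, 4},
    2: {5, 6},     # shares nothing; returned as-is
}
disjoint = deduplicate(toy_scorer, bins)
# Every contig now appears in at most one output bin.
assert all(a.isdisjoint(b) for a in disjoint for b in disjoint if a is not b)
print(disjoint)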


def remove_duplicate_bins(sets: Iterable[set[ContigId]]) -> list[set[ContigId]]:
seen_sets: set[frozenset[ContigId]] = set()
seen_singletons: set[ContigId] = set()
for contig_set in sets:
if len(contig_set) == 1:
seen_singletons.add(next(iter(contig_set)))
else:
fz = frozenset(contig_set)
if fz not in seen_sets:
seen_sets.add(fz)

result: list[set[ContigId]] = [{s} for s in seen_singletons]
for st in seen_sets:
result.append(set(st))
return result


def remove_badly_contaminated(
scorer: Callable[[Iterable[ContigId]], tuple[float, float]],
sets: Iterable[set[ContigId]],
) -> list[tuple[set[ContigId], tuple[float, float]]]:
result: list[tuple[set[ContigId], tuple[float, float]]] = []
max_contamination = 1.0
for contig_set in sets:
(completeness, contamination) = scorer(contig_set)
if contamination <= max_contamination:
result.append((contig_set, (completeness, contamination)))
return result


def remove_unambigous_bins(
sets: list[tuple[set[ContigId], tuple[float, float]]],
) -> tuple[list[tuple[set[ContigId], tuple[float, float]]], list[set[ContigId]]]:
"""Remove all bins from d for which all the contigs are only present in that one bin,
and put them in the returned list.
These contigs have a trivial, unambiguous assignment.
"""
in_single_bin: dict[ContigId, bool] = dict()
for contig_set, _ in sets:
for contig in contig_set:
existing = in_single_bin.get(contig)
if existing is None:
in_single_bin[contig] = True
elif existing is True:
in_single_bin[contig] = False
to_deduplicate: list[tuple[set[ContigId], tuple[float, float]]] = []
unambiguous: list[set[ContigId]] = []
for contig_set, scores in sets:
if all(in_single_bin[c] for c in contig_set):
unambiguous.append(contig_set)
else:
to_deduplicate.append((contig_set, scores))
return (to_deduplicate, unambiguous)
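For example (toy values; the score tuples are hypothetical (completeness, contamination) pairs), a bin whose contigs appear nowhere else is split off as unambiguous:

from vamb.reclustering import remove_unambigous_bins

sets = [
    ({1, 2, 3}, (0.9, 0.0)),
    ({3, 4}, (0.4, 0.1)),  # shares contig 3 with the first bin
    ({5, 6}, (0.8, 0.0)),  # no contig shared with any other bin
]
to_deduplicate, unambiguous = remove_unambigous_bins(sets)
print(unambiguous)          # [{5, 6}]
print(len(to_deduplicate))  # 2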


class KmeansAlgorithm:
"Arguments needed specifically when using the KMeans algorithm"

def __init__(
self, clusters: list[set[ContigId]], random_seed: int, contiglengths: np.ndarray
):
assert np.issubdtype(contiglengths.dtype, np.integer)
self.contiglengths = contiglengths
self.clusters = clusters
self.random_seed = random_seed
@@ -138,6 +196,9 @@ def recluster_bins(
latent: np.ndarray,
algorithm: Union[KmeansAlgorithm, DBScanAlgorithm],
) -> list[set[ContigId]]:
assert np.issubdtype(algorithm.contiglengths.dtype, np.integer)
assert np.issubdtype(latent.dtype, np.floating)

if not (len(algorithm.contiglengths) == markers.n_seqs == len(latent)):
raise ValueError(
"Number of elements in contiglengths, markers and latent must match"
@@ -171,6 +232,9 @@ def recluster_kmeans(
random_seed: int,
) -> list[set[ContigId]]:
assert len(latent) == len(contiglengths) == markers.n_seqs
assert np.issubdtype(contiglengths.dtype, np.integer)
assert np.issubdtype(latent.dtype, np.floating)
assert latent.ndim == 2

result: list[set[ContigId]] = []
indices_by_medoid: dict[int, set[ContigId]] = defaultdict(set)
@@ -203,7 +267,7 @@
median_counts,
)

cluster_indices = np.ndarray(list(cluster))
cluster_indices = np.array(list(cluster))
cluter_latent = latent[cluster_indices]
cluster_lengths = contiglengths[cluster_indices]
seed_latent = latent[seeds]
@@ -271,15 +335,18 @@ def get_kmeans_seeds(
return result


# An arbitrary score of a bin, where higher numbers is better.
# completeness - 5 * contamination is used by the CheckM group as a heuristic.
def score_bin(bin: set[ContigId], markers: Markers) -> float:
counts = count_markers(bin, markers)
def get_completeness_contamination(counts: np.ndarray) -> tuple[float, float]:
n_total = counts.sum()
n_unique = (counts > 0).sum()
completeness = n_unique / len(counts)
contamination = (n_total - n_unique) / len(counts)
return completeness - 5 * contamination
return (completeness, contamination)


# An arbitrary score of a bin, where higher numbers are better.
# completeness - 5 * contamination is used by the CheckM group as a heuristic.
def score_from_comp_cont(comp_cont: tuple[float, float]) -> float:
return comp_cont[0] - 5 * comp_cont[1]
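As a quick worked example of the two helpers above, using hypothetical marker counts:

import numpy as np

counts = np.array([1, 2, 0, 1])  # copy number of each of four marker genes in one bin
n_total = counts.sum()                              # 4 marker hits in total
n_unique = (counts > 0).sum()                       # 3 distinct markers present
completeness = n_unique / len(counts)               # 0.75
contamination = (n_total - n_unique) / len(counts)  # 0.25
score = completeness - 5 * contamination            # -0.5, the CheckM-style heuristic
print(completeness, contamination, score)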


def recluster_dbscan(
Expand Down Expand Up @@ -342,8 +409,9 @@ def dbscan_genus(
# present in one output bin, using the scoring function to greedily
# output the best of the redundant bins.
bin_dict = {BinId(i): c for (i, c) in enumerate(redundant_bins)}
scored_deduplicated = deduplicate(lambda x: score_bin(x, markers), bin_dict)
return [b for (_, b) in scored_deduplicated]
return deduplicate(
lambda x: get_completeness_contamination(count_markers(x, markers)), bin_dict
)


def group_indices_by_genus(
Expand All @@ -358,4 +426,4 @@ def group_indices_by_genus(
# Currently, we just skip it here
if genus is not None:
by_genus[genus].append(ContigId(i))
return {g: np.ndarray(i, dtype=np.int32) for (g, i) in by_genus.items()}
return {g: np.array(i, dtype=np.int32) for (g, i) in by_genus.items()}
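One of the bug fixes in this file is replacing np.ndarray(...) with np.array(...) in two places. The two calls are not interchangeable; a short sketch of the difference:

import numpy as np

indices = [3, 1, 4]
converted = np.array(indices, dtype=np.int32)  # array([3, 1, 4], dtype=int32)
print(converted)

# np.ndarray is the array *type*: calling it interprets the argument as a shape
# and returns an uninitialised buffer, so np.ndarray(indices) would allocate a
# (3, 1, 4) array of garbage values rather than converting the list.
uninitialised = np.ndarray((2, 2))
print(uninitialised.shape)  # (2, 2); contents are whatever was in memory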
