Skip to content

Commit

Permalink
Resolve bin name clashes between Z and Y space in AAE (#382)
Browse files Browse the repository at this point in the history
The AVAMB ensemble model produces bins from the Z space, the Y space and the
ordinary VAE model. When the FASTA bins are written with the --minfasta flag,
the bins derived from the three models would erroneously  overwrite each other.

Now, the written bins may be given an optional prefix to avoid name clashes.
This prefix is only used when writing bins from the AAE Z and Y spaces - the
VAE and VAEVAE models use no prefix.

So, now the bins created using the --minfasta flag with the AAE model will be
called e.g. `z_S5C9` or `y_S12C51402`.
  • Loading branch information
jakobnissen authored Jan 8, 2025
1 parent 77e1cba commit 905115c
Showing 1 changed file with 22 additions and 2 deletions.
24 changes: 22 additions & 2 deletions vamb/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1167,6 +1167,7 @@ def cluster_and_write_files(
cuda: bool,
base_clusters_name: str, # e.g. /foo/bar/vae -> /foo/bar/vae_unsplit.tsv
fasta_output: Optional[FastaOutput],
bin_prefix: Optional[str], # see write_clusters_and_bins
):
begintime = time.time()
# Create cluser iterator
Expand Down Expand Up @@ -1220,6 +1221,7 @@ def cluster_and_write_files(

write_clusters_and_bins(
fasta_output,
bin_prefix,
binsplitter,
base_clusters_name,
cluster_dict,
Expand All @@ -1231,6 +1233,11 @@ def cluster_and_write_files(

def write_clusters_and_bins(
fasta_output: Optional[FastaOutput],
# If `x` and not None, clusters will be renamed `x` + old_name.
# This is necessary since for the AAE, we may need to write bins
# from three latent spaces into the same directory, and the names
# must not clash.
bin_prefix: Optional[str],
binsplitter: vamb.vambtools.BinSplitter,
base_clusters_name: str, # e.g. /foo/bar/vae -> /foo/bar/vae_unsplit.tsv
clusters: dict[str, set[str]],
Expand Down Expand Up @@ -1269,7 +1276,8 @@ def write_clusters_and_bins(
sizeof = dict(zip(sequence_names, sequence_lens))
for binname, contigs in clusters.items():
if sum(sizeof[c] for c in contigs) >= fasta_output.min_fasta_size:
filtered_clusters[binname] = contigs
new_name = binname if bin_prefix is None else bin_prefix + binname
filtered_clusters[new_name] = contigs

with vamb.vambtools.Reader(fasta_output.existing_fasta_path.path) as file:
vamb.vambtools.write_bins(
Expand Down Expand Up @@ -1321,6 +1329,7 @@ def run_bin_default(opt: BinDefaultOptions):
opt.common.general.cuda,
str(opt.common.general.out_dir.joinpath("vae_clusters")),
FastaOutput.try_from_common(opt.common),
None,
)
del latent

Expand Down Expand Up @@ -1351,6 +1360,12 @@ def run_bin_aae(opt: BinAvambOptions):
comp_metadata = composition.metadata
del composition, abundance
assert comp_metadata.nseqs == len(latent_z)
# Cluster and output the Z clusters
# This function calls write_clusters_and_bins,
# but also does the actual clustering and writes cluster metadata.
# This does not apply to the aae_y clusters, since their cluster label
# can be extracted directly from the latent space without clustering,
# and hence below, `write_clusters_and_bins` is called directly instead.
cluster_and_write_files(
opt.common.clustering,
opt.common.output.binsplitter,
Expand All @@ -1361,13 +1376,16 @@ def run_bin_aae(opt: BinAvambOptions):
opt.common.general.cuda,
str(opt.common.general.out_dir.joinpath("aae_z_clusters")),
FastaOutput.try_from_common(opt.common),
"z_",
)
del latent_z

# We enforce this in the VAEAAEOptions constructor, see comment there
# Cluster and output the Y clusters
assert opt.common.clustering.max_clusters is None
write_clusters_and_bins(
FastaOutput.try_from_common(opt.common),
"y_",
binsplitter=opt.common.output.binsplitter,
base_clusters_name=str(opt.common.general.out_dir.joinpath("aae_y_clusters")),
clusters=clusters_y_dict,
Expand Down Expand Up @@ -1629,6 +1647,7 @@ def run_vaevae(opt: BinTaxVambOptions):
opt.common.general.cuda,
str(opt.common.general.out_dir.joinpath("vaevae_clusters")),
FastaOutput.try_from_common(opt.common),
None,
)


Expand Down Expand Up @@ -1732,6 +1751,7 @@ def run_reclustering(opt: ReclusteringOptions):

write_clusters_and_bins(
fasta_output,
None,
opt.output.binsplitter,
str(opt.general.out_dir.joinpath("clusters_reclustered")),
clusters_dict,
Expand Down Expand Up @@ -1982,7 +2002,7 @@ def add_vae_arguments(subparser: argparse.ArgumentParser):
"-r",
dest="lrate",
metavar="",
type=Optional[float],
type=float,
default=None,
help=argparse.SUPPRESS,
)
Expand Down

0 comments on commit 905115c

Please sign in to comment.