Skip to content

Commit

Permalink
Make taxonomy more consistently used (#374)
Browse files Browse the repository at this point in the history
More consistently differentiate between unrefined and refined taxonomy.
* Taxometer requires an unrefined taxonomy, and errors on a refined one
* Recluster DBScan can take either, but will warn if passed a refined one
  without `--no_predictor`. If refinement is needed and the requisite
  composition and abundance are not passed, it raises an error
* TaxVamb can take either, but warns like DBScan. It does no additional check
  for composition and abundance, since these are always required for TaxVamb.

Also minor fixes in logging
Also some changes to the DBScan algorithm - the old algo has not yet been
validated, so this is not a high-impact change.
  • Loading branch information
jakobnissen authored Dec 3, 2024
1 parent 6c97ece commit 589194e
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 131 deletions.
5 changes: 1 addition & 4 deletions .github/workflows/cli_vamb.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
cache-dependency-path: '**/pyproject.toml'
- name: Download fixtures
run: |
wget https://www.dropbox.com/scl/fi/xzc0tro7oe6tfm3igygpj/ci_data.zip\?rlkey\=xuv6b5eoynfryp4fba1kfp5jm\&st\=rjb1xccw\&dl\=0 -O ci_data.zip
wget https://www.dropbox.com/scl/fi/10tdf0w0kf70pf46hy8ks/ci_data.zip\?rlkey\=smlcinkesuwiw557zulgbb59l\&st\=hhokiqma\&dl\=0 -O ci_data.zip
unzip -o ci_data.zip
- name: Install dependencies
run: |
Expand Down Expand Up @@ -56,9 +56,6 @@ jobs:
vamb taxometer --outdir outdir_taxometer --fasta catalogue_mock.fna.gz --abundance abundance_mock.npz --taxonomy taxonomy_mock.tsv -pe 10 -pt 10
ls -la outdir_taxometer
cat outdir_taxometer/log.txt
vamb taxometer --outdir outdir_taxometer_pred --fasta catalogue_mock.fna.gz --abundance abundance_mock.npz --taxonomy outdir_taxometer/results_taxometer.tsv -pe 10 -pt 10
ls -la outdir_taxometer_pred
cat outdir_taxometer/log.txt
- name: Run k-means reclustering
run: |
vamb recluster --outdir outdir_recluster --fasta catalogue_mock.fna.gz --abundance abundance_mock.npz --latent_path outdir_taxvamb/vaevae_latent.npz --clusters_path outdir_taxvamb/vaevae_clusters_split.tsv --markers markers_mock.npz --algorithm kmeans --minfasta 200000
Expand Down
207 changes: 136 additions & 71 deletions vamb/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,30 +373,61 @@ def __init__(
self.refcheck = refcheck


class TaxonomyPath:
@classmethod
def from_args(cls, args: argparse.Namespace):
if args.taxonomy is None:
raise argparse.ArgumentTypeError(
"Cannot load taxonomy without specifying --taxonomy"
)
return cls(typeasserted(args.taxonomy, Path))
class TaxonomyBase:
__slots__ = ["path"]

def __init__(self, path: Path):
self.path = check_existing_file(path)
self.path = path

def get_tax_path(self) -> Path:
return self.path

class RefinedTaxonomy(TaxonomyBase):
    """Taxonomy file that is already refined.

    `get_taxonomy` returns this type when the first line of the file equals
    `vamb.taxonomy.PREDICTED_TAXONOMY_HEADER`. Adds nothing to `TaxonomyBase`;
    the type itself is the information.
    """


class UnrefinedTaxonomy(TaxonomyBase):
    """Taxonomy file that has not been refined.

    `get_taxonomy` returns this type when the first line of the file equals
    `vamb.taxonomy.TAXONOMY_HEADER`. Adds nothing to `TaxonomyBase`;
    the type itself is the information.
    """


def get_taxonomy(args: argparse.Namespace) -> Union[RefinedTaxonomy, UnrefinedTaxonomy]:
    """Classify the file given on `--taxonomy` as refined or unrefined.

    Only the first line of the file is read; it is compared (after stripping
    the line terminator) against the two known taxonomy headers.

    Args:
        args: Parsed command-line arguments; `args.taxonomy` is the path to
            the taxonomy file, or None if the flag was not given.

    Returns:
        `UnrefinedTaxonomy(path)` if the header equals
        `vamb.taxonomy.TAXONOMY_HEADER`, or `RefinedTaxonomy(path)` if it
        equals `vamb.taxonomy.PREDICTED_TAXONOMY_HEADER`.

    Raises:
        ValueError: If `--taxonomy` was not given, if the file has no first
            line, or if the first line matches neither known header.
    """
    path = args.taxonomy
    if path is None:
        # This function is shared by Taxometer, TaxVamb and DBScan
        # reclustering, so the message must not name a single tool.
        raise ValueError("Cannot load taxonomy without specifying --taxonomy")

    # The path is passed through `check_existing_file` before opening.
    with open(check_existing_file(path)) as file:
        # `None` signals a file without any lines at all.
        first_line = next(file, None)

    if first_line is None:
        raise ValueError(f'Empty taxonomy path at "{path}"')

    header = first_line.rstrip("\r\n")
    if header == vamb.taxonomy.TAXONOMY_HEADER:
        return UnrefinedTaxonomy(path)
    if header == vamb.taxonomy.PREDICTED_TAXONOMY_HEADER:
        return RefinedTaxonomy(path)

    raise ValueError(
        f'ERROR: When reading taxonomy file at "{path}", '
        f"the first line was not either {vamb.taxonomy.TAXONOMY_HEADER!r} "
        f"or {vamb.taxonomy.PREDICTED_TAXONOMY_HEADER!r}"
    )


class TaxometerOptions:
@classmethod
def from_args(cls, args: argparse.Namespace):
tax = get_taxonomy(args)
if isinstance(tax, RefinedTaxonomy):
raise ValueError(
f'Attempted to run Taxometer to refine taxonomy at "{args.taxonomy}", '
"but this file appears to already be an output of Taxometer"
)
return cls(
GeneralOptions.from_args(args),
CompositionOptions.from_args(args),
AbundanceOptions.from_args(args),
TaxonomyPath.from_args(args),
tax,
BasicTrainingOptions.from_args_taxometer(args),
typeasserted(args.pred_softmax_threshold, float),
args.ploss,
Expand All @@ -407,7 +438,7 @@ def __init__(
general: GeneralOptions,
composition: CompositionOptions,
abundance: AbundanceOptions,
taxonomy: TaxonomyPath,
taxonomy: UnrefinedTaxonomy,
basic: BasicTrainingOptions,
softmax_threshold: float,
ploss: Union[
Expand All @@ -432,38 +463,6 @@ def __init__(
self.basic = basic


class RefinableTaxonomyOptions:
    """Taxonomy input that is either used as-is or refined with Taxometer.

    Holds a `TaxonomyPath` when the taxonomy file is used directly, or a
    `TaxometerOptions` when the taxonomy is to be refined first.
    """

    __slots__ = ["path_or_tax_options"]

    def __init__(self, path_or_tax_options: Union[TaxonomyPath, TaxometerOptions]):
        self.path_or_tax_options = path_or_tax_options

    @classmethod
    def from_args(cls, args: argparse.Namespace):
        """Build the options from parsed CLI arguments.

        With `--no_predictor` the taxonomy path is wrapped directly.
        Otherwise Taxometer options are built, which needs both composition
        and abundance arguments to be present.

        Raises:
            ValueError: If refinement is requested but the composition or
                abundance arguments are missing.
        """
        if typeasserted(args.no_predictor, bool):
            return cls(TaxonomyPath.from_args(args))

        # TaxometerOptions have more options, but only the composition and the
        # abundance can be omitted, so only those are checked here.
        inputs_present = CompositionOptions.are_args_present(
            args
        ) and AbundanceOptions.are_args_present(args)
        if not inputs_present:
            raise ValueError(
                "If `--no_predictor` is not passed, Taxometer is run to refine taxonomy, "
                "and this requires composition input and abundance input to be passed in"
            )
        return cls(TaxometerOptions.from_args(args))

    def get_tax_path(self) -> Path:
        """Return the taxonomy file path, whichever variant is stored."""
        inner = self.path_or_tax_options
        # A TaxometerOptions keeps its taxonomy one level down.
        holder = inner if isinstance(inner, TaxonomyPath) else inner.taxonomy
        return holder.get_tax_path()

class MarkerPath:
def __init__(self, path: Path):
self.path = check_existing_file(path)
Expand Down Expand Up @@ -504,8 +503,50 @@ def __init__(self, clusters: Path):


class DBScanOptions:
def __init__(self, taxonomy_options: RefinableTaxonomyOptions, n_processes: int):
self.taxonomy_options = taxonomy_options
@classmethod
def from_args(cls, args: argparse.Namespace, n_threads: int):
    """Build DBScan options from parsed CLI arguments.

    The taxonomy given on `--taxonomy` is classified with `get_taxonomy`.
    It is wrapped in `TaxometerOptions` (i.e. scheduled for refinement) only
    when `--no_predictor` is not set and the taxonomy is unrefined; in every
    other case the taxonomy object is stored directly.

    Raises:
        ValueError: If refinement is needed but the composition or abundance
            arguments are missing.
    """
    taxonomy = get_taxonomy(args)
    refinement_disabled = typeasserted(args.no_predictor, bool)

    # Refinement explicitly switched off: use the taxonomy as given.
    if refinement_disabled:
        return cls(taxonomy, n_threads)

    # Nothing left to refine; tell the user the predictor was skipped.
    if isinstance(taxonomy, RefinedTaxonomy):
        logger.warning(
            "Flag --no_predictor not set, but the taxonomy passed in "
            "on --taxonomy is already refined. Skipped refinement."
        )
        return cls(taxonomy, n_threads)

    inputs_present = CompositionOptions.are_args_present(
        args
    ) and AbundanceOptions.are_args_present(args)
    if not inputs_present:
        raise ValueError(
            "Flag --no_predictor is not set, but abundance "
            "or composition has not been passed in, so there is no information "
            "to refine taxonomy with"
        )

    general = GeneralOptions.from_args(args)
    composition = CompositionOptions.from_args(args)
    abundance = AbundanceOptions.from_args(args)
    training = BasicTrainingOptions.from_args_taxometer(args)
    softmax_threshold = typeasserted(args.pred_softmax_threshold, float)
    refine_options = TaxometerOptions(
        general,
        composition,
        abundance,
        taxonomy,
        training,
        softmax_threshold,
        args.ploss,
    )
    return cls(refine_options, n_threads)

def __init__(
    self,
    ops: Union[UnrefinedTaxonomy, RefinedTaxonomy, TaxometerOptions],
    n_processes: int,
):
    """Store the taxonomy source and the process count.

    Args:
        ops: Either a taxonomy object (refined or unrefined), or the
            `TaxometerOptions` to use for refining the taxonomy.
            Stored as `self.taxonomy`.
        n_processes: Stored as `self.n_processes`.
    """
    self.n_processes = n_processes
    self.taxonomy = ops


Expand Down Expand Up @@ -714,15 +755,35 @@ def from_args(cls, args: argparse.Namespace):
common = BinnerCommonOptions.from_args(args)
basic = BasicTrainingOptions.from_args_vae(args)
vae = VAEOptions.from_args(basic, args)
taxonomy = RefinableTaxonomyOptions.from_args(args)
return cls(common, vae, taxonomy)
taxonomy = get_taxonomy(args)
predict = not typeasserted(args.no_predictor, bool)
if predict:
if isinstance(taxonomy, RefinedTaxonomy):
logger.warning(
"Flag --no_predictor not set, but the taxonomy passed in "
"on --taxonomy is already refined. Skipped refinement."
)
tax = taxonomy
else:
tax = TaxometerOptions(
common.general,
common.comp,
common.abundance,
taxonomy,
basic,
typeasserted(args.pred_softmax_threshold, float),
args.ploss,
)
else:
tax = taxonomy
return cls(common, vae, tax)

# The VAEVAE models share the same settings as the VAE model, so we just use VAEOptions
def __init__(
self,
common: BinnerCommonOptions,
vae: VAEOptions,
taxonomy: RefinableTaxonomyOptions,
taxonomy: Union[RefinedTaxonomy, TaxometerOptions, UnrefinedTaxonomy],
):
self.common = common
self.vae = vae
Expand Down Expand Up @@ -768,15 +829,15 @@ def from_args(cls, args: argparse.Namespace):
)
algorithm = KmeansOptions(clusters)
elif args.algorithm == "dbscan":
tax = RefinableTaxonomyOptions.from_args(args)
algorithm = DBScanOptions(tax, general.n_threads)
algorithm = DBScanOptions.from_args(args, general.n_threads)
else:
assert False # no more algorithms

# Avoid loading composition again if already loaded in DBScanOptions
if isinstance(algorithm, DBScanOptions) and isinstance(
algorithm.taxonomy_options.path_or_tax_options, TaxometerOptions
algorithm.taxonomy, tuple
):
comp = algorithm.taxonomy_options.path_or_tax_options.composition
comp = algorithm.taxonomy[1]
else:
comp = CompositionOptions.from_args(args)

Expand Down Expand Up @@ -1456,24 +1517,26 @@ def run_vaevae(opt: BinTaxVambOptions):
composition.metadata.lengths,
composition.metadata.identifiers,
)
if isinstance(opt.taxonomy.path_or_tax_options, TaxometerOptions):
logger.info("Predicting missing values from taxonomy")
if isinstance(opt.taxonomy, TaxometerOptions):
predicted_contig_taxonomies = predict_taxonomy(
comp_metadata=composition.metadata,
abundance=abundance,
tnfs=tnfs,
lengths=lengths,
out_dir=opt.common.general.out_dir,
taxonomy_options=opt.taxonomy.path_or_tax_options,
taxonomy_options=opt.taxonomy,
cuda=opt.common.general.cuda,
)
contig_taxonomies = predicted_contig_taxonomies.to_taxonomy()
elif isinstance(opt.taxonomy, RefinedTaxonomy):
logger.info("Loading already-refined taxonomy from file")
contig_taxonomies = vamb.taxonomy.Taxonomy.from_refined_file(
opt.taxonomy.path, composition.metadata, False
)
else:
logger.info("Not predicting the taxonomy")
logger.info("Loading unrefined taxonomy from file")
contig_taxonomies = vamb.taxonomy.Taxonomy.from_file(
opt.taxonomy.path_or_tax_options.path,
composition.metadata,
False,
opt.taxonomy.path, composition.metadata, False
)

nodes, ind_nodes, table_parent = vamb.taxvamb_encode.make_graph(
Expand Down Expand Up @@ -1569,9 +1632,6 @@ def run_vaevae(opt: BinTaxVambOptions):
)


# TODO: The whole data flow around predict_taxonomy needs to change.
# Ideally, we should have a "get taxonomy" function that loads, possibly refines,
# and then returns the taxonomy object.
def run_reclustering(opt: ReclusteringOptions):
composition = calc_tnf(
opt.composition,
Expand All @@ -1587,8 +1647,8 @@ def run_reclustering(opt: ReclusteringOptions):

if isinstance(alg, DBScanOptions):
# If we should refine or not.
if isinstance(alg.taxonomy_options.path_or_tax_options, TaxometerOptions):
taxopt = alg.taxonomy_options.path_or_tax_options
if isinstance(alg.taxonomy, TaxometerOptions):
taxopt = alg.taxonomy
abundance = calc_abundance(
taxopt.abundance,
taxopt.general.out_dir,
Expand All @@ -1607,11 +1667,16 @@ def run_reclustering(opt: ReclusteringOptions):
)
taxonomy = predicted_tax.to_taxonomy()
else:
tax_path = alg.taxonomy_options.path_or_tax_options.path
logger.info(f'Loading taxonomy from file "{tax_path}"')
taxonomy = vamb.taxonomy.Taxonomy.from_file(
tax_path, composition.metadata, True
)
logger.info(f'Loading taxonomy from file "{alg.taxonomy.path}"')
if isinstance(alg.taxonomy, UnrefinedTaxonomy):
taxonomy = vamb.taxonomy.Taxonomy.from_file(
alg.taxonomy.path, composition.metadata, True
)
else:
taxonomy = vamb.taxonomy.Taxonomy.from_refined_file(
alg.taxonomy.path, composition.metadata, True
)

instantiated_alg = vamb.reclustering.DBScanAlgorithm(
composition.metadata, taxonomy, opt.general.n_threads
)
Expand Down
Loading

0 comments on commit 589194e

Please sign in to comment.