Skip to content

Commit

Permalink
Streamline Taxonomic input
Browse files Browse the repository at this point in the history
More consistently differentiate between unrefined and refined taxonomy.
* Taxometer requires an unrefined, and errors on a refined one
* Recluster DBScan can take either, but will warn if passed a refined one, and
  not `--no_predictor`. If refinement is needed and the requisite comp and ab
  are not passed, error
* TaxVamb can take either, but warns like DBScan. Does not do addtional check
  for comp and ab, since this is always required for TaxVamb.
  • Loading branch information
jakobnissen committed Dec 3, 2024
1 parent 90e195e commit 104dfba
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 85 deletions.
207 changes: 136 additions & 71 deletions vamb/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,30 +373,61 @@ def __init__(
self.refcheck = refcheck


class TaxonomyPath:
@classmethod
def from_args(cls, args: argparse.Namespace):
if args.taxonomy is None:
raise argparse.ArgumentTypeError(
"Cannot load taxonomy without specifying --taxonomy"
)
return cls(typeasserted(args.taxonomy, Path))
class TaxonomyBase:
__slots__ = ["path"]

def __init__(self, path: Path):
self.path = check_existing_file(path)
self.path = path

def get_tax_path(self) -> Path:
return self.path

class RefinedTaxonomy(TaxonomyBase):
pass


class UnrefinedTaxonomy(TaxonomyBase):
pass


def get_taxonomy(args: argparse.Namespace) -> Union[RefinedTaxonomy, UnrefinedTaxonomy]:
path = args.taxonomy
if path is None:
raise ValueError(
"Cannot load taxonomy for Taxometer without specifying --taxonomy"
)
with open(check_existing_file(path)) as file:
try:
header = next(file)
except StopIteration:
header = None

if header is None:
raise ValueError(f'Empty taxonomy path at "{path}"')
elif header == vamb.taxonomy.TAXONOMY_HEADER:
return UnrefinedTaxonomy(path)
elif header == vamb.taxonomy.PREDICTED_TAXONOMY_HEADER:
return RefinedTaxonomy(path)
else:
raise ValueError(
f'ERROR: When reading taxonomy file at "{path}", '
f"the first line was not either {repr(vamb.taxonomy.TAXONOMY_HEADER)} "
f"or {repr(vamb.taxonomy.TAXONOMY_HEADER)}'"
)


class TaxometerOptions:
@classmethod
def from_args(cls, args: argparse.Namespace):
tax = get_taxonomy(args.taxonomy)
if isinstance(tax, RefinedTaxonomy):
raise ValueError(
f'Attempted to run Taxometer to refine taxonomy at "{args.taxonomy}", '
"but this file appears to already be an output of Taxometer"
)
return cls(
GeneralOptions.from_args(args),
CompositionOptions.from_args(args),
AbundanceOptions.from_args(args),
TaxonomyPath.from_args(args),
tax,
BasicTrainingOptions.from_args_taxometer(args),
typeasserted(args.pred_softmax_threshold, float),
args.ploss,
Expand All @@ -407,7 +438,7 @@ def __init__(
general: GeneralOptions,
composition: CompositionOptions,
abundance: AbundanceOptions,
taxonomy: TaxonomyPath,
taxonomy: UnrefinedTaxonomy,
basic: BasicTrainingOptions,
softmax_threshold: float,
ploss: Union[
Expand All @@ -432,38 +463,6 @@ def __init__(
self.basic = basic


class RefinableTaxonomyOptions:
__slots__ = ["path_or_tax_options"]

@classmethod
def from_args(cls, args: argparse.Namespace):
predict = not typeasserted(args.no_predictor, bool)

# TaxometerOptions have more options, but only the composition and the abundance
# can be omitted, so we only check those here.
if predict:
if not CompositionOptions.are_args_present(
args
) or not AbundanceOptions.are_args_present(args):
raise ValueError(
"If `--no_predictor` is not passed, Taxometer is run to refine taxonomy, "
"and this requires composition input and abundance input to be passed in"
)
return cls(TaxometerOptions.from_args(args))
else:
return cls(TaxonomyPath.from_args(args))

def __init__(self, path_or_tax_options: Union[TaxonomyPath, TaxometerOptions]):
self.path_or_tax_options = path_or_tax_options

def get_tax_path(self) -> Path:
p = self.path_or_tax_options
if isinstance(p, TaxonomyPath):
return p.get_tax_path()
else:
return p.taxonomy.get_tax_path()


class MarkerPath:
def __init__(self, path: Path):
self.path = check_existing_file(path)
Expand Down Expand Up @@ -504,8 +503,50 @@ def __init__(self, clusters: Path):


class DBScanOptions:
def __init__(self, taxonomy_options: RefinableTaxonomyOptions, n_processes: int):
self.taxonomy_options = taxonomy_options
@classmethod
def from_args(cls, args: argparse.Namespace, n_threads: int):
tax = get_taxonomy(args)
predict = not typeasserted(args.no_predictor, bool)
if predict:
if isinstance(tax, RefinedTaxonomy):
logger.warning(
"Flag --no_predictor not set, but the taxonomy passed in "
"on --taxonomy is already refined. Skipped refinement."
)
return cls(tax, n_threads)
else:
if not (
CompositionOptions.are_args_present(args)
and AbundanceOptions.are_args_present(args)
):
raise ValueError(
"Flag --no_predictor is not set, but abundance "
"or composition has not been passed in, so there is no information "
"to refine taxonomy with"
)
tax_options = TaxometerOptions(
GeneralOptions.from_args(args),
CompositionOptions.from_args(args),
AbundanceOptions.from_args(args),
tax,
BasicTrainingOptions.from_args_taxometer(args),
typeasserted(args.pred_softmax_threshold, float),
args.ploss,
)
return cls(tax_options, n_threads)
else:
return cls(tax, n_threads)

def __init__(
self,
ops: Union[
UnrefinedTaxonomy,
RefinedTaxonomy,
TaxometerOptions,
],
n_processes: int,
):
self.taxonomy = ops
self.n_processes = n_processes


Expand Down Expand Up @@ -714,15 +755,35 @@ def from_args(cls, args: argparse.Namespace):
common = BinnerCommonOptions.from_args(args)
basic = BasicTrainingOptions.from_args_vae(args)
vae = VAEOptions.from_args(basic, args)
taxonomy = RefinableTaxonomyOptions.from_args(args)
return cls(common, vae, taxonomy)
taxonomy = get_taxonomy(args)
predict = not typeasserted(args.no_predictor, bool)
if predict:
if isinstance(taxonomy, RefinedTaxonomy):
logger.warning(
"Flag --no_predictor not set, but the taxonomy passed in "
"on --taxonomy is already refined. Skipped refinement."
)
tax = taxonomy
else:
tax = TaxometerOptions(
common.general,
common.comp,
common.abundance,
taxonomy,
basic,
typeasserted(args.pred_softmax_threshold, float),
args.ploss,
)
else:
tax = taxonomy
return cls(common, vae, tax)

# The VAEVAE models share the same settings as the VAE model, so we just use VAEOptions
def __init__(
self,
common: BinnerCommonOptions,
vae: VAEOptions,
taxonomy: RefinableTaxonomyOptions,
taxonomy: Union[RefinedTaxonomy, TaxometerOptions, UnrefinedTaxonomy],
):
self.common = common
self.vae = vae
Expand Down Expand Up @@ -768,15 +829,15 @@ def from_args(cls, args: argparse.Namespace):
)
algorithm = KmeansOptions(clusters)
elif args.algorithm == "dbscan":
tax = RefinableTaxonomyOptions.from_args(args)
algorithm = DBScanOptions(tax, general.n_threads)
algorithm = DBScanOptions.from_args(args, general.n_threads)
else:
assert False # no more algorithms

# Avoid loading composition again if already loaded in DBScanOptions
if isinstance(algorithm, DBScanOptions) and isinstance(
algorithm.taxonomy_options.path_or_tax_options, TaxometerOptions
algorithm.taxonomy, tuple
):
comp = algorithm.taxonomy_options.path_or_tax_options.composition
comp = algorithm.taxonomy[1]
else:
comp = CompositionOptions.from_args(args)

Expand Down Expand Up @@ -1456,24 +1517,26 @@ def run_vaevae(opt: BinTaxVambOptions):
composition.metadata.lengths,
composition.metadata.identifiers,
)
if isinstance(opt.taxonomy.path_or_tax_options, TaxometerOptions):
logger.info("Predicting missing values from taxonomy")
if isinstance(opt.taxonomy, TaxometerOptions):
predicted_contig_taxonomies = predict_taxonomy(
comp_metadata=composition.metadata,
abundance=abundance,
tnfs=tnfs,
lengths=lengths,
out_dir=opt.common.general.out_dir,
taxonomy_options=opt.taxonomy.path_or_tax_options,
taxonomy_options=opt.taxonomy,
cuda=opt.common.general.cuda,
)
contig_taxonomies = predicted_contig_taxonomies.to_taxonomy()
elif isinstance(opt.taxonomy, RefinedTaxonomy):
logger.info("Loading already-refined taxonomy from file")
contig_taxonomies = vamb.taxonomy.Taxonomy.from_refined_file(
opt.taxonomy.path, composition.metadata, False
)
else:
logger.info("Not predicting the taxonomy")
logger.info("Loading unrefined taxonomy from file")
contig_taxonomies = vamb.taxonomy.Taxonomy.from_file(
opt.taxonomy.path_or_tax_options.path,
composition.metadata,
False,
opt.taxonomy.path, composition.metadata, False
)

nodes, ind_nodes, table_parent = vamb.taxvamb_encode.make_graph(
Expand Down Expand Up @@ -1569,9 +1632,6 @@ def run_vaevae(opt: BinTaxVambOptions):
)


# TODO: The whole data flow around predict_taxonomy needs to change.
# Ideally, we should have a "get taxonomy" function that loads, possibly refines,
# and then returns the taxonomy object.
def run_reclustering(opt: ReclusteringOptions):
composition = calc_tnf(
opt.composition,
Expand All @@ -1587,8 +1647,8 @@ def run_reclustering(opt: ReclusteringOptions):

if isinstance(alg, DBScanOptions):
# If we should refine or not.
if isinstance(alg.taxonomy_options.path_or_tax_options, TaxometerOptions):
taxopt = alg.taxonomy_options.path_or_tax_options
if isinstance(alg.taxonomy, TaxometerOptions):
taxopt = alg.taxonomy
abundance = calc_abundance(
taxopt.abundance,
taxopt.general.out_dir,
Expand All @@ -1607,11 +1667,16 @@ def run_reclustering(opt: ReclusteringOptions):
)
taxonomy = predicted_tax.to_taxonomy()
else:
tax_path = alg.taxonomy_options.path_or_tax_options.path
logger.info(f'Loading taxonomy from file "{tax_path}"')
taxonomy = vamb.taxonomy.Taxonomy.from_file(
tax_path, composition.metadata, True
)
logger.info(f'Loading taxonomy from file "{alg.taxonomy.path}"')
if isinstance(alg.taxonomy, UnrefinedTaxonomy):
taxonomy = vamb.taxonomy.Taxonomy.from_file(
alg.taxonomy.path, composition.metadata, True
)
else:
taxonomy = vamb.taxonomy.Taxonomy.from_refined_file(
alg.taxonomy.path, composition.metadata, True
)

instantiated_alg = vamb.reclustering.DBScanAlgorithm(
composition.metadata, taxonomy, opt.general.n_threads
)
Expand Down
33 changes: 19 additions & 14 deletions vamb/taxonomy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from vamb.parsecontigs import CompositionMetaData
import numpy as np

TAXONOMY_HEADER = "contigs\tpredictions"
PREDICTED_TAXONOMY_HEADER = "contigs\tpredictions\tscores"


class ContigTaxonomy:
"""
Expand Down Expand Up @@ -56,6 +59,14 @@ def from_file(
observed = cls.parse_tax_file(tax_file, is_canonical)
return cls.from_observed(observed, metadata, is_canonical)

@classmethod
def from_refined_file(
cls, tax_file: Path, metadata: CompositionMetaData, is_canonical: bool
):
observed = PredictedTaxonomy.parse_tax_file(tax_file, is_canonical)
observed = [(name, tax.contig_taxonomy) for (name, tax) in observed]
return cls.from_observed(observed, metadata, is_canonical)

@classmethod
def from_observed(
cls,
Expand Down Expand Up @@ -101,9 +112,9 @@ def parse_tax_file(
with open(path) as file:
result: list[tuple[str, ContigTaxonomy]] = []
header = next(file, None)
if header is None or not header.startswith("contigs\tpredictions"):
if header is None or not header.startswith(TAXONOMY_HEADER):
raise ValueError(
'In taxonomy file, expected header to begin with "contigs\\tpredictions"'
f"In taxonomy file, expected header to begin with {repr(TAXONOMY_HEADER)}"
)
for line in file:
(contigname, taxonomy, *_) = line.split("\t")
Expand Down Expand Up @@ -160,31 +171,25 @@ def nseqs(self) -> int:

@staticmethod
def parse_tax_file(
path: Path, minlen: int, force_canonical: bool
) -> list[tuple[str, int, PredictedContigTaxonomy]]:
path: Path, force_canonical: bool
) -> list[tuple[str, PredictedContigTaxonomy]]:
with open(path) as file:
result: list[tuple[str, int, PredictedContigTaxonomy]] = []
result: list[tuple[str, PredictedContigTaxonomy]] = []
lines = filter(None, map(str.rstrip, file))
header = next(lines, None)
if header is None or not header.startswith(
"contigs\tpredictions\tlengths\tscores"
):
if header is None or not header.startswith(PREDICTED_TAXONOMY_HEADER):
raise ValueError(
'In predicted taxonomy file, expected header to begin with "contigs\\tpredictions\\tlengths\\tscores"'
f"In predicted taxonomy file, expected header to begin with {repr(PREDICTED_TAXONOMY_HEADER)}"
)
for line in lines:
(contigname, taxonomy, lengthstr, scores, *_) = line.split("\t")
length = int(lengthstr)
if length < minlen:
continue
(contigname, taxonomy, scores, *_) = line.split("\t")
contig_taxonomy = ContigTaxonomy.from_semicolon_sep(
taxonomy, force_canonical
)
probs = np.array([float(i) for i in scores.split(";")], dtype=float)
result.append(
(
contigname,
length,
PredictedContigTaxonomy(contig_taxonomy, probs),
)
)
Expand Down

0 comments on commit 104dfba

Please sign in to comment.