Skip to content

Commit

Permalink
Remove overfitting warning
Browse files Browse the repository at this point in the history
In experiments, we've been unable to show that Vamb overfits, even with datasets
10% the size of our smallest test dataset. We're not sure why, but it's probably
because Vamb is already strongly regularized, and that regularization only
becomes significant when the number of contigs is low.

Vamb might still overfit on high-quality assemblies such as HiFi data, if there are
only a few thousand or a few hundred contigs, but we cannot assess this until we
get good benchmark datasets of that quality.
  • Loading branch information
jakobnissen committed Feb 15, 2024
1 parent f37b5fd commit 491590c
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 32 deletions.
29 changes: 6 additions & 23 deletions vamb/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,19 +64,17 @@ class CompositionPath(type(Path())):


class CompositionOptions:
__slots__ = ["path", "min_contig_length", "warn_on_few_seqs"]
__slots__ = ["path", "min_contig_length"]

def __init__(
self,
fastapath: Optional[Path],
npzpath: Optional[Path],
min_contig_length: int,
warn_on_few_seqs: bool,
):
assert isinstance(fastapath, (Path, type(None)))
assert isinstance(npzpath, (Path, type(None)))
assert isinstance(min_contig_length, int)
assert isinstance(warn_on_few_seqs, bool)

if min_contig_length < 250:
raise argparse.ArgumentTypeError(
Expand All @@ -98,7 +96,6 @@ def __init__(
assert npzpath is not None
self.path = CompositionPath(npzpath)
self.min_contig_length = min_contig_length
self.warn_on_few_seqs = warn_on_few_seqs


class AbundancePath(type(Path())):
Expand Down Expand Up @@ -546,17 +543,6 @@ def calc_tnf(

binsplitter.initialize(composition.metadata.identifiers)

if options.warn_on_few_seqs and composition.nseqs < 20_000:
message = (
f"Kept only {composition.nseqs} sequences from FASTA file. "
"We normally expect 20,000 sequences or more to prevent overfitting. "
"As a deep learning model, VAEs are prone to overfitting with too few sequences. "
"You may want to bin more samples as a time, lower the beta parameter, "
"or use a different binner altogether.\n"
)
logger.opt(raw=True).info("\n")
logger.warning(message)

# Warn the user if any contigs have been observed, which is smaller
# than the threshold.
if not np.all(composition.metadata.mask):
Expand Down Expand Up @@ -771,9 +757,11 @@ def cluster_and_write_files(
print(
str(i + 1),
None if cluster.radius is None else round(cluster.radius, 3),
None
if cluster.observed_pvr is None
else round(cluster.observed_pvr, 2),
(
None
if cluster.observed_pvr is None
else round(cluster.observed_pvr, 2)
),
cluster.kind_str,
sum(sequence_lens[i] for i in cluster.members),
len(cluster.members),
Expand Down Expand Up @@ -1384,7 +1372,6 @@ def __init__(self, args):
self.args.fasta,
self.args.composition,
self.args.minlength,
args.warn_on_few_seqs,
)
self.abundance_options = AbundanceOptions(
self.args.bampaths,
Expand Down Expand Up @@ -2180,10 +2167,8 @@ def main():
args = parser.parse_args()

if args.subcommand == TAXOMETER:
args.warn_on_few_seqs = True
runner = TaxometerArguments(args)
elif args.subcommand == BIN:
args.warn_on_few_seqs = True
if args.model_subcommand is None:
vaevae_parserbin_parser.print_help()
sys.exit(1)
Expand All @@ -2194,8 +2179,6 @@ def main():
}
runner = classes_map[args.model_subcommand](args)
elif args.subcommand == RECLUSTER:
# Uniquely, the reclustering cannot overfit, so we don't need this warning
args.warn_on_few_seqs = False
runner = ReclusteringArguments(args)
else:
# There are no more subcommands
Expand Down
1 change: 0 additions & 1 deletion vamb/aamb_encode.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Adversarial autoencoders (AAE) for metagenomics binning, this files contains the implementation of the AAE"""


import numpy as np
from math import log, isfinite
import time
Expand Down
1 change: 0 additions & 1 deletion vamb/semisupervised_encode.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Semisupervised multimodal VAEs for metagenomics binning, this files contains the implementation of the VAEVAE for MMSEQ predictions"""


__cmd_doc__ = """Encode depths and TNF using a VAE to latent representation"""

import numpy as _np
Expand Down
1 change: 0 additions & 1 deletion vamb/taxvamb_encode.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Hierarchical loss for the labels suggested in https://arxiv.org/abs/2210.10929"""


__cmd_doc__ = """Hierarchical loss for the labels"""


Expand Down
12 changes: 6 additions & 6 deletions workflow_avamb/src/rip_bins.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,9 @@ def remove_meaningless_edges_from_pairs(
contig_length,
)
print("Cluster ripped because of a meaningless edge ", cluster_updated)
clusters_changed_but_not_intersecting_contigs[
cluster_updated
] = cluster_contigs[cluster_updated]
clusters_changed_but_not_intersecting_contigs[cluster_updated] = (
cluster_contigs[cluster_updated]
)

components: list[set[str]] = list()
for component in nx.connected_components(graph_clusters):
Expand Down Expand Up @@ -295,9 +295,9 @@ def make_all_components_pair(
contig_length,
)
print("Cluster ripped because of a pairing component ", cluster_updated)
clusters_changed_but_not_intersecting_contigs[
cluster_updated
] = cluster_contigs[cluster_updated]
clusters_changed_but_not_intersecting_contigs[cluster_updated] = (
cluster_contigs[cluster_updated]
)
component_len = max(
[
len(nx.node_connected_component(graph_clusters, node_i))
Expand Down

0 comments on commit 491590c

Please sign in to comment.