diff --git a/test/test_encode.py b/test/test_encode.py
index 12afdc07..b579657e 100644
--- a/test/test_encode.py
+++ b/test/test_encode.py
@@ -174,7 +174,7 @@ def test_loss_falls(self):
         vae = vamb.encode.VAE(self.rpkm.shape[1])
         rpkm_copy = self.rpkm.copy()
         tnfs_copy = self.tnfs.copy()
-        dl, mask = vamb.encode.make_dataloader(
+        dl, _ = vamb.encode.make_dataloader(
             rpkm_copy, tnfs_copy, self.lens, batchsize=16, destroy=True
         )
         di = torch.Tensor(rpkm_copy)
@@ -202,10 +202,20 @@ def test_loss_falls(self):
         after_encoding = vae_2.encode(dl)
         self.assertTrue(np.all(np.abs(before_encoding - after_encoding) < 1e-6))
 
+    def test_warn_too_many_batch_steps(self):
+        vae = vamb.encode.VAE(self.rpkm.shape[1])
+        rpkm_copy = self.rpkm.copy()
+        tnfs_copy = self.tnfs.copy()
+        dl, _ = vamb.encode.make_dataloader(
+            rpkm_copy, tnfs_copy, self.lens, batchsize=16, destroy=True
+        )
+        with self.assertWarns(Warning):
+            vae.trainmodel(dl, nepochs=4, batchsteps=[1, 2, 3])
+
     def test_encoding(self):
         nlatent = 15
         vae = vamb.encode.VAE(self.rpkm.shape[1], nlatent=nlatent)
-        dl, mask = vamb.encode.make_dataloader(
+        dl, _ = vamb.encode.make_dataloader(
             self.rpkm, self.tnfs, self.lens, batchsize=32
         )
         encoding = vae.encode(dl)
diff --git a/vamb/encode.py b/vamb/encode.py
index 5e380b34..6a081d59 100644
--- a/vamb/encode.py
+++ b/vamb/encode.py
@@ -8,7 +8,8 @@
 from torch import Tensor
 from torch import nn as _nn
 from math import log as _log
 from time import time
+import warnings
 
 __doc__ = """Encode a depths matrix and a tnf matrix to latent representation.
 
@@ -379,7 +380,7 @@ def trainepoch(
         epoch_celoss = 0.0
 
         if epoch in batchsteps:
-            data_loader = set_batchsize(data_loader, data_loader.batch_size * 2)
+            data_loader = set_batchsize(data_loader, data_loader.batch_size * 2)  # type: ignore
 
         for depths_in, tnf_in, weights in data_loader:
             depths_in.requires_grad = True
@@ -551,28 +552,40 @@ def trainmodel(
         if nepochs < 1:
-            raise ValueError("Minimum 1 epoch, not {nepochs}")
+            raise ValueError(f"Minimum 1 epoch, not {nepochs}")
 
-        if batchsteps is None:
-            batchsteps_set: set[int] = set()
+        if batchsteps is None or len(batchsteps) == 0:
+            sorted_batch_steps: list[int] = []
         else:
-            # First collect to list in order to allow all element types, then check that
-            # they are integers
-            batchsteps = list(batchsteps)
+            # Check that all elements of batchsteps are integers
             if not all(isinstance(i, int) for i in batchsteps):
                 raise ValueError("All elements of batchsteps must be integers")
-            if max(batchsteps, default=0) >= nepochs:
+            sorted_batch_steps = sorted(set(batchsteps))
+            if sorted_batch_steps[0] < 1:
+                raise ValueError(
+                    f"Minimum of batchsteps must be 1, not {sorted_batch_steps[0]}"
+                )
+            if sorted_batch_steps[-1] >= nepochs:
                 raise ValueError("Max batchsteps must not equal or exceed nepochs")
-            last_batchsize = dataloader.batch_size * 2 ** len(batchsteps)
-            if len(dataloader.dataset) < last_batchsize:  # type: ignore
+
+            n_contigs = len(dataloader.dataset)  # type: ignore
+            starting_batch_size: int = dataloader.batch_size  # type: ignore
+            if n_contigs < starting_batch_size:
                 raise ValueError(
-                    f"Last batch size of {last_batchsize} exceeds dataset length "
-                    f"of {len(dataloader.dataset)}. "  # type: ignore
+                    f"Starting batch size of {starting_batch_size} exceeds dataset length "
+                    f"of {n_contigs}. "
                     "This means you have too few contigs left after filtering to train. "
-                    "It is not adviced to run Vamb with fewer than 10,000 sequences "
+                    "It is not advised to run Vamb with fewer than 10,000 sequences "
                     "after filtering. "
" "Please check the Vamb log file to see where the sequences were " "filtered away, and verify BAM files has sensible content." ) - batchsteps_set = set(batchsteps) + maximum_batch_steps = (n_contigs // starting_batch_size).bit_length() - 1 + if maximum_batch_steps < len(sorted_batch_steps): + warnings.warn( + f"Requested {len(sorted_batch_steps)} batch steps, but with a starting " + f"batch size of {starting_batch_size} and {n_contigs} contigs, " + f"only the first {maximum_batch_steps} batch steps can be used." + ) + sorted_batch_steps = sorted_batch_steps[:maximum_batch_steps] # Get number of features # Following line is un-inferrable due to typing problems with DataLoader @@ -591,8 +609,8 @@ def trainmodel( print("\tN epochs:", nepochs, file=logfile) print("\tStarting batch size:", dataloader.batch_size, file=logfile) batchsteps_string = ( - ", ".join(map(str, sorted(batchsteps_set))) - if batchsteps_set + ", ".join(map(str, sorted_batch_steps)) + if len(sorted_batch_steps) > 0 else "None" ) print("\tBatchsteps:", batchsteps_string, file=logfile) @@ -603,7 +621,7 @@ def trainmodel( # Train for epoch in range(nepochs): dataloader = self.trainepoch( - dataloader, epoch, optimizer, sorted(batchsteps_set), time(), logfile + dataloader, epoch, optimizer, sorted_batch_steps, time(), logfile ) # Save weights - Lord forgive me, for I have sinned when catching all exceptions