Skip to content

Commit

Permalink
Add simple tests for the AAE (#227)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakobnissen authored Oct 13, 2023
1 parent 62f53a4 commit 33fbca1
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 6 deletions.
82 changes: 82 additions & 0 deletions test/test_aamb_encode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import unittest
import numpy as np
import io
import random

import vamb


class TestAAE(unittest.TestCase):
tnfs = np.random.random((111, 103)).astype(np.float32)
rpkm = np.random.random((111, 14)).astype(np.float32)
lens = np.random.randint(2000, 5000, size=111)
contignames = ["".join(random.choices("abcdefghijklmnopqrstu", k=10)) for _ in lens]
nlatent_l = 32
default_args = (14, 256, nlatent_l, 25, 0.5, 0.5, 0.15, False, 0)
default_temperature = 0.16
default_lr = 0.001

# Construction
def test_bad_args(self):
default_args = self.default_args

# Test the default args work
aae = vamb.aamb_encode.AAE(*default_args)
self.assertIsInstance(aae, vamb.aamb_encode.AAE)

with self.assertRaises(ValueError):
vamb.aamb_encode.AAE(0, *default_args[1:])

with self.assertRaises(ValueError):
vamb.aamb_encode.AAE(*default_args[:1], 0, *default_args[2:])

with self.assertRaises(ValueError):
vamb.aamb_encode.AAE(*default_args[:2], 0, *default_args[3:])

with self.assertRaises(ValueError):
vamb.aamb_encode.AAE(*default_args[:3], 0, *default_args[4:])

with self.assertRaises(ValueError):
vamb.aamb_encode.AAE(*default_args[:5], float("nan"), *default_args[6:])

with self.assertRaises(ValueError):
vamb.aamb_encode.AAE(*default_args[:5], -0.0001, *default_args[6:])

with self.assertRaises(ValueError):
vamb.aamb_encode.AAE(*default_args[:6], float("nan"), *default_args[7:])

def test_loss_falls(self):
aae = vamb.aamb_encode.AAE(*self.default_args)
rpkm_copy = self.rpkm.copy()
tnfs_copy = self.tnfs.copy()
dl = vamb.encode.make_dataloader(
rpkm_copy, tnfs_copy, self.lens, batchsize=16, destroy=True
)
(di, ti, ai, we) = next(iter(dl))
mu, do, to, _, _, _, _ = aae(di, ti)
start_loss = aae.calc_loss(di, do, ti, to)[0].data.item()
iobuffer = io.StringIO()

# Loss drops with training
aae.trainmodel(
dl,
nepochs=3,
batchsteps=[1, 2],
T=self.default_temperature,
lr=self.default_lr,
logfile=iobuffer,
modelfile=None,
)
mu, do, to, _, _, _, _ = aae(di, ti)
end_loss = aae.calc_loss(di, do, ti, to)[0].data.item()
self.assertLess(end_loss, start_loss)

def test_encode(self):
aae = vamb.aamb_encode.AAE(*self.default_args)
dl = vamb.encode.make_dataloader(
self.rpkm.copy(), self.tnfs.copy(), self.lens, batchsize=16, destroy=True
)
(_, encoding) = aae.get_latents(self.contignames, dl)
self.assertIsInstance(encoding, np.ndarray)
self.assertEqual(encoding.dtype, np.float32)
self.assertEqual(encoding.shape, (len(self.rpkm), self.nlatent_l))
25 changes: 19 additions & 6 deletions vamb/aamb_encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


import numpy as np
from math import log
from math import log, isfinite
import time
from torch.autograd import Variable
from torch.distributions.relaxed_categorical import RelaxedOneHotCategorical
Expand Down Expand Up @@ -30,10 +30,23 @@ def __init__(
_cuda: bool,
seed: int,
):
if nsamples is None:
raise ValueError(
f"Number of samples should be provided to define the encoder input layer as well as the categorical latent dimension, not {nsamples}"
)
for variable, name in [
(nsamples, "nsamples"),
(nhiddens, "nhiddens"),
(nlatent_l, "nlatents_l"),
(nlatent_y, "nlatents_y"),
]:
if variable < 1:
raise ValueError(f"{name} must be at least 1, not {variable}")

real_variables = [(sl, "sl"), (slr, "slr")]
if alpha is not None:
real_variables.append((alpha, "alpha"))
for variable, name in real_variables:
if not isfinite(variable) or not (0.0 <= variable <= 1.0):
raise ValueError(
f"{name} must be in the interval [0.0, 1.0], not {variable}"
)

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
Expand Down Expand Up @@ -443,7 +456,7 @@ def get_latents(
self.eval()

new_data_loader = set_batchsize(data_loader, 256, encode=True)
depths_array, _, _ = data_loader.dataset.tensors
depths_array, _, _, _ = data_loader.dataset.tensors

length = len(depths_array)
latent = np.empty((length, self.ld), dtype=np.float32)
Expand Down

0 comments on commit 33fbca1

Please sign in to comment.