Skip to content

Commit

Permalink
Add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
melihyilmaz committed Nov 27, 2023
1 parent ade5a31 commit 8033fc7
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
1 change: 0 additions & 1 deletion casanovo/denovo/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def __init__(
self.trainer = None
self.model = None
self.loaders = None

self.writer = None

# Configure checkpoints.
Expand Down
19 changes: 16 additions & 3 deletions tests/unit_tests/test_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,15 +514,28 @@ def test_spectrum_id_mzml(mzml_small, tmp_path):

def test_train_val_step_functions():
"""Test train and validation step functions operating on batches."""
model = Spec2Pep(n_beams=1, residues="massivekb", min_peptide_len=4)
model = Spec2Pep(
n_beams=1,
residues="massivekb",
min_peptide_len=4,
train_label_smoothing=0.1,
)
spectra = torch.zeros(1, 5, 2)
precursors = torch.tensor([[469.25364, 2.0, 235.63410]])
peptides = ["PEPK"]
batch = (spectra, precursors, peptides)

train_step_loss = model.training_step(batch)
val_step_loss = model.validation_step(batch)

# Check if valid loss value returned
assert model.training_step(batch) > 0
assert model.validation_step(batch) > 0
assert train_step_loss > 0
assert val_step_loss > 0

# Check if smoothing is applied in training and not in validation
assert model.celoss.label_smoothing == 0.1
assert model.val_celoss.label_smoothing == 0
assert val_step_loss != train_step_loss


def test_run_map(mgf_small):
Expand Down

0 comments on commit 8033fc7

Please sign in to comment.