diff --git a/tests/test_base.py b/tests/test_base.py index a47e5ab..8a31728 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -9,6 +9,7 @@ from scdataloader.utils import populate_my_ontology from scprint.tasks import Denoiser from scprint import scPrint +import pytest def test_base(): @@ -54,8 +55,11 @@ def test_base(): predict_depth_mult=3, dtype=torch.float32, ) - metrics, random_indices, genes, expr_pred = dn( - model=model, - adata=adata, - ) + try: + metrics, random_indices, genes, expr_pred = dn( + model=model, + adata=adata, + ) + except Exception as e: + pytest.fail(f"An exception occurred: {str(e)}") assert metrics["reco2full"] - metrics["noisy2full"] > 0, "Model is not denoising"