diff --git a/environment.yaml b/environment.yaml index f98cc53..7840e43 100644 --- a/environment.yaml +++ b/environment.yaml @@ -7,6 +7,7 @@ dependencies: - _libgcc_mutex=0.1 - _openmp_mutex=5.1 - blas=1.0 + - brotli-python=1.0.9 - bzip2=1.0.8 - ca-certificates=2024.3.11 - certifi=2024.2.2 @@ -25,6 +26,7 @@ dependencies: - gmpy2=2.1.2 - gnutls=3.6.15 - idna=3.4 + - iniconfig=1.1.1 - intel-openmp=2023.1.0 - jinja2=3.1.3 - jpeg=9e @@ -73,9 +75,15 @@ dependencies: - openh264=2.1.1 - openjpeg=2.4.0 - openssl=3.0.13 + - packaging=23.2 - pillow=10.2.0 - pip=23.3.1 + - pluggy=1.0.0 + - pyparsing=3.0.9 + - pysocks=1.7.1 + - pytest=7.4.0 - python=3.11.8 + - python-dateutil=2.8.2 - pytorch=2.2.1 - pytorch-cuda=12.1 - pytorch-mutex=1.0 @@ -83,6 +91,7 @@ dependencies: - readline=8.2 - requests=2.31.0 - setuptools=68.2.2 + - six=1.16.0 - sqlite=3.41.2 - sympy=1.12 - tbb=2021.8.0 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_datasets_loading.py b/tests/test_datasets_loading.py new file mode 100644 index 0000000..b1f7dee --- /dev/null +++ b/tests/test_datasets_loading.py @@ -0,0 +1,18 @@ +import pytest +from utils.data_utils import get_dataloaders + +def test_load_cifar10(): + config = { + 'data': { + 'name': 'CIFAR10', + 'dataset_path': './data', + }, + 'training': { + 'batch_size': 4, + } + } + train_loader, test_loader = get_dataloaders(config) + # Check that loaders are not empty + assert len(train_loader) > 0, "CIFAR10 training loader should not be empty" + assert len(test_loader) > 0, "CIFAR10 test loader should not be empty" + diff --git a/tests/test_model_initialization.py b/tests/test_model_initialization.py new file mode 100644 index 0000000..550034c --- /dev/null +++ b/tests/test_model_initialization.py @@ -0,0 +1,17 @@ +import pytest +import torch +from models.efficientnet import get_efficientnet +from models.resnet import get_resnet + +@pytest.mark.parametrize("model_func, model_name", [ + (get_efficientnet, 'efficientnet_b0'), + (get_resnet, 'resnet18'), +]) +def test_model_initialization_and_forward_pass(model_func, model_name): + model = model_func(model_name, num_classes=10, pretrained=False) + assert model is not None, f"{model_name} should be initialized" + + # Forward pass test + dummy_input = torch.randn(2, 3, 224, 224) + output = model(dummy_input) + assert output.shape == (2, 10), f"Output shape of {model_name} should be (2, 10) for batch size of 2 and 10 classes"