Skip to content

Commit

Permalink
Tests for data loading and model initialization advances #2
Browse files Browse the repository at this point in the history
These are simple tests to start with. They will be enhanced once I had started to implement my own dataset and try different models.
  • Loading branch information
pab1s committed Mar 19, 2024
1 parent 2a1e936 commit 63a94ef
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 0 deletions.
9 changes: 9 additions & 0 deletions environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ dependencies:
- _libgcc_mutex=0.1
- _openmp_mutex=5.1
- blas=1.0
- brotli-python=1.0.9
- bzip2=1.0.8
- ca-certificates=2024.3.11
- certifi=2024.2.2
Expand All @@ -25,6 +26,7 @@ dependencies:
- gmpy2=2.1.2
- gnutls=3.6.15
- idna=3.4
- iniconfig=1.1.1
- intel-openmp=2023.1.0
- jinja2=3.1.3
- jpeg=9e
Expand Down Expand Up @@ -73,16 +75,23 @@ dependencies:
- openh264=2.1.1
- openjpeg=2.4.0
- openssl=3.0.13
- packaging=23.2
- pillow=10.2.0
- pip=23.3.1
- pluggy=1.0.0
- pyparsing=3.0.9
- pysocks=1.7.1
- pytest=7.4.0
- python=3.11.8
- python-dateutil=2.8.2
- pytorch=2.2.1
- pytorch-cuda=12.1
- pytorch-mutex=1.0
- pyyaml=6.0.1
- readline=8.2
- requests=2.31.0
- setuptools=68.2.2
- six=1.16.0
- sqlite=3.41.2
- sympy=1.12
- tbb=2021.8.0
Expand Down
Empty file added tests/__init__.py
Empty file.
18 changes: 18 additions & 0 deletions tests/test_datasets_loading.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest
from utils.data_utils import get_dataloaders

def test_load_cifar10():
config = {
'data': {
'name': 'CIFAR10',
'dataset_path': './data',
},
'training': {
'batch_size': 4,
}
}
train_loader, test_loader = get_dataloaders(config)
# Check that loaders are not empty
assert len(train_loader) > 0, "CIFAR10 training loader should not be empty"
assert len(test_loader) > 0, "CIFAR10 test loader should not be empty"

17 changes: 17 additions & 0 deletions tests/test_model_initialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pytest
import torch
from models.efficientnet import get_efficientnet
from models.resnet import get_resnet

@pytest.mark.parametrize("model_func, model_name", [
(get_efficientnet, 'efficientnet_b0'),
(get_resnet, 'resnet18'),
])
def test_model_initialization_and_forward_pass(model_func, model_name):
model = model_func(model_name, num_classes=10, pretrained=False)
assert model is not None, f"{model_name} should be initialized"

# Forward pass test
dummy_input = torch.randn(2, 3, 224, 224)
output = model(dummy_input)
assert output.shape == (2, 10), f"Output shape of {model_name} should be (2, 10) for batch size of 2 and 10 classes"

0 comments on commit 63a94ef

Please sign in to comment.