Skip to content

Commit

Permalink
Tests for factories closes #25
Browse files Browse the repository at this point in the history
  • Loading branch information
pab1s committed Apr 26, 2024
1 parent cbcb9d4 commit 1c339ac
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 4 deletions.
22 changes: 22 additions & 0 deletions tests/test_callback_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import pytest
from callbacks import CSVLogging, EarlyStopping
from factories.callback_factory import CallbackFactory

def test_callback_factory_creation():
factory = CallbackFactory()

# Test CSVLogging creation
params = {"csv_path": "./logs"}
csv_logger = factory.create("CSVLogging", **params)
assert isinstance(csv_logger, CSVLogging), "Failed to create CSVLogging"

# Test EarlyStopping creation
params = {"patience": 5}
early_stopper = factory.create("EarlyStopping", **params)
assert isinstance(early_stopper, EarlyStopping), "Failed to create EarlyStopping"

def test_callback_factory_unknown():
factory = CallbackFactory()
with pytest.raises(ValueError) as e:
factory.create("UnknownCallback")
assert "Unknown callback" in str(e.value)
2 changes: 1 addition & 1 deletion tests/test_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
def test_checkpoint():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transforms = get_transforms(CONFIG_TEST)
transforms = get_transforms(CONFIG_TEST['data']['transforms'])
data = get_dataset(
name=CONFIG_TEST['data']['name'],
root_dir=CONFIG_TEST['data']['dataset_path'],
Expand Down
2 changes: 1 addition & 1 deletion tests/test_early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
def test_early_stopping():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transforms = get_transforms(CONFIG_TEST)
transforms = get_transforms(CONFIG_TEST['data']['transforms'])
data = get_dataset(
name=CONFIG_TEST['data']['name'],
root_dir=CONFIG_TEST['data']['dataset_path'],
Expand Down
2 changes: 1 addition & 1 deletion tests/test_fine_tuning_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_fine_tuning_loop():

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transforms = get_transforms(CONFIG_TEST)
transforms = get_transforms(CONFIG_TEST['data']['transforms'])

data = get_dataset(
name=CONFIG_TEST['data']['name'],
Expand Down
20 changes: 20 additions & 0 deletions tests/test_loss_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import pytest
from torch.nn import CrossEntropyLoss, MSELoss
from factories.loss_factory import LossFactory

def test_loss_factory_creation():
factory = LossFactory()

# Test CrossEntropyLoss creation
loss = factory.create("CrossEntropyLoss")
assert isinstance(loss, CrossEntropyLoss), "Failed to create CrossEntropyLoss"

# Test MSELoss creation
loss = factory.create("MSELoss")
assert isinstance(loss, MSELoss), "Failed to create MSELoss"

def test_loss_factory_unknown():
factory = LossFactory()
with pytest.raises(ValueError) as e:
factory.create("UnknownLoss")
assert "Unknown configuration" in str(e.value)
20 changes: 20 additions & 0 deletions tests/test_model_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import pytest
from torchvision.models import efficientnet_b0, resnet18
from factories.model_factory import ModelFactory

def test_model_factory_creation():
factory = ModelFactory()

# Test EfficientNet B0 creation
model = factory.create("efficientnet_b0", num_classes=10, pretrained=False)
assert isinstance(model, efficientnet_b0().__class__), "Failed to create EfficientNet B0 with correct class"

# Test ResNet18 creation
model = factory.create("resnet18", num_classes=10, pretrained=False)
assert isinstance(model, resnet18().__class__), "Failed to create ResNet18 with correct class"

def test_model_factory_unknown_model():
factory = ModelFactory()
with pytest.raises(ValueError) as e:
factory.create("unknown_model", num_classes=10, pretrained=True)
assert "Unknown configuration" in str(e.value)
25 changes: 25 additions & 0 deletions tests/test_optimizer_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import pytest
import torch
from torch.optim import Adam, SGD
from factories.optimizer_factory import OptimizerFactory

def test_optimizer_factory_creation():
factory = OptimizerFactory()

# Simulate model parameters
params = [torch.tensor([1.0, 2.0], requires_grad=True)]

# Test Adam creation with a learning rate
optimizer = factory.create("Adam")
assert isinstance(optimizer(lr=0.01, params=params), Adam), "Failed to create Adam optimizer"

# Test SGD creation with a learning rate
optimizer = factory.create("SGD")
assert isinstance(optimizer(lr=0.01, params=params), SGD), "Failed to create SGD optimizer"

def test_optimizer_factory_unknown():
factory = OptimizerFactory()
with pytest.raises(ValueError) as excinfo:
# Even though this will fail, you should still simulate correct usage
factory.create("UnknownOptimizer")
assert "Unknown configuration" in str(excinfo.value)
2 changes: 1 addition & 1 deletion tests/test_training_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_training_loop():

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transforms = get_transforms(CONFIG_TEST)
transforms = get_transforms(CONFIG_TEST['data']['transforms'])

data = get_dataset(
name=CONFIG_TEST['data']['name'],
Expand Down

0 comments on commit 1c339ac

Please sign in to comment.