-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
91 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import pytest | ||
from callbacks import CSVLogging, EarlyStopping | ||
from factories.callback_factory import CallbackFactory | ||
|
||
def test_callback_factory_creation(): | ||
factory = CallbackFactory() | ||
|
||
# Test CSVLogging creation | ||
params = {"csv_path": "./logs"} | ||
csv_logger = factory.create("CSVLogging", **params) | ||
assert isinstance(csv_logger, CSVLogging), "Failed to create CSVLogging" | ||
|
||
# Test EarlyStopping creation | ||
params = {"patience": 5} | ||
early_stopper = factory.create("EarlyStopping", **params) | ||
assert isinstance(early_stopper, EarlyStopping), "Failed to create EarlyStopping" | ||
|
||
def test_callback_factory_unknown(): | ||
factory = CallbackFactory() | ||
with pytest.raises(ValueError) as e: | ||
factory.create("UnknownCallback") | ||
assert "Unknown callback" in str(e.value) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import pytest | ||
from torch.nn import CrossEntropyLoss, MSELoss | ||
from factories.loss_factory import LossFactory | ||
|
||
def test_loss_factory_creation(): | ||
factory = LossFactory() | ||
|
||
# Test CrossEntropyLoss creation | ||
loss = factory.create("CrossEntropyLoss") | ||
assert isinstance(loss, CrossEntropyLoss), "Failed to create CrossEntropyLoss" | ||
|
||
# Test MSELoss creation | ||
loss = factory.create("MSELoss") | ||
assert isinstance(loss, MSELoss), "Failed to create MSELoss" | ||
|
||
def test_loss_factory_unknown(): | ||
factory = LossFactory() | ||
with pytest.raises(ValueError) as e: | ||
factory.create("UnknownLoss") | ||
assert "Unknown configuration" in str(e.value) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import pytest | ||
from torchvision.models import efficientnet_b0, resnet18 | ||
from factories.model_factory import ModelFactory | ||
|
||
def test_model_factory_creation(): | ||
factory = ModelFactory() | ||
|
||
# Test EfficientNet B0 creation | ||
model = factory.create("efficientnet_b0", num_classes=10, pretrained=False) | ||
assert isinstance(model, efficientnet_b0().__class__), "Failed to create EfficientNet B0 with correct class" | ||
|
||
# Test ResNet18 creation | ||
model = factory.create("resnet18", num_classes=10, pretrained=False) | ||
assert isinstance(model, resnet18().__class__), "Failed to create ResNet18 with correct class" | ||
|
||
def test_model_factory_unknown_model(): | ||
factory = ModelFactory() | ||
with pytest.raises(ValueError) as e: | ||
factory.create("unknown_model", num_classes=10, pretrained=True) | ||
assert "Unknown configuration" in str(e.value) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import pytest | ||
import torch | ||
from torch.optim import Adam, SGD | ||
from factories.optimizer_factory import OptimizerFactory | ||
|
||
def test_optimizer_factory_creation(): | ||
factory = OptimizerFactory() | ||
|
||
# Simulate model parameters | ||
params = [torch.tensor([1.0, 2.0], requires_grad=True)] | ||
|
||
# Test Adam creation with a learning rate | ||
optimizer = factory.create("Adam") | ||
assert isinstance(optimizer(lr=0.01, params=params), Adam), "Failed to create Adam optimizer" | ||
|
||
# Test SGD creation with a learning rate | ||
optimizer = factory.create("SGD") | ||
assert isinstance(optimizer(lr=0.01, params=params), SGD), "Failed to create SGD optimizer" | ||
|
||
def test_optimizer_factory_unknown(): | ||
factory = OptimizerFactory() | ||
with pytest.raises(ValueError) as excinfo: | ||
# Even though this will fail, you should still simulate correct usage | ||
factory.create("UnknownOptimizer") | ||
assert "Unknown configuration" in str(excinfo.value) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters