Skip to content

Commit

Permalink
Documented files closes #14
Browse files: browse the repository at this point in the history
  • Loading branch information
pab1s committed Apr 1, 2024
1 parent da2fce5 commit 99491c1
Show file tree
Hide file tree
Showing 16 changed files with 377 additions and 10 deletions.
33 changes: 32 additions & 1 deletion datasets/car_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,26 @@
class CarDataset(Dataset):
def __init__(self, data_dir, train=None, transform=None):
    """
    Build a car-image dataset rooted at ``data_dir``.

    Args:
        data_dir (string): Path to the dataset directory.
        train (bool, optional): Accepted as part of the signature; this
            constructor neither stores nor reads it. Default is None.
        transform (callable, optional): Optional transform for samples,
            kept on the instance as ``self.transform``. Default is None.
    """
    self.transform = transform
    # data_dir is assigned before loading, as in the original ordering,
    # in case _load_dataset reads it (its body is not fully visible here).
    self.data_dir = data_dir
    loaded = self._load_dataset()
    self.images, self.labels, self.idx_to_class = loaded

def _load_dataset(self):
"""
Load the dataset from the data directory.
Returns:
images (list): List of image paths.
labels (list): List of corresponding labels.
idx_to_class (dict): Mapping of label index to class name.
"""
images = []
labels = []
label_to_idx = {}
Expand All @@ -37,9 +48,24 @@ def _load_dataset(self):
return images, labels, idx_to_class

def __len__(self):
    """
    Report how many samples the dataset holds.

    Returns:
        int: Number of entries in ``self.images``.
    """
    image_paths = self.images
    return len(image_paths)

def __getitem__(self, idx):
"""
Get a sample from the dataset.
Args:
idx (int): Index of the sample.
Returns:
tuple: A tuple containing the image and its corresponding label.
"""
image_path = self.images[idx]
image = Image.open(image_path)
label = self.labels[idx]
Expand All @@ -50,5 +76,10 @@ def __getitem__(self, idx):
return image, label

def get_classes(self):
    """
    Return the list of class names.

    Returns:
        list: Class names ordered by ascending label index.
    """
    # Iterate over the sorted indices so the result order follows the label
    # index rather than the insertion order of the mapping.
    return [self.idx_to_class[idx] for idx in sorted(self.idx_to_class)]
15 changes: 15 additions & 0 deletions datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,21 @@
from datasets.car_dataset import CarDataset

def get_dataset(name, root_dir, train=None, transform=None):
"""
Returns a dataset based on the given name.
Args:
name (str): The name of the dataset.
root_dir (str): The root directory where the dataset is stored.
train (bool, optional): If True, returns the training set. If False, returns the test set. Defaults to None.
transform (callable, optional): A function/transform that takes in an image and returns a transformed version. Defaults to None.
Returns:
torch.utils.data.Dataset: The requested dataset.
Raises:
ValueError: If the dataset name is not supported.
"""
if name == 'CIFAR10':
return CIFAR10(root=root_dir, train=train, download=True, transform=transform)
elif name == 'CIFAR100':
Expand Down
24 changes: 24 additions & 0 deletions datasets/transformations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,30 @@
from torchvision import transforms

def get_transforms(config):
"""
Returns a composed transformation object based on the provided configuration.
Args:
config (dict): A dictionary containing the configuration for the transformations.
Returns:
torchvision.transforms.Compose: A composed transformation object.
Raises:
ValueError: If the specified transform is not recognized.
Example:
>>> config = {
>>> 'data': {
>>> 'transforms': [
>>> {'name': 'RandomHorizontalFlip'},
>>> {'name': 'RandomVerticalFlip'},
>>> {'name': 'ToTensor'}
>>> ]
>>> }
>>> }
>>> transforms = get_transforms(config)
"""
transform_list = []
for transform_config in config['data']['transforms']:
transform_name = transform_config['name']
Expand Down
14 changes: 14 additions & 0 deletions models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,20 @@
from models.resnet import get_resnet

def get_model(model_name, num_classes, pretrained=True):
"""
Returns a pre-trained model based on the given model_name.
Args:
model_name (str): The name of the model to be used.
num_classes (int): The number of output classes.
pretrained (bool, optional): Whether to load pre-trained weights. Defaults to True.
Returns:
torch.nn.Module: The pre-trained model.
Raises:
ValueError: If the model_name is not supported.
"""
if 'efficientnet' in model_name:
return get_efficientnet(model_name, num_classes, pretrained)
elif 'resnet' in model_name:
Expand Down
15 changes: 14 additions & 1 deletion models/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,20 @@
import torch.nn as nn

def get_efficientnet(model_name, num_classes, pretrained=True):
"""
Get an EfficientNet model with a custom classifier.
Args:
model_name (str): Name of the EfficientNet model to use. Supported options are "efficientnet_b0", "efficientnet_b1", and "efficientnet_b2".
num_classes (int): Number of output classes for the custom classifier.
pretrained (bool, optional): Whether to load pretrained weights for the model. Defaults to True.
Returns:
torch.nn.Module: EfficientNet model with a custom classifier.
Raises:
ValueError: If an unsupported EfficientNet version is specified.
"""
weights = "DEFAULT" if pretrained else None

if model_name == "efficientnet_b0":
Expand All @@ -13,7 +27,6 @@ def get_efficientnet(model_name, num_classes, pretrained=True):
else:
raise ValueError("Unsupported EfficientNet version")

# Change the classifier head to match the number of classes
num_ftrs = model.classifier[1].in_features
model.classifier[1] = nn.Linear(num_ftrs, num_classes)
return model
15 changes: 14 additions & 1 deletion models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,20 @@
import torch.nn as nn

def get_resnet(model_name, num_classes, pretrained=True):
"""
Get a ResNet model with a specified architecture.
Args:
model_name (str): The name of the ResNet architecture. Supported options are "resnet18", "resnet34", and "resnet50".
num_classes (int): The number of output classes.
pretrained (bool, optional): Whether to load pretrained weights. Defaults to True.
Returns:
torch.nn.Module: The ResNet model with the specified architecture and number of output classes.
Raises:
ValueError: If an unsupported ResNet version is specified.
"""
weights = "DEFAULT" if pretrained else None

if model_name == "resnet18":
Expand All @@ -13,7 +27,6 @@ def get_resnet(model_name, num_classes, pretrained=True):
else:
raise ValueError("Unsupported ResNet version")

# Change the classifier head to match the number of classes
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes)
return model
15 changes: 15 additions & 0 deletions tests/test_datasets_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,21 @@ def basic_transform():
return transforms.Compose([transforms.ToTensor()])

def test_cifar10_download_and_load(basic_transform):
"""
Test function to download and load CIFAR10 dataset.
Args:
basic_transform: A transformation to apply to the dataset.
Returns:
None
Raises:
AssertionError: If the train dataset is not an instance of datasets.CIFAR10.
AssertionError: If the test dataset is not an instance of datasets.CIFAR10.
AssertionError: If the CIFAR10 train dataset does not contain 50,000 images.
AssertionError: If the CIFAR10 test dataset does not contain 10,000 images.
"""
root_dir = './data'
train_dataset = get_dataset('CIFAR10', root_dir=root_dir, train=True, transform=basic_transform)
test_dataset = get_dataset('CIFAR10', root_dir=root_dir, train=False, transform=basic_transform)
Expand Down
22 changes: 21 additions & 1 deletion tests/test_fine_tuning_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,27 @@
CONFIG_TEST = yaml.safe_load(file)

def test_fine_tuning_loop():
"""Test a short fine-tuning loop to ensure pipeline works with BasicTrainer."""
"""
Test the fine-tuning loop of the model training pipeline.
This function performs the following steps:
1. Sets the device to CUDA if available, otherwise to CPU.
2. Retrieves the data transforms using the CONFIG_TEST dictionary.
3. Gets the dataset using the specified name and root directory from CONFIG_TEST.
4. Splits the dataset into train, validation, and test sets.
5. Creates data loaders for the train, validation, and test sets.
6. Retrieves the model using the specified name, number of classes, and pretrained flag from CONFIG_TEST.
7. Defines the criterion, optimizer, and metrics for training.
8. Initializes the trainer object using the CONFIG_TEST dictionary.
9. Builds the trainer for fine-tuning, optionally freezing layers until a specified layer.
10. Trains the model using the train and validation loaders for a specified number of epochs.
11. Unfreezes all layers of the model.
12. Builds the trainer for fine-tuning without freezing any layers.
13. Trains the model again using the train and validation loaders for a specified number of epochs.
14. Evaluates the model on the test set and retrieves the metrics results.
15. Asserts that the length of the metrics results is equal to the number of metrics.
16. Asserts that all metric values are greater than or equal to 0.
"""

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Expand Down
1 change: 0 additions & 1 deletion tests/test_model_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ def test_model_initialization_and_forward_pass(model_func, model_name):
model = model_func(model_name, num_classes=10, pretrained=False)
assert model is not None, f"{model_name} should be initialized"

# Forward pass test
dummy_input = torch.randn(2, 3, 224, 224)
output = model(dummy_input)
assert output.shape == (2, 10), f"Output shape of {model_name} should be (2, 10) for batch size of 2 and 10 classes"
21 changes: 20 additions & 1 deletion tests/test_training_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,26 @@
CONFIG_TEST = yaml.safe_load(file)

def test_training_loop():
"""Test a short training loop to ensure pipeline works with BasicTrainer."""
"""
Test a short training loop to ensure pipeline works with BasicTrainer.
This function performs a short training loop using the BasicTrainer class to ensure that the training pipeline is functioning correctly.
The function performs the following steps:
1. Sets the device to CUDA if available, otherwise sets it to CPU.
2. Retrieves the necessary transforms using the CONFIG_TEST dictionary.
3. Retrieves the dataset using the CONFIG_TEST dictionary.
4. Splits the dataset into train, validation, and test sets.
5. Creates data loaders for the train, validation, and test sets.
6. Retrieves the model using the CONFIG_TEST dictionary.
7. Defines the criterion, optimizer, and metrics for training.
8. Initializes the trainer using the CONFIG_TEST dictionary.
9. Builds the trainer with the specified criterion, optimizer, and metrics.
10. Trains the model using the train and validation data loaders.
11. Evaluates the model using the test data loader.
12. Asserts that the number of metrics results is equal to the number of metrics.
13. Asserts that all metric values are greater than or equal to 0.
"""

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Expand Down
13 changes: 13 additions & 0 deletions trainers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,19 @@
from trainers.basic_trainer import BasicTrainer

def get_trainer(trainer_name, **kwargs):
"""
Returns an instance of the specified trainer.
Parameters:
- trainer_name (str): The name of the trainer.
- **kwargs: Additional keyword arguments to be passed to the trainer constructor.
Returns:
- Trainer: An instance of the specified trainer.
Raises:
- ValueError: If the trainer name is not recognized.
"""
if trainer_name == "BasicTrainer":
return BasicTrainer(**kwargs)
else:
Expand Down
35 changes: 35 additions & 0 deletions trainers/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,25 @@


class BaseTrainer(ABC):
"""
Base class for trainers in the tda-nn-separability project.
Attributes:
model (nn.Module): The model to be trained.
device (torch.device): The device to be used for training.
criterion: The loss function used for training.
optimizer: The optimizer used for training.
scheduler: The learning rate scheduler.
metrics (list): List of metrics used for evaluation during training.
Methods:
build: Build the model, criterion, optimizer, and scheduler.
freeze_layers: Freeze layers up to a specified layer.
unfreeze_all_layers: Unfreeze all layers of the model.
train: Train the model for a given number of epochs.
evaluate: Evaluate the model on a given dataset.
"""

def __init__(self, model, device):
self.model = model
self.device = device
Expand Down Expand Up @@ -49,6 +68,14 @@ def train(self, train_loader, num_epochs, valid_loader=None, log_path=None, plot
"""
Train the model for a given number of epochs, calculating metrics at the end of each epoch
for both training and validation sets.
Args:
train_loader: The data loader for the training set.
num_epochs (int): The number of epochs to train the model.
valid_loader: The data loader for the validation set (optional).
log_path: The path to save the training log (optional).
plot_path: The path to save the training plot (optional).
verbose (bool): Whether to print training progress (default: True).
"""
training_epoch_losses = []
validation_epoch_losses = []
Expand Down Expand Up @@ -88,6 +115,14 @@ def train(self, train_loader, num_epochs, valid_loader=None, log_path=None, plot
def evaluate(self, data_loader, metrics=None, verbose=True) -> Tuple[float, dict]:
"""
Evaluate the model on a given dataset.
Args:
data_loader: The data loader for the dataset.
metrics (list): List of metrics to evaluate (optional).
verbose (bool): Whether to print evaluation results (default: True).
Returns:
Tuple[float, dict]: The average loss and metric results.
"""
if metrics is None:
metrics = self.metrics
Expand Down
Loading

0 comments on commit 99491c1

Please sign in to comment.