Feature/callbacks Completed #24

Merged 11 commits on Apr 22, 2024
4 changes: 4 additions & 0 deletions callbacks/__init__.py
@@ -0,0 +1,4 @@
from callbacks.csv_logging import CSVLogging
from callbacks.epoch_results_logging import EpochResultsLogging
from callbacks.early_stopping import EarlyStopping
from callbacks.checkpoint import Checkpoint
73 changes: 73 additions & 0 deletions callbacks/callback.py
@@ -0,0 +1,73 @@
class Callback:
"""
A base class for defining callbacks in a training process.

Callbacks are hooks that the training loop invokes at specific points during training.
They can be used to perform additional actions or to modify the behavior of the training process.

Methods:
should_continue(logs=None) -> bool:
Determines whether the training process should continue or stop.

on_epoch_begin(epoch, logs=None) -> None:
Executed at the beginning of each epoch.

on_epoch_end(epoch, logs=None) -> None:
Executed at the end of each epoch.

on_train_begin(logs=None) -> None:
Executed at the beginning of the training process.

on_train_end(logs=None) -> None:
Executed at the end of the training process.
"""

def should_continue(self, logs=None) -> bool:
"""
Determines whether the training process should continue or stop.

Args:
logs (dict): Optional dictionary containing training logs.

Returns:
bool: True if the training process should continue, False otherwise.
"""
return True

def on_epoch_begin(self, epoch, logs=None) -> None:
"""
Executed at the beginning of each epoch.

Args:
epoch (int): The current epoch number.
logs (dict): Optional dictionary containing training logs.
"""
pass

def on_epoch_end(self, epoch, logs=None) -> None:
"""
Executed at the end of each epoch.

Args:
epoch (int): The current epoch number.
logs (dict): Optional dictionary containing training logs.
"""
pass

def on_train_begin(self, logs=None) -> None:
"""
Executed at the beginning of the training process.

Args:
logs (dict): Optional dictionary containing training logs.
"""
pass

def on_train_end(self, logs=None) -> None:
"""
Executed at the end of the training process.

Args:
logs (dict): Optional dictionary containing training logs.
"""
pass
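For orientation, here is a minimal sketch of how a training loop might drive these hooks. The trainer itself is not part of this diff, so the trainer object, its run_one_epoch method, and the structure of logs below are assumptions for illustration only:

def run_training(trainer, callbacks, num_epochs):
    # Hypothetical dispatch loop; only the hook names come from callbacks/callback.py.
    logs = {}
    for cb in callbacks:
        cb.on_train_begin(logs)
    for epoch in range(num_epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch, logs)
        logs = trainer.run_one_epoch(epoch)  # assumed to return a dict of losses/metrics
        for cb in callbacks:
            cb.on_epoch_end(epoch, logs)
        if not all(cb.should_continue(logs) for cb in callbacks):
            break
    for cb in callbacks:
        cb.on_train_end(logs)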
60 changes: 60 additions & 0 deletions callbacks/checkpoint.py
@@ -0,0 +1,60 @@
from callbacks.callback import Callback
import os
import torch

class Checkpoint(Callback):
"""
Callback class for saving model checkpoints during training.

Args:
checkpoint_dir (str): Directory to save the checkpoints.
model (torch.nn.Module): The model to be saved.
optimizer (torch.optim.Optimizer): The optimizer to be saved.
scheduler (torch.optim.lr_scheduler._LRScheduler, optional): The scheduler to be saved. Default is None.
save_freq (int, optional): Save a checkpoint every `save_freq` epochs. Default is 1.
verbose (bool, optional): Whether to print the checkpoint save path. Default is False.
"""

def __init__(self, checkpoint_dir, model, optimizer, scheduler=None, save_freq=1, verbose=False):
self.checkpoint_dir = checkpoint_dir
self.model = model
self.optimizer = optimizer
self.scheduler = scheduler
self.save_freq = save_freq
self.verbose = verbose
os.makedirs(checkpoint_dir, exist_ok=True)

def on_epoch_end(self, epoch, logs=None):
"""
Callback function called at the end of each epoch.

Args:
epoch (int): The current epoch number.
logs (dict, optional): Dictionary containing training and validation losses. Default is None.
"""
if (epoch + 1) % self.save_freq == 0:
self.save_checkpoint(epoch, logs)

def save_checkpoint(self, epoch, logs=None):
"""
Save the model checkpoint.

Args:
epoch (int): The current epoch number.
logs (dict, optional): Dictionary containing training and validation losses. Default is None.
"""
logs = logs or {}  # guard: save_checkpoint may be called with logs=None
state = {
'epoch': epoch,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'train_losses': logs.get('training_losses', []),
'val_losses': logs.get('validation_losses', []),
}
if self.scheduler:
state['scheduler_state_dict'] = self.scheduler.state_dict()

save_path = os.path.join(self.checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth')
torch.save(state, save_path)
if self.verbose:
print(f"Checkpoint saved at {save_path}")
54 changes: 54 additions & 0 deletions callbacks/csv_logging.py
@@ -0,0 +1,54 @@
from callbacks.callback import Callback
import csv

class CSVLogging(Callback):
"""
Callback for logging training and validation metrics to a CSV file.

Args:
csv_path (str): The path to the CSV file.

Attributes:
csv_path (str): The path to the CSV file.
headers_written (bool): Flag indicating whether the headers have been written to the CSV file.
"""

def __init__(self, csv_path):
self.csv_path = csv_path
self.headers_written = False

def on_epoch_end(self, epoch, logs=None):
"""
Method called at the end of each epoch during training.

Args:
epoch (int): The current epoch number.
logs (dict): Dictionary containing the training and validation metrics.

Returns:
None
"""
if logs is None:
return

epoch_data = logs.get('epoch')
train_loss = logs.get('train_loss')
val_loss = logs.get('val_loss')
train_metrics = logs.get('train_metrics', {})
val_metrics = logs.get('val_metrics', {})

metrics = {'train_loss': train_loss, 'val_loss': val_loss}
metrics.update({f'train_{key}': value for key, value in train_metrics.items()})
metrics.update({f'val_{key}': value for key, value in val_metrics.items()})

headers = ['epoch'] + list(metrics.keys())
if not self.headers_written:
with open(self.csv_path, 'w', newline='') as file:
writer = csv.writer(file)
writer.writerow(headers)
self.headers_written = True

values = [epoch_data] + [metrics[key] for key in headers[1:]]  # keep value order aligned with the header row
with open(self.csv_path, 'a', newline='') as file:
writer = csv.writer(file)
writer.writerow(values)
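To illustrate the logs dictionary CSVLogging expects, here is a small usage example; the key names come from the code above, while the file path and numbers are invented:

logger = CSVLogging('training_log.csv')
logger.on_epoch_end(0, {
    'epoch': 1,
    'train_loss': 0.9123,
    'val_loss': 0.8710,
    'train_metrics': {'accuracy': 0.71},
    'val_metrics': {'accuracy': 0.69},
})
# The first call writes the header row; later calls append one row per epoch:
#   epoch,train_loss,val_loss,train_accuracy,val_accuracy
#   1,0.9123,0.871,0.71,0.69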
34 changes: 34 additions & 0 deletions callbacks/early_stopping.py
@@ -0,0 +1,34 @@
from callbacks.callback import Callback

class EarlyStopping(Callback):
"""
Callback that stops training when a monitored metric stops improving.

Args:
monitor (str): Name of the metric in `logs` to monitor. Default is 'val_loss'.
patience (int): Number of consecutive epochs without improvement before stopping. Default is 5.
verbose (bool): Whether to print a message when early stopping triggers. Default is False.
delta (float): Minimum change in the monitored value to count as an improvement. Default is 0.
"""

def __init__(self, monitor='val_loss', patience=5, verbose=False, delta=0):
self.monitor = monitor
self.patience = patience
self.verbose = verbose
self.delta = delta
self.best_score = None
self.epochs_no_improve = 0
self.stopped_epoch = 0
self.early_stop = False

def on_epoch_end(self, epoch, logs=None):
if logs is None:
return
current = logs.get(self.monitor)
if current is None:
return

score = -current if 'loss' in self.monitor else current
if self.best_score is None:
self.best_score = score
elif score < self.best_score + self.delta:
self.epochs_no_improve += 1
if self.epochs_no_improve >= self.patience:
self.early_stop = True
self.stopped_epoch = epoch + 1
if self.verbose:
print(f"Early stopping triggered at epoch {self.stopped_epoch}")
else:
self.best_score = score
self.epochs_no_improve = 0

def should_continue(self, logs=None):
return not self.early_stop
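EarlyStopping flags early_stop once the monitored value fails to improve by more than delta for patience consecutive epochs; should_continue then lets the training loop bail out. A small usage sketch with invented loss values:

stopper = EarlyStopping(monitor='val_loss', patience=2, delta=0.0, verbose=True)
for epoch, val_loss in enumerate([0.90, 0.85, 0.86, 0.87, 0.88]):
    stopper.on_epoch_end(epoch, {'val_loss': val_loss})
    if not stopper.should_continue():
        break  # triggers at the fourth epoch, after two epochs without improvement over 0.85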
39 changes: 39 additions & 0 deletions callbacks/epoch_results_logging.py
@@ -0,0 +1,39 @@
from callbacks.callback import Callback

class EpochResultsLogging(Callback):
"""
Callback for logging epoch results during training.

This callback prints the training loss, training metrics, and validation metrics at the end of each epoch.

Inherits from the project's base Callback class defined in callbacks/callback.py.

Methods:
on_epoch_end: Called at the end of each epoch to print the epoch results.
"""

def on_epoch_end(self, epoch, logs=None):
"""
Prints the epoch results at the end of each epoch.

Args:
epoch (int): The current epoch number.
logs (dict): A dictionary containing the training and validation metrics.

Returns:
None
"""
epoch_loss_train = logs.get('train_loss')
epoch_metrics_train = logs.get('train_metrics', {})
epoch_metrics_valid = logs.get('val_metrics', {})
num_epochs = logs.get('num_epochs', 0)

print(f"Epoch {epoch+1}/{num_epochs}, Training loss: {epoch_loss_train:.4f}")

for metric_name, value in epoch_metrics_train.items():
print(f"Training {metric_name}: {value:.4f}")

if epoch_metrics_valid:
for metric_name, value in epoch_metrics_valid.items():
print(f"Validation {metric_name}: {value if value is not None else 'N/A'}")
4 changes: 1 addition & 3 deletions fine_tuning.py
@@ -4,7 +4,7 @@
from torch.utils.data import DataLoader, random_split
from datasets.dataset import get_dataset
from datasets.transformations import get_transforms
from utils.metrics import Accuracy, Precision, Recall, F1Score
from utils.metrics import Accuracy, Precision
from models import get_model
from trainers import get_trainer
from os import path
@@ -62,7 +62,6 @@ def main(config_path):
train_loader=train_loader,
valid_loader=valid_loader,
num_epochs=config['training']['initial_epochs'],
log_path=log_filename,
plot_path=plot_filename
)

@@ -83,7 +82,6 @@ def main(config_path):
train_loader=train_loader,
valid_loader=valid_loader,
num_epochs=config['training']['fine_tuning_epochs'],
log_path=log_filename,
plot_path=plot_filename
)

15 changes: 13 additions & 2 deletions main.py
@@ -3,6 +3,7 @@
from datetime import datetime
from torch.utils.data import DataLoader, random_split
from datasets.dataset import get_dataset
from callbacks import EarlyStopping, CSVLogging, EpochResultsLogging
from datasets.transformations import get_transforms
from utils.metrics import Accuracy, Precision, Recall, F1Score
from models import get_model
@@ -58,12 +59,22 @@ def main(config_path):
metrics=metrics
)

callbacks = [
CSVLogging(log_filename),
EpochResultsLogging(),
EarlyStopping(
monitor='val_loss',
patience=2,
delta=0.1,
verbose=True
)
]

trainer.train(
train_loader=train_loader,
valid_loader=valid_loader,
num_epochs=config['training']['num_epochs'],
log_path=log_filename,
plot_path=plot_filename
callbacks=callbacks,
)

trainer.evaluate(data_loader=test_loader)