-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #24 from pab1s/feature/callbacks
Feature/callbacks Completed
- Loading branch information
Showing
16 changed files
with
438 additions
and
149 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from callbacks.csv_logging import CSVLogging | ||
from callbacks.epoch_results_logging import EpochResultsLogging | ||
from callbacks.early_stopping import EarlyStopping | ||
from callbacks.checkpoint import Checkpoint |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
class Callback: | ||
""" | ||
A base class for defining callbacks in a training process. | ||
Callbacks are functions that can be executed at various stages during training. | ||
They can be used to perform additional actions or modify the behavior of the training process. | ||
Methods: | ||
should_continue(logs=None) -> bool: | ||
Determines whether the training process should continue or stop. | ||
on_epoch_begin(epoch, logs=None) -> None: | ||
Executed at the beginning of each epoch. | ||
on_epoch_end(epoch, logs=None) -> None: | ||
Executed at the end of each epoch. | ||
on_train_begin(logs=None) -> None: | ||
Executed at the beginning of the training process. | ||
on_train_end(logs=None) -> None: | ||
Executed at the end of the training process. | ||
""" | ||
|
||
def should_continue(self, logs=None) -> bool: | ||
""" | ||
Determines whether the training process should continue or stop. | ||
Args: | ||
logs (dict): Optional dictionary containing training logs. | ||
Returns: | ||
bool: True if the training process should continue, False otherwise. | ||
""" | ||
return True | ||
|
||
def on_epoch_begin(self, epoch, logs=None) -> None: | ||
""" | ||
Executed at the beginning of each epoch. | ||
Args: | ||
epoch (int): The current epoch number. | ||
logs (dict): Optional dictionary containing training logs. | ||
""" | ||
pass | ||
|
||
def on_epoch_end(self, epoch, logs=None) -> None: | ||
""" | ||
Executed at the end of each epoch. | ||
Args: | ||
epoch (int): The current epoch number. | ||
logs (dict): Optional dictionary containing training logs. | ||
""" | ||
pass | ||
|
||
def on_train_begin(self, logs=None) -> None: | ||
""" | ||
Executed at the beginning of the training process. | ||
Args: | ||
logs (dict): Optional dictionary containing training logs. | ||
""" | ||
pass | ||
|
||
def on_train_end(self, logs=None) -> None: | ||
""" | ||
Executed at the end of the training process. | ||
Args: | ||
logs (dict): Optional dictionary containing training logs. | ||
""" | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from callbacks.callback import Callback | ||
import os | ||
import torch | ||
|
||
class Checkpoint(Callback): | ||
""" | ||
Callback class for saving model checkpoints during training. | ||
Args: | ||
checkpoint_dir (str): Directory to save the checkpoints. | ||
model (torch.nn.Module): The model to be saved. | ||
optimizer (torch.optim.Optimizer): The optimizer to be saved. | ||
scheduler (torch.optim.lr_scheduler._LRScheduler, optional): The scheduler to be saved. Default is None. | ||
save_freq (int, optional): Frequency of saving checkpoints. Default is 1. | ||
verbose (bool, optional): Whether to print the checkpoint save path. Default is False. | ||
""" | ||
|
||
def __init__(self, checkpoint_dir, model, optimizer, scheduler=None, save_freq=1, verbose=False): | ||
self.checkpoint_dir = checkpoint_dir | ||
self.model = model | ||
self.optimizer = optimizer | ||
self.scheduler = scheduler | ||
self.save_freq = save_freq | ||
self.verbose = verbose | ||
if not os.path.exists(checkpoint_dir): | ||
os.makedirs(checkpoint_dir, exist_ok=True) | ||
|
||
def on_epoch_end(self, epoch, logs=None): | ||
""" | ||
Callback function called at the end of each epoch. | ||
Args: | ||
epoch (int): The current epoch number. | ||
logs (dict, optional): Dictionary containing training and validation losses. Default is None. | ||
""" | ||
if (epoch + 1) % self.save_freq == 0: | ||
self.save_checkpoint(epoch, logs) | ||
|
||
def save_checkpoint(self, epoch, logs=None): | ||
""" | ||
Save the model checkpoint. | ||
Args: | ||
epoch (int): The current epoch number. | ||
logs (dict, optional): Dictionary containing training and validation losses. Default is None. | ||
""" | ||
state = { | ||
'epoch': epoch, | ||
'model_state_dict': self.model.state_dict(), | ||
'optimizer_state_dict': self.optimizer.state_dict(), | ||
'train_losses': logs.get('training_losses', []), | ||
'val_losses': logs.get('validation_losses', []), | ||
} | ||
if self.scheduler: | ||
state['scheduler_state_dict'] = self.scheduler.state_dict() | ||
|
||
save_path = os.path.join(self.checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth') | ||
torch.save(state, save_path) | ||
if self.verbose: | ||
print(f"Checkpoint saved at {save_path}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from callbacks.callback import Callback | ||
import csv | ||
|
||
class CSVLogging(Callback): | ||
""" | ||
Callback for logging training and validation metrics to a CSV file. | ||
Args: | ||
csv_path (str): The path to the CSV file. | ||
Attributes: | ||
csv_path (str): The path to the CSV file. | ||
headers_written (bool): Flag indicating whether the headers have been written to the CSV file. | ||
""" | ||
|
||
def __init__(self, csv_path): | ||
self.csv_path = csv_path | ||
self.headers_written = False | ||
|
||
def on_epoch_end(self, epoch, logs=None): | ||
""" | ||
Method called at the end of each epoch during training. | ||
Args: | ||
epoch (int): The current epoch number. | ||
logs (dict): Dictionary containing the training and validation metrics. | ||
Returns: | ||
None | ||
""" | ||
if logs is None: | ||
return | ||
|
||
epoch_data = logs.get('epoch') | ||
train_loss = logs.get('train_loss') | ||
val_loss = logs.get('val_loss') | ||
train_metrics = logs.get('train_metrics', {}) | ||
val_metrics = logs.get('val_metrics', {}) | ||
|
||
metrics = {'train_loss': train_loss, 'val_loss': val_loss} | ||
metrics.update({f'train_{key}': value for key, value in train_metrics.items()}) | ||
metrics.update({f'val_{key}': value for key, value in val_metrics.items()}) | ||
|
||
if not self.headers_written: | ||
headers = ['epoch'] + list(metrics.keys()) | ||
with open(self.csv_path, 'w', newline='') as file: | ||
writer = csv.writer(file) | ||
writer.writerow(headers) | ||
self.headers_written = True | ||
|
||
values = [epoch_data] + [metrics[key] for key in headers[1:]] # Ensure the order matches headers | ||
with open(self.csv_path, 'a', newline='') as file: | ||
writer = csv.writer(file) | ||
writer.writerow(values) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from callbacks.callback import Callback | ||
|
||
class EarlyStopping(Callback): | ||
def __init__(self, monitor='val_loss', patience=5, verbose=False, delta=0): | ||
self.monitor = monitor | ||
self.patience = patience | ||
self.verbose = verbose | ||
self.delta = delta | ||
self.best_score = None | ||
self.epochs_no_improve = 0 | ||
self.stopped_epoch = 0 | ||
self.early_stop = False | ||
|
||
def on_epoch_end(self, epoch, logs=None): | ||
current = logs.get(self.monitor) | ||
if current is None: | ||
return | ||
|
||
score = -current if 'loss' in self.monitor else current | ||
if self.best_score is None: | ||
self.best_score = score | ||
elif score < self.best_score + self.delta: | ||
self.epochs_no_improve += 1 | ||
if self.epochs_no_improve >= self.patience: | ||
self.early_stop = True | ||
self.stopped_epoch = epoch + 1 | ||
if self.verbose: | ||
print(f"Early stopping triggered at epoch {self.stopped_epoch}") | ||
else: | ||
self.best_score = score | ||
self.epochs_no_improve = 0 | ||
|
||
def should_continue(self, logs=None): | ||
return not self.early_stop |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from callbacks.callback import Callback | ||
|
||
class EpochResultsLogging(Callback): | ||
""" | ||
Callback for logging epoch results during training. | ||
This callback prints the training loss, training metrics, and validation metrics at the end of each epoch. | ||
Args: | ||
Callback: The base class for Keras callbacks. | ||
Methods: | ||
on_epoch_end: Called at the end of each epoch to print the epoch results. | ||
""" | ||
|
||
def on_epoch_end(self, epoch, logs=None): | ||
""" | ||
Prints the epoch results at the end of each epoch. | ||
Args: | ||
epoch (int): The current epoch number. | ||
logs (dict): A dictionary containing the training and validation metrics. | ||
Returns: | ||
None | ||
""" | ||
epoch_loss_train = logs.get('train_loss') | ||
epoch_metrics_train = logs.get('train_metrics', {}) | ||
epoch_metrics_valid = logs.get('val_metrics', {}) | ||
num_epochs = logs.get('num_epochs', 0) | ||
|
||
print(f"Epoch {epoch+1}/{num_epochs}, Training loss: {epoch_loss_train:.4f}") | ||
|
||
for metric_name, value in epoch_metrics_train.items(): | ||
print(f"Training {metric_name}: {value:.4f}") | ||
|
||
if epoch_metrics_valid: | ||
for metric_name, value in epoch_metrics_valid.items(): | ||
print(f"Validation {metric_name}: {value if value is not None else 'N/A'}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.