Skip to content

Commit

Permalink
Update base_trainer.py
Browse files Browse the repository at this point in the history
  • Loading branch information
pab1s committed Apr 1, 2024
1 parent c737377 commit a3dc3ab
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions trainers/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,20 @@ def unfreeze_all_layers(self) -> None:

@abstractmethod
def _train_epoch(self, train_loader, epoch, num_epochs, verbose=True) -> float:
""" Train the model for one epoch. """
raise NotImplementedError(
"The train_epoch method must be implemented by the subclass.")
"""
Trains the model for one epoch using the provided train_loader.
Args:
train_loader (DataLoader): The data loader for training data.
epoch (int): The current epoch number.
num_epochs (int): The total number of epochs.
verbose (bool, optional): Whether to print training progress. Defaults to True.
Returns:
float: The loss value for the epoch.
"""
raise NotImplementedError(
"The train_epoch method must be implemented by the subclass.")

def train(self, train_loader, num_epochs, valid_loader=None, log_path=None, plot_path=None, verbose=True) -> None:
"""
Expand Down

0 comments on commit a3dc3ab

Please sign in to comment.