diff --git a/trainers/base_trainer.py b/trainers/base_trainer.py index f08244f..2dea095 100644 --- a/trainers/base_trainer.py +++ b/trainers/base_trainer.py @@ -60,9 +60,20 @@ def unfreeze_all_layers(self) -> None: @abstractmethod def _train_epoch(self, train_loader, epoch, num_epochs, verbose=True) -> float: - """ Train the model for one epoch. """ - raise NotImplementedError( - "The train_epoch method must be implemented by the subclass.") + """ + Trains the model for one epoch using the provided train_loader. + + Args: + train_loader (DataLoader): The data loader for training data. + epoch (int): The current epoch number. + num_epochs (int): The total number of epochs. + verbose (bool, optional): Whether to print training progress. Defaults to True. + + Returns: + float: The loss value for the epoch. + """ + raise NotImplementedError( + "The train_epoch method must be implemented by the subclass.") def train(self, train_loader, num_epochs, valid_loader=None, log_path=None, plot_path=None, verbose=True) -> None: """