From a3dc3ab052f002871e9191159715370cfeebd30d Mon Sep 17 00:00:00 2001 From: Pablo Olivares Date: Mon, 1 Apr 2024 15:31:00 +0200 Subject: [PATCH] Update base_trainer.py --- trainers/base_trainer.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/trainers/base_trainer.py b/trainers/base_trainer.py index f08244f..2dea095 100644 --- a/trainers/base_trainer.py +++ b/trainers/base_trainer.py @@ -60,9 +60,20 @@ def unfreeze_all_layers(self) -> None: @abstractmethod def _train_epoch(self, train_loader, epoch, num_epochs, verbose=True) -> float: - """ Train the model for one epoch. """ - raise NotImplementedError( - "The train_epoch method must be implemented by the subclass.") + """ + Trains the model for one epoch using the provided train_loader. + + Args: + train_loader (DataLoader): The data loader for training data. + epoch (int): The current epoch number. + num_epochs (int): The total number of epochs. + verbose (bool, optional): Whether to print training progress. Defaults to True. + + Returns: + float: The loss value for the epoch. + """ + raise NotImplementedError( + "The train_epoch method must be implemented by the subclass.") def train(self, train_loader, num_epochs, valid_loader=None, log_path=None, plot_path=None, verbose=True) -> None: """