diff --git a/trainers/base_trainer.py b/trainers/base_trainer.py index dd19048..bc48e6a 100644 --- a/trainers/base_trainer.py +++ b/trainers/base_trainer.py @@ -1,10 +1,11 @@ import torch from utils.plotting import plot_loss from utils.logging import log_to_csv +from abc import ABC, abstractmethod import time -class BaseTrainer(): +class BaseTrainer(ABC): def __init__(self, model, device): self.model = model self.device = device @@ -37,6 +38,7 @@ def unfreeze_all_layers(self) -> None: for param in self.model.parameters(): param.requires_grad = True + @abstractmethod def _train_epoch(self, train_loader, epoch, num_epochs, verbose=True) -> float: """ Train the model for one epoch. """ raise NotImplementedError(