From c4f7e5e9f1cd943eba6cf151316f87f6ba0d9d4f Mon Sep 17 00:00:00 2001 From: Pablo Olivares Date: Sun, 31 Mar 2024 03:19:40 +0200 Subject: [PATCH] `BaseTrainer` turned abstract class closes #10 --- trainers/base_trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/trainers/base_trainer.py b/trainers/base_trainer.py index dd19048..bc48e6a 100644 --- a/trainers/base_trainer.py +++ b/trainers/base_trainer.py @@ -1,10 +1,11 @@ import torch from utils.plotting import plot_loss from utils.logging import log_to_csv +from abc import ABC, abstractmethod import time -class BaseTrainer(): +class BaseTrainer(ABC): def __init__(self, model, device): self.model = model self.device = device @@ -37,6 +38,7 @@ def unfreeze_all_layers(self) -> None: for param in self.model.parameters(): param.requires_grad = True + @abstractmethod def _train_epoch(self, train_loader, epoch, num_epochs, verbose=True) -> float: """ Train the model for one epoch. """ raise NotImplementedError(