Skip to content

Commit

Permalink
BaseTrainer turned abstract class closes #10
Browse files Browse the repository at this point in the history
  • Loading branch information
pab1s committed Mar 31, 2024
1 parent 06bcecf commit c4f7e5e
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion trainers/base_trainer.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import torch
from utils.plotting import plot_loss
from utils.logging import log_to_csv
from abc import ABC, abstractmethod
import time


class BaseTrainer():
class BaseTrainer(ABC):
def __init__(self, model, device):
self.model = model
self.device = device
Expand Down Expand Up @@ -37,6 +38,7 @@ def unfreeze_all_layers(self) -> None:
for param in self.model.parameters():
param.requires_grad = True

@abstractmethod
def _train_epoch(self, train_loader, epoch, num_epochs, verbose=True) -> float:
""" Train the model for one epoch. """
raise NotImplementedError(
Expand Down

0 comments on commit c4f7e5e

Please sign in to comment.