From 593a21d83286966d3d94e32c49d75dbd63373be9 Mon Sep 17 00:00:00 2001 From: Pablo Olivares Date: Mon, 22 Apr 2024 19:50:53 +0200 Subject: [PATCH] Created `NaiveTrainer` for test purposes advances #2 --- trainers/__init__.py | 3 +++ trainers/naive_trainer.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) create mode 100644 trainers/naive_trainer.py diff --git a/trainers/__init__.py b/trainers/__init__.py index e5d9ec2..87299e4 100644 --- a/trainers/__init__.py +++ b/trainers/__init__.py @@ -1,4 +1,5 @@ from trainers.basic_trainer import BasicTrainer +from trainers.naive_trainer import NaiveTrainer def get_trainer(trainer_name, **kwargs): """ @@ -16,5 +17,7 @@ def get_trainer(trainer_name, **kwargs): """ if trainer_name == "BasicTrainer": return BasicTrainer(**kwargs) + if trainer_name == "NaiveTrainer": + return NaiveTrainer(**kwargs) else: raise ValueError(f"Trainer {trainer_name} not recognized.") diff --git a/trainers/naive_trainer.py b/trainers/naive_trainer.py new file mode 100644 index 0000000..0f9e574 --- /dev/null +++ b/trainers/naive_trainer.py @@ -0,0 +1,34 @@ +from trainers.base_trainer import BaseTrainer + + +class NaiveTrainer(BaseTrainer): + """ + A complete naive trainer class for training a model. It main purpose is to + be used as a boilerplate to test some functionalities of the trainer. + + Args: + model (nn.Module): The model to be trained. + device (torch.device): The device to be used for training. + + Attributes: + model (nn.Module): The model to be trained. + device (torch.device): The device to be used for training. + """ + + def __init__(self, model, device): + super().__init__(model, device) + + def _train_epoch(self, train_loader, epoch, num_epochs, verbose=True) -> float: + """ + Simulates training the model for one epoch. It actually does nothing. + + Args: + train_loader (DataLoader): The data loader for training data. + epoch (int): The current epoch number. + num_epochs (int): The total number of epochs. + verbose (bool, optional): Whether to display progress bar. Defaults to True. + + Returns: + float: A mock loss value for the epoch of 0.0. + """ + return 0.0 \ No newline at end of file