From d570dc8ddad27803d0868a609c0c5e4dec1c97c9 Mon Sep 17 00:00:00 2001 From: Pablo Olivares Date: Tue, 19 Mar 2024 00:49:23 +0100 Subject: [PATCH] Training time displayed advances #7 --- trainers/base_trainer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/trainers/base_trainer.py b/trainers/base_trainer.py index 0ad5410..206e56b 100644 --- a/trainers/base_trainer.py +++ b/trainers/base_trainer.py @@ -1,6 +1,7 @@ import torch from utils.plotting import plot_loss from utils.logging import log_to_csv +import time class BaseTrainer: def __init__(self, model, device): @@ -25,6 +26,7 @@ def train(self, train_loader, num_epochs, log_path, plot_path, verbose=True): """ Train the model for a given number of epochs. """ epoch_losses = [] + start_time = time.time() for epoch in range(num_epochs): epoch_loss = self._train_epoch(train_loader, epoch, num_epochs, verbose) @@ -34,7 +36,9 @@ def train(self, train_loader, num_epochs, log_path, plot_path, verbose=True): print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}") log_to_csv(epoch_losses, log_path) + elapsed_time = time.time() - start_time if verbose: + print(f"Training completed in: {elapsed_time:.2f} seconds") plot_loss(epoch_losses, plot_path) def evaluate(self, test_loader, verbose=True):