Skip to content

Commit

Permalink
Training time displayed advances #7
Browse files Browse the repository at this point in the history
  • Loading branch information
pab1s committed Mar 18, 2024
1 parent 4010459 commit d570dc8
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions trainers/base_trainer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
from utils.plotting import plot_loss
from utils.logging import log_to_csv
import time

class BaseTrainer:
def __init__(self, model, device):
Expand All @@ -25,6 +26,7 @@ def train(self, train_loader, num_epochs, log_path, plot_path, verbose=True):
""" Train the model for a given number of epochs. """

epoch_losses = []
start_time = time.time()

for epoch in range(num_epochs):
epoch_loss = self._train_epoch(train_loader, epoch, num_epochs, verbose)
Expand All @@ -34,7 +36,9 @@ def train(self, train_loader, num_epochs, log_path, plot_path, verbose=True):
print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
log_to_csv(epoch_losses, log_path)

elapsed_time = time.time() - start_time
if verbose:
print(f"Training completed in: {elapsed_time:.2f} seconds")
plot_loss(epoch_losses, plot_path)

def evaluate(self, test_loader, verbose=True):
Expand Down

0 comments on commit d570dc8

Please sign in to comment.