From 1386025f66baf31f353406a0d5e7ee14f45a3b95 Mon Sep 17 00:00:00 2001 From: Pablo Olivares Date: Mon, 22 Apr 2024 12:59:05 +0200 Subject: [PATCH] Logging times supported closes #22 --- utils/logging.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/utils/logging.py b/utils/logging.py index 726d6f2..84adb9b 100644 --- a/utils/logging.py +++ b/utils/logging.py @@ -3,7 +3,7 @@ import csv -def log_to_csv(training_losses, validation_losses, metric_values, csv_path) -> None: +def log_to_csv(training_losses, validation_losses, metric_values, times, csv_path) -> None: """ Logs the training and validation losses, along with metric values, to a CSV file. @@ -14,6 +14,7 @@ def log_to_csv(training_losses, validation_losses, metric_values, csv_path) -> N The keys are the names of the metrics, and the values are dictionaries with 'train' and 'valid' keys, representing the metric values for training and validation sets, respectively. + times (list): A list of times taken for each epoch. csv_path (str): The path to the CSV file where the log will be saved. """ headers = ['epoch', 'train_loss', 'val_loss'] @@ -21,6 +22,7 @@ def log_to_csv(training_losses, validation_losses, metric_values, csv_path) -> N for name in metric_names: headers.append(f'train_{name}') headers.append(f'val_{name}') + headers.append('time') with open(csv_path, 'w', newline='') as csv_file: writer = csv.writer(csv_file) @@ -39,6 +41,8 @@ def log_to_csv(training_losses, validation_losses, metric_values, csv_path) -> N row.append(train_metrics[epoch] if epoch < len(train_metrics) else 'N/A') row.append(valid_metrics[epoch] if epoch < len(valid_metrics) else 'N/A') + + row.append(times[epoch]) writer.writerow(row)