Skip to content

Commit

Permalink
Logging times supported closes #22
Browse files Browse the repository at this point in the history
  • Loading branch information
pab1s committed Apr 22, 2024
1 parent 5a1633a commit 1386025
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import csv

def log_to_csv(training_losses, validation_losses, metric_values, csv_path) -> None:
def log_to_csv(training_losses, validation_losses, metric_values, times, csv_path) -> None:
"""
Logs the training and validation losses, along with metric values, to a CSV file.
Expand All @@ -14,13 +14,15 @@ def log_to_csv(training_losses, validation_losses, metric_values, csv_path) -> N
The keys are the names of the metrics, and the values are dictionaries
with 'train' and 'valid' keys, representing the metric values for training
and validation sets, respectively.
times (list): A list of times taken for each epoch.
csv_path (str): The path to the CSV file where the log will be saved.
"""
headers = ['epoch', 'train_loss', 'val_loss']
metric_names = list(metric_values.keys())
for name in metric_names:
headers.append(f'train_{name}')
headers.append(f'val_{name}')
headers.append('time')

with open(csv_path, 'w', newline='') as csv_file:
writer = csv.writer(csv_file)
Expand All @@ -39,6 +41,8 @@ def log_to_csv(training_losses, validation_losses, metric_values, csv_path) -> N

row.append(train_metrics[epoch] if epoch < len(train_metrics) else 'N/A')
row.append(valid_metrics[epoch] if epoch < len(valid_metrics) else 'N/A')

row.append(times[epoch])

writer.writerow(row)

Expand Down

0 comments on commit 1386025

Please sign in to comment.