-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvisualizer.py
31 lines (27 loc) · 877 Bytes
/
visualizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import os

from matplotlib import pyplot as plt
def visualize_metrics(skill_net, epoch_list, acc_list, auc_list, rmse_list,
                      output_dir="plots"):
    """Plot per-epoch AUC, ACC and RMSE curves and save them as one PNG.

    The three metrics are drawn as vertically stacked subplots (AUC, ACC,
    RMSE from top to bottom) against ``epoch_list`` and written to
    ``<output_dir>/model_performance_<skill_net.tag>.png``. The figure is
    closed after saving, so it is not left open in pyplot's global state.

    Args:
        skill_net: Model object; only its ``tag`` attribute is read, to build
            the output filename.
        epoch_list: X-axis values (epochs).
        acc_list: Accuracy value for each entry of ``epoch_list``.
        auc_list: AUC value for each entry of ``epoch_list``.
        rmse_list: RMSE value for each entry of ``epoch_list``.
        output_dir: Directory the PNG is written into; created if missing.
            Defaults to ``"plots"``, the directory previously hard-coded.

    Returns:
        The path of the saved PNG file.
    """
    # Order matches the original layout: AUC on top, then ACC, then RMSE.
    metrics = (
        ("AUC", auc_list),
        ("ACC", acc_list),
        ("RMSE", rmse_list),
    )

    # Hold an explicit Figure handle so it can be closed below; without
    # closing, repeated calls (e.g. once per epoch) keep every figure alive.
    fig, axes = plt.subplots(len(metrics), 1, figsize=(10, 6))
    try:
        for ax, (name, values) in zip(axes, metrics):
            ax.plot(epoch_list, values, label=name)
            ax.set_xlabel("Epoch")
            ax.set_ylabel(name)
            ax.set_title(f"{name} Over Epochs")
            ax.grid(True)
        fig.tight_layout()  # Adjust spacing between stacked subplots

        # savefig does not create missing parent directories.
        os.makedirs(output_dir, exist_ok=True)
        path = os.path.join(output_dir, f"model_performance_{skill_net.tag}.png")
        fig.savefig(path)
    finally:
        plt.close(fig)
    return path