Skip to content

Commit

Permalink
plotter
Browse files Browse the repository at this point in the history
  • Loading branch information
Jaap-Meerhof committed Oct 31, 2023
1 parent d6ee391 commit e468590
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions src/SFXGBoost/view/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def plot_acc_one(data, labels, destination='plot.png', name='Sample Text', subte
plt.legend(loc='best')
plt.grid(True)
# plt.show()
plt.savefig(destination, format="jpeg", dpi=1200, bbox_inches='tight')
plt.savefig(destination, format="pdf", dpi=1200, bbox_inches='tight')

def plot_experiment(all_data:dict, experiment_number:int):
"""Plots accuracy and precision of the attack in one plot with two subplots
Expand All @@ -38,7 +38,7 @@ def plot_experiment(all_data:dict, experiment_number:int):
All_data = {name_network: {dataset:{precision,...,metric}}}
Args:
all_data (_type_): _description_
destination (str, optional): _description_. Defaults to f"Plots/experiment{experiment_number}.jpeg".
destination (str, optional): _description_. Defaults to f"Plots/experiment{experiment_number}.pdf".
"""
data_acc = {}
data_prec = {}
Expand All @@ -61,14 +61,14 @@ def plot_experiment(all_data:dict, experiment_number:int):
data_overfitting[name].append(overfitting)

datasets = list(all_data[list(all_data.keys())[0]].keys())
plot_histogram(datasets, data_acc, title="accuracy attack", y_label="accuracy", destination=f"Plots/experiment{experiment_number}_acc_attack.jpeg")
plot_histogram(datasets, data_prec, title="precision attack", y_label="precision", destination=f"Plots/experiment{experiment_number}_precision.jpeg")
plot_histogram(datasets, data_acc_test, title="accuracy test target", y_label="accuracy", destination=f"Plots/experiment{experiment_number}_acc_test.jpeg")
plot_histogram(datasets, data_overfitting, title="overfitting target", y_label="overfitting", destination=f"Plots/experiment{experiment_number}_overfitting.jpeg")
plot_histogram(datasets, data_acc, title="accuracy attack", y_label="accuracy", destination=f"Plots/experiment{experiment_number}_acc_attack.pdf")
plot_histogram(datasets, data_prec, title="precision attack", y_label="precision", destination=f"Plots/experiment{experiment_number}_precision.pdf")
plot_histogram(datasets, data_acc_test, title="accuracy test target", y_label="accuracy", destination=f"Plots/experiment{experiment_number}_acc_test.pdf")
plot_histogram(datasets, data_overfitting, title="overfitting target", y_label="overfitting", destination=f"Plots/experiment{experiment_number}_overfitting.pdf")



def plot_histogram(datasets, data, title="Sample text", y_label="y_label", destination="Plots/experiment2.jpeg"):
def plot_histogram(datasets, data, title="Sample text", y_label="y_label", destination="Plots/experiment2.pdf"):
"""_summary_
example
datasets = ("Healthcare", "MNIST")
Expand All @@ -79,13 +79,13 @@ def plot_histogram(datasets, data, title="Sample text", y_label="y_label", desti
Args:
datasets (_type_): _description_
data (_type_): _description_
destination (str, optional): _description_. Defaults to "Plots/experiment2.jpeg".
destination (str, optional): _description_. Defaults to "Plots/experiment2.pdf".
"""
from datetime import date
import time
day = date.today().strftime("%b-%d-%Y")
curTime = time.strftime("%H:%M", time.localtime())
destination = destination.replace(".jpeg", f"{day},{curTime}.jpeg")
destination = destination.replace(".pdf", f"{day},{curTime}.pdf")
width = 0.25
multiplier = 0
one_value = False
Expand All @@ -107,10 +107,10 @@ def plot_histogram(datasets, data, title="Sample text", y_label="y_label", desti
ax.set_ylim(0, 1.1)
else:
ax.set_ylim(0, 1)
plt.savefig(destination, dpi=1200, format='jpeg')
plt.savefig(destination, dpi=1200, format='pdf')
# plt.show()

def plot_auc(y_true, y_pred, destination="Plots/experiment2_AUC_attack2.jpeg"):
def plot_auc(y_true, y_pred, destination="Plots/experiment2_AUC_attack2.pdf"):
"""plots a AUC curve that will also display the best threshold in the legend.
thanks chatGPT for creating this
Expand All @@ -122,7 +122,7 @@ def plot_auc(y_true, y_pred, destination="Plots/experiment2_AUC_attack2.jpeg"):
import time
day = date.today().strftime("%b-%d-%Y")
curTime = time.strftime("%H:%M", time.localtime())
destination = destination.replace(".jpeg", f"{day},{curTime}.jpeg")
destination = destination.replace(".pdf", f"{day},{curTime}.pdf")

fpr, tpr, thresholds = roc_curve(y_true, y_pred)

Expand All @@ -141,4 +141,4 @@ def plot_auc(y_true, y_pred, destination="Plots/experiment2_AUC_attack2.jpeg"):
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc='lower right')
plt.savefig(destination, dpi=1200, format='jpeg')
plt.savefig(destination, dpi=1200, format='pdf')

0 comments on commit e468590

Please sign in to comment.