diff --git a/experiments/attack_defense_metric_test.py b/experiments/attack_defense_metric_test.py index 7bdc547..e2bd79e 100644 --- a/experiments/attack_defense_metric_test.py +++ b/experiments/attack_defense_metric_test.py @@ -1,3 +1,4 @@ +import copy import warnings import torch @@ -140,7 +141,7 @@ def attack_defense_metrics(): # print(metric_loc) adm = FrameworkAttackDefenseManager( - gen_dataset=dataset, + gen_dataset=copy.deepcopy(dataset), gnn_manager=gnn_model_manager, ) # adm.evasion_attack_pipeline( diff --git a/src/models_builder/attack_defense_manager.py b/src/models_builder/attack_defense_manager.py index 3b70097..4e0f98d 100644 --- a/src/models_builder/attack_defense_manager.py +++ b/src/models_builder/attack_defense_manager.py @@ -1,3 +1,4 @@ +import copy import json import os import warnings @@ -89,14 +90,15 @@ def evasion_attack_pipeline( self.gnn_manager.modification.epochs = 0 self.gnn_manager.gnn.reset_parameters() from models_builder.gnn_models import Metric + local_gen_dataset_copy = copy.deepcopy(self.gen_dataset) self.gnn_manager.train_model( - gen_dataset=self.gen_dataset, + gen_dataset=local_gen_dataset_copy, steps=steps, - save_model_flag=save_model_flag, + save_model_flag=False, metrics=[Metric("F1", mask='train', average=None)] ) y_predict_clean = self.gnn_manager.run_model( - gen_dataset=self.gen_dataset, + gen_dataset=local_gen_dataset_copy, mask=mask, out='logits', ) @@ -105,17 +107,17 @@ def evasion_attack_pipeline( self.gnn_manager.modification.epochs = 0 self.gnn_manager.gnn.reset_parameters() self.gnn_manager.train_model( - gen_dataset=self.gen_dataset, + gen_dataset=local_gen_dataset_copy, steps=steps, save_model_flag=save_model_flag, metrics=[Metric("F1", mask='train', average=None)] ) self.gnn_manager.call_evasion_attack( - gen_dataset=self.gen_dataset, + gen_dataset=local_gen_dataset_copy, mask=mask, ) y_predict_attack = self.gnn_manager.run_model( - gen_dataset=self.gen_dataset, + gen_dataset=local_gen_dataset_copy, mask=mask, out='logits', ) @@ -152,14 +154,15 @@ def poison_attack_pipeline( self.gnn_manager.modification.epochs = 0 self.gnn_manager.gnn.reset_parameters() from models_builder.gnn_models import Metric + local_gen_dataset_copy = copy.deepcopy(self.gen_dataset) self.gnn_manager.train_model( - gen_dataset=self.gen_dataset, + gen_dataset=local_gen_dataset_copy, steps=steps, save_model_flag=False, metrics=[Metric("F1", mask='train', average=None)] ) y_predict_clean = self.gnn_manager.run_model( - gen_dataset=self.gen_dataset, + gen_dataset=local_gen_dataset_copy, mask=mask, out='logits', ) @@ -168,13 +171,13 @@ def poison_attack_pipeline( self.gnn_manager.modification.epochs = 0 self.gnn_manager.gnn.reset_parameters() self.gnn_manager.train_model( - gen_dataset=self.gen_dataset, + gen_dataset=local_gen_dataset_copy, steps=steps, save_model_flag=save_model_flag, metrics=[Metric("F1", mask='train', average=None)] ) y_predict_attack = self.gnn_manager.run_model( - gen_dataset=self.gen_dataset, + gen_dataset=local_gen_dataset_copy, mask=mask, out='logits', )