From 3f066f6639e123652fab4d7961c95e00b234669b Mon Sep 17 00:00:00 2001 From: "lukyanov_kirya@bk.ru" Date: Wed, 20 Nov 2024 16:52:37 +0300 Subject: [PATCH] add models_utils.py and fix come bugs with tensors and device in gnn_models.py --- experiments/attack_defense_test.py | 38 ++++++++++++++------------- src/models_builder/gnn_models.py | 4 ++- src/models_builder/models_utils.py | 42 ++++++++++++++++++++++++++++++ 3 files changed, 65 insertions(+), 19 deletions(-) create mode 100644 src/models_builder/models_utils.py diff --git a/experiments/attack_defense_test.py b/experiments/attack_defense_test.py index 5ee2d53..6d2b6b6 100644 --- a/experiments/attack_defense_test.py +++ b/experiments/attack_defense_test.py @@ -2,9 +2,9 @@ import warnings - from torch import device +from models_builder.models_utils import apply_decorator_to_graph_layers from src.aux.utils import POISON_ATTACK_PARAMETERS_PATH, POISON_DEFENSE_PARAMETERS_PATH, EVASION_ATTACK_PARAMETERS_PATH, \ EVASION_DEFENSE_PARAMETERS_PATH from src.models_builder.gnn_models import FrameworkGNNModelManager, Metric @@ -18,15 +18,14 @@ def test_attack_defense(): - my_device = device('cuda' if torch.cuda.is_available() else 'cpu') full_name = None # full_name = ("multiple-graphs", "TUDataset", 'MUTAG') # full_name = ("single-graph", "custom", 'karate') - # full_name = ("single-graph", "Planetoid", 'Cora') - full_name = ("single-graph", "Amazon", 'Photo') + full_name = ("single-graph", "Planetoid", 'Cora') + # full_name = ("single-graph", "Amazon", 'Photo') # full_name = ("single-graph", "Planetoid", 'CiteSeer') # full_name = ("multiple-graphs", "TUDataset", 'PROTEINS') @@ -183,7 +182,7 @@ def test_attack_defense(): _import_path=EVASION_ATTACK_PARAMETERS_PATH, _config_class="EvasionAttackConfig", _config_kwargs={ - "node_idx": 0, # Node for attack + "node_idx": 0, # Node for attack "n_perturbations": 20, "perturb_features": True, "perturb_structure": True, @@ -192,12 +191,12 @@ def test_attack_defense(): } ) - netattackgroup_evasion_attack_config = ConfigPattern( + netattackgroup_evasion_attack_config = ConfigPattern( _class_name="NettackGroupEvasionAttacker", _import_path=EVASION_ATTACK_PARAMETERS_PATH, _config_class="EvasionAttackConfig", _config_kwargs={ - "node_idxs": [random.randint(0, 500) for _ in range(20)], # Nodes for attack + "node_idxs": [random.randint(0, 500) for _ in range(20)], # Nodes for attack "n_perturbations": 50, "perturb_features": True, "perturb_structure": True, @@ -215,7 +214,6 @@ def test_attack_defense(): } ) - fgsm_evasion_attack_config0 = ConfigPattern( _class_name="FGSM", _import_path=EVASION_ATTACK_PARAMETERS_PATH, @@ -230,14 +228,14 @@ def test_attack_defense(): _config_class="EvasionDefenseConfig", _config_kwargs={ "attack_name": None, - "attack_config": fgsm_evasion_attack_config0 # evasion_attack_config + "attack_config": fgsm_evasion_attack_config0 } ) # gnn_model_manager.set_poison_attacker(poison_attack_config=random_poison_attack_config) # gnn_model_manager.set_poison_defender(poison_defense_config=gnnguard_poison_defense_config) # gnn_model_manager.set_evasion_attacker(evasion_attack_config=netattackgroup_evasion_attack_config) - gnn_model_manager.set_evasion_defender(evasion_defense_config=at_evasion_defense_config) + # gnn_model_manager.set_evasion_defender(evasion_defense_config=at_evasion_defense_config) warnings.warn("Start training") dataset.train_test_split() @@ -265,6 +263,7 @@ def test_attack_defense(): Metric("Accuracy", mask='test')]) print(metric_loc) + def test_meta(): from attacks.metattack import meta_gradient_attack # my_device = device('cpu') @@ -336,6 +335,7 @@ def test_meta(): Metric("Accuracy", mask='test')]) print(metric_loc) + def test_nettack_evasion(): my_device = device('cpu') @@ -444,6 +444,7 @@ def test_nettack_evasion(): metrics=[Metric("Accuracy", mask=mask_loc)])[mask_loc]['Accuracy'] print(f"Accuracy on test loc: {acc_test_loc}") + def test_qattack(): from attacks.QAttack import qattack my_device = device('cpu') @@ -499,7 +500,6 @@ def test_qattack(): # acc_train = gnn_model_manager.evaluate_model(gen_dataset=dataset, # metrics=[Metric("Accuracy", mask='train')])['train']['Accuracy'] - acc_test = gnn_model_manager.evaluate_model(gen_dataset=dataset, metrics=[Metric("Accuracy", mask='test')])['test']['Accuracy'] # print(f"Accuracy on train: {acc_train}. Accuracy on test: {acc_test}") @@ -524,8 +524,7 @@ def test_qattack(): # Attack config - - #dataset = gnn_model_manager.evasion_attacker.attack(gnn_model_manager, dataset, None) + # dataset = gnn_model_manager.evasion_attacker.attack(gnn_model_manager, dataset, None) # Attack # gnn_model_manager.evaluate_model(gen_dataset=dataset, metrics=[Metric("F1", mask='test', average='macro')]) @@ -551,6 +550,7 @@ def test_qattack(): # print(f"info_before_evasion_attack: {info_before_evasion_attack}") # print(f"info_after_evasion_attack: {info_after_evasion_attack}") + def test_jaccard(): from defense.JaccardDefense import jaccard_def # my_device = device('cuda' if is_available() else 'cpu') @@ -778,6 +778,7 @@ def test_adv_training(): Metric("Accuracy", mask='test')]) print(metric_loc) + def test_pgd(): # ______________________ Attack on node ______________________ my_device = device('cpu') @@ -953,8 +954,9 @@ def test_pgd(): # Model prediction on a graph after PGD attack on it with torch.no_grad(): - probabilities = torch.exp(gnn_model_manager.gnn(gnn_model_manager.evasion_attacker.attack_diff.dataset[graph_idx].x, - gnn_model_manager.evasion_attacker.attack_diff.dataset[graph_idx].edge_index)) + probabilities = torch.exp( + gnn_model_manager.gnn(gnn_model_manager.evasion_attacker.attack_diff.dataset[graph_idx].x, + gnn_model_manager.evasion_attacker.attack_diff.dataset[graph_idx].edge_index)) predicted_class = probabilities.argmax().item() predicted_probability = probabilities[0][predicted_class].item() @@ -974,10 +976,10 @@ def test_pgd(): if __name__ == '__main__': import random + random.seed(10) - #test_attack_defense() + test_attack_defense() # torch.manual_seed(5000) # test_gnnguard() # test_jaccard() - # test_attack_defense() - test_pgd() + # test_pgd() diff --git a/src/models_builder/gnn_models.py b/src/models_builder/gnn_models.py index 1fad619..c906a32 100644 --- a/src/models_builder/gnn_models.py +++ b/src/models_builder/gnn_models.py @@ -72,6 +72,8 @@ def __init__(self, name, mask, **kwargs): def compute(self, y_true, y_pred): if self.name in Metric.available_metrics: + if y_true.device != "cpu": + y_true = y_true.cpu() return Metric.available_metrics[self.name](y_true, y_pred, **self.kwargs) raise NotImplementedError() @@ -1034,7 +1036,7 @@ def run_model(self, gen_dataset, mask='test', out='answers'): number_of_batches = ceil(mask_size / self.batch) # data_x_elem_len = data.x.size()[1] - full_out = torch.Tensor() + full_out = torch.empty(0, device=data.x.device) # features_mask_tensor = torch.full(size=data.x.size(), fill_value=True) for batch_ind in range(number_of_batches): diff --git a/src/models_builder/models_utils.py b/src/models_builder/models_utils.py new file mode 100644 index 0000000..d060e45 --- /dev/null +++ b/src/models_builder/models_utils.py @@ -0,0 +1,42 @@ +import torch +from torch_geometric.nn import MessagePassing + + +def apply_message_gradient_capture(layer, name): + """ + # Example how get Tensors + # for name, layer in self.gnn.named_children(): + # if isinstance(layer, MessagePassing): + # print(f"{name}: {layer.get_message_gradients()}") + """ + original_message = layer.message + layer.message_gradients = {} + + def capture_message_gradients(x_j, *args, **kwargs): + x_j = x_j.requires_grad_() + if not layer.training: + return original_message(x_j=x_j, *args, **kwargs) + + def save_message_grad(grad): + layer.message_gradients[name] = grad.detach() + x_j.register_hook(save_message_grad) + return original_message(x_j=x_j, *args, **kwargs) + layer.message = capture_message_gradients + + def get_message_gradients(): + return layer.message_gradients + layer.get_message_gradients = get_message_gradients + + +def apply_decorator_to_graph_layers(model): + # TODO Kirill add more options + """ + Example how use this def + apply_decorator_to_graph_layers(gnn) + """ + for name, layer in model.named_children(): + if isinstance(layer, MessagePassing): + apply_message_gradient_capture(layer, name) + elif isinstance(layer, torch.nn.Module): + apply_decorator_to_graph_layers(layer) +