Skip to content

Commit

Permalink
add models_utils.py and fix come bugs with tensors and device in gnn_…
Browse files Browse the repository at this point in the history
…models.py
  • Loading branch information
LukyanovKirillML committed Nov 20, 2024
1 parent c2a7a42 commit 3f066f6
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 19 deletions.
38 changes: 20 additions & 18 deletions experiments/attack_defense_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import warnings


from torch import device

from models_builder.models_utils import apply_decorator_to_graph_layers
from src.aux.utils import POISON_ATTACK_PARAMETERS_PATH, POISON_DEFENSE_PARAMETERS_PATH, EVASION_ATTACK_PARAMETERS_PATH, \
EVASION_DEFENSE_PARAMETERS_PATH
from src.models_builder.gnn_models import FrameworkGNNModelManager, Metric
Expand All @@ -18,15 +18,14 @@


def test_attack_defense():

my_device = device('cuda' if torch.cuda.is_available() else 'cpu')

full_name = None

# full_name = ("multiple-graphs", "TUDataset", 'MUTAG')
# full_name = ("single-graph", "custom", 'karate')
# full_name = ("single-graph", "Planetoid", 'Cora')
full_name = ("single-graph", "Amazon", 'Photo')
full_name = ("single-graph", "Planetoid", 'Cora')
# full_name = ("single-graph", "Amazon", 'Photo')
# full_name = ("single-graph", "Planetoid", 'CiteSeer')
# full_name = ("multiple-graphs", "TUDataset", 'PROTEINS')

Expand Down Expand Up @@ -183,7 +182,7 @@ def test_attack_defense():
_import_path=EVASION_ATTACK_PARAMETERS_PATH,
_config_class="EvasionAttackConfig",
_config_kwargs={
"node_idx": 0, # Node for attack
"node_idx": 0, # Node for attack
"n_perturbations": 20,
"perturb_features": True,
"perturb_structure": True,
Expand All @@ -192,12 +191,12 @@ def test_attack_defense():
}
)

netattackgroup_evasion_attack_config = ConfigPattern(
netattackgroup_evasion_attack_config = ConfigPattern(
_class_name="NettackGroupEvasionAttacker",
_import_path=EVASION_ATTACK_PARAMETERS_PATH,
_config_class="EvasionAttackConfig",
_config_kwargs={
"node_idxs": [random.randint(0, 500) for _ in range(20)], # Nodes for attack
"node_idxs": [random.randint(0, 500) for _ in range(20)], # Nodes for attack
"n_perturbations": 50,
"perturb_features": True,
"perturb_structure": True,
Expand All @@ -215,7 +214,6 @@ def test_attack_defense():
}
)


fgsm_evasion_attack_config0 = ConfigPattern(
_class_name="FGSM",
_import_path=EVASION_ATTACK_PARAMETERS_PATH,
Expand All @@ -230,14 +228,14 @@ def test_attack_defense():
_config_class="EvasionDefenseConfig",
_config_kwargs={
"attack_name": None,
"attack_config": fgsm_evasion_attack_config0 # evasion_attack_config
"attack_config": fgsm_evasion_attack_config0
}
)

# gnn_model_manager.set_poison_attacker(poison_attack_config=random_poison_attack_config)
# gnn_model_manager.set_poison_defender(poison_defense_config=gnnguard_poison_defense_config)
# gnn_model_manager.set_evasion_attacker(evasion_attack_config=netattackgroup_evasion_attack_config)
gnn_model_manager.set_evasion_defender(evasion_defense_config=at_evasion_defense_config)
# gnn_model_manager.set_evasion_defender(evasion_defense_config=at_evasion_defense_config)

warnings.warn("Start training")
dataset.train_test_split()
Expand Down Expand Up @@ -265,6 +263,7 @@ def test_attack_defense():
Metric("Accuracy", mask='test')])
print(metric_loc)


def test_meta():
from attacks.metattack import meta_gradient_attack
# my_device = device('cpu')
Expand Down Expand Up @@ -336,6 +335,7 @@ def test_meta():
Metric("Accuracy", mask='test')])
print(metric_loc)


def test_nettack_evasion():
my_device = device('cpu')

Expand Down Expand Up @@ -444,6 +444,7 @@ def test_nettack_evasion():
metrics=[Metric("Accuracy", mask=mask_loc)])[mask_loc]['Accuracy']
print(f"Accuracy on test loc: {acc_test_loc}")


def test_qattack():
from attacks.QAttack import qattack
my_device = device('cpu')
Expand Down Expand Up @@ -499,7 +500,6 @@ def test_qattack():
# acc_train = gnn_model_manager.evaluate_model(gen_dataset=dataset,
# metrics=[Metric("Accuracy", mask='train')])['train']['Accuracy']


acc_test = gnn_model_manager.evaluate_model(gen_dataset=dataset,
metrics=[Metric("Accuracy", mask='test')])['test']['Accuracy']
# print(f"Accuracy on train: {acc_train}. Accuracy on test: {acc_test}")
Expand All @@ -524,8 +524,7 @@ def test_qattack():

# Attack config


#dataset = gnn_model_manager.evasion_attacker.attack(gnn_model_manager, dataset, None)
# dataset = gnn_model_manager.evasion_attacker.attack(gnn_model_manager, dataset, None)

# Attack
# gnn_model_manager.evaluate_model(gen_dataset=dataset, metrics=[Metric("F1", mask='test', average='macro')])
Expand All @@ -551,6 +550,7 @@ def test_qattack():
# print(f"info_before_evasion_attack: {info_before_evasion_attack}")
# print(f"info_after_evasion_attack: {info_after_evasion_attack}")


def test_jaccard():
from defense.JaccardDefense import jaccard_def
# my_device = device('cuda' if is_available() else 'cpu')
Expand Down Expand Up @@ -778,6 +778,7 @@ def test_adv_training():
Metric("Accuracy", mask='test')])
print(metric_loc)


def test_pgd():
# ______________________ Attack on node ______________________
my_device = device('cpu')
Expand Down Expand Up @@ -953,8 +954,9 @@ def test_pgd():

# Model prediction on a graph after PGD attack on it
with torch.no_grad():
probabilities = torch.exp(gnn_model_manager.gnn(gnn_model_manager.evasion_attacker.attack_diff.dataset[graph_idx].x,
gnn_model_manager.evasion_attacker.attack_diff.dataset[graph_idx].edge_index))
probabilities = torch.exp(
gnn_model_manager.gnn(gnn_model_manager.evasion_attacker.attack_diff.dataset[graph_idx].x,
gnn_model_manager.evasion_attacker.attack_diff.dataset[graph_idx].edge_index))

predicted_class = probabilities.argmax().item()
predicted_probability = probabilities[0][predicted_class].item()
Expand All @@ -974,10 +976,10 @@ def test_pgd():

if __name__ == '__main__':
import random

random.seed(10)
#test_attack_defense()
test_attack_defense()
# torch.manual_seed(5000)
# test_gnnguard()
# test_jaccard()
# test_attack_defense()
test_pgd()
# test_pgd()
4 changes: 3 additions & 1 deletion src/models_builder/gnn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ def __init__(self, name, mask, **kwargs):

def compute(self, y_true, y_pred):
if self.name in Metric.available_metrics:
if y_true.device != "cpu":
y_true = y_true.cpu()
return Metric.available_metrics[self.name](y_true, y_pred, **self.kwargs)

raise NotImplementedError()
Expand Down Expand Up @@ -1034,7 +1036,7 @@ def run_model(self, gen_dataset, mask='test', out='answers'):

number_of_batches = ceil(mask_size / self.batch)
# data_x_elem_len = data.x.size()[1]
full_out = torch.Tensor()
full_out = torch.empty(0, device=data.x.device)
# features_mask_tensor = torch.full(size=data.x.size(), fill_value=True)

for batch_ind in range(number_of_batches):
Expand Down
42 changes: 42 additions & 0 deletions src/models_builder/models_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import torch
from torch_geometric.nn import MessagePassing


def apply_message_gradient_capture(layer, name):
"""
# Example how get Tensors
# for name, layer in self.gnn.named_children():
# if isinstance(layer, MessagePassing):
# print(f"{name}: {layer.get_message_gradients()}")
"""
original_message = layer.message
layer.message_gradients = {}

def capture_message_gradients(x_j, *args, **kwargs):
x_j = x_j.requires_grad_()
if not layer.training:
return original_message(x_j=x_j, *args, **kwargs)

def save_message_grad(grad):
layer.message_gradients[name] = grad.detach()
x_j.register_hook(save_message_grad)
return original_message(x_j=x_j, *args, **kwargs)
layer.message = capture_message_gradients

def get_message_gradients():
return layer.message_gradients
layer.get_message_gradients = get_message_gradients


def apply_decorator_to_graph_layers(model):
# TODO Kirill add more options
"""
Example how use this def
apply_decorator_to_graph_layers(gnn)
"""
for name, layer in model.named_children():
if isinstance(layer, MessagePassing):
apply_message_gradient_capture(layer, name)
elif isinstance(layer, torch.nn.Module):
apply_decorator_to_graph_layers(layer)

0 comments on commit 3f066f6

Please sign in to comment.