diff --git a/experiments/attack_defense_test.py b/experiments/attack_defense_test.py index 13f1c6e..5ee2d53 100644 --- a/experiments/attack_defense_test.py +++ b/experiments/attack_defense_test.py @@ -778,6 +778,200 @@ def test_adv_training(): Metric("Accuracy", mask='test')]) print(metric_loc) +def test_pgd(): + # ______________________ Attack on node ______________________ + my_device = device('cpu') + + # Load dataset + full_name = ("single-graph", "Planetoid", 'Cora') + dataset, data, results_dataset_path = DatasetManager.get_by_full_name( + full_name=full_name, + dataset_ver_ind=0 + ) + + gcn_gcn = model_configs_zoo(dataset=dataset, model_name='gcn_gcn') + + manager_config = ConfigPattern( + _config_class="ModelManagerConfig", + _config_kwargs={ + "mask_features": [], + "optimizer": { + "_class_name": "Adam", + "_config_kwargs": {}, + } + } + ) + + gnn_model_manager = FrameworkGNNModelManager( + gnn=gcn_gcn, + dataset_path=results_dataset_path, + manager_config=manager_config, + modification=ModelModificationConfig(model_ver_ind=0, epochs=0) + ) + + gnn_model_manager.gnn.to(my_device) + + num_steps = 200 + gnn_model_manager.train_model(gen_dataset=dataset, + steps=num_steps, + save_model_flag=False) + + acc_test = gnn_model_manager.evaluate_model(gen_dataset=dataset, + metrics=[Metric("Accuracy", mask='test')])['test']['Accuracy'] + print(f"Accuracy on test: {acc_test}") + + # Node for attack + node_idx = 650 + + # Model prediction on a node before PGD attack on it + gnn_model_manager.gnn.eval() + with torch.no_grad(): + probabilities = torch.exp(gnn_model_manager.gnn(dataset.data.x, dataset.data.edge_index)) + + predicted_class = probabilities[node_idx].argmax().item() + predicted_probability = probabilities[node_idx][predicted_class].item() + real_class = dataset.data.y[node_idx].item() + + info_before_pgd_attack_on_node = {"node_idx": node_idx, + "predicted_class": predicted_class, + "predicted_probability": predicted_probability, + "real_class": real_class} + + # Attack config + evasion_attack_config = ConfigPattern( + _class_name="PGD", + _import_path=EVASION_ATTACK_PARAMETERS_PATH, + _config_class="EvasionAttackConfig", + _config_kwargs={ + "is_feature_attack": True, + "element_idx": node_idx, + "epsilon": 0.1, + "learning_rate": 0.001, + "num_iterations": 500, + "num_rand_trials": 100 + } + ) + + gnn_model_manager.set_evasion_attacker(evasion_attack_config=evasion_attack_config) + + # Attack + _ = gnn_model_manager.evaluate_model(gen_dataset=dataset, + metrics=[Metric("Accuracy", mask='test')])['test']['Accuracy'] + + # Model prediction on a node after PGD attack on it + with torch.no_grad(): + probabilities = torch.exp(gnn_model_manager.gnn(gnn_model_manager.evasion_attacker.attack_diff.data.x, + gnn_model_manager.evasion_attacker.attack_diff.data.edge_index)) + + predicted_class = probabilities[node_idx].argmax().item() + predicted_probability = probabilities[node_idx][predicted_class].item() + real_class = dataset.data.y[node_idx].item() + + info_after_pgd_attack_on_node = {"node_idx": node_idx, + "predicted_class": predicted_class, + "predicted_probability": predicted_probability, + "real_class": real_class} + # ____________________________________________________________ + + # ______________________ Attack on graph _____________________ + # Load dataset + full_name = ("multiple-graphs", "TUDataset", 'MUTAG') + dataset, data, results_dataset_path = DatasetManager.get_by_full_name( + full_name=full_name, + dataset_ver_ind=0 + ) + + model = model_configs_zoo(dataset=dataset, 
model_name='gin_gin_gin_lin_lin_con') + + manager_config = ConfigPattern( + _config_class="ModelManagerConfig", + _config_kwargs={ + "mask_features": [], + "optimizer": { + "_class_name": "Adam", + "_config_kwargs": {}, + } + } + ) + + gnn_model_manager = FrameworkGNNModelManager( + gnn=model, + dataset_path=results_dataset_path, + manager_config=manager_config, + modification=ModelModificationConfig(model_ver_ind=0, epochs=0) + ) + + gnn_model_manager.gnn.to(my_device) + + num_steps = 200 + gnn_model_manager.train_model(gen_dataset=dataset, + steps=num_steps, + save_model_flag=False) + + acc_test = gnn_model_manager.evaluate_model(gen_dataset=dataset, + metrics=[Metric("Accuracy", mask='test')])['test']['Accuracy'] + print(f"Accuracy on test: {acc_test}") + + # Graph for attack + graph_idx = 0 + + # Model prediction on a graph before PGD attack on it + gnn_model_manager.gnn.eval() + with torch.no_grad(): + probabilities = torch.exp(gnn_model_manager.gnn(dataset.dataset[graph_idx].x, + dataset.dataset[graph_idx].edge_index)) + + predicted_class = probabilities.argmax().item() + predicted_probability = probabilities[0][predicted_class].item() + real_class = dataset.dataset[graph_idx].y.item() + + info_before_pgd_attack_on_graph = {"graph_idx": graph_idx, + "predicted_class": predicted_class, + "predicted_probability": predicted_probability, + "real_class": real_class} + + # Attack config + evasion_attack_config = ConfigPattern( + _class_name="PGD", + _import_path=EVASION_ATTACK_PARAMETERS_PATH, + _config_class="EvasionAttackConfig", + _config_kwargs={ + "is_feature_attack": True, + "element_idx": graph_idx, + "epsilon": 0.1, + "learning_rate": 0.001, + "num_iterations": 500, + "num_rand_trials": 100 + } + ) + + gnn_model_manager.set_evasion_attacker(evasion_attack_config=evasion_attack_config) + + # Attack + _ = gnn_model_manager.evaluate_model(gen_dataset=dataset, + metrics=[Metric("Accuracy", mask='test')])['test']['Accuracy'] + + # Model prediction on a graph after PGD attack on it + with torch.no_grad(): + probabilities = torch.exp(gnn_model_manager.gnn(gnn_model_manager.evasion_attacker.attack_diff.dataset[graph_idx].x, + gnn_model_manager.evasion_attacker.attack_diff.dataset[graph_idx].edge_index)) + + predicted_class = probabilities.argmax().item() + predicted_probability = probabilities[0][predicted_class].item() + real_class = dataset.dataset[graph_idx].y.item() + + info_after_pgd_attack_on_graph = {"graph_idx": graph_idx, + "predicted_class": predicted_class, + "predicted_probability": predicted_probability, + "real_class": real_class} + + # ____________________________________________________________ + print(f"Before PGD attack on node (Cora dataset): {info_before_pgd_attack_on_node}") + print(f"After PGD attack on node (Cora dataset): {info_after_pgd_attack_on_node}") + print(f"Before PGD attack on graph (MUTAG dataset): {info_before_pgd_attack_on_graph}") + print(f"After PGD attack on graph (MUTAG dataset): {info_after_pgd_attack_on_graph}") + + if __name__ == '__main__': import random random.seed(10) @@ -785,4 +979,5 @@ def test_adv_training(): # torch.manual_seed(5000) # test_gnnguard() # test_jaccard() - test_attack_defense() + # test_attack_defense() + test_pgd() diff --git a/experiments/interpretation_metrics_test.py b/experiments/interpretation_metrics_test.py new file mode 100644 index 0000000..8b01d9e --- /dev/null +++ b/experiments/interpretation_metrics_test.py @@ -0,0 +1,119 @@ +import random +import warnings + +import torch + +from aux.custom_decorators import 
timing_decorator +from aux.utils import EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH, EXPLAINERS_INIT_PARAMETERS_PATH +from explainers.explainers_manager import FrameworkExplainersManager +from models_builder.gnn_models import FrameworkGNNModelManager, Metric +from src.aux.configs import ModelModificationConfig, ConfigPattern +from src.base.datasets_processing import DatasetManager +from src.models_builder.models_zoo import model_configs_zoo + + +@timing_decorator +def run_interpretation_test(): + full_name = ("single-graph", "Planetoid", 'Cora') + steps_epochs = 10 + save_model_flag = False + my_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + dataset, data, results_dataset_path = DatasetManager.get_by_full_name( + full_name=full_name, + dataset_ver_ind=0 + ) + gnn = model_configs_zoo(dataset=dataset, model_name='gcn_gcn') + manager_config = ConfigPattern( + _config_class="ModelManagerConfig", + _config_kwargs={ + "mask_features": [], + "optimizer": { + # "_config_class": "Config", + "_class_name": "Adam", + # "_import_path": OPTIMIZERS_PARAMETERS_PATH, + # "_class_import_info": ["torch.optim"], + "_config_kwargs": {}, + } + } + ) + gnn_model_manager = FrameworkGNNModelManager( + gnn=gnn, + dataset_path=results_dataset_path, + manager_config=manager_config, + modification=ModelModificationConfig(model_ver_ind=0, epochs=steps_epochs) + ) + gnn_model_manager.gnn.to(my_device) + data.x = data.x.float() + data = data.to(my_device) + + warnings.warn("Start training") + try: + raise FileNotFoundError() + except FileNotFoundError: + gnn_model_manager.epochs = gnn_model_manager.modification.epochs = 0 + train_test_split_path = gnn_model_manager.train_model(gen_dataset=dataset, steps=steps_epochs, + save_model_flag=save_model_flag, + metrics=[Metric("F1", mask='train', average=None)]) + + if train_test_split_path is not None: + dataset.save_train_test_mask(train_test_split_path) + train_mask, val_mask, test_mask, train_test_sizes = torch.load(train_test_split_path / 'train_test_split')[ + :] + dataset.train_mask, dataset.val_mask, dataset.test_mask = train_mask, val_mask, test_mask + data.percent_train_class, data.percent_test_class = train_test_sizes + warnings.warn("Training was successful") + + metric_loc = gnn_model_manager.evaluate_model( + gen_dataset=dataset, metrics=[Metric("F1", mask='test', average='macro')]) + print(metric_loc) + + explainer_init_config = ConfigPattern( + _class_name="GNNExplainer(torch-geom)", + _import_path=EXPLAINERS_INIT_PARAMETERS_PATH, + _config_class="ExplainerInitConfig", + _config_kwargs={ + "epochs": 10 + } + ) + explainer_metrics_run_config = ConfigPattern( + _config_class="ExplainerRunConfig", + _config_kwargs={ + "mode": "local", + "kwargs": { + "_class_name": "GNNExplainer(torch-geom)", + "_import_path": EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH, + "_config_class": "Config", + "_config_kwargs": { + "stability_graph_perturbations_nums": 10, + "stability_feature_change_percent": 0.05, + "stability_node_removal_percent": 0.05, + "consistency_num_explanation_runs": 10 + }, + } + } + ) + + explainer_GNNExpl = FrameworkExplainersManager( + init_config=explainer_init_config, + dataset=dataset, gnn_manager=gnn_model_manager, + explainer_name='GNNExplainer(torch-geom)', + ) + + num_explaining_nodes = 10 + node_indices = random.sample(range(dataset.data.x.shape[0]), num_explaining_nodes) + + # explainer_GNNExpl.explainer.pbar = ProgressBar(socket, "er", desc=f'{explainer_GNNExpl.explainer.name} explaining') + # explanation_metric = NodesExplainerMetric( 
+ # model=explainer_GNNExpl.gnn, + # graph=explainer_GNNExpl.gen_dataset.data, + # explainer=explainer_GNNExpl.explainer + # ) + # res = explanation_metric.evaluate(node_indices) + explanation_metrics = explainer_GNNExpl.evaluate_metrics(node_indices, explainer_metrics_run_config) + print(explanation_metrics) + + +if __name__ == '__main__': + random.seed(11) + run_interpretation_test() diff --git a/metainfo/evasion_attack_parameters.json b/metainfo/evasion_attack_parameters.json index 27b3c19..03ea7be 100644 --- a/metainfo/evasion_attack_parameters.json +++ b/metainfo/evasion_attack_parameters.json @@ -26,5 +26,12 @@ "generations" : ["Generations", "int", 50, {"min": 0, "step": 1}, "Number of generations for genetic algorithm"], "prob_cross": ["Probability for crossover", "float", 0.5, {"min": 0, "max": 1, "step": 0.01}, "Probability of crossover between two genes"], "prob_mutate": ["Probability for mutation", "float", 0.02, {"min": 0, "max": 1, "step": 0.01}, "Probability of gene mutation"] + }, + "PGD": { + "epsilon": ["Epsilon", "float", 0.1, {"min": 0, "max": 1, "step": 0.01}, "Epsilon"], + "learning_rate": ["Learning rate", "float", 0.01, {}, "Learning rate for adjacency matrix optimization"], + "num_iterations": ["Number of iterations", "int", 100, {"min": 1, "step": 1}, "Number of iterations of gradient descent"], + "num_rand_trials": ["Number of random trials", "int", 100, {"min": 1, "step": 1}, "number of random trials in Random Sampling Algorithm"] } -} \ No newline at end of file +} + diff --git a/src/attacks/evasion_attacks.py b/src/attacks/evasion_attacks.py index 43f2621..47a65a5 100644 --- a/src/attacks/evasion_attacks.py +++ b/src/attacks/evasion_attacks.py @@ -8,6 +8,13 @@ from src.attacks.nettack.nettack import Nettack from src.attacks.nettack.utils import preprocess_graph, largest_connected_components, data_to_csr_matrix, train_w1_w2 +# PGD imports +from attacks.evasion_attacks_collection.pgd.utils import Projection, RandomSampling +import torch.nn.functional as F +from torch_geometric.utils import to_dense_adj, dense_to_sparse, k_hop_subgraph +from tqdm import tqdm +from torch_geometric.nn import SGConv + class EvasionAttacker(Attacker): def __init__(self, **kwargs): @@ -42,6 +49,110 @@ def attack(self, model_manager, gen_dataset, mask_tensor): return gen_dataset +class PGDAttacker(EvasionAttacker): + name = "PGD" + + def __init__(self, + is_feature_attack=False, + element_idx=0, + epsilon=0.5, + learning_rate=0.001, + num_iterations=100, + num_rand_trials=100): + + super().__init__() + self.attack_diff = None + self.is_feature_attack = is_feature_attack # feature / structure + self.element_idx = element_idx + self.epsilon = epsilon + self.learning_rate = learning_rate + self.num_iterations = num_iterations + self.num_rand_trials = num_rand_trials + + def attack(self, model_manager, gen_dataset, mask_tensor): + if gen_dataset.is_multi(): + self._attack_on_graph(model_manager, gen_dataset) + else: + self._attack_on_node(model_manager, gen_dataset) + + def _attack_on_node(self, model_manager, gen_dataset): + node_idx = self.element_idx + + edge_index = gen_dataset.data.edge_index + y = gen_dataset.data.y + x = gen_dataset.data.x + + model = model_manager.gnn + num_hops = model.n_layers + + subset, edge_index_subset, inv, edge_mask = k_hop_subgraph(node_idx=node_idx, + num_hops=num_hops, + edge_index=edge_index, + relabel_nodes=True, + directed=False) + + if self.is_feature_attack: # feature attack + node_idx_remap = torch.where(subset == node_idx)[0].item() + y = 
y.clone()
+            y = y[subset]
+            x = x.clone()
+            x = x[subset]
+            orig_x = x.clone()
+            x.requires_grad = True
+            optimizer = torch.optim.Adam([x], lr=self.learning_rate, weight_decay=5e-4)
+
+            for t in tqdm(range(self.num_iterations)):
+                out = model(x, edge_index_subset)
+                loss = -model_manager.loss_function(out[node_idx_remap], y[node_idx_remap])
+                # print(loss)
+                model.zero_grad()
+                loss.backward()
+                x.grad.sign_()
+                optimizer.step()
+                with torch.no_grad():
+                    x.copy_(torch.max(torch.min(x, orig_x + self.epsilon), orig_x - self.epsilon))  # project onto the eps-ball around orig_x
+                    x.copy_(torch.clamp(x, 0, 1))  # keep the perturbed features in the valid [0, 1] range
+            # write the perturbed rows back into the original feature tensor x
+            gen_dataset.data.x[subset] = x.detach()
+            self.attack_diff = gen_dataset
+        else:  # structure attack
+            pass
+
+    def _attack_on_graph(self, model_manager, gen_dataset):
+        graph_idx = self.element_idx
+
+        edge_index = gen_dataset.dataset[graph_idx].edge_index
+        y = gen_dataset.dataset[graph_idx].y
+        x = gen_dataset.dataset[graph_idx].x
+
+        model = model_manager.gnn
+
+        if self.is_feature_attack:  # feature attack
+            x = x.clone()
+            orig_x = x.clone()
+            x.requires_grad = True
+            optimizer = torch.optim.Adam([x], lr=self.learning_rate, weight_decay=5e-4)
+
+            for t in tqdm(range(self.num_iterations)):
+                out = model(x, edge_index)
+                loss = -model_manager.loss_function(out, y)
+                # print(loss)
+                model.zero_grad()
+                loss.backward()
+                x.grad.sign_()
+                optimizer.step()
+                with torch.no_grad():
+                    x.copy_(torch.max(torch.min(x, orig_x + self.epsilon), orig_x - self.epsilon))  # project onto the eps-ball around orig_x
+                    x.copy_(torch.clamp(x, 0, 1))  # keep the perturbed features in the valid [0, 1] range
+            gen_dataset.dataset[graph_idx].x.copy_(x.detach())
+            self.attack_diff = gen_dataset
+        else:  # structure attack
+            pass
+
+    def attack_diff(self):
+        return self.attack_diff
+
+
 class NettackEvasionAttacker(EvasionAttacker):
     name = "NettackEvasionAttacker"
diff --git a/src/attacks/evasion_attacks_collection/pgd/utils.py b/src/attacks/evasion_attacks_collection/pgd/utils.py
new file mode 100644
index 0000000..bdcf953
--- /dev/null
+++ b/src/attacks/evasion_attacks_collection/pgd/utils.py
@@ -0,0 +1,96 @@
+import torch
+from tqdm import tqdm
+
+
+class Projection:
+    def __init__(self, eps):
+        self.eps = eps
+
+    def __call__(self, a):
+        """
+        """
+        # a = a_matrix.flatten()
+
+        projection = self.projection(a)
+
+        # projection_matrix = projection.view(a_matrix.shape)
+        return projection
+
+    def projection(self, a):
+        """
+        Calculate the projection of 'a' onto the set 'S'.
+        """
+        if torch.isnan(a).any():
+            # NaN found in vector a. Replace with zeros.
+            a = torch.nan_to_num(a, nan=0.0)
+
+        # Projection onto [0, 1]
+        s = torch.clamp(a, min=0, max=1)
+
+        # If the sum already satisfies the budget, return; otherwise project onto the simplex
+        if torch.sum(s) <= self.eps:
+            return s
+
+        # Projection onto simplex
+        return self.projection_onto_simplex(s)
+
+    def projection_onto_simplex(self, v):
+        """
+        Project the vector 'v' onto the simplex defined by the budget on the sum of its elements.
+        """
+        if torch.sum(v) <= self.eps:
+            return v
+
+        # Sort the elements of the vector 'v' in descending order
+        u, _ = torch.sort(v, descending=True)
+        cssv = torch.cumsum(u, dim=0) - self.eps
+
+        # Find the index 'rho' where the projection begins
+        mask = u > (cssv / torch.arange(1, len(u) + 1, dtype=v.dtype))
+        indices = torch.nonzero(mask).squeeze()
+
+        rho = indices[-1]
+
+        theta = cssv[rho] / (rho.item() + 1)
+
+        # Final projection onto simplex
+        proj = torch.clamp(v - theta, min=0)
+
+        # This simplex projection guarantees that torch.sum(proj) equals self.eps.
+ # Let's write an assertion taking into account the rule for comparing floating-point numbers. + assert torch.allclose(torch.sum(proj), torch.tensor(self.eps, dtype=torch.float)) is True + + return proj + + +# TODO check name of variables in RandomSampling Algorithm +class RandomSampling: + def __init__(self, K, eps, attack_loss, model, edge_index_joint, data): + """Random sampling from probabilistic to binary topology perturbation""" + self.K = K + # TODO add condition (1^T, s) <= eps on result + self.eps = eps + self.edge_index_joint = edge_index_joint + self.attack_loss = attack_loss + self.model = model + self.data = data + + def __call__(self, mask): + u_list = [] + for k in tqdm(range(self.K), desc="Random sampling", leave=True): + random_matrix = torch.rand_like(mask) + comparison_matrix = random_matrix < mask + u = comparison_matrix.to(dtype=torch.float32) + u_list.append(u) + + print("Please wait...") + best_u = None + best_f_value = float('inf') + + # for u in u_list: + # preds = self.model(self.data.x, self.edge_index_joint, None, edge_weight_perturbed) + # f_value = self.attack_loss(preds, self.data.y) + # if f_value < best_f_value: + # best_f_value = f_value + # best_u = u + return best_u \ No newline at end of file diff --git a/src/attacks/poison_attacks.py b/src/attacks/poison_attacks.py index 0d85927..e02028d 100644 --- a/src/attacks/poison_attacks.py +++ b/src/attacks/poison_attacks.py @@ -1,5 +1,4 @@ import numpy as np -import importlib import torch from attacks.attack_base import Attacker diff --git a/src/explainers/GNNExplainer/torch_geom_our/out.py b/src/explainers/GNNExplainer/torch_geom_our/out.py index e0a10ce..5595d5a 100644 --- a/src/explainers/GNNExplainer/torch_geom_our/out.py +++ b/src/explainers/GNNExplainer/torch_geom_our/out.py @@ -102,6 +102,17 @@ def run(self, mode, kwargs, finalize=True): self.raw_explanation = self.explainer(self.x, self.edge_index, index=self.node_idx) self.pbar.close() + @finalize_decorator + def evaluate_tensor_graph(self, x, edge_index, node_idx, **kwargs): + self._run_mode = "local" + self.node_idx = node_idx + self.x = x + self.edge_index = edge_index + self.pbar.reset(total=self.epochs, mode=self._run_mode) + self.explainer.algorithm.pbar = self.pbar + self.raw_explanation = self.explainer(self.x, self.edge_index, index=self.node_idx, **kwargs) + self.pbar.close() + def _finalize(self): mode = self._run_mode assert mode == "local" @@ -111,13 +122,14 @@ def _finalize(self): self.explanation = AttributionExplanation( local=mode, - edges="continuous" if edge_mask is not None else False, - features="continuous" if node_mask is not None else False) + edges="continuous" if self.edge_mask_type=="object" else False, + nodes="continuous" if self.node_mask_type=="object" else False, + features="continuous" if self.node_mask_type=="common_attributes" else False) important_edges = {} important_nodes = {} + important_features = {} - # TODO What if edge_mask_type or node_mask_type is None, common_attributes, attributes? 
if self.edge_mask_type is not None and self.node_mask_type is not None: # Multi graphs check is not needed: the explanation format for @@ -125,32 +137,50 @@ def _finalize(self): eps = 0.001 # Edges - num_edges = edge_mask.size(0) - assert num_edges == self.edge_index.size(1) - edges = self.edge_index - - for i in range(num_edges): - imp = float(edge_mask[i]) - if not imp < eps: - edge = edges[0][i], edges[1][i] - important_edges[f"{edge[0]},{edge[1]}"] = format(imp, '.4f') + if self.edge_mask_type=="object": + num_edges = edge_mask.size(0) + assert num_edges == self.edge_index.size(1) + edges = self.edge_index + + for i in range(num_edges): + imp = float(edge_mask[i]) + if not imp < eps: + edge = edges[0][i], edges[1][i] + important_edges[f"{edge[0]},{edge[1]}"] = format(imp, '.4f') + else: # if "common_attributes" or "attributes" + raise NotImplementedError(f"Edge mask type '{self.edge_mask_type}' is not yet implemented.") # Nodes - num_nodes = node_mask.size(0) - assert num_nodes == self.x.size(0) - - for i in range(num_nodes): - imp = float(node_mask[i][0]) - if not imp < eps: - important_nodes[i] = format(imp, '.4f') + if self.node_mask_type=="object": + num_nodes = node_mask.size(0) + assert num_nodes == self.x.size(0) + + for i in range(num_nodes): + imp = float(node_mask[i][0]) + if not imp < eps: + important_nodes[i] = format(imp, '.4f') + # Features + elif self.node_mask_type=="common_attributes": + num_features = node_mask.size(1) + assert num_features == self.x.size(1) + + for i in range(num_features): + imp = float(node_mask[0][i]) + if not imp < eps: + important_features[i] = format(imp, '.4f') + else: # if "attributes" + # TODO add functional if node_mask_type=="attributes" + raise NotImplementedError(f"Node mask type '{self.node_mask_type}' is not yet implemented.") if self.gen_dataset.is_multi(): important_edges = {self.graph_idx: important_edges} important_nodes = {self.graph_idx: important_nodes} + important_features = {self.graph_idx: important_features} # TODO Write functions with output threshold self.explanation.add_edges(important_edges) self.explanation.add_nodes(important_nodes) + self.explanation.add_features(important_features) # print(important_edges) # print(important_nodes) diff --git a/src/explainers/explainer_metrics.py b/src/explainers/explainer_metrics.py new file mode 100644 index 0000000..32e0eb5 --- /dev/null +++ b/src/explainers/explainer_metrics.py @@ -0,0 +1,203 @@ +import numpy as np +import torch +from torch_geometric.utils import subgraph + + +class NodesExplainerMetric: + def __init__(self, model, graph, explainer, kwargs_dict): + self.model = model + self.explainer = explainer + self.graph = graph + self.x = self.graph.x + self.edge_index = self.graph.edge_index + self.kwargs_dict = { + "stability_graph_perturbations_nums": 10, + "stability_feature_change_percent": 0.05, + "stability_node_removal_percent": 0.05, + "consistency_num_explanation_runs": 10 + } + self.kwargs_dict.update(kwargs_dict) + self.nodes_explanations = {} # explanations cache. 
node_ind -> explanation
+        self.dictionary = {
+        }
+
+    def evaluate(self, target_nodes_indices):
+        num_targets = len(target_nodes_indices)
+        sparsity = 0
+        stability = 0
+        consistency = 0
+        for node_ind in target_nodes_indices:
+            self.get_explanation(node_ind)
+            sparsity += self.calculate_sparsity(node_ind)
+            stability += self.calculate_stability(
+                node_ind,
+                graph_perturbations_nums=self.kwargs_dict["stability_graph_perturbations_nums"],
+                feature_change_percent=self.kwargs_dict["stability_feature_change_percent"],
+                node_removal_percent=self.kwargs_dict["stability_node_removal_percent"]
+            )
+            consistency += self.calculate_consistency(
+                node_ind,
+                num_explanation_runs=self.kwargs_dict["consistency_num_explanation_runs"]
+            )
+        fidelity = self.calculate_fidelity(target_nodes_indices)
+        self.dictionary["sparsity"] = sparsity / num_targets
+        self.dictionary["stability"] = stability / num_targets
+        self.dictionary["consistency"] = consistency / num_targets
+        self.dictionary["fidelity"] = fidelity
+        return self.dictionary
+
+    def calculate_fidelity(self, target_nodes_indices):
+        original_answer = self.model.get_answer(self.x, self.edge_index)
+        same_answers_count = 0
+        for node_ind in target_nodes_indices:
+            node_explanation = self.get_explanation(node_ind)
+            new_x, new_edge_index, new_target_node = self.filter_graph_by_explanation(
+                self.x, self.edge_index, node_explanation, node_ind
+            )
+            filtered_answer = self.model.get_answer(new_x, new_edge_index)
+            matched = filtered_answer[new_target_node] == original_answer[node_ind]
+            print(f"Processed fidelity calculation for node id {node_ind}. Matched: {matched}")
+            if matched:
+                same_answers_count += 1
+        fidelity = same_answers_count / len(target_nodes_indices)
+        return fidelity
+
+    def calculate_sparsity(self, node_ind):
+        explanation = self.get_explanation(node_ind)
+        sparsity = 1 - (len(explanation["data"]["nodes"]) + len(explanation["data"]["edges"])) / (
+                self.x.shape[0] + self.edge_index.shape[1])  # total number of nodes plus edges in the graph
+        return sparsity
+
+    def calculate_stability(
+            self,
+            node_ind,
+            graph_perturbations_nums=10,
+            feature_change_percent=0.05,
+            node_removal_percent=0.05
+    ):
+        base_explanation = self.get_explanation(node_ind)
+        stability = 0
+        for _ in range(graph_perturbations_nums):
+            new_x, new_edge_index = self.perturb_graph(
+                self.x, self.edge_index, node_ind, feature_change_percent, node_removal_percent
+            )
+            perturbed_explanation = self.calculate_explanation(new_x, new_edge_index, node_ind)
+            base_explanation_vector, perturbed_explanation_vector = \
+                NodesExplainerMetric.calculate_explanation_vectors(base_explanation, perturbed_explanation)
+
+            stability += euclidean_distance(base_explanation_vector, perturbed_explanation_vector)
+
+        stability = stability / graph_perturbations_nums
+        return stability
+
+    def calculate_consistency(self, node_ind, num_explanation_runs=10):
+        explanation = self.get_explanation(node_ind)
+        consistency = 0
+        for _ in range(num_explanation_runs):
+            perturbed_explanation = self.calculate_explanation(self.x, self.edge_index, node_ind)
+            base_explanation_vector, perturbed_explanation_vector = \
+                NodesExplainerMetric.calculate_explanation_vectors(explanation, perturbed_explanation)
+            consistency += cosine_similarity(base_explanation_vector, perturbed_explanation_vector)
+            explanation = perturbed_explanation
+
+        consistency = consistency / num_explanation_runs
+        return consistency
+
+    def calculate_explanation(self, x, edge_index, node_idx, **kwargs):
+        print(f"Processing explanation calculation for node id {node_idx}.")
+
self.explainer.evaluate_tensor_graph(x, edge_index, node_idx, **kwargs) + print(f"Explanation calculation for node id {node_idx} completed.") + return self.explainer.explanation.dictionary + + def get_explanation(self, node_ind): + if node_ind in self.nodes_explanations: + node_explanation = self.nodes_explanations[node_ind] + else: + node_explanation = self.calculate_explanation(self.x, self.edge_index, node_ind) + self.nodes_explanations[node_ind] = node_explanation + return node_explanation + + @staticmethod + def parse_explanation(explanation): + important_nodes = { + int(node): float(weight) for node, weight in explanation["data"]["nodes"].items() + } + important_edges = { + tuple(map(int, edge.split(','))): float(weight) + for edge, weight in explanation["data"]["edges"].items() + } + return important_nodes, important_edges + + @staticmethod + def filter_graph_by_explanation(x, edge_index, explanation, target_node): + important_nodes, important_edges = NodesExplainerMetric.parse_explanation(explanation) + all_important_nodes = set(important_nodes.keys()) + all_important_nodes.add(target_node) + for u, v in important_edges.keys(): + all_important_nodes.add(u) + all_important_nodes.add(v) + + important_node_indices = list(all_important_nodes) + node_mask = torch.zeros(x.size(0), dtype=torch.bool) + node_mask[important_node_indices] = True + + new_edge_index, new_edge_weight = subgraph(node_mask, edge_index, relabel_nodes=True) + new_x = x[node_mask] + new_target_node = important_node_indices.index(target_node) + return new_x, new_edge_index, new_target_node + + @staticmethod + def calculate_explanation_vectors(base_explanation, perturbed_explanation): + base_important_nodes, base_important_edges = NodesExplainerMetric.parse_explanation( + base_explanation + ) + perturbed_important_nodes, perturbed_important_edges = NodesExplainerMetric.parse_explanation( + perturbed_explanation + ) + union_nodes = set(base_important_nodes.keys()) | set(perturbed_important_nodes.keys()) + union_edges = set(base_important_edges.keys()) | set(perturbed_important_edges.keys()) + explain_vector_len = len(union_nodes) + len(union_edges) + base_explanation_vector = np.zeros(explain_vector_len) + perturbed_explanation_vector = np.zeros(explain_vector_len) + i = 0 + for expl_node_ind in union_nodes: + base_explanation_vector[i] = base_important_nodes.get(expl_node_ind, 0) + perturbed_explanation_vector[i] = perturbed_important_nodes.get(expl_node_ind, 0) + i += 1 + for expl_edge in union_edges: + base_explanation_vector[i] = base_important_edges.get(expl_edge, 0) + perturbed_explanation_vector[i] = perturbed_important_edges.get(expl_edge, 0) + i += 1 + return base_explanation_vector, perturbed_explanation_vector + + @staticmethod + def perturb_graph(x, edge_index, node_ind, feature_change_percent, node_removal_percent): + new_x = x.clone() + num_nodes = x.shape[0] + num_features = x.shape[1] + num_features_to_change = int(feature_change_percent * num_nodes * num_features) + indices = torch.randint(0, num_nodes * num_features, (num_features_to_change,), device=x.device) + new_x.view(-1)[indices] = 1.0 - new_x.view(-1)[indices] + + neighbors = edge_index[1][edge_index[0] == node_ind].unique() + num_nodes_to_remove = int(node_removal_percent * neighbors.shape[0]) + + if num_nodes_to_remove > 0: + nodes_to_remove = neighbors[ + torch.randperm(neighbors.size(0), device=edge_index.device)[:num_nodes_to_remove] + ] + mask = ~((edge_index[0] == node_ind).unsqueeze(1) & (edge_index[1].unsqueeze(0) == 
nodes_to_remove.unsqueeze(1)).any(
+                dim=0).unsqueeze(1)).squeeze(1)  # True for edges that are kept (not node_ind -> removed neighbour)
+            new_edge_index = edge_index[:, mask]
+        else:
+            new_edge_index = edge_index
+
+        return new_x, new_edge_index
+
+
+def cosine_similarity(a, b):
+    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
+
+
+def euclidean_distance(a, b):
+    return np.linalg.norm(a - b)
diff --git a/src/explainers/explainers_manager.py b/src/explainers/explainers_manager.py
index 933771d..cf93aae 100644
--- a/src/explainers/explainers_manager.py
+++ b/src/explainers/explainers_manager.py
@@ -1,10 +1,10 @@
 import json
 
-from aux.configs import ExplainerInitConfig, ExplainerModificationConfig, ExplainerRunConfig, \
-    CONFIG_CLASS_NAME, CONFIG_OBJ, ConfigPattern
+from aux.configs import ExplainerInitConfig, ExplainerModificationConfig, CONFIG_OBJ, ConfigPattern
 from aux.declaration import Declare
 from aux.utils import EXPLAINERS_INIT_PARAMETERS_PATH
 from explainers.explainer import Explainer, ProgressBar
+from explainers.explainer_metrics import NodesExplainerMetric
 
 # TODO misha can we do it not manually?
 # Need to import all modules with subclasses of Explainer, otherwise python can't see them
@@ -176,6 +176,50 @@ def conduct_experiment(self, run_config, socket=None):
 
         return result
 
+    def evaluate_metrics(self, target_nodes_indices, run_config=None, socket=None):
+        """
+        Evaluate explanation metrics for the given node indices.
+        """
+        # TODO: Refactor this method for framework design
+        if run_config:
+            params = getattr(getattr(run_config, CONFIG_OBJ).kwargs, CONFIG_OBJ).to_dict()
+        else:
+            params = {}
+        self.explainer.pbar = ProgressBar(
+            socket, "er", desc=f'{self.explainer.name} explaining metrics calculation'
+        )  # progress bar
+        try:
+            print("Evaluating explanation metrics...")
+            if self.gen_dataset.is_multi():
+                raise NotImplementedError("Explanation metrics are not implemented for graph classification datasets")
+            else:
+                explanation_metrics_calculator = NodesExplainerMetric(
+                    model=self.gnn,
+                    graph=self.gen_dataset.data,
+                    explainer=self.explainer,
+                    kwargs_dict=params
+                )
+                result = explanation_metrics_calculator.evaluate(target_nodes_indices)
+                print("Explanation metrics are ready")
+
+            if socket:
+                # TODO: Handle this on frontend
+                socket.send("er", {
+                    "status": "OK",
+                    "explanation_metrics": result
+                })
+
+            # TODO what if save_explanation_flag=False?
+            if self.save_explanation_flag:
+                # self.save_explanation_metrics(run_config)
+                self.model_manager.save_model_executor()
+        except Exception as e:
+            if socket:
+                socket.send("er", {"status": "FAILED"})
+            raise e
+
+        return result
+
     @staticmethod
     def available_explainers(gen_dataset, model_manager):
         """ Get a list of explainers applicable for current model and dataset.
diff --git a/src/models_builder/models_zoo.py b/src/models_builder/models_zoo.py index 2c639c0..72003bd 100644 --- a/src/models_builder/models_zoo.py +++ b/src/models_builder/models_zoo.py @@ -311,6 +311,46 @@ def model_configs_zoo(dataset, model_name): ) ) + gcn_gcn_no_self_loops = FrameworkGNNConstructor( + model_config=ModelConfig( + structure=ModelStructureConfig( + [ + { + 'label': 'n', + 'layer': { + 'layer_name': 'GCNConv', + 'layer_kwargs': { + 'in_channels': dataset.num_node_features, + 'out_channels': 16, + 'add_self_loops': False + }, + }, + 'activation': { + 'activation_name': 'ReLU', + 'activation_kwargs': None, + }, + }, + + { + 'label': 'n', + 'layer': { + 'layer_name': 'GCNConv', + 'layer_kwargs': { + 'in_channels': 16, + 'out_channels': dataset.num_classes, + 'add_self_loops': False + }, + }, + 'activation': { + 'activation_name': 'LogSoftmax', + 'activation_kwargs': None, + }, + }, + ] + ) + ) + ) + gcn_gcn_linearized = FrameworkGNNConstructor( model_config=ModelConfig( structure=ModelStructureConfig(