From cce84fed6688b9415506309c5f5402a62fdbc7ce Mon Sep 17 00:00:00 2001 From: Jaap-Meerhof Date: Fri, 18 Aug 2023 18:26:08 +0200 Subject: [PATCH] pfff, how am i gonna do the federated attack --- tests/main.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/main.py b/tests/main.py index c4b7d3f..e92a8b7 100644 --- a/tests/main.py +++ b/tests/main.py @@ -138,7 +138,7 @@ def train_all_federated(target_model, shadow_models, attack_model, config:Config data = retrieve_data(target_model, shadow_model, attack_model, X_train, y_train, X_test, y_test, z_test, labels_test) return data -def retrieve_data(target_model, shadow_model, attack_model, X_train, y_train, X_test, y_test, z_test, labels_test): # todo put this in SFXGBoost +def retrieve_data(target_model, shadow_model, attack_model, X_train, y_train, X_test, y_test, z_test, labels_test): # TODO put this in SFXGBoost data = {} from sklearn.metrics import accuracy_score, precision_score @@ -175,7 +175,13 @@ def retrieve_data(target_model, shadow_model, attack_model, X_train, y_train, X_ data["precision test attack"] = prec_test_attack return data -def create_attack_model_federated(config:Config): +def create_attack_model_federated(config:Config, G, H): + + nFeatures = len(G[0]) + nTrees = len(G) + max_depth = config.max_depth + max_tree = config.max_tree + import torch.nn as nn class CNN(nn.Module): def __init__(self, *args, **kwargs) -> None: