-
Notifications
You must be signed in to change notification settings - Fork 172
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix dev sgx: Add label protection (#1090)
* Update estimator.py: add marvell related * Update trainer_worker.py: add arguments * Add privacy via upload: add protects and attacks * Update __init__.py: add privacy into package * Update marvell.py:Add comments * Update marvell.py: add comments on zero division problems * Update norm_attack.py: add comments * Update emb_attack.py: add comments * Update discorloss.py: add paper * Update estimator.py: add comments * Update trainer_worker.py:add comments * Update discorloss.py:add comments * adjust structure move the privacy protection related codes to /privacy/splitnn * Delete fedlearner/privacy/discorloss.py * Delete fedlearner/privacy/emb_attack.py * Delete fedlearner/privacy/marvell.py * Delete fedlearner/privacy/norm_attack.py * Update estimator.py: label protection
- Loading branch information
1 parent
ab1a6bb
commit 51673e8
Showing
9 changed files
with
796 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import tensorflow as tf | ||
import logging | ||
import time | ||
|
||
# For the DisCorLoss paper, see: https://arxiv.org/abs/2203.01451 | ||
|
||
class DisCorLoss(tf.keras.losses.Loss):
  """Distance-correlation measure between embeddings and labels.

  Paper: https://arxiv.org/abs/2203.01451

  The class only forwards its keyword arguments to `tf.keras.losses.Loss`
  and does not override `call`; the value is obtained by calling
  `tf_distance_cor` directly.
  """

  def __init__(self, **kwargs):
    super(DisCorLoss, self).__init__(**kwargs)

  def _pairwise_dist(self, A, B):
    """Return the matrix of Euclidean distances between rows of A and B.

    Args:
      A: rank-2 tensor, one sample per row.
      B: rank-2 tensor with the same number of columns as A.

    Returns:
      Tensor D where D[i, j] is the Euclidean distance between A[i] and B[j].
    """
    # squared norms of each row in A and B
    na = tf.reduce_sum(tf.square(A), 1)
    nb = tf.reduce_sum(tf.square(B), 1)

    # na as a column vector and nb as a row vector so that they broadcast
    # against the [rows(A), rows(B)] matrix product below
    na = tf.reshape(na, [-1, 1])
    nb = tf.reshape(nb, [1, -1])

    # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2; clamp at zero before the square
    # root because rounding can make the expanded form slightly negative
    D = tf.sqrt(tf.maximum(na - 2 * tf.matmul(A, B, False, True) + nb + 1e-20,
                           0.0))
    return D

  def tf_distance_cor(self, embeddings, labels, debug=False):
    """Compute the sample distance correlation of embeddings and labels.

    Args:
      embeddings: rank-2 tensor, one sample per row.
      labels: tensor with one entry per sample; an axis is inserted at
        position 1 before distances are computed, so a rank-1 tensor is
        presumably expected (TODO confirm against callers).
      debug: when true, print the intermediate statistics.

    Returns:
      Scalar tensor with the distance correlation. It is 0 (instead of NaN)
      when either distance variance is zero, e.g. when every label in the
      batch is identical.
    """
    # NOTE(review): under graph execution these timestamps would bracket op
    # construction rather than the actual computation -- confirm intent.
    start = time.time()

    embeddings = tf.debugging.check_numerics(embeddings,
                                             "embeddings contains nan/inf")
    labels = tf.debugging.check_numerics(labels, "labels contains nan/inf")
    labels = tf.expand_dims(labels, 1)

    n = tf.cast(tf.shape(embeddings)[0], tf.float32)
    a = self._pairwise_dist(embeddings, embeddings)
    b = self._pairwise_dist(labels, labels)

    # Double centering: X = x - row mean - column mean + grand mean.
    # `a` and `b` are symmetric distance matrices, so their row means and
    # column means coincide; that is why the first mean can broadcast along
    # columns and the second along rows.
    A = a - tf.reduce_mean(a,
                           axis=1) - tf.expand_dims(tf.reduce_mean(a,
                                                                   axis=0),
                                                    axis=1) + tf.reduce_mean(a)
    B = b - tf.reduce_mean(b,
                           axis=1) - tf.expand_dims(tf.reduce_mean(b,
                                                                   axis=0),
                                                    axis=1) + tf.reduce_mean(b)
    # distance covariance
    dCovXY = tf.sqrt(tf.abs(tf.reduce_sum(A * B) / (n ** 2)))
    # distance variances
    dVarXX = tf.sqrt(tf.abs(tf.reduce_sum(A * A) / (n ** 2)))
    dVarYY = tf.sqrt(tf.abs(tf.reduce_sum(B * B) / (n ** 2)))
    # distance correlation; a zero variance would otherwise give 0/0 = NaN
    dCorXY = tf.math.divide_no_nan(dCovXY, tf.sqrt(dVarXX * dVarYY))
    end = time.time()
    if debug:
      print(("tf distance cov: {} and cor: {}, dVarXX: {}, "
             "dVarYY:{} uses: {}").format(
                 dCovXY, dCorXY,
                 dVarXX, dVarYY,
                 end - start))
    return dCorXY
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import tensorflow.compat.v1 as tf | ||
|
||
# For the embedding attack, see the paper: https://arxiv.org/pdf/2203.01451.pdf | ||
|
||
def get_emb_pred(emb):
  """Score each embedding row by its projection on the top singular direction.

  The embeddings are centered by subtracting the mean over axis 0, the
  centered matrix is decomposed with SVD, and every centered row is projected
  onto the right singular vector of the largest singular value. The
  projections are squashed with a sigmoid so they can be used as scores.

  Args:
    emb: rank-2 tensor of embeddings, one sample per row.

  Returns:
    Rank-1 tensor with one sigmoid score per row of `emb`.
  """
  # center the embeddings around their per-column mean
  centered = emb - tf.reduce_mean(emb, axis=0)
  # only the right singular vectors are needed
  _, _, right_vectors = tf.linalg.svd(centered)
  # the first column of the right-vector matrix is taken as the direction
  # of the largest singular value
  principal_direction = tf.transpose(right_vectors)[0]
  # the projections are expected to separate into two clusters
  projection = tf.linalg.matvec(centered, principal_direction)
  return tf.math.sigmoid(projection)
|
||
def emb_attack_auc(emb, y):
  """Return the AUC of the embedding attack against the labels `y`.

  Args:
    emb: rank-2 tensor of embeddings, one sample per row.
    y: label tensor holding one value per row of `emb`.

  Returns:
    The second element of the tuple returned by `tf.metrics.auc`.
    NOTE(review): per the TF1 metrics API this is presumably the update op,
    which relies on local variables being initialized by the caller --
    confirm against the call site.
  """
  emb_pred = get_emb_pred(emb)
  # Use the dynamic shape: `y.shape` cannot be fed to tf.reshape when some
  # dimension (typically the batch size) is statically unknown.
  emb_pred = tf.reshape(emb_pred, tf.shape(y))
  # AUC of the attack scores against the labels
  _, emb_auc = tf.metrics.auc(y, emb_pred)
  return emb_auc
Oops, something went wrong.