Commit a2a9903 (1 parent: 4e4c72d). Showing 1 changed file with 191 additions and 0 deletions.
@@ -0,0 +1,191 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.summary import summary
from sklearn import metrics

def pairwise_distance(feature, squared=False):
  """Computes the pairwise distance matrix with numerical stability.

  output[i, j] = || feature[i, :] - feature[j, :] ||_2

  Args:
    feature: 2-D Tensor of size [number of data, feature dimension].
    squared: Boolean, whether or not to square the pairwise distances.

  Returns:
    pairwise_distances: 2-D Tensor of size [number of data, number of data].
  """
  pairwise_distances_squared = math_ops.add(
      math_ops.reduce_sum(math_ops.square(feature), axis=[1], keepdims=True),
      math_ops.reduce_sum(
          math_ops.square(array_ops.transpose(feature)),
          axis=[0],
          keepdims=True)) - 2.0 * math_ops.matmul(feature,
                                                  array_ops.transpose(feature))

  # Deal with numerical inaccuracies. Set small negatives to zero.
  pairwise_distances_squared = math_ops.maximum(pairwise_distances_squared, 0.0)
  # Get the mask where the zero distances are at.
  error_mask = math_ops.less_equal(pairwise_distances_squared, 0.0)

  # Optionally take the sqrt.
  if squared:
    pairwise_distances = pairwise_distances_squared
  else:
    pairwise_distances = math_ops.sqrt(
        pairwise_distances_squared + math_ops.to_float(error_mask) * 1e-16)

  # Undo conditionally adding 1e-16.
  pairwise_distances = math_ops.multiply(
      pairwise_distances, math_ops.to_float(math_ops.logical_not(error_mask)))

  num_data = array_ops.shape(feature)[0]
  # Explicitly set diagonals to zero.
  mask_offdiagonals = array_ops.ones_like(pairwise_distances) - array_ops.diag(
      array_ops.ones([num_data]))
  pairwise_distances = math_ops.multiply(pairwise_distances, mask_offdiagonals)
  return pairwise_distances
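

# A minimal usage sketch (illustrative, not part of the original file). It
# assumes TF1.x graph mode, since this module imports tensorflow.python.*
# internals directly; `feat` and `sess` are hypothetical names.
#
#   import numpy as np
#   import tensorflow as tf
#   feat = tf.constant(np.random.rand(4, 8), dtype=tf.float32)
#   dist = pairwise_distance(feat, squared=True)  # [4, 4], zero diagonal
#   with tf.Session() as sess:
#       print(sess.run(dist))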


def pairwise_cosine_distance(feature):
  """Computes the pairwise cosine distance matrix.

  output[i, j] = 1 - cos(feature[i, :], feature[j, :])

  Args:
    feature: 2-D Tensor of size [number of data, feature dimension].

  Returns:
    dist: 2-D Tensor of size [number of data, number of data].
  """
  # Normalize each row.
  normalized = nn.l2_normalize(feature, axis=1)

  # Multiply row i with row j via the transposed matmul, so that prod[i, j]
  # is the cosine similarity between rows i and j.
  prod = math_ops.matmul(normalized, normalized,
                         adjoint_b=True)  # transpose second matrix

  dist = 1 - prod
  return dist
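
# A small worked example (illustrative, not from the original file): for
# l2-normalized rows, cosine distance lies in [0, 2] and identical rows give 0.
#
#   feat = tf.constant([[1., 0.], [0., 1.], [1., 0.]])
#   pairwise_cosine_distance(feat)
#   # [[0., 1., 0.],
#   #  [1., 0., 1.],
#   #  [0., 1., 0.]]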


def _build_multilabel_adjacency(labels):
  """Builds a pairwise similarity matrix from multi-label annotations.

  We assume two samples are similar if their label vectors share at least one
  concept and dissimilar if they share none. Computing labels times its
  transpose then gives zero entries exactly at the dissimilar pairs and
  nonzero entries at the similar ones.

  Args:
    labels: 2-D Tensor of size [batch_size, class_num].

  Returns:
    A [batch_size, batch_size] boolean adjacency matrix.
  """
  adj = math_ops.matmul(labels, array_ops.transpose(labels))
  return math_ops.greater(adj, 0)
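
# A small worked example (illustrative, not from the original file): with
# multi-hot labels for three samples, pairs sharing a concept become True.
#
#   labels = tf.constant([[1., 0., 1.],
#                         [0., 1., 0.],
#                         [1., 1., 0.]])
#   _build_multilabel_adjacency(labels)
#   # [[ True, False,  True],
#   #  [False,  True,  True],
#   #  [ True,  True,  True]]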


def masked_maximum(data, mask, dim=1):
  """Computes the axis wise maximum over chosen elements.

  Args:
    data: 2-D float `Tensor` of size [n, m].
    mask: 2-D Boolean `Tensor` of size [n, m].
    dim: The dimension over which to compute the maximum.

  Returns:
    masked_maximums: N-D `Tensor`.
      The maximized dimension is of size 1 after the operation.
  """
  axis_minimums = math_ops.reduce_min(data, dim, keepdims=True)
  masked_maximums = math_ops.reduce_max(
      math_ops.multiply(data - axis_minimums, mask), dim,
      keepdims=True) + axis_minimums
  return masked_maximums


def masked_minimum(data, mask, dim=1):
  """Computes the axis wise minimum over chosen elements.

  Args:
    data: 2-D float `Tensor` of size [n, m].
    mask: 2-D Boolean `Tensor` of size [n, m].
    dim: The dimension over which to compute the minimum.

  Returns:
    masked_minimums: N-D `Tensor`.
      The minimized dimension is of size 1 after the operation.
  """
  axis_maximums = math_ops.reduce_max(data, dim, keepdims=True)
  masked_minimums = math_ops.reduce_min(
      math_ops.multiply(data - axis_maximums, mask), dim,
      keepdims=True) + axis_maximums
  return masked_minimums
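
# A brief sketch of the shift-by-extremum trick above (illustrative, not part
# of the original file): subtracting the row minimum before masking makes the
# masked-out entries zero, so they can never win the max; the row maximum
# plays the mirrored role for the min.
#
#   data = tf.constant([[1., 5., 3.]])
#   mask = tf.constant([[1., 0., 1.]])
#   masked_maximum(data, mask)  # [[3.]] -- the masked 5. cannot win
#   masked_minimum(data, mask)  # [[1.]]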


def triplet_semihard_loss_multilabel(labels, embeddings, use_cos=False,
                                     margin=1.0):
  """Computes the triplet loss with semi-hard negative mining.

  The loss encourages the positive distances (between a pair of embeddings
  with the same labels) to be smaller than the minimum negative distance
  among the negatives that are at least greater than the positive distance
  plus the margin constant (the semi-hard negatives) in the mini-batch. If
  no such negative exists, it uses the largest negative distance instead.

  See: https://arxiv.org/abs/1503.03832.

  Args:
    labels: Tensor of shape [batch_size, class_num] for multi-label samples.
    embeddings: 2-D float `Tensor` of embedding vectors. Embeddings should
      be l2 normalized.
    use_cos: Boolean; if True use cosine distance, otherwise squared l2
      distance.
    margin: Float, margin term in the loss definition.

  Returns:
    triplet_loss: tf.float32 scalar.
  """
  # Build the pairwise distance matrix (cosine or squared l2).
  pdist_matrix = (pairwise_cosine_distance(embeddings) if use_cos
                  else pairwise_distance(embeddings, squared=True))
  # Build pairwise binary adjacency matrix.
  adjacency = _build_multilabel_adjacency(labels)
  # Invert so we can select negatives only.
  adjacency_not = math_ops.logical_not(adjacency)

  batch_size = labels.get_shape().as_list()[0]

  # Compute the mask.
  pdist_matrix_tile = array_ops.tile(pdist_matrix, [batch_size, 1])
  mask = math_ops.logical_and(
      array_ops.tile(adjacency_not, [batch_size, 1]),
      math_ops.greater(
          pdist_matrix_tile, array_ops.reshape(
              array_ops.transpose(pdist_matrix), [-1, 1])))
  mask_final = array_ops.reshape(
      math_ops.greater(
          math_ops.reduce_sum(
              math_ops.cast(mask, dtype=dtypes.float32), 1, keepdims=True),
          0.0), [batch_size, batch_size])
  mask_final = array_ops.transpose(mask_final)

  adjacency_not = math_ops.cast(adjacency_not, dtype=dtypes.float32)
  mask = math_ops.cast(mask, dtype=dtypes.float32)

  # negatives_outside: smallest D_an where D_an > D_ap.
  negatives_outside = array_ops.reshape(
      masked_minimum(pdist_matrix_tile, mask), [batch_size, batch_size])
  negatives_outside = array_ops.transpose(negatives_outside)

  # negatives_inside: largest D_an.
  negatives_inside = array_ops.tile(
      masked_maximum(pdist_matrix, adjacency_not), [1, batch_size])
  semi_hard_negatives = array_ops.where(
      mask_final, negatives_outside, negatives_inside)

  loss_mat = math_ops.add(margin, pdist_matrix - semi_hard_negatives)

  mask_positives = math_ops.cast(
      adjacency, dtype=dtypes.float32) - array_ops.diag(
          array_ops.ones([batch_size]))

  # In lifted-struct, the authors multiply by 0.5 for the upper triangular;
  # in semihard, all positive pairs except the diagonal are taken.
  num_positives = math_ops.reduce_sum(mask_positives)

  triplet_loss = math_ops.truediv(
      math_ops.reduce_sum(
          math_ops.maximum(
              math_ops.multiply(loss_mat, mask_positives), 0.0)),
      num_positives,
      name='triplet_semihard_loss')

  return triplet_loss
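

# A minimal end-to-end sketch (illustrative, not part of the original file).
# It assumes a TF1.x graph-mode session, since this module imports
# tensorflow.python.* internals, and a statically known batch size, which the
# loss requires via labels.get_shape(). All names below are hypothetical.
if __name__ == '__main__':
  import numpy as np
  import tensorflow as tf

  batch_size, class_num, embed_dim = 8, 5, 16
  rng = np.random.RandomState(0)
  labels_np = (rng.rand(batch_size, class_num) > 0.7).astype('float32')
  # Guarantee at least one positive pair so num_positives is nonzero.
  labels_np[0, 0] = labels_np[1, 0] = 1.0
  embeds_np = rng.rand(batch_size, embed_dim).astype('float32')

  labels_t = tf.constant(labels_np)
  # l2-normalize, as the docstring asks for normalized embeddings.
  embeddings_t = tf.nn.l2_normalize(tf.constant(embeds_np), axis=1)

  loss = triplet_semihard_loss_multilabel(labels_t, embeddings_t,
                                          use_cos=True, margin=1.0)
  with tf.Session() as sess:
    print(sess.run(loss))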