Fix undefined gradient shape in relative MultiHeadAttention (#696)
guillaumekln authored Jun 23, 2020
1 parent 9079cc4 commit 5f720e5
Showing 2 changed files with 16 additions and 2 deletions.
4 changes: 2 additions & 2 deletions opennmt/layers/transformer.py
@@ -272,8 +272,8 @@ def _compute_kv(x):
           keys_length,
           self.maximum_relative_position,
           with_cache=bool(cache))
-      relative_repr_keys = tf.gather(self.relative_position_keys, relative_pos)
-      relative_repr_values = tf.gather(self.relative_position_values, relative_pos)
+      relative_repr_keys = tf.nn.embedding_lookup(self.relative_position_keys, relative_pos)
+      relative_repr_values = tf.nn.embedding_lookup(self.relative_position_values, relative_pos)
     else:
       relative_repr_keys = None
       relative_repr_values = None
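The change swaps tf.gather for tf.nn.embedding_lookup when gathering the relative position representations, so that the gradients of the relative position weights keep a fully defined static shape when computed inside a tf.function. Below is a minimal, self-contained sketch of that property (not part of the commit; the table shape, names, and lookup indices are illustrative), assuming TensorFlow 2.x:

import tensorflow as tf

# Weight table standing in for the layer's relative position embeddings.
# The shape is illustrative (e.g. 2 * maximum_relative_position + 1 rows).
table = tf.Variable(tf.random.uniform([13, 20]))

@tf.function
def gradient_shape_is_static(ids):
  with tf.GradientTape() as tape:
    rows = tf.nn.embedding_lookup(table, ids)  # the lookup the fix switches to
    loss = tf.math.reduce_sum(rows)
  gradient = tape.gradient(loss, table)
  # Evaluated at trace time: with tf.nn.embedding_lookup the static shape of
  # the gradient is known, which is what the new unit test asserts below for
  # all of the attention layer's weights.
  return tf.constant(gradient.shape.is_fully_defined())

print(gradient_shape_is_static(tf.constant([[0, 1], [2, 3]])))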
14 changes: 14 additions & 0 deletions opennmt/tests/transformer_test.py
@@ -142,6 +142,20 @@ def testMultiHeadSelfAttentionRelativePositionsWithCache(self):
     cache = (tf.zeros([4, 4, 0, 5]), tf.zeros([4, 4, 0, 5]))
     _, cache = attention(x, cache=cache)
 
+  def testMultiHeadSelfAttentionRelativeGradients(self):
+    attention = transformer.MultiHeadAttention(4, 20, maximum_relative_position=6)
+
+    @tf.function
+    def _compute_gradients_in_function(x):
+      with tf.GradientTape() as tape:
+        y, _ = attention(x)
+        loss = tf.math.reduce_sum(y)
+      gradients = tape.gradient(loss, attention.weights)
+      for gradient in gradients:
+        self.assertTrue(gradient.shape.is_fully_defined())
+
+    _compute_gradients_in_function(tf.random.uniform([4, 1, 10]))
+
   def testMultiHeadAttention(self):
     attention = transformer.MultiHeadAttention(4, 20)
     queries = tf.random.uniform([4, 5, 10])
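Note that the new test computes the gradients inside a @tf.function on purpose: in eager execution every gradient has a concrete shape, so a missing static shape only surfaces when the gradients are produced by a traced graph.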
