diff --git a/opennmt/layers/transformer.py b/opennmt/layers/transformer.py
index 9e7021e3e..0f4d06107 100644
--- a/opennmt/layers/transformer.py
+++ b/opennmt/layers/transformer.py
@@ -117,6 +117,13 @@ def matmul_with_relative_representations(a, b, transpose_b=False):
     return c
 
 
+def calculate_attn(dot, values, dropout, training):
+    attn = tf.cast(tf.nn.softmax(tf.cast(dot, tf.float32)), dot.dtype)
+    drop_attn = common.dropout(attn, dropout, training=training)
+    heads = tf.matmul(drop_attn, values)
+    return heads, attn, drop_attn
+
+
 class FeedForwardNetwork(tf.keras.layers.Layer):
     """Implements the Transformer's "Feed Forward" layer.
 
@@ -215,6 +222,9 @@ def __init__(
         dropout=0.1,
         return_attention=False,
         maximum_relative_position=None,
+        max_length_full_attention=None,
+        local_attention_radius=0,
+        global_attention_length=0,
         **kwargs
     ):
         """Initializes this layer.
@@ -227,6 +237,11 @@ def __init__(
          return_attention: If ``True``, also return the attention weights.
          maximum_relative_position: Maximum relative position representation
            (from https://arxiv.org/abs/1803.02155).
+          max_length_full_attention: Maximum sequence length for full attention.
+            If not ``None``, use sparse attention for longer sequences
+            (from https://arxiv.org/abs/2004.08483).
+          local_attention_radius: Attention radius around each token for local sliding attention.
+          global_attention_length: Number of tokens used for global attention with sparse attention.
          kwargs: Additional layer arguments.
        """
        super().__init__(**kwargs)
@@ -244,6 +259,9 @@ def __init__(
        self.dropout = dropout
        self.return_attention = return_attention
        self.maximum_relative_position = maximum_relative_position
+        self.max_length_full_attention = max_length_full_attention
+        self.local_attention_radius = local_attention_radius
+        self.global_attention_length = global_attention_length
 
    def map_v1_weights(self, weights):
        # V1 used conv1d layers that have a leading dimensions.
@@ -356,6 +374,42 @@ def _compute_kv(x):
 
            cache = (keys, values)
 
+        queries_length = misc.shape_list(queries)[2]
+
+        if self.max_length_full_attention is not None:
+            if memory is not None:
+                raise ValueError("Sparse attention only supports self-attention.")
+            if self.maximum_relative_position is not None:
+                raise ValueError("Sparse attention doesn't support relative positions.")
+            if self.return_attention:
+                raise ValueError(
+                    "Cannot return attention weights when using sparse attention."
+                )
+
+        if self.max_length_full_attention is not None:
+            use_sparse_att = (
+                tf.less(self.max_length_full_attention, queries_length)
+                & tf.less(0, queries_length)
+                & tf.less(self.global_attention_length, queries_length)
+            )
+            if self.global_attention_length > 0:
+                global_keys = keys
+                global_values = values
+                global_queries = queries
+                if use_sparse_att:
+                    global_queries = queries[:, :, : self.global_attention_length, :]
+                    queries = queries[:, :, self.global_attention_length :, :]
+            queries, keys, values, num_chunks = tf.cond(
+                use_sparse_att,
+                lambda: split_qkv(
+                    queries,
+                    keys,
+                    values,
+                    self.local_attention_radius,
+                    self.global_attention_length,
+                ),
+                lambda: (queries, keys, values, 0),
+            )
        # Dot product attention.
dot = tf.matmul(queries, keys, transpose_b=True) if relative_repr_keys is not None: @@ -363,17 +417,59 @@ def _compute_kv(x): queries, relative_repr_keys, transpose_b=True ) + if ( + self.max_length_full_attention is not None + and self.global_attention_length > 0 + ): + global_dot = global_queries + if use_sparse_att: + global_dot = tf.matmul(global_queries, global_keys, transpose_b=True) + if mask is not None: mask = tf.cast(mask, dot.dtype) - if mask.shape.rank == 2: - mask = tf.expand_dims(mask, 1) # Broadcast on time dimension. + if self.max_length_full_attention is not None: + if self.global_attention_length > 0: + global_mask = mask + if use_sparse_att: + if mask.shape.rank == 2: + global_mask = mask[:, tf.newaxis, :] + else: + global_mask = mask[:, : self.global_attention_length, :] + global_mask = global_mask[:, tf.newaxis, :, :] + global_dot = (global_dot * global_mask) + ( + 1.0 - global_mask + ) * global_dot.dtype.min + mask = tf.cond( + use_sparse_att, + lambda: chunk_att_mask( + mask, self.local_attention_radius, self.global_attention_length + ), + lambda: tf.expand_dims(mask, 1) if mask.shape.rank == 2 else mask, + ) + elif mask.shape.rank == 2: + mask = tf.expand_dims(mask, 1) mask = tf.expand_dims(mask, 1) # Broadcast on head dimension. dot = (dot * mask) + (1.0 - mask) * dot.dtype.min - attn = tf.nn.softmax(dot) - drop_attn = common.dropout(attn, self.dropout, training=training) + heads, attn, drop_attn = calculate_attn(dot, values, self.dropout, training) + + if self.max_length_full_attention is not None: + if self.global_attention_length > 0: + global_heads = heads + global_attn = attn + if use_sparse_att: + global_heads, global_attn, _ = calculate_attn( + global_dot, global_values, self.dropout, training + ) + + heads = tf.cond( + use_sparse_att, + lambda: combine_chunks( + heads, num_chunks, queries_length - self.global_attention_length + ), + lambda: heads, + ) - heads = tf.matmul(drop_attn, values) if relative_repr_values is not None: heads += matmul_with_relative_representations( drop_attn, relative_repr_values @@ -382,6 +478,18 @@ def _compute_kv(x): # Concatenate all heads output. combined = combine_heads(heads) outputs = self.linear_output(combined) + + if ( + self.max_length_full_attention is not None + and self.global_attention_length > 0 + ): + global_combined = combined + global_outputs = outputs + if use_sparse_att: + global_combined = combine_heads(global_heads) + global_outputs = self.linear_output(global_combined) + outputs = tf.concat((global_outputs, outputs), axis=1) + if self.return_attention: return outputs, cache, attn return outputs, cache @@ -610,3 +718,189 @@ def call( outputs = self.ffn(outputs, training=training) cache = dict(self_kv=self_kv, memory_kv=memory_kv) return outputs, cache, attention + + +def split_chunks(a, chunk_length, concat_3_chunks=True, global_length=0): + """Splits a tensor into chunks along the timesteps axis. + + Args: + a: A ``tf.Tensor`` of shape :math:`[B, H, T, D]`. + chunk_length: The length of a chunk :math:`C`. + concat_3_chunks: Optional, if ``True``, append previous and following chunks to each chunk. + + Returns: + A ``tf.Tensor`` of shape :math:`[B * N, H, C (* 3), D]`, where :math:`N` is the chunk number. + """ + + if global_length: + global_a = a[:, :, :global_length, :] + a = a[:, :, global_length:, :] + + batch, num_heads, timesteps, units_per_head = misc.shape_list(a) + + # Pad to a factor of chunk_length. 
+ pad_len = -timesteps % chunk_length + # batch, num_heads, timesteps padded, units_per_head + a_padded = tf.pad(tensor=a, paddings=[[0, 0], [0, 0], [0, pad_len], [0, 0]]) + padded_len = misc.shape_list(a_padded)[2] + + # Chunk along timesteps axis. + num_chunks = padded_len // chunk_length + chunked_shape = [batch, num_heads, num_chunks, chunk_length, units_per_head] + # batch, num_heads, num_chunks, chunk_length, units_per_head + a_chunked = tf.reshape(a_padded, chunked_shape) + + # Concatenate previous and next chunk to each chunk, for overlapping. + if concat_3_chunks: + # batch, num_heads, 1 + num_chunks + 1, chunk_length, units_per_head + a_chunked_padded = tf.pad( + a_chunked, paddings=[[0, 0], [0, 0], [1, 1], [0, 0], [0, 0]] + ) + # batch, num_heads, num_chunks, chunk_length*3, units_per_head + a_chunked = tf.concat( + [a_chunked_padded[:, :, i : (i + num_chunks), ...] for i in range(3)], 3 + ) + + # Transpose and flatten first dimension (batch * num_chunks). + # batch, num_chunks, num_heads, chunk_length (*3), units_per_head + a_transposed = tf.transpose(a_chunked, perm=[0, 2, 1, 3, 4]) + + if global_length: + # batch, num_chunks, num_heads, global timesteps, units_per_head + expanded_global_a = tf.tile( + tf.expand_dims(global_a, 1), [1, num_chunks, 1, 1, 1] + ) + a_transposed = tf.concat([a_transposed, expanded_global_a], axis=3) + + input_shape = misc.shape_list(a_transposed) + output_shape = tf.concat([[batch * num_chunks], input_shape[2:]], 0) + # batch x num_chunks, num_heads, chunk_length (*3) + global_length, units_per_head + return tf.reshape(a_transposed, output_shape), num_chunks + + +def split_qkv(queries, keys, values, chunk_length, global_attention_length): + + # batch x num_chunks, num_heads, chunk_length, units_per_head + queries, _ = split_chunks(queries, chunk_length, concat_3_chunks=False) + # batch x num_chunks, num_heads, chunk_length*3 + global_length, units_per_head + keys, _ = split_chunks(keys, chunk_length, global_length=global_attention_length) + # batch x num_chunks, num_heads, chunk_length*3 + global_length, units_per_head + values, num_chunks = split_chunks( + values, chunk_length, global_length=global_attention_length + ) + + return queries, keys, values, num_chunks + + +def combine_chunks(a, num_chunks, unchunked_length): + # Unchunk + a_shape = misc.shape_list(a) + # batch, num_chunks, num_heads, chunk_length, self.num_units_per_head + a = tf.reshape( + a, + [ + a_shape[0] // num_chunks, + num_chunks, + a_shape[1], + a_shape[2], + a_shape[3], + ], + ) + # batch, num_heads, num_chunks, chunk_length, self.num_units_per_head + a = tf.transpose(a, perm=[0, 2, 1, 3, 4]) + a_shape = misc.shape_list(a) + a = tf.reshape( + a, + [ + a_shape[0], + a_shape[1], + a_shape[2] * a_shape[3], + a_shape[4], + ], + ) + + # Remove padding used for chunking. + return a[:, :, :unchunked_length, :] + + +def chunk_att_mask(mask, chunk_length, global_length=0): + """Transforms an attention mask into a chunked representation. + + Chunked mask masks everything but a sliding diagonal with a radius of ``chunk_length``. + + Args: + mask: A ``tf.Tensor`` of shape :math:`[B, T]` or :math:`[B, T, T]`. + chunk_length: The length of a chunk :math:`C`. + + Returns: + A ``tf.Tensor`` of shape :math:`[B * N, C, C * 3]`, where :math:`N` is the number of chunks. + """ + + mask_shape = misc.shape_list(mask) + batch = mask_shape[0] + timesteps = mask_shape[-1] + rank = len(mask_shape) + + if rank == 2: + # Broadcast on queries time dimension. 
+ mask = tf.expand_dims(mask, 1) + mask = tf.broadcast_to(mask, [batch, timesteps, timesteps]) + + if global_length: + global_mask = mask[:, global_length:, :global_length] + mask = mask[:, global_length:, global_length:] + timesteps = timesteps - global_length + + # Pad to a factor of chunk_length. + pad_len = -timesteps % chunk_length + mask = tf.pad(tensor=mask, paddings=[[0, 0], [0, pad_len], [0, pad_len]]) + if global_length: + global_mask = tf.pad( + tensor=global_mask, paddings=[[0, 0], [0, pad_len], [0, 0]] + ) + padded_timesteps = misc.shape_list(mask)[-1] + + # Append chunk_length padding to timestep axis, before and after. + mask_padded = tf.pad( + tensor=mask, paddings=[[0, 0], [0, 0], [chunk_length, chunk_length]] + ) + padded_len = misc.shape_list(mask_padded)[-1] + mask_flattened = tf.reshape(mask_padded, shape=[batch, -1]) + + # Skew to the left by one and keep 2*chunk_length + 1 relevant locations. + # This corresponds to chunk_length radius around the diagonal. + skewed_len = padded_len + 1 + skewed_padding_len = ( + padded_timesteps * skewed_len - misc.shape_list(mask_flattened)[-1] + ) + mask_padded = tf.pad(mask_flattened, paddings=[[0, 0], [0, skewed_padding_len]]) + skewed_shape = [batch, -1, skewed_len] + mask_skewed = tf.reshape(mask_padded, shape=skewed_shape) + mask_skewed = mask_skewed[:, :, : chunk_length * 2 + 1] + + chunk_num = padded_timesteps // chunk_length + mask_skewed_chunked = tf.reshape(mask_skewed, [batch, chunk_num, chunk_length, -1]) + + # Unskew each chunk to be compatible with chunked attention shape. + unskewed_len = chunk_length * 3 + mask_skewed_padded = tf.pad( + mask_skewed_chunked, paddings=[[0, 0], [0, 0], [0, 0], [0, chunk_length]] + ) + mask_skewed_flattened = tf.reshape(mask_skewed_padded, shape=[batch, chunk_num, -1]) + mask_skewed_flattened = mask_skewed_flattened[:, :, : (chunk_length * unskewed_len)] + mask_unskewed = tf.reshape( + mask_skewed_flattened, shape=[batch, chunk_num, chunk_length, chunk_length * 3] + ) + + if global_length: + # batch, num_chunks, chunk_length, global_length + expanded_global_mask = tf.reshape( + global_mask, shape=[batch, chunk_num, chunk_length, global_length] + ) + mask_unskewed = tf.concat([mask_unskewed, expanded_global_mask], axis=3) + + # Flatten the first dimension to batch * chunk_num. 
+ return tf.reshape( + mask_unskewed, + shape=[batch * chunk_num, chunk_length, chunk_length * 3 + global_length], + ) diff --git a/opennmt/tests/transformer_test.py b/opennmt/tests/transformer_test.py index afc5d57c1..f81a5f35d 100644 --- a/opennmt/tests/transformer_test.py +++ b/opennmt/tests/transformer_test.py @@ -1,3 +1,5 @@ +import itertools + import numpy as np import tensorflow as tf @@ -116,6 +118,92 @@ def testRelativePositions(self): [[2, 3, 4, 4], [1, 2, 3, 4], [0, 1, 2, 3], [0, 0, 1, 2]], ) + @parameterized.expand(itertools.product([2, 3], [True, False], [0, 1, 2])) + def testSplitChunks(self, chunk_length, concat_3_chunks, global_length=0): + batch = 3 + length = [5, 3, 7] + num_heads = 4 + depth = 10 + + inputs = tf.random.normal( + [batch, num_heads, max(length), depth], dtype=tf.float32 + ) + split, num_chunks = transformer.split_chunks( + inputs, + chunk_length=chunk_length, + concat_3_chunks=concat_3_chunks, + global_length=global_length, + ) + split_shape = split.shape + self.assertEqual(num_chunks, split_shape[0] / batch) + self.assertEqual(num_heads, split_shape[1]) + chunk_length_eval = chunk_length * 3 if concat_3_chunks else chunk_length + chunk_length_eval += global_length + self.assertEqual(chunk_length_eval, split_shape[2]) + self.assertEqual(depth, split_shape[3]) + + @parameterized.expand(itertools.product([tf.bool, tf.float32], [2, 3], [0, 1, 2])) + def testChunkAttentionMask(self, dtype, chunk_length, global_length=0): + length = [2, 4, 3] + batch = len(length) + maximum_length = 5 + mask = tf.sequence_mask(lengths=length, maxlen=maximum_length, dtype=dtype) + mask_chunked = transformer.chunk_att_mask( + mask, chunk_length=chunk_length, global_length=global_length + ) + ( + output_batch_times_chunks, + output_chunk_length, + output_expanded_chunk_length, + ) = mask_chunked.shape + if global_length: + maximum_length = maximum_length - global_length + length = [el - global_length for el in length] + num_chunks = abs(-(maximum_length) // chunk_length) + self.assertEqual(num_chunks * batch, output_batch_times_chunks) + self.assertEqual(chunk_length, output_chunk_length) + self.assertEqual(chunk_length * 3 + global_length, output_expanded_chunk_length) + + self.assertIs(mask_chunked.dtype, dtype) + + expected = np.zeros( + (output_batch_times_chunks, output_chunk_length, chunk_length * 3), + dtype=dtype.as_numpy_dtype, + ) + + token_radius = chunk_length * 2 + 1 + for b in range(batch): + seq_length = length[b] + for ch in range(num_chunks): + end = chunk_length + seq_length - chunk_length * ch + if end > 0: + chunk_idx = b * num_chunks + ch + for ch_l in range(chunk_length): + seq_length_idx = ch * chunk_length + ch_l + if seq_length_idx < maximum_length: + start_idx = ch_l if ch != 0 else chunk_length + end_idx = min(end, token_radius + ch_l) + expected[chunk_idx][ch_l][start_idx:end_idx] = 1 + + mask_chunked = self.evaluate(mask_chunked) + if global_length: + expanded_mask = tf.expand_dims(mask, 1) + expanded_mask = tf.broadcast_to( + expanded_mask, [batch, maximum_length, maximum_length + global_length] + ) + pad = chunk_length * num_chunks - maximum_length + expanded_mask = tf.pad( + tensor=expanded_mask, paddings=[[0, 0], [0, pad], [0, 0]] + ) + expanded_mask = tf.reshape( + expanded_mask, + [batch * num_chunks, chunk_length, maximum_length + global_length], + ) + expected = tf.concat( + (expected, expanded_mask[:, :, :global_length]), axis=2 + ) + self.assertAllEqual(mask_chunked, expected) + def testFeedForwardNetwork(self): ffn = 
transformer.FeedForwardNetwork(20, 10) x = tf.random.uniform([4, 5, 10]) @@ -159,6 +247,18 @@ def testMultiHeadSelfAttentionRelativePositionsWithCache(self): cache = (tf.zeros([4, 4, 0, 5]), tf.zeros([4, 4, 0, 5])) _, cache = attention(x, cache=cache) + def testMultiHeadSelfAttentionSparse(self): + attention = transformer.MultiHeadAttention( + 4, + 20, + local_attention_radius=2, + max_length_full_attention=3, + global_attention_length=2, + ) + x = tf.random.uniform([2, 9, 10]) + mask = tf.sequence_mask([9, 7]) + attention(x, mask=mask) + def testMultiHeadSelfAttentionRelativeGradients(self): attention = transformer.MultiHeadAttention(4, 20, maximum_relative_position=6)
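For reference, here is a minimal usage sketch of the new sparse attention options, mirroring the added testMultiHeadSelfAttentionSparse test above (the sizes are arbitrary and only for illustration):

import tensorflow as tf
from opennmt.layers import transformer

# Sparse attention is only used when the input is longer than
# max_length_full_attention: the first global_attention_length tokens get
# full ("global") attention, the remaining tokens attend within a sliding
# window of radius local_attention_radius.
attention = transformer.MultiHeadAttention(
    4,   # num_heads
    20,  # num_units
    max_length_full_attention=3,
    local_attention_radius=2,
    global_attention_length=2,
)
x = tf.random.uniform([2, 9, 10])  # batch of 2, 9 timesteps
mask = tf.sequence_mask([9, 7])    # per-example lengths
outputs, _ = attention(x, mask=mask)  # outputs should be [batch, time, num_units] = [2, 9, 20]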
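And a small shape sketch for the chunking helper, with arbitrary sizes; the expected shapes follow the inline comments in split_chunks above:

a = tf.random.normal([3, 4, 7, 10])  # [batch, heads, time, depth]
chunks, num_chunks = transformer.split_chunks(
    a, chunk_length=2, concat_3_chunks=True, global_length=1
)
# The 1 global timestep is split off, the remaining 6 timesteps form 3 chunks
# of length 2, and each chunk is concatenated with its two neighbours plus the
# global timestep:
#   chunks: [batch * num_chunks, heads, chunk_length * 3 + global_length, depth]
#         = [9, 4, 7, 10], with num_chunks == 3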