From 12c86203e09423a4ad31aaa061e6addf0d9eb85b Mon Sep 17 00:00:00 2001 From: Natalia Segal Date: Mon, 20 Jun 2022 17:46:19 +0200 Subject: [PATCH 01/12] First implementation of sliding window local sparse attention. --- opennmt/layers/transformer.py | 186 +++++++++++++++++++++++++++++- opennmt/tests/transformer_test.py | 64 ++++++++++ 2 files changed, 247 insertions(+), 3 deletions(-) diff --git a/opennmt/layers/transformer.py b/opennmt/layers/transformer.py index 94fe0d638..0f235beb7 100644 --- a/opennmt/layers/transformer.py +++ b/opennmt/layers/transformer.py @@ -117,6 +117,122 @@ def matmul_with_relative_representations(a, b, transpose_b=False): return c +def split_chunks(a, chunk_length, concat_3_chunks=True): + """Splits a tensor into chunks along the timesteps axis. + + Args: + a: A ``tf.Tensor`` of shape :math:`[B, H, T, D]`. + chunk_length: The length of a chunk :math:`C`. + + Returns: + A ``tf.Tensor`` of shape :math:`[B * N, H, C (* 3), D]`, where :math:`N` is the number of chunks. + """ + + batch, num_heads, timesteps, units_per_head = misc.shape_list(a) + + # Pad to a factor of chunk_length. + rank = a.shape.rank + timestep_axis = rank - 2 + pad_len = -timesteps % chunk_length + paddings = pad_len * tf.one_hot([-1, timestep_axis], rank, axis=0, dtype=tf.int32) + # batch, num_heads, timesteps padded, units_per_head + a_padded = tf.pad(tensor=a, paddings=paddings) + padded_len = misc.shape_list(a_padded)[timestep_axis] + + # Chunk along timesteps axis. + num_chunks = padded_len // chunk_length + chunked_shape = [batch, num_heads, num_chunks, chunk_length, units_per_head] + # batch, num_heads, num_chunks, chunk_length, units_per_head + a_chunked = tf.reshape(a_padded, chunked_shape) + + # Concatenate previous and next chunk to each chunk, for overlapping. + if concat_3_chunks: + paddings = tf.one_hot([2, 2], rank + 1, axis=0, dtype=tf.int32) + # batch, num_heads, 1 + num_chunks + 1, chunk_length, units_per_head + a_chunked_padded = tf.pad(a_chunked, paddings) + # batch, num_heads, num_chunks, chunk_length*3, units_per_head + a_chunked = tf.concat( + [a_chunked_padded[:, :, i : (i + num_chunks), ...] for i in range(3)], 3 + ) + + # Transpose and flatten first dimension (batch * num_chunks). + # batch, num_chunks, num_heads, chunk_length (*3), units_per_head + a_transposed = tf.transpose(a_chunked, perm=[0, 2, 1, 3, 4]) + input_shape = misc.shape_list(a_transposed) + output_shape = tf.concat([[batch * num_chunks], input_shape[2:]], 0) + # batch x num_chunks, num_heads, chunk_length (*3), units_per_head + return tf.reshape(a_transposed, output_shape), num_chunks + + +def chunk_att_mask(mask, chunk_length): + """Transforms an attention mask into a chunked representation, masking everything but a sliding diagonal with a radius of chunk length. + + Args: + mask: A ``tf.Tensor`` of shape :math:`[B, T]` or :math:`[B, T, T]`. + chunk_length: The length of a chunk :math:`C`. + + Returns: + A ``tf.Tensor`` of shape :math:`[B * N, C, C * 3]`, where :math:`N` is the number of chunks. + """ + + mask_shape = misc.shape_list(mask) + batch = mask_shape[0] + timesteps = mask_shape[-1] + rank = len(mask_shape) + + if rank == 2: + # Broadcast on queries time dimension. + mask = tf.expand_dims(mask, 1) + mask = tf.broadcast_to(mask, [batch, timesteps, timesteps]) + rank = 3 + + # Pad to a factor of chunk_length. 
+ pad_len = -timesteps % chunk_length + mask = tf.pad(tensor=mask, paddings=[[0, 0], [0, pad_len], [0, pad_len]]) + padded_timesteps = misc.shape_list(mask)[-1] + + # Append chunk_length padding to timestep axis, before and after. + paddings = chunk_length * tf.one_hot( + [rank - 1, rank - 1], rank, axis=0, dtype=tf.int32 + ) + mask_padded = tf.pad(tensor=mask, paddings=paddings) + padded_len = misc.shape_list(mask_padded)[-1] + mask_flattened = tf.reshape(mask_padded, shape=[batch, -1]) + + # Skew to the left by one and keep 2*chunk_length + 1 relevant locations (chunk_length radius around diagonal). + skewed_len = padded_len + 1 + skewed_padding_len = ( + padded_timesteps * skewed_len - misc.shape_list(mask_flattened)[-1] + ) + skewed_paddings = skewed_padding_len * tf.one_hot( + [-1, rank - 2], rank - 1, axis=0, dtype=tf.int32 + ) + mask_padded = tf.pad(mask_flattened, paddings=skewed_paddings) + skewed_shape = [batch, -1, skewed_len] + mask_skewed = tf.reshape(mask_padded, shape=skewed_shape) + mask_skewed = mask_skewed[:, :, : chunk_length * 2 + 1] + + chunk_num = padded_timesteps // chunk_length + mask_skewed_chunked = tf.reshape(mask_skewed, [batch, chunk_num, chunk_length, -1]) + + # Unskew each chunk to be compatible with chunked attention shape. + unskewed_len = chunk_length * 3 + unskewed_paddings = chunk_length * tf.one_hot( + [-1, rank], rank + 1, axis=0, dtype=tf.int32 + ) + mask_skewed_padded = tf.pad(mask_skewed_chunked, paddings=unskewed_paddings) + mask_skewed_flattened = tf.reshape(mask_skewed_padded, shape=[batch, chunk_num, -1]) + mask_skewed_flattened = mask_skewed_flattened[:, :, : (chunk_length * unskewed_len)] + mask_unskewed = tf.reshape( + mask_skewed_flattened, shape=[batch, chunk_num, chunk_length, chunk_length * 3] + ) + + # Flatten the first dimension to batch * chunk_num. + return tf.reshape( + mask_unskewed, shape=[batch * chunk_num, chunk_length, chunk_length * 3] + ) + + class FeedForwardNetwork(tf.keras.layers.Layer): """Implements the Transformer's "Feed Forward" layer. @@ -214,6 +330,8 @@ def __init__( dropout=0.1, return_attention=False, maximum_relative_position=None, + max_length_full_attention=None, + local_attention_radius=None, **kwargs ): """Initializes this layer. @@ -225,6 +343,8 @@ def __init__( return_attention: If ``True``, also return the attention weights. maximum_relative_position: Maximum relative position representation (from https://arxiv.org/abs/1803.02155). + max_length_full_attention: Maximum sequence length for full attention. If this parameter is not None, sparse attention is calculated for longer sequences. + local_attention_radius: Attention radius around each token for local sliding window sparse attention. kwargs: Additional layer arguments. """ super().__init__(**kwargs) @@ -242,6 +362,8 @@ def __init__( self.dropout = dropout self.return_attention = return_attention self.maximum_relative_position = maximum_relative_position + self.max_length_full_attention = max_length_full_attention + self.local_attention_radius = local_attention_radius def map_v1_weights(self, weights): # V1 used conv1d layers that have a leading dimensions. 
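[Annotation, not part of the diff] For context, the two new constructor options are exercised roughly as in the test this patch adds (a minimal sketch; it assumes eager execution and that the module is importable as opennmt.layers.transformer):

    import tensorflow as tf
    from opennmt.layers import transformer

    attention = transformer.MultiHeadAttention(
        4, 20, local_attention_radius=2, max_length_full_attention=3
    )
    x = tf.random.uniform([2, 9, 10])
    mask = tf.sequence_mask([9, 7])
    # Length 9 exceeds max_length_full_attention, so the sparse path is taken.
    outputs, _ = attention(x, mask=mask)
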
@@ -354,15 +476,38 @@ def _compute_kv(x): cache = (keys, values) + queries_length = misc.shape_list(queries)[2] + + use_sparse_att = False + if self.max_length_full_attention is not None: + if memory is not None: + raise ValueError("Sparse attention only supports self-attention.") + if self.maximum_relative_position is not None: + raise ValueError("Sparse attention doesn't support relative positions.") + use_sparse_att = queries_length > self.max_length_full_attention + + chunk_length = self.local_attention_radius # Dot product attention. - dot = tf.matmul(queries, keys, transpose_b=True) + if use_sparse_att: + # batch x num_chunks, num_heads, chunk_length, units_per_head + queries_chunked, _ = split_chunks( + queries, chunk_length, concat_3_chunks=False + ) + # batch x num_chunks, num_heads, chunk_length*3, units_per_head + keys_chunked, _ = split_chunks(keys, chunk_length) + # batch x num_chunks, num_heads, chunk_length, chunk_length*3 + dot = tf.matmul(queries_chunked, keys_chunked, transpose_b=True) + else: + dot = tf.matmul(queries, keys, transpose_b=True) if relative_repr_keys is not None: dot += matmul_with_relative_representations( queries, relative_repr_keys, transpose_b=True ) if mask is not None: mask = tf.cast(mask, tf.float32) - if mask.shape.rank == 2: + if use_sparse_att: + mask = chunk_att_mask(mask, chunk_length) + elif mask.shape.rank == 2: mask = tf.expand_dims(mask, 1) # Broadcast on time dimension. mask = tf.expand_dims(mask, 1) # Broadcast on head dimension. dot = tf.cast( @@ -371,7 +516,42 @@ def _compute_kv(x): ) attn = tf.cast(tf.nn.softmax(tf.cast(dot, tf.float32)), dot.dtype) drop_attn = common.dropout(attn, self.dropout, training=training) - heads = tf.matmul(drop_attn, values) + if use_sparse_att: + # batch x num_chunks, num_heads, chunk_length*3, units_per_head + values_chunked, num_chunks = split_chunks(values, chunk_length) + # batch x num_chunks, num_heads, chunk_length, units_per_head + heads = tf.matmul(drop_attn, values_chunked) + + # Unchunk + heads_shape = misc.shape_list(heads) + # batch, num_chunks, num_heads, chunk_length, self.num_units_per_head + heads = tf.reshape( + heads, + [ + heads_shape[0] // num_chunks, + num_chunks, + heads_shape[1], + heads_shape[2], + heads_shape[3], + ], + ) + # batch, num_heads, num_chunks, chunk_length, self.num_units_per_head + heads = tf.transpose(heads, perm=[0, 2, 1, 3, 4]) + heads_shape = misc.shape_list(heads) + heads = tf.reshape( + heads, + [ + heads_shape[0], + heads_shape[1], + heads_shape[2] * heads_shape[3], + heads_shape[4], + ], + ) + + # Remove padding used for chunking. 
+ heads = heads[:, :, :queries_length, :] + else: + heads = tf.matmul(drop_attn, values) if relative_repr_values is not None: heads += matmul_with_relative_representations( drop_attn, relative_repr_values diff --git a/opennmt/tests/transformer_test.py b/opennmt/tests/transformer_test.py index afc5d57c1..034b1000c 100644 --- a/opennmt/tests/transformer_test.py +++ b/opennmt/tests/transformer_test.py @@ -116,6 +116,62 @@ def testRelativePositions(self): [[2, 3, 4, 4], [1, 2, 3, 4], [0, 1, 2, 3], [0, 0, 1, 2]], ) + @parameterized.expand([[2, True], [2, False], [3, True], [3, False]]) + def testSplitChunks(self, chunk_length, concat_3_chunks): + batch = 3 + length = [5, 3, 7] + num_heads = 4 + depth = 10 + + inputs = tf.random.normal( + [batch, num_heads, max(length), depth], dtype=tf.float32 + ) + split, num_chunks = transformer.split_chunks( + inputs, chunk_length=chunk_length, concat_3_chunks=concat_3_chunks + ) + split_shape = split.shape + self.assertEqual(num_chunks, split_shape[0] / batch) + self.assertEqual(num_heads, split_shape[1]) + chunk_length_eval = chunk_length * 3 if concat_3_chunks else chunk_length + self.assertEqual(chunk_length_eval, split_shape[2]) + self.assertEqual(depth, split_shape[3]) + + @parameterized.expand( + [[tf.bool, 2], [tf.float32, 2], [tf.bool, 3], [tf.float32, 3]] + ) + def testChunkAttentionMask(self, dtype, chunk_length): + length = [2, 4, 3] + batch = len(length) + maximum_length = 5 + mask = tf.sequence_mask(lengths=length, maxlen=maximum_length, dtype=dtype) + mask_chunked = transformer.chunk_att_mask(mask, chunk_length=chunk_length) + output_shape = mask_chunked.shape + num_chunks = abs(-maximum_length // chunk_length) + self.assertEqual(num_chunks, output_shape[0] / batch) + self.assertEqual(chunk_length, output_shape[1]) + self.assertEqual(chunk_length * 3, output_shape[2]) + + self.assertIs(mask_chunked.dtype, dtype) + + expected = np.zeros(output_shape, dtype=dtype.as_numpy_dtype) + + token_radius = chunk_length * 2 + 1 + for b in range(batch): + seq_length = length[b] + for ch in range(num_chunks): + end = chunk_length + seq_length - chunk_length * ch + if end > 0: + chunk_idx = b * num_chunks + ch + for l in range(chunk_length): + seq_length_idx = ch * chunk_length + l + if seq_length_idx < maximum_length: + start_idx = l if ch != 0 else chunk_length + end_idx = min(end, token_radius + l) + expected[chunk_idx][l][start_idx:end_idx] = 1 + + mask_chunked = self.evaluate(mask_chunked) + self.assertAllEqual(mask_chunked, expected) + def testFeedForwardNetwork(self): ffn = transformer.FeedForwardNetwork(20, 10) x = tf.random.uniform([4, 5, 10]) @@ -159,6 +215,14 @@ def testMultiHeadSelfAttentionRelativePositionsWithCache(self): cache = (tf.zeros([4, 4, 0, 5]), tf.zeros([4, 4, 0, 5])) _, cache = attention(x, cache=cache) + def testMultiHeadSelfAttentionSparse(self): + attention = transformer.MultiHeadAttention( + 4, 20, local_attention_radius=2, max_length_full_attention=3 + ) + x = tf.random.uniform([2, 9, 10]) + mask = tf.sequence_mask([9, 7]) + attention(x, mask=mask) + def testMultiHeadSelfAttentionRelativeGradients(self): attention = transformer.MultiHeadAttention(4, 20, maximum_relative_position=6) From d0f9adf96d8875d486236a8a83032bc20a6a1d89 Mon Sep 17 00:00:00 2001 From: Natalia Segal Date: Mon, 20 Jun 2022 18:19:41 +0200 Subject: [PATCH 02/12] Apply flake8 --- opennmt/layers/transformer.py | 14 +++++++++----- opennmt/tests/transformer_test.py | 10 +++++----- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git 
a/opennmt/layers/transformer.py b/opennmt/layers/transformer.py index 0f235beb7..5713ea3cf 100644 --- a/opennmt/layers/transformer.py +++ b/opennmt/layers/transformer.py @@ -125,7 +125,7 @@ def split_chunks(a, chunk_length, concat_3_chunks=True): chunk_length: The length of a chunk :math:`C`. Returns: - A ``tf.Tensor`` of shape :math:`[B * N, H, C (* 3), D]`, where :math:`N` is the number of chunks. + A ``tf.Tensor`` of shape :math:`[B * N, H, C (* 3), D]`, where :math:`N` is the chunk number. """ batch, num_heads, timesteps, units_per_head = misc.shape_list(a) @@ -165,7 +165,9 @@ def split_chunks(a, chunk_length, concat_3_chunks=True): def chunk_att_mask(mask, chunk_length): - """Transforms an attention mask into a chunked representation, masking everything but a sliding diagonal with a radius of chunk length. + """Transforms an attention mask into a chunked representation. + + Chunked mask masks everything but a sliding diagonal with a radius of ``chunk_length``. Args: mask: A ``tf.Tensor`` of shape :math:`[B, T]` or :math:`[B, T, T]`. @@ -199,7 +201,8 @@ def chunk_att_mask(mask, chunk_length): padded_len = misc.shape_list(mask_padded)[-1] mask_flattened = tf.reshape(mask_padded, shape=[batch, -1]) - # Skew to the left by one and keep 2*chunk_length + 1 relevant locations (chunk_length radius around diagonal). + # Skew to the left by one and keep 2*chunk_length + 1 relevant locations. + # This corresponds to chunk_length radius around the diagonal. skewed_len = padded_len + 1 skewed_padding_len = ( padded_timesteps * skewed_len - misc.shape_list(mask_flattened)[-1] @@ -343,8 +346,9 @@ def __init__( return_attention: If ``True``, also return the attention weights. maximum_relative_position: Maximum relative position representation (from https://arxiv.org/abs/1803.02155). - max_length_full_attention: Maximum sequence length for full attention. If this parameter is not None, sparse attention is calculated for longer sequences. - local_attention_radius: Attention radius around each token for local sliding window sparse attention. + max_length_full_attention: Maximum sequence length for full attention. + If ``None``, use sparse attention for longer sequences. + local_attention_radius: Attention radius around each token for local sliding attention. kwargs: Additional layer arguments. 
""" super().__init__(**kwargs) diff --git a/opennmt/tests/transformer_test.py b/opennmt/tests/transformer_test.py index 034b1000c..2ef86bbb5 100644 --- a/opennmt/tests/transformer_test.py +++ b/opennmt/tests/transformer_test.py @@ -162,12 +162,12 @@ def testChunkAttentionMask(self, dtype, chunk_length): end = chunk_length + seq_length - chunk_length * ch if end > 0: chunk_idx = b * num_chunks + ch - for l in range(chunk_length): - seq_length_idx = ch * chunk_length + l + for ch_l in range(chunk_length): + seq_length_idx = ch * chunk_length + ch_l if seq_length_idx < maximum_length: - start_idx = l if ch != 0 else chunk_length - end_idx = min(end, token_radius + l) - expected[chunk_idx][l][start_idx:end_idx] = 1 + start_idx = ch_l if ch != 0 else chunk_length + end_idx = min(end, token_radius + ch_l) + expected[chunk_idx][ch_l][start_idx:end_idx] = 1 mask_chunked = self.evaluate(mask_chunked) self.assertAllEqual(mask_chunked, expected) From 54b722c55502decd0c31fef7d71e487c2b8e25ea Mon Sep 17 00:00:00 2001 From: Natalia Segal Date: Tue, 21 Jun 2022 18:18:48 +0200 Subject: [PATCH 03/12] Use explicit paddings instead of tf.one_hot --- opennmt/layers/transformer.py | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/opennmt/layers/transformer.py b/opennmt/layers/transformer.py index 5713ea3cf..6285b0e58 100644 --- a/opennmt/layers/transformer.py +++ b/opennmt/layers/transformer.py @@ -131,13 +131,10 @@ def split_chunks(a, chunk_length, concat_3_chunks=True): batch, num_heads, timesteps, units_per_head = misc.shape_list(a) # Pad to a factor of chunk_length. - rank = a.shape.rank - timestep_axis = rank - 2 pad_len = -timesteps % chunk_length - paddings = pad_len * tf.one_hot([-1, timestep_axis], rank, axis=0, dtype=tf.int32) # batch, num_heads, timesteps padded, units_per_head - a_padded = tf.pad(tensor=a, paddings=paddings) - padded_len = misc.shape_list(a_padded)[timestep_axis] + a_padded = tf.pad(tensor=a, paddings=[[0, 0], [0, 0], [0, pad_len], [0, 0]]) + padded_len = misc.shape_list(a_padded)[2] # Chunk along timesteps axis. num_chunks = padded_len // chunk_length @@ -147,9 +144,10 @@ def split_chunks(a, chunk_length, concat_3_chunks=True): # Concatenate previous and next chunk to each chunk, for overlapping. if concat_3_chunks: - paddings = tf.one_hot([2, 2], rank + 1, axis=0, dtype=tf.int32) # batch, num_heads, 1 + num_chunks + 1, chunk_length, units_per_head - a_chunked_padded = tf.pad(a_chunked, paddings) + a_chunked_padded = tf.pad( + a_chunked, paddings=[[0, 0], [0, 0], [1, 1], [0, 0], [0, 0]] + ) # batch, num_heads, num_chunks, chunk_length*3, units_per_head a_chunked = tf.concat( [a_chunked_padded[:, :, i : (i + num_chunks), ...] for i in range(3)], 3 @@ -186,7 +184,6 @@ def chunk_att_mask(mask, chunk_length): # Broadcast on queries time dimension. mask = tf.expand_dims(mask, 1) mask = tf.broadcast_to(mask, [batch, timesteps, timesteps]) - rank = 3 # Pad to a factor of chunk_length. pad_len = -timesteps % chunk_length @@ -194,10 +191,9 @@ def chunk_att_mask(mask, chunk_length): padded_timesteps = misc.shape_list(mask)[-1] # Append chunk_length padding to timestep axis, before and after. 
- paddings = chunk_length * tf.one_hot( - [rank - 1, rank - 1], rank, axis=0, dtype=tf.int32 + mask_padded = tf.pad( + tensor=mask, paddings=[[0, 0], [0, 0], [chunk_length, chunk_length]] ) - mask_padded = tf.pad(tensor=mask, paddings=paddings) padded_len = misc.shape_list(mask_padded)[-1] mask_flattened = tf.reshape(mask_padded, shape=[batch, -1]) @@ -207,10 +203,7 @@ def chunk_att_mask(mask, chunk_length): skewed_padding_len = ( padded_timesteps * skewed_len - misc.shape_list(mask_flattened)[-1] ) - skewed_paddings = skewed_padding_len * tf.one_hot( - [-1, rank - 2], rank - 1, axis=0, dtype=tf.int32 - ) - mask_padded = tf.pad(mask_flattened, paddings=skewed_paddings) + mask_padded = tf.pad(mask_flattened, paddings=[[0, 0], [0, skewed_padding_len]]) skewed_shape = [batch, -1, skewed_len] mask_skewed = tf.reshape(mask_padded, shape=skewed_shape) mask_skewed = mask_skewed[:, :, : chunk_length * 2 + 1] @@ -220,10 +213,9 @@ def chunk_att_mask(mask, chunk_length): # Unskew each chunk to be compatible with chunked attention shape. unskewed_len = chunk_length * 3 - unskewed_paddings = chunk_length * tf.one_hot( - [-1, rank], rank + 1, axis=0, dtype=tf.int32 + mask_skewed_padded = tf.pad( + mask_skewed_chunked, paddings=[[0, 0], [0, 0], [0, 0], [0, chunk_length]] ) - mask_skewed_padded = tf.pad(mask_skewed_chunked, paddings=unskewed_paddings) mask_skewed_flattened = tf.reshape(mask_skewed_padded, shape=[batch, chunk_num, -1]) mask_skewed_flattened = mask_skewed_flattened[:, :, : (chunk_length * unskewed_len)] mask_unskewed = tf.reshape( From 1d8f323d17d87b0d57b51923ba7ae2b32bb9c256 Mon Sep 17 00:00:00 2001 From: Natalia Segal Date: Tue, 21 Jun 2022 18:22:54 +0200 Subject: [PATCH 04/12] Fix docstrings and add ETC paper reference --- opennmt/layers/transformer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/opennmt/layers/transformer.py b/opennmt/layers/transformer.py index 6285b0e58..51dd7b908 100644 --- a/opennmt/layers/transformer.py +++ b/opennmt/layers/transformer.py @@ -123,6 +123,7 @@ def split_chunks(a, chunk_length, concat_3_chunks=True): Args: a: A ``tf.Tensor`` of shape :math:`[B, H, T, D]`. chunk_length: The length of a chunk :math:`C`. + concat_3_chunks: Optional, if ``True``, append previous and following chunks to each chunk. Returns: A ``tf.Tensor`` of shape :math:`[B * N, H, C (* 3), D]`, where :math:`N` is the chunk number. @@ -339,7 +340,8 @@ def __init__( maximum_relative_position: Maximum relative position representation (from https://arxiv.org/abs/1803.02155). max_length_full_attention: Maximum sequence length for full attention. - If ``None``, use sparse attention for longer sequences. + If not ``None``, use sparse attention for longer sequences + (from https://arxiv.org/abs/2004.08483). local_attention_radius: Attention radius around each token for local sliding attention. kwargs: Additional layer arguments. """ From f513e20de0accbb789c177886b2778412fc37562 Mon Sep 17 00:00:00 2001 From: Natalia Segal Date: Tue, 21 Jun 2022 18:52:06 +0200 Subject: [PATCH 05/12] Combine chunks in a function, override queries, keys, and values. 
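[Annotation, not part of the commit] The chunked local attention that these helpers factor out can be sketched roughly as follows. This is illustrative only: the function name is made up, the timestep axis is assumed to already be a multiple of chunk_length, and masking, dropout and global tokens are omitted.

    import tensorflow as tf

    def local_attention_sketch(queries, keys, values, chunk_length):
        # queries, keys, values: [batch, heads, time, depth]
        batch, heads, time, depth = queries.shape
        num_chunks = time // chunk_length

        def split(t):
            # [batch, heads, num_chunks, chunk_length, depth]
            return tf.reshape(t, [batch, heads, num_chunks, chunk_length, depth])

        q, k, v = split(queries), split(keys), split(values)
        # Give each chunk access to its previous and next chunk of keys/values.
        pad = [[0, 0], [0, 0], [1, 1], [0, 0], [0, 0]]
        k3 = tf.concat([tf.pad(k, pad)[:, :, i : i + num_chunks] for i in range(3)], 3)
        v3 = tf.concat([tf.pad(v, pad)[:, :, i : i + num_chunks] for i in range(3)], 3)
        # Per-chunk attention: scores are [batch, heads, num_chunks, chunk_length, 3 * chunk_length].
        dot = tf.matmul(q, k3, transpose_b=True)
        heads_out = tf.matmul(tf.nn.softmax(dot), v3)
        return tf.reshape(heads_out, [batch, heads, time, depth])

Because each chunk of queries only sees a window of 3 * chunk_length keys, the score tensor grows linearly with the sequence length instead of quadratically.
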
--- opennmt/layers/transformer.py | 81 ++++++++++++++++------------------- 1 file changed, 38 insertions(+), 43 deletions(-) diff --git a/opennmt/layers/transformer.py b/opennmt/layers/transformer.py index 51dd7b908..09fcc39aa 100644 --- a/opennmt/layers/transformer.py +++ b/opennmt/layers/transformer.py @@ -163,6 +163,37 @@ def split_chunks(a, chunk_length, concat_3_chunks=True): return tf.reshape(a_transposed, output_shape), num_chunks +def combine_chunks(a, num_chunks, unchunked_length): + # Unchunk + a_shape = misc.shape_list(a) + # batch, num_chunks, num_heads, chunk_length, self.num_units_per_head + a = tf.reshape( + a, + [ + a_shape[0] // num_chunks, + num_chunks, + a_shape[1], + a_shape[2], + a_shape[3], + ], + ) + # batch, num_heads, num_chunks, chunk_length, self.num_units_per_head + a = tf.transpose(a, perm=[0, 2, 1, 3, 4]) + a_shape = misc.shape_list(a) + a = tf.reshape( + a, + [ + a_shape[0], + a_shape[1], + a_shape[2] * a_shape[3], + a_shape[4], + ], + ) + + # Remove padding used for chunking. + return a[:, :, :unchunked_length, :] + + def chunk_att_mask(mask, chunk_length): """Transforms an attention mask into a chunked representation. @@ -488,15 +519,12 @@ def _compute_kv(x): # Dot product attention. if use_sparse_att: # batch x num_chunks, num_heads, chunk_length, units_per_head - queries_chunked, _ = split_chunks( - queries, chunk_length, concat_3_chunks=False - ) + queries, _ = split_chunks(queries, chunk_length, concat_3_chunks=False) # batch x num_chunks, num_heads, chunk_length*3, units_per_head - keys_chunked, _ = split_chunks(keys, chunk_length) - # batch x num_chunks, num_heads, chunk_length, chunk_length*3 - dot = tf.matmul(queries_chunked, keys_chunked, transpose_b=True) - else: - dot = tf.matmul(queries, keys, transpose_b=True) + keys, _ = split_chunks(keys, chunk_length) + # batch x num_chunks, num_heads, chunk_length*3, units_per_head + values, num_chunks = split_chunks(values, chunk_length) + dot = tf.matmul(queries, keys, transpose_b=True) if relative_repr_keys is not None: dot += matmul_with_relative_representations( queries, relative_repr_keys, transpose_b=True @@ -514,42 +542,9 @@ def _compute_kv(x): ) attn = tf.cast(tf.nn.softmax(tf.cast(dot, tf.float32)), dot.dtype) drop_attn = common.dropout(attn, self.dropout, training=training) + heads = tf.matmul(drop_attn, values) if use_sparse_att: - # batch x num_chunks, num_heads, chunk_length*3, units_per_head - values_chunked, num_chunks = split_chunks(values, chunk_length) - # batch x num_chunks, num_heads, chunk_length, units_per_head - heads = tf.matmul(drop_attn, values_chunked) - - # Unchunk - heads_shape = misc.shape_list(heads) - # batch, num_chunks, num_heads, chunk_length, self.num_units_per_head - heads = tf.reshape( - heads, - [ - heads_shape[0] // num_chunks, - num_chunks, - heads_shape[1], - heads_shape[2], - heads_shape[3], - ], - ) - # batch, num_heads, num_chunks, chunk_length, self.num_units_per_head - heads = tf.transpose(heads, perm=[0, 2, 1, 3, 4]) - heads_shape = misc.shape_list(heads) - heads = tf.reshape( - heads, - [ - heads_shape[0], - heads_shape[1], - heads_shape[2] * heads_shape[3], - heads_shape[4], - ], - ) - - # Remove padding used for chunking. 
- heads = heads[:, :, :queries_length, :] - else: - heads = tf.matmul(drop_attn, values) + heads = combine_chunks(heads, num_chunks, queries_length) if relative_repr_values is not None: heads += matmul_with_relative_representations( drop_attn, relative_repr_values From 6cd7516901850e0486ed721db0cfe30a7a564724 Mon Sep 17 00:00:00 2001 From: Natalia Segal Date: Fri, 26 Aug 2022 13:01:40 +0200 Subject: [PATCH 06/12] Add global attention --- opennmt/layers/transformer.py | 90 +++++++++++++++++++++++++++---- opennmt/tests/transformer_test.py | 82 +++++++++++++++++++++++----- 2 files changed, 149 insertions(+), 23 deletions(-) diff --git a/opennmt/layers/transformer.py b/opennmt/layers/transformer.py index 09fcc39aa..5da80921a 100644 --- a/opennmt/layers/transformer.py +++ b/opennmt/layers/transformer.py @@ -117,7 +117,7 @@ def matmul_with_relative_representations(a, b, transpose_b=False): return c -def split_chunks(a, chunk_length, concat_3_chunks=True): +def split_chunks(a, chunk_length, concat_3_chunks=True, global_length=0): """Splits a tensor into chunks along the timesteps axis. Args: @@ -129,6 +129,10 @@ def split_chunks(a, chunk_length, concat_3_chunks=True): A ``tf.Tensor`` of shape :math:`[B * N, H, C (* 3), D]`, where :math:`N` is the chunk number. """ + if global_length: + global_a = a[:, :, :global_length, :] + a = a[:, :, global_length:, :] + batch, num_heads, timesteps, units_per_head = misc.shape_list(a) # Pad to a factor of chunk_length. @@ -157,9 +161,17 @@ def split_chunks(a, chunk_length, concat_3_chunks=True): # Transpose and flatten first dimension (batch * num_chunks). # batch, num_chunks, num_heads, chunk_length (*3), units_per_head a_transposed = tf.transpose(a_chunked, perm=[0, 2, 1, 3, 4]) + + if global_length: + # batch, num_chunks, num_heads, global timesteps, units_per_head + expanded_global_a = tf.tile( + tf.expand_dims(global_a, 1), [1, num_chunks, 1, 1, 1] + ) + a_transposed = tf.concat([a_transposed, expanded_global_a], axis=3) + input_shape = misc.shape_list(a_transposed) output_shape = tf.concat([[batch * num_chunks], input_shape[2:]], 0) - # batch x num_chunks, num_heads, chunk_length (*3), units_per_head + # batch x num_chunks, num_heads, chunk_length (*3) + global_length, units_per_head return tf.reshape(a_transposed, output_shape), num_chunks @@ -194,7 +206,7 @@ def combine_chunks(a, num_chunks, unchunked_length): return a[:, :, :unchunked_length, :] -def chunk_att_mask(mask, chunk_length): +def chunk_att_mask(mask, chunk_length, global_length=0): """Transforms an attention mask into a chunked representation. Chunked mask masks everything but a sliding diagonal with a radius of ``chunk_length``. @@ -207,6 +219,10 @@ def chunk_att_mask(mask, chunk_length): A ``tf.Tensor`` of shape :math:`[B * N, C, C * 3]`, where :math:`N` is the number of chunks. """ + if global_length: + global_mask = mask[:, :global_length] + mask = mask[:, global_length:] + mask_shape = misc.shape_list(mask) batch = mask_shape[0] timesteps = mask_shape[-1] @@ -254,9 +270,17 @@ def chunk_att_mask(mask, chunk_length): mask_skewed_flattened, shape=[batch, chunk_num, chunk_length, chunk_length * 3] ) + if global_length: + # batch, num_chunks, chunk_length, global_length + expanded_global_mask = tf.tile( + global_mask[:, tf.newaxis, tf.newaxis, :], [1, chunk_num, chunk_length, 1] + ) + mask_unskewed = tf.concat([mask_unskewed, expanded_global_mask], axis=3) + # Flatten the first dimension to batch * chunk_num. 
return tf.reshape( - mask_unskewed, shape=[batch * chunk_num, chunk_length, chunk_length * 3] + mask_unskewed, + shape=[batch * chunk_num, chunk_length, chunk_length * 3 + global_length], ) @@ -359,6 +383,7 @@ def __init__( maximum_relative_position=None, max_length_full_attention=None, local_attention_radius=None, + global_attention_length=0, **kwargs ): """Initializes this layer. @@ -374,6 +399,7 @@ def __init__( If not ``None``, use sparse attention for longer sequences (from https://arxiv.org/abs/2004.08483). local_attention_radius: Attention radius around each token for local sliding attention. + global_attention_length: Number of tokens used for global attention with sparse attention. kwargs: Additional layer arguments. """ super().__init__(**kwargs) @@ -393,6 +419,7 @@ def __init__( self.maximum_relative_position = maximum_relative_position self.max_length_full_attention = max_length_full_attention self.local_attention_radius = local_attention_radius + self.global_attention_length = global_attention_length def map_v1_weights(self, weights): # V1 used conv1d layers that have a leading dimensions. @@ -513,17 +540,33 @@ def _compute_kv(x): raise ValueError("Sparse attention only supports self-attention.") if self.maximum_relative_position is not None: raise ValueError("Sparse attention doesn't support relative positions.") + if self.return_attention: + raise ValueError( + "Cannot return attention weights when using sparse attention." + ) + use_sparse_att = queries_length > self.max_length_full_attention chunk_length = self.local_attention_radius # Dot product attention. if use_sparse_att: + if self.global_attention_length: + global_queries = queries[:, :, : self.global_attention_length, :] + queries = queries[:, :, self.global_attention_length :, :] + global_keys = keys + global_values = values + global_dot = tf.matmul(global_queries, global_keys, transpose_b=True) + # batch x num_chunks, num_heads, chunk_length, units_per_head queries, _ = split_chunks(queries, chunk_length, concat_3_chunks=False) - # batch x num_chunks, num_heads, chunk_length*3, units_per_head - keys, _ = split_chunks(keys, chunk_length) - # batch x num_chunks, num_heads, chunk_length*3, units_per_head - values, num_chunks = split_chunks(values, chunk_length) + # batch x num_chunks, num_heads, chunk_length*3 + global_length, units_per_head + keys, _ = split_chunks( + keys, chunk_length, global_length=self.global_attention_length + ) + # batch x num_chunks, num_heads, chunk_length*3 + global_length, units_per_head + values, num_chunks = split_chunks( + values, chunk_length, global_length=self.global_attention_length + ) dot = tf.matmul(queries, keys, transpose_b=True) if relative_repr_keys is not None: dot += matmul_with_relative_representations( @@ -532,7 +575,8 @@ def _compute_kv(x): if mask is not None: mask = tf.cast(mask, tf.float32) if use_sparse_att: - mask = chunk_att_mask(mask, chunk_length) + global_mask = mask[:, tf.newaxis, tf.newaxis, :] + mask = chunk_att_mask(mask, chunk_length, self.global_attention_length) elif mask.shape.rank == 2: mask = tf.expand_dims(mask, 1) # Broadcast on time dimension. mask = tf.expand_dims(mask, 1) # Broadcast on head dimension. 
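[Annotation, not part of the diff] A quick shape check for the extended mask helper (an illustrative example, not taken from the patch; the lengths are arbitrary and the import path is assumed to be opennmt.layers.transformer):

    import tensorflow as tf
    from opennmt.layers import transformer

    mask = tf.sequence_mask([5, 7], maxlen=7)  # [batch=2, T=7]
    chunked = transformer.chunk_att_mask(mask, chunk_length=2, global_length=1)
    # The local part covers T - global_length = 6 timesteps, i.e. 3 chunks of length 2, so
    # chunked.shape == [batch * num_chunks, chunk_length, chunk_length * 3 + global_length]
    #               == [6, 2, 7]
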
@@ -540,11 +584,30 @@ def _compute_kv(x): tf.cast(dot, tf.float32) * mask + ((1.0 - mask) * tf.float32.min), dot.dtype, ) + if self.global_attention_length: + global_dot = tf.cast( + tf.cast(global_dot, tf.float32) * global_mask + + ((1.0 - global_mask) * tf.float32.min), + global_dot.dtype, + ) + attn = tf.cast(tf.nn.softmax(tf.cast(dot, tf.float32)), dot.dtype) drop_attn = common.dropout(attn, self.dropout, training=training) heads = tf.matmul(drop_attn, values) + + if self.global_attention_length: + global_attn = tf.cast( + tf.nn.softmax(tf.cast(global_dot, tf.float32)), global_dot.dtype + ) + global_drop_attn = common.dropout( + global_attn, self.dropout, training=training + ) + global_heads = tf.matmul(global_drop_attn, global_values) + if use_sparse_att: - heads = combine_chunks(heads, num_chunks, queries_length) + heads = combine_chunks( + heads, num_chunks, queries_length - self.global_attention_length + ) if relative_repr_values is not None: heads += matmul_with_relative_representations( drop_attn, relative_repr_values @@ -553,6 +616,13 @@ def _compute_kv(x): # Concatenate all heads output. combined = combine_heads(heads) outputs = self.linear_output(combined) + if self.global_attention_length: + global_combined = combine_heads(global_heads) + global_outputs = self.linear_output( + global_combined + ) # TODO : a separate global linear input and output layers ? + outputs = tf.concat((global_outputs, outputs), axis=1) + if self.return_attention: return outputs, cache, attn return outputs, cache diff --git a/opennmt/tests/transformer_test.py b/opennmt/tests/transformer_test.py index 2ef86bbb5..d3403cc1d 100644 --- a/opennmt/tests/transformer_test.py +++ b/opennmt/tests/transformer_test.py @@ -116,8 +116,23 @@ def testRelativePositions(self): [[2, 3, 4, 4], [1, 2, 3, 4], [0, 1, 2, 3], [0, 0, 1, 2]], ) - @parameterized.expand([[2, True], [2, False], [3, True], [3, False]]) - def testSplitChunks(self, chunk_length, concat_3_chunks): + @parameterized.expand( + [ + [2, True], + [2, False], + [3, True], + [3, False], + [2, True, 1], + [2, False, 1], + [3, True, 1], + [3, False, 1], + [2, True, 2], + [2, False, 2], + [3, True, 2], + [3, False, 2], + ] + ) + def testSplitChunks(self, chunk_length, concat_3_chunks, global_length=0): batch = 3 length = [5, 3, 7] num_heads = 4 @@ -127,33 +142,62 @@ def testSplitChunks(self, chunk_length, concat_3_chunks): [batch, num_heads, max(length), depth], dtype=tf.float32 ) split, num_chunks = transformer.split_chunks( - inputs, chunk_length=chunk_length, concat_3_chunks=concat_3_chunks + inputs, + chunk_length=chunk_length, + concat_3_chunks=concat_3_chunks, + global_length=global_length, ) split_shape = split.shape self.assertEqual(num_chunks, split_shape[0] / batch) self.assertEqual(num_heads, split_shape[1]) chunk_length_eval = chunk_length * 3 if concat_3_chunks else chunk_length + chunk_length_eval += global_length self.assertEqual(chunk_length_eval, split_shape[2]) self.assertEqual(depth, split_shape[3]) @parameterized.expand( - [[tf.bool, 2], [tf.float32, 2], [tf.bool, 3], [tf.float32, 3]] + [ + [tf.bool, 2], + [tf.float32, 2], + [tf.bool, 3], + [tf.float32, 3], + [tf.bool, 2, 1], + [tf.float32, 2, 1], + [tf.bool, 3, 1], + [tf.float32, 3, 1], + [tf.bool, 2, 2], + [tf.float32, 2, 2], + [tf.bool, 3, 2], + [tf.float32, 3, 2], + ] ) - def testChunkAttentionMask(self, dtype, chunk_length): + def testChunkAttentionMask(self, dtype, chunk_length, global_length=0): length = [2, 4, 3] batch = len(length) maximum_length = 5 mask = 
tf.sequence_mask(lengths=length, maxlen=maximum_length, dtype=dtype) - mask_chunked = transformer.chunk_att_mask(mask, chunk_length=chunk_length) - output_shape = mask_chunked.shape - num_chunks = abs(-maximum_length // chunk_length) - self.assertEqual(num_chunks, output_shape[0] / batch) - self.assertEqual(chunk_length, output_shape[1]) - self.assertEqual(chunk_length * 3, output_shape[2]) + mask_chunked = transformer.chunk_att_mask( + mask, chunk_length=chunk_length, global_length=global_length + ) + ( + output_batch_times_chunks, + output_chunk_length, + output_expanded_chunk_length, + ) = mask_chunked.shape + if global_length: + maximum_length = maximum_length - global_length + length = [el - global_length for el in length] + num_chunks = abs(-(maximum_length) // chunk_length) + self.assertEqual(num_chunks * batch, output_batch_times_chunks) + self.assertEqual(chunk_length, output_chunk_length) + self.assertEqual(chunk_length * 3 + global_length, output_expanded_chunk_length) self.assertIs(mask_chunked.dtype, dtype) - expected = np.zeros(output_shape, dtype=dtype.as_numpy_dtype) + expected = np.zeros( + (output_batch_times_chunks, output_chunk_length, chunk_length * 3), + dtype=dtype.as_numpy_dtype, + ) token_radius = chunk_length * 2 + 1 for b in range(batch): @@ -170,6 +214,14 @@ def testChunkAttentionMask(self, dtype, chunk_length): expected[chunk_idx][ch_l][start_idx:end_idx] = 1 mask_chunked = self.evaluate(mask_chunked) + if global_length: + expanded_mask = np.repeat(mask, num_chunks, axis=0) + expanded_mask = np.repeat( + expanded_mask[:, np.newaxis, :], chunk_length, axis=1 + ) + expected = tf.concat( + (expected, expanded_mask[:, :, :global_length]), axis=2 + ) self.assertAllEqual(mask_chunked, expected) def testFeedForwardNetwork(self): @@ -217,7 +269,11 @@ def testMultiHeadSelfAttentionRelativePositionsWithCache(self): def testMultiHeadSelfAttentionSparse(self): attention = transformer.MultiHeadAttention( - 4, 20, local_attention_radius=2, max_length_full_attention=3 + 4, + 20, + local_attention_radius=2, + max_length_full_attention=3, + global_attention_length=2, ) x = tf.random.uniform([2, 9, 10]) mask = tf.sequence_mask([9, 7]) From 7170c207d95972b60efa683fba518546a0ea6354 Mon Sep 17 00:00:00 2001 From: Natalia Segal Date: Wed, 31 Aug 2022 17:34:44 +0200 Subject: [PATCH 07/12] Use itertools.product for test parameter combinations --- opennmt/tests/transformer_test.py | 36 ++++--------------------------- 1 file changed, 4 insertions(+), 32 deletions(-) diff --git a/opennmt/tests/transformer_test.py b/opennmt/tests/transformer_test.py index d3403cc1d..f13cc710d 100644 --- a/opennmt/tests/transformer_test.py +++ b/opennmt/tests/transformer_test.py @@ -1,3 +1,5 @@ +import itertools + import numpy as np import tensorflow as tf @@ -116,22 +118,7 @@ def testRelativePositions(self): [[2, 3, 4, 4], [1, 2, 3, 4], [0, 1, 2, 3], [0, 0, 1, 2]], ) - @parameterized.expand( - [ - [2, True], - [2, False], - [3, True], - [3, False], - [2, True, 1], - [2, False, 1], - [3, True, 1], - [3, False, 1], - [2, True, 2], - [2, False, 2], - [3, True, 2], - [3, False, 2], - ] - ) + @parameterized.expand(itertools.product([2, 3], [True, False], [0, 1, 2])) def testSplitChunks(self, chunk_length, concat_3_chunks, global_length=0): batch = 3 length = [5, 3, 7] @@ -155,22 +142,7 @@ def testSplitChunks(self, chunk_length, concat_3_chunks, global_length=0): self.assertEqual(chunk_length_eval, split_shape[2]) self.assertEqual(depth, split_shape[3]) - @parameterized.expand( - [ - [tf.bool, 2], - 
[tf.float32, 2], - [tf.bool, 3], - [tf.float32, 3], - [tf.bool, 2, 1], - [tf.float32, 2, 1], - [tf.bool, 3, 1], - [tf.float32, 3, 1], - [tf.bool, 2, 2], - [tf.float32, 2, 2], - [tf.bool, 3, 2], - [tf.float32, 3, 2], - ] - ) + @parameterized.expand(itertools.product([tf.bool, tf.float32], [2, 3], [0, 1, 2])) def testChunkAttentionMask(self, dtype, chunk_length, global_length=0): length = [2, 4, 3] batch = len(length) From 1f9a2de6a893118d6421ae2c35cf258d0761b854 Mon Sep 17 00:00:00 2001 From: Natalia Segal Date: Mon, 12 Sep 2022 11:37:30 +0200 Subject: [PATCH 08/12] Fix conditions for using global attention --- opennmt/layers/transformer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/opennmt/layers/transformer.py b/opennmt/layers/transformer.py index 5da80921a..0f69b9bba 100644 --- a/opennmt/layers/transformer.py +++ b/opennmt/layers/transformer.py @@ -584,7 +584,7 @@ def _compute_kv(x): tf.cast(dot, tf.float32) * mask + ((1.0 - mask) * tf.float32.min), dot.dtype, ) - if self.global_attention_length: + if use_sparse_att and self.global_attention_length: global_dot = tf.cast( tf.cast(global_dot, tf.float32) * global_mask + ((1.0 - global_mask) * tf.float32.min), @@ -595,7 +595,7 @@ def _compute_kv(x): drop_attn = common.dropout(attn, self.dropout, training=training) heads = tf.matmul(drop_attn, values) - if self.global_attention_length: + if use_sparse_att and self.global_attention_length: global_attn = tf.cast( tf.nn.softmax(tf.cast(global_dot, tf.float32)), global_dot.dtype ) @@ -616,7 +616,7 @@ def _compute_kv(x): # Concatenate all heads output. combined = combine_heads(heads) outputs = self.linear_output(combined) - if self.global_attention_length: + if use_sparse_att and self.global_attention_length: global_combined = combine_heads(global_heads) global_outputs = self.linear_output( global_combined From e8ffc95a6bd7804e49fd8b2f744a8c59d09eca33 Mon Sep 17 00:00:00 2001 From: Natalia Segal Date: Tue, 13 Sep 2022 14:18:07 +0200 Subject: [PATCH 09/12] Fix global attention mask --- opennmt/layers/transformer.py | 28 +++++++++++++++++----------- opennmt/tests/transformer_test.py | 9 +++++---- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/opennmt/layers/transformer.py b/opennmt/layers/transformer.py index 0f69b9bba..9f00cc03f 100644 --- a/opennmt/layers/transformer.py +++ b/opennmt/layers/transformer.py @@ -219,10 +219,6 @@ def chunk_att_mask(mask, chunk_length, global_length=0): A ``tf.Tensor`` of shape :math:`[B * N, C, C * 3]`, where :math:`N` is the number of chunks. """ - if global_length: - global_mask = mask[:, :global_length] - mask = mask[:, global_length:] - mask_shape = misc.shape_list(mask) batch = mask_shape[0] timesteps = mask_shape[-1] @@ -233,9 +229,16 @@ def chunk_att_mask(mask, chunk_length, global_length=0): mask = tf.expand_dims(mask, 1) mask = tf.broadcast_to(mask, [batch, timesteps, timesteps]) + if global_length: + global_mask = mask[:, global_length:, :global_length] + mask = mask[:, global_length:, global_length:] + timesteps = timesteps - global_length + # Pad to a factor of chunk_length. pad_len = -timesteps % chunk_length mask = tf.pad(tensor=mask, paddings=[[0, 0], [0, pad_len], [0, pad_len]]) + if global_length: + global_mask = tf.pad(tensor=global_mask, paddings=[[0, 0], [0, pad_len], [0, 0]]) padded_timesteps = misc.shape_list(mask)[-1] # Append chunk_length padding to timestep axis, before and after. 
@@ -272,9 +275,7 @@ def chunk_att_mask(mask, chunk_length, global_length=0): if global_length: # batch, num_chunks, chunk_length, global_length - expanded_global_mask = tf.tile( - global_mask[:, tf.newaxis, tf.newaxis, :], [1, chunk_num, chunk_length, 1] - ) + expanded_global_mask = tf.reshape(global_mask, shape=[batch, chunk_num, chunk_length, global_length]) mask_unskewed = tf.concat([mask_unskewed, expanded_global_mask], axis=3) # Flatten the first dimension to batch * chunk_num. @@ -575,7 +576,12 @@ def _compute_kv(x): if mask is not None: mask = tf.cast(mask, tf.float32) if use_sparse_att: - global_mask = mask[:, tf.newaxis, tf.newaxis, :] + if self.global_attention_length: + if mask.shape.rank == 2: + global_mask = mask[:, tf.newaxis, :] + else: + global_mask = mask[:, :self.global_attention_length, :] + global_mask = global_mask[:, tf.newaxis, :, :] mask = chunk_att_mask(mask, chunk_length, self.global_attention_length) elif mask.shape.rank == 2: mask = tf.expand_dims(mask, 1) # Broadcast on time dimension. @@ -584,7 +590,7 @@ def _compute_kv(x): tf.cast(dot, tf.float32) * mask + ((1.0 - mask) * tf.float32.min), dot.dtype, ) - if use_sparse_att and self.global_attention_length: + if use_sparse_att and self.global_attention_length > 0: global_dot = tf.cast( tf.cast(global_dot, tf.float32) * global_mask + ((1.0 - global_mask) * tf.float32.min), @@ -595,7 +601,7 @@ def _compute_kv(x): drop_attn = common.dropout(attn, self.dropout, training=training) heads = tf.matmul(drop_attn, values) - if use_sparse_att and self.global_attention_length: + if use_sparse_att and self.global_attention_length > 0: global_attn = tf.cast( tf.nn.softmax(tf.cast(global_dot, tf.float32)), global_dot.dtype ) @@ -616,7 +622,7 @@ def _compute_kv(x): # Concatenate all heads output. 
combined = combine_heads(heads) outputs = self.linear_output(combined) - if use_sparse_att and self.global_attention_length: + if use_sparse_att and self.global_attention_length > 0: global_combined = combine_heads(global_heads) global_outputs = self.linear_output( global_combined diff --git a/opennmt/tests/transformer_test.py b/opennmt/tests/transformer_test.py index f13cc710d..97bced4d1 100644 --- a/opennmt/tests/transformer_test.py +++ b/opennmt/tests/transformer_test.py @@ -187,10 +187,11 @@ def testChunkAttentionMask(self, dtype, chunk_length, global_length=0): mask_chunked = self.evaluate(mask_chunked) if global_length: - expanded_mask = np.repeat(mask, num_chunks, axis=0) - expanded_mask = np.repeat( - expanded_mask[:, np.newaxis, :], chunk_length, axis=1 - ) + expanded_mask = tf.expand_dims(mask, 1) + expanded_mask = tf.broadcast_to(expanded_mask, [batch, maximum_length, maximum_length + global_length]) + pad = chunk_length*num_chunks - maximum_length + expanded_mask = tf.pad(tensor=expanded_mask, paddings=[[0, 0], [0, pad], [0, 0]]) + expanded_mask = tf.reshape(expanded_mask, [batch * num_chunks, chunk_length, maximum_length + global_length]) expected = tf.concat( (expected, expanded_mask[:, :, :global_length]), axis=2 ) From 927670ec7e29a5f0a871bf669b74628a6bcaa903 Mon Sep 17 00:00:00 2001 From: Natalia Segal Date: Tue, 20 Sep 2022 12:28:59 +0200 Subject: [PATCH 10/12] Apply black --- opennmt/layers/transformer.py | 10 +++++++--- opennmt/tests/transformer_test.py | 15 +++++++++++---- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/opennmt/layers/transformer.py b/opennmt/layers/transformer.py index 9f00cc03f..79ca3d4d0 100644 --- a/opennmt/layers/transformer.py +++ b/opennmt/layers/transformer.py @@ -238,7 +238,9 @@ def chunk_att_mask(mask, chunk_length, global_length=0): pad_len = -timesteps % chunk_length mask = tf.pad(tensor=mask, paddings=[[0, 0], [0, pad_len], [0, pad_len]]) if global_length: - global_mask = tf.pad(tensor=global_mask, paddings=[[0, 0], [0, pad_len], [0, 0]]) + global_mask = tf.pad( + tensor=global_mask, paddings=[[0, 0], [0, pad_len], [0, 0]] + ) padded_timesteps = misc.shape_list(mask)[-1] # Append chunk_length padding to timestep axis, before and after. @@ -275,7 +277,9 @@ def chunk_att_mask(mask, chunk_length, global_length=0): if global_length: # batch, num_chunks, chunk_length, global_length - expanded_global_mask = tf.reshape(global_mask, shape=[batch, chunk_num, chunk_length, global_length]) + expanded_global_mask = tf.reshape( + global_mask, shape=[batch, chunk_num, chunk_length, global_length] + ) mask_unskewed = tf.concat([mask_unskewed, expanded_global_mask], axis=3) # Flatten the first dimension to batch * chunk_num. 
@@ -580,7 +584,7 @@ def _compute_kv(x): if mask.shape.rank == 2: global_mask = mask[:, tf.newaxis, :] else: - global_mask = mask[:, :self.global_attention_length, :] + global_mask = mask[:, : self.global_attention_length, :] global_mask = global_mask[:, tf.newaxis, :, :] mask = chunk_att_mask(mask, chunk_length, self.global_attention_length) elif mask.shape.rank == 2: diff --git a/opennmt/tests/transformer_test.py b/opennmt/tests/transformer_test.py index 97bced4d1..f81a5f35d 100644 --- a/opennmt/tests/transformer_test.py +++ b/opennmt/tests/transformer_test.py @@ -188,10 +188,17 @@ def testChunkAttentionMask(self, dtype, chunk_length, global_length=0): mask_chunked = self.evaluate(mask_chunked) if global_length: expanded_mask = tf.expand_dims(mask, 1) - expanded_mask = tf.broadcast_to(expanded_mask, [batch, maximum_length, maximum_length + global_length]) - pad = chunk_length*num_chunks - maximum_length - expanded_mask = tf.pad(tensor=expanded_mask, paddings=[[0, 0], [0, pad], [0, 0]]) - expanded_mask = tf.reshape(expanded_mask, [batch * num_chunks, chunk_length, maximum_length + global_length]) + expanded_mask = tf.broadcast_to( + expanded_mask, [batch, maximum_length, maximum_length + global_length] + ) + pad = chunk_length * num_chunks - maximum_length + expanded_mask = tf.pad( + tensor=expanded_mask, paddings=[[0, 0], [0, pad], [0, 0]] + ) + expanded_mask = tf.reshape( + expanded_mask, + [batch * num_chunks, chunk_length, maximum_length + global_length], + ) expected = tf.concat( (expected, expanded_mask[:, :, :global_length]), axis=2 ) From 9e3a0a13512644c019cc16205063472d8b230a86 Mon Sep 17 00:00:00 2001 From: Natalia Segal Date: Mon, 15 May 2023 17:33:16 +0200 Subject: [PATCH 11/12] Fix graph mode execution for sparse attention --- opennmt/layers/transformer.py | 156 +++++++++++++++++++++------------- 1 file changed, 99 insertions(+), 57 deletions(-) diff --git a/opennmt/layers/transformer.py b/opennmt/layers/transformer.py index 79ca3d4d0..3f8f2120d 100644 --- a/opennmt/layers/transformer.py +++ b/opennmt/layers/transformer.py @@ -175,6 +175,20 @@ def split_chunks(a, chunk_length, concat_3_chunks=True, global_length=0): return tf.reshape(a_transposed, output_shape), num_chunks +def split_qkv(queries, keys, values, chunk_length, global_attention_length): + + # batch x num_chunks, num_heads, chunk_length, units_per_head + queries, _ = split_chunks(queries, chunk_length, concat_3_chunks=False) + # batch x num_chunks, num_heads, chunk_length*3 + global_length, units_per_head + keys, _ = split_chunks(keys, chunk_length, global_length=global_attention_length) + # batch x num_chunks, num_heads, chunk_length*3 + global_length, units_per_head + values, num_chunks = split_chunks( + values, chunk_length, global_length=global_attention_length + ) + + return queries, keys, values, num_chunks + + def combine_chunks(a, num_chunks, unchunked_length): # Unchunk a_shape = misc.shape_list(a) @@ -289,6 +303,13 @@ def chunk_att_mask(mask, chunk_length, global_length=0): ) +def calculate_attn(dot, values, dropout, training): + attn = tf.cast(tf.nn.softmax(tf.cast(dot, tf.float32)), dot.dtype) + drop_attn = common.dropout(attn, dropout, training=training) + heads = tf.matmul(drop_attn, values) + return heads, attn, drop_attn + + class FeedForwardNetwork(tf.keras.layers.Layer): """Implements the Transformer's "Feed Forward" layer. 
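[Annotation, not part of the diff] For reference, the shapes produced by the new split_qkv helper (an illustrative example, not part of the patch; sizes are arbitrary):

    import tensorflow as tf
    from opennmt.layers import transformer

    q = k = v = tf.random.normal([2, 4, 9, 5])  # batch, heads, time, depth
    qc, kc, vc, num_chunks = transformer.split_qkv(
        q, k, v, chunk_length=2, global_attention_length=0
    )
    # num_chunks == 5 (the time axis is padded from 9 to 10)
    # qc.shape    == [batch * num_chunks, heads, chunk_length, depth]     == [10, 4, 2, 5]
    # kc/vc.shape == [batch * num_chunks, heads, chunk_length * 3, depth] == [10, 4, 6, 5]

Grouping the chunking in one helper lets the layer build it inside a tf.cond branch, which is what makes the sparse/full switch usable when the sequence length is only known at graph execution time.
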
@@ -387,7 +408,7 @@ def __init__( return_attention=False, maximum_relative_position=None, max_length_full_attention=None, - local_attention_radius=None, + local_attention_radius=0, global_attention_length=0, **kwargs ): @@ -539,7 +560,6 @@ def _compute_kv(x): queries_length = misc.shape_list(queries)[2] - use_sparse_att = False if self.max_length_full_attention is not None: if memory is not None: raise ValueError("Sparse attention only supports self-attention.") @@ -550,73 +570,88 @@ def _compute_kv(x): "Cannot return attention weights when using sparse attention." ) - use_sparse_att = queries_length > self.max_length_full_attention - - chunk_length = self.local_attention_radius - # Dot product attention. - if use_sparse_att: - if self.global_attention_length: - global_queries = queries[:, :, : self.global_attention_length, :] - queries = queries[:, :, self.global_attention_length :, :] + if self.max_length_full_attention is not None: + use_sparse_att = tf.less(self.max_length_full_attention, queries_length) + if self.global_attention_length > 0: global_keys = keys global_values = values - global_dot = tf.matmul(global_queries, global_keys, transpose_b=True) - - # batch x num_chunks, num_heads, chunk_length, units_per_head - queries, _ = split_chunks(queries, chunk_length, concat_3_chunks=False) - # batch x num_chunks, num_heads, chunk_length*3 + global_length, units_per_head - keys, _ = split_chunks( - keys, chunk_length, global_length=self.global_attention_length - ) - # batch x num_chunks, num_heads, chunk_length*3 + global_length, units_per_head - values, num_chunks = split_chunks( - values, chunk_length, global_length=self.global_attention_length + global_queries = queries + if use_sparse_att: + queries = queries[:, :, self.global_attention_length :, :] + global_queries = queries[:, :, : self.global_attention_length, :] + queries, keys, values, num_chunks = tf.cond( + use_sparse_att, + lambda: split_qkv( + queries, + keys, + values, + self.local_attention_radius, + self.global_attention_length, + ), + lambda: (queries, keys, values, 0), ) + # Dot product attention. 
dot = tf.matmul(queries, keys, transpose_b=True) if relative_repr_keys is not None: dot += matmul_with_relative_representations( queries, relative_repr_keys, transpose_b=True ) + if ( + self.max_length_full_attention is not None + and self.global_attention_length > 0 + ): + global_dot = global_queries + if use_sparse_att: + global_dot = tf.matmul(global_queries, global_keys, transpose_b=True) + if mask is not None: mask = tf.cast(mask, tf.float32) - if use_sparse_att: - if self.global_attention_length: - if mask.shape.rank == 2: - global_mask = mask[:, tf.newaxis, :] - else: - global_mask = mask[:, : self.global_attention_length, :] - global_mask = global_mask[:, tf.newaxis, :, :] - mask = chunk_att_mask(mask, chunk_length, self.global_attention_length) + if self.max_length_full_attention is not None: + if self.global_attention_length > 0: + global_mask = mask + if use_sparse_att: + if mask.shape.rank == 2: + global_mask = mask[:, tf.newaxis, :] + else: + global_mask = mask[:, : self.global_attention_length, :] + global_mask = global_mask[:, tf.newaxis, :, :] + global_dot = tf.cast( + tf.cast(global_dot, tf.float32) * global_mask + + ((1.0 - global_mask) * tf.float32.min), + global_dot.dtype, + ) + mask = tf.cond( + use_sparse_att, + lambda: chunk_att_mask( + mask, self.local_attention_radius, self.global_attention_length + ), + lambda: tf.expand_dims(mask, 1) if mask.shape.rank == 2 else mask, + ) elif mask.shape.rank == 2: - mask = tf.expand_dims(mask, 1) # Broadcast on time dimension. + mask = tf.expand_dims(mask, 1) mask = tf.expand_dims(mask, 1) # Broadcast on head dimension. dot = tf.cast( tf.cast(dot, tf.float32) * mask + ((1.0 - mask) * tf.float32.min), dot.dtype, ) - if use_sparse_att and self.global_attention_length > 0: - global_dot = tf.cast( - tf.cast(global_dot, tf.float32) * global_mask - + ((1.0 - global_mask) * tf.float32.min), - global_dot.dtype, - ) - attn = tf.cast(tf.nn.softmax(tf.cast(dot, tf.float32)), dot.dtype) - drop_attn = common.dropout(attn, self.dropout, training=training) - heads = tf.matmul(drop_attn, values) + heads, attn, drop_attn = calculate_attn(dot, values, self.dropout, training) - if use_sparse_att and self.global_attention_length > 0: - global_attn = tf.cast( - tf.nn.softmax(tf.cast(global_dot, tf.float32)), global_dot.dtype - ) - global_drop_attn = common.dropout( - global_attn, self.dropout, training=training - ) - global_heads = tf.matmul(global_drop_attn, global_values) - - if use_sparse_att: - heads = combine_chunks( - heads, num_chunks, queries_length - self.global_attention_length + if self.max_length_full_attention is not None: + if self.global_attention_length > 0: + global_heads = heads + global_attn = attn + if use_sparse_att: + global_heads, global_attn, _ = calculate_attn( + global_dot, global_values, self.dropout, training + ) + + heads = tf.cond( + use_sparse_att, + lambda: combine_chunks( + heads, num_chunks, queries_length - self.global_attention_length + ), + lambda: heads, ) if relative_repr_values is not None: heads += matmul_with_relative_representations( @@ -626,12 +661,19 @@ def _compute_kv(x): # Concatenate all heads output. combined = combine_heads(heads) outputs = self.linear_output(combined) - if use_sparse_att and self.global_attention_length > 0: - global_combined = combine_heads(global_heads) - global_outputs = self.linear_output( - global_combined - ) # TODO : a separate global linear input and output layers ? 
- outputs = tf.concat((global_outputs, outputs), axis=1) + + if ( + self.max_length_full_attention is not None + and self.global_attention_length > 0 + ): + global_combined = combined + global_outputs = outputs + if use_sparse_att: + global_combined = combine_heads(global_heads) + global_outputs = self.linear_output( + global_combined + ) # TODO : a separate global linear input and output layers ? + outputs = tf.concat((global_outputs, outputs), axis=1) if self.return_attention: return outputs, cache, attn From 6c1e090cb580c3354dcaf8dc8cdf54b30c78eb58 Mon Sep 17 00:00:00 2001 From: Natalia Segal Date: Fri, 26 May 2023 19:38:40 +0200 Subject: [PATCH 12/12] Move chunk functions and add conditions for short input lengths. --- opennmt/layers/transformer.py | 386 +++++++++++++++++----------------- 1 file changed, 195 insertions(+), 191 deletions(-) diff --git a/opennmt/layers/transformer.py b/opennmt/layers/transformer.py index e34bdf689..0f4d06107 100644 --- a/opennmt/layers/transformer.py +++ b/opennmt/layers/transformer.py @@ -117,192 +117,6 @@ def matmul_with_relative_representations(a, b, transpose_b=False): return c -def split_chunks(a, chunk_length, concat_3_chunks=True, global_length=0): - """Splits a tensor into chunks along the timesteps axis. - - Args: - a: A ``tf.Tensor`` of shape :math:`[B, H, T, D]`. - chunk_length: The length of a chunk :math:`C`. - concat_3_chunks: Optional, if ``True``, append previous and following chunks to each chunk. - - Returns: - A ``tf.Tensor`` of shape :math:`[B * N, H, C (* 3), D]`, where :math:`N` is the chunk number. - """ - - if global_length: - global_a = a[:, :, :global_length, :] - a = a[:, :, global_length:, :] - - batch, num_heads, timesteps, units_per_head = misc.shape_list(a) - - # Pad to a factor of chunk_length. - pad_len = -timesteps % chunk_length - # batch, num_heads, timesteps padded, units_per_head - a_padded = tf.pad(tensor=a, paddings=[[0, 0], [0, 0], [0, pad_len], [0, 0]]) - padded_len = misc.shape_list(a_padded)[2] - - # Chunk along timesteps axis. - num_chunks = padded_len // chunk_length - chunked_shape = [batch, num_heads, num_chunks, chunk_length, units_per_head] - # batch, num_heads, num_chunks, chunk_length, units_per_head - a_chunked = tf.reshape(a_padded, chunked_shape) - - # Concatenate previous and next chunk to each chunk, for overlapping. - if concat_3_chunks: - # batch, num_heads, 1 + num_chunks + 1, chunk_length, units_per_head - a_chunked_padded = tf.pad( - a_chunked, paddings=[[0, 0], [0, 0], [1, 1], [0, 0], [0, 0]] - ) - # batch, num_heads, num_chunks, chunk_length*3, units_per_head - a_chunked = tf.concat( - [a_chunked_padded[:, :, i : (i + num_chunks), ...] for i in range(3)], 3 - ) - - # Transpose and flatten first dimension (batch * num_chunks). 
- # batch, num_chunks, num_heads, chunk_length (*3), units_per_head - a_transposed = tf.transpose(a_chunked, perm=[0, 2, 1, 3, 4]) - - if global_length: - # batch, num_chunks, num_heads, global timesteps, units_per_head - expanded_global_a = tf.tile( - tf.expand_dims(global_a, 1), [1, num_chunks, 1, 1, 1] - ) - a_transposed = tf.concat([a_transposed, expanded_global_a], axis=3) - - input_shape = misc.shape_list(a_transposed) - output_shape = tf.concat([[batch * num_chunks], input_shape[2:]], 0) - # batch x num_chunks, num_heads, chunk_length (*3) + global_length, units_per_head - return tf.reshape(a_transposed, output_shape), num_chunks - - -def split_qkv(queries, keys, values, chunk_length, global_attention_length): - - # batch x num_chunks, num_heads, chunk_length, units_per_head - queries, _ = split_chunks(queries, chunk_length, concat_3_chunks=False) - # batch x num_chunks, num_heads, chunk_length*3 + global_length, units_per_head - keys, _ = split_chunks(keys, chunk_length, global_length=global_attention_length) - # batch x num_chunks, num_heads, chunk_length*3 + global_length, units_per_head - values, num_chunks = split_chunks( - values, chunk_length, global_length=global_attention_length - ) - - return queries, keys, values, num_chunks - - -def combine_chunks(a, num_chunks, unchunked_length): - # Unchunk - a_shape = misc.shape_list(a) - # batch, num_chunks, num_heads, chunk_length, self.num_units_per_head - a = tf.reshape( - a, - [ - a_shape[0] // num_chunks, - num_chunks, - a_shape[1], - a_shape[2], - a_shape[3], - ], - ) - # batch, num_heads, num_chunks, chunk_length, self.num_units_per_head - a = tf.transpose(a, perm=[0, 2, 1, 3, 4]) - a_shape = misc.shape_list(a) - a = tf.reshape( - a, - [ - a_shape[0], - a_shape[1], - a_shape[2] * a_shape[3], - a_shape[4], - ], - ) - - # Remove padding used for chunking. - return a[:, :, :unchunked_length, :] - - -def chunk_att_mask(mask, chunk_length, global_length=0): - """Transforms an attention mask into a chunked representation. - - Chunked mask masks everything but a sliding diagonal with a radius of ``chunk_length``. - - Args: - mask: A ``tf.Tensor`` of shape :math:`[B, T]` or :math:`[B, T, T]`. - chunk_length: The length of a chunk :math:`C`. - - Returns: - A ``tf.Tensor`` of shape :math:`[B * N, C, C * 3]`, where :math:`N` is the number of chunks. - """ - - mask_shape = misc.shape_list(mask) - batch = mask_shape[0] - timesteps = mask_shape[-1] - rank = len(mask_shape) - - if rank == 2: - # Broadcast on queries time dimension. - mask = tf.expand_dims(mask, 1) - mask = tf.broadcast_to(mask, [batch, timesteps, timesteps]) - - if global_length: - global_mask = mask[:, global_length:, :global_length] - mask = mask[:, global_length:, global_length:] - timesteps = timesteps - global_length - - # Pad to a factor of chunk_length. - pad_len = -timesteps % chunk_length - mask = tf.pad(tensor=mask, paddings=[[0, 0], [0, pad_len], [0, pad_len]]) - if global_length: - global_mask = tf.pad( - tensor=global_mask, paddings=[[0, 0], [0, pad_len], [0, 0]] - ) - padded_timesteps = misc.shape_list(mask)[-1] - - # Append chunk_length padding to timestep axis, before and after. - mask_padded = tf.pad( - tensor=mask, paddings=[[0, 0], [0, 0], [chunk_length, chunk_length]] - ) - padded_len = misc.shape_list(mask_padded)[-1] - mask_flattened = tf.reshape(mask_padded, shape=[batch, -1]) - - # Skew to the left by one and keep 2*chunk_length + 1 relevant locations. - # This corresponds to chunk_length radius around the diagonal. 
- skewed_len = padded_len + 1 - skewed_padding_len = ( - padded_timesteps * skewed_len - misc.shape_list(mask_flattened)[-1] - ) - mask_padded = tf.pad(mask_flattened, paddings=[[0, 0], [0, skewed_padding_len]]) - skewed_shape = [batch, -1, skewed_len] - mask_skewed = tf.reshape(mask_padded, shape=skewed_shape) - mask_skewed = mask_skewed[:, :, : chunk_length * 2 + 1] - - chunk_num = padded_timesteps // chunk_length - mask_skewed_chunked = tf.reshape(mask_skewed, [batch, chunk_num, chunk_length, -1]) - - # Unskew each chunk to be compatible with chunked attention shape. - unskewed_len = chunk_length * 3 - mask_skewed_padded = tf.pad( - mask_skewed_chunked, paddings=[[0, 0], [0, 0], [0, 0], [0, chunk_length]] - ) - mask_skewed_flattened = tf.reshape(mask_skewed_padded, shape=[batch, chunk_num, -1]) - mask_skewed_flattened = mask_skewed_flattened[:, :, : (chunk_length * unskewed_len)] - mask_unskewed = tf.reshape( - mask_skewed_flattened, shape=[batch, chunk_num, chunk_length, chunk_length * 3] - ) - - if global_length: - # batch, num_chunks, chunk_length, global_length - expanded_global_mask = tf.reshape( - global_mask, shape=[batch, chunk_num, chunk_length, global_length] - ) - mask_unskewed = tf.concat([mask_unskewed, expanded_global_mask], axis=3) - - # Flatten the first dimension to batch * chunk_num. - return tf.reshape( - mask_unskewed, - shape=[batch * chunk_num, chunk_length, chunk_length * 3 + global_length], - ) - - def calculate_attn(dot, values, dropout, training): attn = tf.cast(tf.nn.softmax(tf.cast(dot, tf.float32)), dot.dtype) drop_attn = common.dropout(attn, dropout, training=training) @@ -573,7 +387,11 @@ def _compute_kv(x): ) if self.max_length_full_attention is not None: - use_sparse_att = tf.less(self.max_length_full_attention, queries_length) + use_sparse_att = ( + tf.less(self.max_length_full_attention, queries_length) + & tf.less(0, queries_length) + & tf.less(self.global_attention_length, queries_length) + ) if self.global_attention_length > 0: global_keys = keys global_values = values @@ -618,7 +436,9 @@ def _compute_kv(x): else: global_mask = mask[:, : self.global_attention_length, :] global_mask = global_mask[:, tf.newaxis, :, :] - global_dot = (global_dot * global_mask) + (1.0 - global_mask) * global_dot.dtype.min + global_dot = (global_dot * global_mask) + ( + 1.0 - global_mask + ) * global_dot.dtype.min mask = tf.cond( use_sparse_att, lambda: chunk_att_mask( @@ -667,9 +487,7 @@ def _compute_kv(x): global_outputs = outputs if use_sparse_att: global_combined = combine_heads(global_heads) - global_outputs = self.linear_output( - global_combined - ) # TODO : a separate global linear input and output layers ? + global_outputs = self.linear_output(global_combined) outputs = tf.concat((global_outputs, outputs), axis=1) if self.return_attention: @@ -900,3 +718,189 @@ def call( outputs = self.ffn(outputs, training=training) cache = dict(self_kv=self_kv, memory_kv=memory_kv) return outputs, cache, attention + + +def split_chunks(a, chunk_length, concat_3_chunks=True, global_length=0): + """Splits a tensor into chunks along the timesteps axis. + + Args: + a: A ``tf.Tensor`` of shape :math:`[B, H, T, D]`. + chunk_length: The length of a chunk :math:`C`. + concat_3_chunks: Optional, if ``True``, append previous and following chunks to each chunk. + + Returns: + A ``tf.Tensor`` of shape :math:`[B * N, H, C (* 3), D]`, where :math:`N` is the chunk number. 
+    """
+
+    if global_length:
+        global_a = a[:, :, :global_length, :]
+        a = a[:, :, global_length:, :]
+
+    batch, num_heads, timesteps, units_per_head = misc.shape_list(a)
+
+    # Pad to a factor of chunk_length.
+    pad_len = -timesteps % chunk_length
+    # batch, num_heads, timesteps padded, units_per_head
+    a_padded = tf.pad(tensor=a, paddings=[[0, 0], [0, 0], [0, pad_len], [0, 0]])
+    padded_len = misc.shape_list(a_padded)[2]
+
+    # Chunk along timesteps axis.
+    num_chunks = padded_len // chunk_length
+    chunked_shape = [batch, num_heads, num_chunks, chunk_length, units_per_head]
+    # batch, num_heads, num_chunks, chunk_length, units_per_head
+    a_chunked = tf.reshape(a_padded, chunked_shape)
+
+    # Concatenate previous and next chunk to each chunk, for overlapping.
+    if concat_3_chunks:
+        # batch, num_heads, 1 + num_chunks + 1, chunk_length, units_per_head
+        a_chunked_padded = tf.pad(
+            a_chunked, paddings=[[0, 0], [0, 0], [1, 1], [0, 0], [0, 0]]
+        )
+        # batch, num_heads, num_chunks, chunk_length*3, units_per_head
+        a_chunked = tf.concat(
+            [a_chunked_padded[:, :, i : (i + num_chunks), ...] for i in range(3)], 3
+        )
+
+    # Transpose and flatten first dimension (batch * num_chunks).
+    # batch, num_chunks, num_heads, chunk_length (*3), units_per_head
+    a_transposed = tf.transpose(a_chunked, perm=[0, 2, 1, 3, 4])
+
+    if global_length:
+        # batch, num_chunks, num_heads, global timesteps, units_per_head
+        expanded_global_a = tf.tile(
+            tf.expand_dims(global_a, 1), [1, num_chunks, 1, 1, 1]
+        )
+        a_transposed = tf.concat([a_transposed, expanded_global_a], axis=3)
+
+    input_shape = misc.shape_list(a_transposed)
+    output_shape = tf.concat([[batch * num_chunks], input_shape[2:]], 0)
+    # batch x num_chunks, num_heads, chunk_length (*3) + global_length, units_per_head
+    return tf.reshape(a_transposed, output_shape), num_chunks
+
+
+def split_qkv(queries, keys, values, chunk_length, global_attention_length):
+
+    # batch x num_chunks, num_heads, chunk_length, units_per_head
+    queries, _ = split_chunks(queries, chunk_length, concat_3_chunks=False)
+    # batch x num_chunks, num_heads, chunk_length*3 + global_length, units_per_head
+    keys, _ = split_chunks(keys, chunk_length, global_length=global_attention_length)
+    # batch x num_chunks, num_heads, chunk_length*3 + global_length, units_per_head
+    values, num_chunks = split_chunks(
+        values, chunk_length, global_length=global_attention_length
+    )
+
+    return queries, keys, values, num_chunks
+
+
+def combine_chunks(a, num_chunks, unchunked_length):
+    # Unchunk
+    a_shape = misc.shape_list(a)
+    # batch, num_chunks, num_heads, chunk_length, self.num_units_per_head
+    a = tf.reshape(
+        a,
+        [
+            a_shape[0] // num_chunks,
+            num_chunks,
+            a_shape[1],
+            a_shape[2],
+            a_shape[3],
+        ],
+    )
+    # batch, num_heads, num_chunks, chunk_length, self.num_units_per_head
+    a = tf.transpose(a, perm=[0, 2, 1, 3, 4])
+    a_shape = misc.shape_list(a)
+    a = tf.reshape(
+        a,
+        [
+            a_shape[0],
+            a_shape[1],
+            a_shape[2] * a_shape[3],
+            a_shape[4],
+        ],
+    )
+
+    # Remove padding used for chunking.
+    return a[:, :, :unchunked_length, :]
+
+
+def chunk_att_mask(mask, chunk_length, global_length=0):
+    """Transforms an attention mask into a chunked representation.
+
+    The chunked mask masks everything but a sliding diagonal with a radius of ``chunk_length``.
+
+    Args:
+      mask: A ``tf.Tensor`` of shape :math:`[B, T]` or :math:`[B, T, T]`.
+      chunk_length: The length of a chunk :math:`C`.
+
+    Returns:
+      A ``tf.Tensor`` of shape :math:`[B * N, C, C * 3 + G]`, where :math:`N` is the number of chunks and :math:`G` is ``global_length``.
+ """ + + mask_shape = misc.shape_list(mask) + batch = mask_shape[0] + timesteps = mask_shape[-1] + rank = len(mask_shape) + + if rank == 2: + # Broadcast on queries time dimension. + mask = tf.expand_dims(mask, 1) + mask = tf.broadcast_to(mask, [batch, timesteps, timesteps]) + + if global_length: + global_mask = mask[:, global_length:, :global_length] + mask = mask[:, global_length:, global_length:] + timesteps = timesteps - global_length + + # Pad to a factor of chunk_length. + pad_len = -timesteps % chunk_length + mask = tf.pad(tensor=mask, paddings=[[0, 0], [0, pad_len], [0, pad_len]]) + if global_length: + global_mask = tf.pad( + tensor=global_mask, paddings=[[0, 0], [0, pad_len], [0, 0]] + ) + padded_timesteps = misc.shape_list(mask)[-1] + + # Append chunk_length padding to timestep axis, before and after. + mask_padded = tf.pad( + tensor=mask, paddings=[[0, 0], [0, 0], [chunk_length, chunk_length]] + ) + padded_len = misc.shape_list(mask_padded)[-1] + mask_flattened = tf.reshape(mask_padded, shape=[batch, -1]) + + # Skew to the left by one and keep 2*chunk_length + 1 relevant locations. + # This corresponds to chunk_length radius around the diagonal. + skewed_len = padded_len + 1 + skewed_padding_len = ( + padded_timesteps * skewed_len - misc.shape_list(mask_flattened)[-1] + ) + mask_padded = tf.pad(mask_flattened, paddings=[[0, 0], [0, skewed_padding_len]]) + skewed_shape = [batch, -1, skewed_len] + mask_skewed = tf.reshape(mask_padded, shape=skewed_shape) + mask_skewed = mask_skewed[:, :, : chunk_length * 2 + 1] + + chunk_num = padded_timesteps // chunk_length + mask_skewed_chunked = tf.reshape(mask_skewed, [batch, chunk_num, chunk_length, -1]) + + # Unskew each chunk to be compatible with chunked attention shape. + unskewed_len = chunk_length * 3 + mask_skewed_padded = tf.pad( + mask_skewed_chunked, paddings=[[0, 0], [0, 0], [0, 0], [0, chunk_length]] + ) + mask_skewed_flattened = tf.reshape(mask_skewed_padded, shape=[batch, chunk_num, -1]) + mask_skewed_flattened = mask_skewed_flattened[:, :, : (chunk_length * unskewed_len)] + mask_unskewed = tf.reshape( + mask_skewed_flattened, shape=[batch, chunk_num, chunk_length, chunk_length * 3] + ) + + if global_length: + # batch, num_chunks, chunk_length, global_length + expanded_global_mask = tf.reshape( + global_mask, shape=[batch, chunk_num, chunk_length, global_length] + ) + mask_unskewed = tf.concat([mask_unskewed, expanded_global_mask], axis=3) + + # Flatten the first dimension to batch * chunk_num. + return tf.reshape( + mask_unskewed, + shape=[batch * chunk_num, chunk_length, chunk_length * 3 + global_length], + )
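
Note (not part of the patch series): the snippet below is a minimal sketch of how the chunking helpers introduced in the final patch behave. It assumes the patch has been applied so that split_chunks and chunk_att_mask are module-level functions of opennmt.layers.transformer, that TensorFlow runs eagerly, and it uses small toy sizes; the shapes in the comments are the values expected for those sizes, with chunk_length standing in for local_attention_radius.

import tensorflow as tf

from opennmt.layers import transformer

# Toy sizes: batch of 2, 4 heads, 10 timesteps, 8 units per head.
batch, heads, timesteps, depth = 2, 4, 10, 8
chunk_length = 4  # plays the role of local_attention_radius

x = tf.random.normal([batch, heads, timesteps, depth])

# Queries: non-overlapping chunks of length C -> [B * N, H, C, D].
# T = 10 is padded to 12, so N = 3 chunks per example.
q, num_chunks = transformer.split_chunks(x, chunk_length, concat_3_chunks=False)
print(q.shape, num_chunks)  # (6, 4, 4, 8), num_chunks == 3

# Keys/values: each chunk is concatenated with its neighbours -> [B * N, H, 3 * C, D].
k, _ = transformer.split_chunks(x, chunk_length)
print(k.shape)  # (6, 4, 12, 8)

# Attention mask in the matching chunked layout -> [B * N, C, 3 * C].
lengths = tf.constant([10, 7])
mask = tf.cast(tf.sequence_mask(lengths, maxlen=timesteps), tf.float32)
chunked_mask = transformer.chunk_att_mask(mask, chunk_length)
print(chunked_mask.shape)  # (6, 4, 12)

# With global tokens, the leading global_length positions are stripped before chunking
# and appended to every chunk of keys/values -> [B * N, H, 3 * C + G, D], with N = 2 here.
k_global, _ = transformer.split_chunks(x, chunk_length, global_length=2)
print(k_global.shape)  # (4, 4, 14, 8)

Inside the attention layer, these chunked tensors and the chunked mask are what the softmax operates on when sparse attention is selected; combine_chunks then undoes the chunking and drops the padding so downstream code keeps working with [B, H, T, D] heads.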