WIP: First implementation of sliding window local sparse attention. #951

Open, wants to merge 13 commits into master
190 changes: 187 additions & 3 deletions opennmt/layers/transformer.py
@@ -117,6 +117,125 @@ def matmul_with_relative_representations(a, b, transpose_b=False):
return c


def split_chunks(a, chunk_length, concat_3_chunks=True):
"""Splits a tensor into chunks along the timesteps axis.

Args:
a: A ``tf.Tensor`` of shape :math:`[B, H, T, D]`.
chunk_length: The length of a chunk :math:`C`.

Returns:
A ``tf.Tensor`` of shape :math:`[B * N, H, C (* 3), D]`, where :math:`N` is the chunk number.
"""

batch, num_heads, timesteps, units_per_head = misc.shape_list(a)

# Pad to a factor of chunk_length.
rank = a.shape.rank
timestep_axis = rank - 2
pad_len = -timesteps % chunk_length
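# e.g. with timesteps=7 and chunk_length=3, pad_len is 2 and the padded length is 9.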
paddings = pad_len * tf.one_hot([-1, timestep_axis], rank, axis=0, dtype=tf.int32)
# batch, num_heads, timesteps padded, units_per_head
a_padded = tf.pad(tensor=a, paddings=paddings)
padded_len = misc.shape_list(a_padded)[timestep_axis]

# Chunk along timesteps axis.
num_chunks = padded_len // chunk_length
chunked_shape = [batch, num_heads, num_chunks, chunk_length, units_per_head]
# batch, num_heads, num_chunks, chunk_length, units_per_head
a_chunked = tf.reshape(a_padded, chunked_shape)

# Concatenate previous and next chunk to each chunk, for overlapping.
if concat_3_chunks:
paddings = tf.one_hot([2, 2], rank + 1, axis=0, dtype=tf.int32)
# batch, num_heads, 1 + num_chunks + 1, chunk_length, units_per_head
a_chunked_padded = tf.pad(a_chunked, paddings)
# batch, num_heads, num_chunks, chunk_length*3, units_per_head
a_chunked = tf.concat(
[a_chunked_padded[:, :, i : (i + num_chunks), ...] for i in range(3)], 3
)

# Transpose and flatten first dimension (batch * num_chunks).
# batch, num_chunks, num_heads, chunk_length (*3), units_per_head
a_transposed = tf.transpose(a_chunked, perm=[0, 2, 1, 3, 4])
input_shape = misc.shape_list(a_transposed)
output_shape = tf.concat([[batch * num_chunks], input_shape[2:]], 0)
# batch x num_chunks, num_heads, chunk_length (*3), units_per_head
return tf.reshape(a_transposed, output_shape), num_chunks
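
# Illustrative shape flow through split_chunks (hypothetical values):
#
#   x = tf.random.normal([2, 4, 7, 8])            # batch=2, heads=4, T=7, depth=8
#   chunks, num_chunks = split_chunks(x, chunk_length=3)
#   # T=7 is padded to 9, so num_chunks == 3 and, with the previous and next
#   # chunks concatenated to each chunk, chunks.shape == [2 * 3, 4, 3 * 3, 8].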


def chunk_att_mask(mask, chunk_length):
"""Transforms an attention mask into a chunked representation.

The chunked mask masks everything but a sliding diagonal band with a radius of ``chunk_length`` around each query position.

Args:
mask: A ``tf.Tensor`` of shape :math:`[B, T]` or :math:`[B, T, T]`.
chunk_length: The length of a chunk :math:`C`.

Returns:
A ``tf.Tensor`` of shape :math:`[B * N, C, C * 3]`, where :math:`N` is the number of chunks.
"""

mask_shape = misc.shape_list(mask)
batch = mask_shape[0]
timesteps = mask_shape[-1]
rank = len(mask_shape)

if rank == 2:
# Broadcast on queries time dimension.
mask = tf.expand_dims(mask, 1)
mask = tf.broadcast_to(mask, [batch, timesteps, timesteps])
rank = 3

# Pad to a factor of chunk_length.
pad_len = -timesteps % chunk_length
mask = tf.pad(tensor=mask, paddings=[[0, 0], [0, pad_len], [0, pad_len]])
padded_timesteps = misc.shape_list(mask)[-1]

# Pad the keys timesteps axis with chunk_length positions, before and after.
paddings = chunk_length * tf.one_hot(
[rank - 1, rank - 1], rank, axis=0, dtype=tf.int32
)
mask_padded = tf.pad(tensor=mask, paddings=paddings)
padded_len = misc.shape_list(mask_padded)[-1]
mask_flattened = tf.reshape(mask_padded, shape=[batch, -1])

# Skew to the left by one and keep 2*chunk_length + 1 relevant locations.
# This corresponds to chunk_length radius around the diagonal.
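# Reshaping the flattened mask with a row length of padded_len + 1 shifts row i
# by i positions, so column j of skewed row i maps to key position
# i + j - chunk_length in the original mask; keeping the first
# 2 * chunk_length + 1 columns keeps exactly the band [i - chunk_length, i + chunk_length].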
skewed_len = padded_len + 1
skewed_padding_len = (
padded_timesteps * skewed_len - misc.shape_list(mask_flattened)[-1]
)
skewed_paddings = skewed_padding_len * tf.one_hot(
[-1, rank - 2], rank - 1, axis=0, dtype=tf.int32
)
mask_padded = tf.pad(mask_flattened, paddings=skewed_paddings)
skewed_shape = [batch, -1, skewed_len]
mask_skewed = tf.reshape(mask_padded, shape=skewed_shape)
mask_skewed = mask_skewed[:, :, : chunk_length * 2 + 1]

chunk_num = padded_timesteps // chunk_length
mask_skewed_chunked = tf.reshape(mask_skewed, [batch, chunk_num, chunk_length, -1])

# Unskew each chunk to be compatible with chunked attention shape.
unskewed_len = chunk_length * 3
unskewed_paddings = chunk_length * tf.one_hot(
[-1, rank], rank + 1, axis=0, dtype=tf.int32
)
mask_skewed_padded = tf.pad(mask_skewed_chunked, paddings=unskewed_paddings)
mask_skewed_flattened = tf.reshape(mask_skewed_padded, shape=[batch, chunk_num, -1])
mask_skewed_flattened = mask_skewed_flattened[:, :, : (chunk_length * unskewed_len)]
mask_unskewed = tf.reshape(
mask_skewed_flattened, shape=[batch, chunk_num, chunk_length, chunk_length * 3]
)

# Flatten the first dimension to batch * chunk_num.
return tf.reshape(
mask_unskewed, shape=[batch * chunk_num, chunk_length, chunk_length * 3]
)
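
# Illustrative sketch of chunk_att_mask (hypothetical values):
#
#   mask = tf.sequence_mask([5, 3], maxlen=5)        # shape [2, 5]
#   chunked = chunk_att_mask(mask, chunk_length=2)
#   # 5 timesteps are padded to 6, giving 3 chunks per batch entry, so
#   # chunked.shape == [2 * 3, 2, 2 * 3]; row q of chunk n only keeps key
#   # positions within 2 tokens of query n * 2 + q that are also unmasked.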


class FeedForwardNetwork(tf.keras.layers.Layer):
"""Implements the Transformer's "Feed Forward" layer.

@@ -214,6 +333,8 @@ def __init__(
dropout=0.1,
return_attention=False,
maximum_relative_position=None,
max_length_full_attention=None,
local_attention_radius=None,
**kwargs
):
"""Initializes this layer.
@@ -225,6 +346,9 @@ def __init__(
return_attention: If ``True``, also return the attention weights.
maximum_relative_position: Maximum relative position representation
(from https://arxiv.org/abs/1803.02155).
max_length_full_attention: Maximum sequence length for full self-attention.
Longer sequences fall back to sparse sliding window attention.
If ``None``, full attention is always used.
local_attention_radius: Attention radius around each token for sliding window local attention, also used as the chunk length.
kwargs: Additional layer arguments.
"""
super().__init__(**kwargs)
@@ -242,6 +366,8 @@ def __init__(
self.dropout = dropout
self.return_attention = return_attention
self.maximum_relative_position = maximum_relative_position
self.max_length_full_attention = max_length_full_attention
self.local_attention_radius = local_attention_radius
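# Example (hypothetical values): MultiHeadAttention(4, 512,
# local_attention_radius=16, max_length_full_attention=256) uses full
# attention up to 256 timesteps; longer self-attention inputs are split into
# chunks of 16 queries, each attending to its own, previous and next key chunk.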

def map_v1_weights(self, weights):
# V1 used conv1d layers that have a leading dimensions.
@@ -354,15 +480,38 @@ def _compute_kv(x):

cache = (keys, values)

queries_length = misc.shape_list(queries)[2]

use_sparse_att = False
if self.max_length_full_attention is not None:
if memory is not None:
raise ValueError("Sparse attention only supports self-attention.")
if self.maximum_relative_position is not None:
raise ValueError("Sparse attention doesn't support relative positions.")
use_sparse_att = queries_length > self.max_length_full_attention

chunk_length = self.local_attention_radius
# Dot product attention.
dot = tf.matmul(queries, keys, transpose_b=True)
if use_sparse_att:
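# Each chunk of chunk_length queries only attends to 3 * chunk_length keys
# (its own, previous and next chunk), so the attention matrix grows linearly
# with the sequence length instead of quadratically.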
# batch x num_chunks, num_heads, chunk_length, units_per_head
queries_chunked, _ = split_chunks(
queries, chunk_length, concat_3_chunks=False
)
# batch x num_chunks, num_heads, chunk_length*3, units_per_head
keys_chunked, _ = split_chunks(keys, chunk_length)
# batch x num_chunks, num_heads, chunk_length, chunk_length*3
dot = tf.matmul(queries_chunked, keys_chunked, transpose_b=True)
else:
dot = tf.matmul(queries, keys, transpose_b=True)
if relative_repr_keys is not None:
dot += matmul_with_relative_representations(
queries, relative_repr_keys, transpose_b=True
)
if mask is not None:
mask = tf.cast(mask, tf.float32)
if mask.shape.rank == 2:
if use_sparse_att:
mask = chunk_att_mask(mask, chunk_length)
elif mask.shape.rank == 2:
mask = tf.expand_dims(mask, 1) # Broadcast on time dimension.
mask = tf.expand_dims(mask, 1) # Broadcast on head dimension.
dot = tf.cast(
@@ -371,7 +520,42 @@
)
attn = tf.cast(tf.nn.softmax(tf.cast(dot, tf.float32)), dot.dtype)
drop_attn = common.dropout(attn, self.dropout, training=training)
heads = tf.matmul(drop_attn, values)
if use_sparse_att:
# batch x num_chunks, num_heads, chunk_length*3, units_per_head
values_chunked, num_chunks = split_chunks(values, chunk_length)
# batch x num_chunks, num_heads, chunk_length, units_per_head
heads = tf.matmul(drop_attn, values_chunked)

# Unchunk
heads_shape = misc.shape_list(heads)
# batch, num_chunks, num_heads, chunk_length, self.num_units_per_head
heads = tf.reshape(
heads,
[
heads_shape[0] // num_chunks,
num_chunks,
heads_shape[1],
heads_shape[2],
heads_shape[3],
],
)
# batch, num_heads, num_chunks, chunk_length, self.num_units_per_head
heads = tf.transpose(heads, perm=[0, 2, 1, 3, 4])
heads_shape = misc.shape_list(heads)
heads = tf.reshape(
heads,
[
heads_shape[0],
heads_shape[1],
heads_shape[2] * heads_shape[3],
heads_shape[4],
],
)

# Remove padding used for chunking.
heads = heads[:, :, :queries_length, :]
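# heads is now [batch, num_heads, queries_length, units_per_head], matching
# the output of the full attention branch below.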
else:
heads = tf.matmul(drop_attn, values)
if relative_repr_values is not None:
heads += matmul_with_relative_representations(
drop_attn, relative_repr_values
64 changes: 64 additions & 0 deletions opennmt/tests/transformer_test.py
@@ -116,6 +116,62 @@ def testRelativePositions(self):
[[2, 3, 4, 4], [1, 2, 3, 4], [0, 1, 2, 3], [0, 0, 1, 2]],
)

@parameterized.expand([[2, True], [2, False], [3, True], [3, False]])
def testSplitChunks(self, chunk_length, concat_3_chunks):
batch = 3
length = [5, 3, 7]
num_heads = 4
depth = 10

inputs = tf.random.normal(
[batch, num_heads, max(length), depth], dtype=tf.float32
)
split, num_chunks = transformer.split_chunks(
inputs, chunk_length=chunk_length, concat_3_chunks=concat_3_chunks
)
split_shape = split.shape
self.assertEqual(num_chunks, split_shape[0] / batch)
self.assertEqual(num_heads, split_shape[1])
chunk_length_eval = chunk_length * 3 if concat_3_chunks else chunk_length
self.assertEqual(chunk_length_eval, split_shape[2])
self.assertEqual(depth, split_shape[3])

@parameterized.expand(
[[tf.bool, 2], [tf.float32, 2], [tf.bool, 3], [tf.float32, 3]]
)
def testChunkAttentionMask(self, dtype, chunk_length):
length = [2, 4, 3]
batch = len(length)
maximum_length = 5
mask = tf.sequence_mask(lengths=length, maxlen=maximum_length, dtype=dtype)
mask_chunked = transformer.chunk_att_mask(mask, chunk_length=chunk_length)
output_shape = mask_chunked.shape
num_chunks = abs(-maximum_length // chunk_length)
self.assertEqual(num_chunks, output_shape[0] / batch)
self.assertEqual(chunk_length, output_shape[1])
self.assertEqual(chunk_length * 3, output_shape[2])

self.assertIs(mask_chunked.dtype, dtype)

expected = np.zeros(output_shape, dtype=dtype.as_numpy_dtype)

token_radius = chunk_length * 2 + 1
for b in range(batch):
seq_length = length[b]
for ch in range(num_chunks):
end = chunk_length + seq_length - chunk_length * ch
if end > 0:
chunk_idx = b * num_chunks + ch
for ch_l in range(chunk_length):
seq_length_idx = ch * chunk_length + ch_l
if seq_length_idx < maximum_length:
start_idx = ch_l if ch != 0 else chunk_length
end_idx = min(end, token_radius + ch_l)
expected[chunk_idx][ch_l][start_idx:end_idx] = 1

mask_chunked = self.evaluate(mask_chunked)
self.assertAllEqual(mask_chunked, expected)

def testFeedForwardNetwork(self):
ffn = transformer.FeedForwardNetwork(20, 10)
x = tf.random.uniform([4, 5, 10])
@@ -159,6 +215,14 @@ def testMultiHeadSelfAttentionRelativePositionsWithCache(self):
cache = (tf.zeros([4, 4, 0, 5]), tf.zeros([4, 4, 0, 5]))
_, cache = attention(x, cache=cache)

def testMultiHeadSelfAttentionSparse(self):
attention = transformer.MultiHeadAttention(
4, 20, local_attention_radius=2, max_length_full_attention=3
)
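# Length 9 exceeds max_length_full_attention=3, so this exercises the
# sparse sliding window path with chunk_length == 2.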
x = tf.random.uniform([2, 9, 10])
mask = tf.sequence_mask([9, 7])
attention(x, mask=mask)

def testMultiHeadSelfAttentionRelativeGradients(self):
attention = transformer.MultiHeadAttention(4, 20, maximum_relative_position=6)
