[WIP] Attention across documents. #213

Open · wants to merge 12 commits into base: main
107 changes: 68 additions & 39 deletions open_lm/attention.py
@@ -16,58 +16,87 @@ def get_rectangular_mask(shape, q_seq_len, k_seq_len, device, dtype):
    )[:, :, :, :k_seq_len]


def xformers_attn(queries, keys, values, is_causal, document_seqlens=None):
    # xformers assumes q, k, v are [batch, seq_len, heads, embed_dim]
    # We assume that queries match the last part of the key / value sequences
    # see (https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.fmha.attn_bias.LowerTriangularFromBottomRightMask)
    # we would like to replace the mask generation with: mask = xops.fmha.attn_bias.LowerTriangularFromBottomRightMask()
    # sadly we cannot use this because it needs xformers>=0.0.23 and this is not compatible with torch<2.1.1 while llm-foundry requires torch<2.1.1

    if document_seqlens is None or all(len(ds) == 1 for ds in document_seqlens):
        # In this case, all the tokens inside the sequence (are considered to) come from the same document.
        # The attention mask is constructed as a simple causal mask.

        mask = None
        # If queries have shape [batch, 1, heads, dim] it means there is only one query in the sequence.
        # In this case, there is no notion of causal masking, so we can just set the mask to None.
        # This is actually needed to get the desired behavior with seq_len=1.
        if is_causal and queries.shape[1] == keys.shape[1]:
            mask = xops.LowerTriangularMask()
        elif is_causal and queries.shape[1] > 1:
            # Build causal mask that assumes queries are in the end of the sequence.
            batch, q_seq_len, heads, _ = queries.shape
            k_seq_len = keys.shape[1]
            mask = get_rectangular_mask((batch, heads), q_seq_len, k_seq_len, queries.device, queries.dtype)

    else:
        masks = []
        batch, q_seq_len, heads, _ = queries.shape
        k_seq_len = keys.shape[1]
        dtype = queries.dtype
        device = queries.device
        for ds in document_seqlens:
            if is_causal and queries.shape[1] == keys.shape[1]:
                masks.append(
                    xops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(ds).materialize(
                        shape=(1, heads, q_seq_len, k_seq_len), device=device, dtype=dtype
                    )
                )
            elif is_causal and queries.shape[1] > 1:
                masks.append(
                    xops.fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask.from_seqlens(ds).materialize(
                        shape=(1, heads, q_seq_len, k_seq_len), device=device, dtype=dtype
                    )
                )
        mask = torch.cat(masks, dim=0)

    return xops.memory_efficient_attention(queries, keys, values, attn_bias=mask)
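
To make the cross-document branch concrete, here is a toy illustration (not part of the diff) of the block-diagonal causal bias it builds for a single sequence split into documents of lengths [2, 3]. It only assumes xformers is available and uses the same from_seqlens / materialize calls as the code above; the printed values are the expected pattern, where 0 means "may attend" and -inf means "masked".

import torch
import xformers.ops as xops

# One sequence of length 5 made up of two documents (lengths 2 and 3).
# Tokens of the second document cannot attend to tokens of the first.
bias = xops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens([2, 3]).materialize(
    shape=(1, 1, 5, 5), device="cpu", dtype=torch.float32
)
print(bias[0, 0])
# Expected pattern (0 = attend, -inf = masked):
# [[  0., -inf, -inf, -inf, -inf],
#  [  0.,   0., -inf, -inf, -inf],
#  [-inf, -inf,   0., -inf, -inf],
#  [-inf, -inf,   0.,   0., -inf],
#  [-inf, -inf,   0.,   0.,   0.]]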


def torch_attn(queries, keys, values, is_causal, document_seqlens=None):
    if document_seqlens is None or len(document_seqlens) == 1:
        # Need to call contiguous in torch >=2.1, otherwise later calls to .view() fail.
        # Possibly related: https://github.com/pytorch/pytorch/issues/110213 - behavior of scaled_dot_product_attention
        # changed between 2.0 and 2.1
        if is_causal and keys.shape[1] > queries.shape[1] > 1:
            q_seq_len = queries.shape[1]
            k_seq_len = keys.shape[1]
            # Same as above, we would like to use:
            # mask = xops.fmha.attn_bias.LowerTriangularFromBottomRightMask().materialize((1, 1, q_seq_len, k_seq_len), queries.dtype, queries.device)
            mask = get_rectangular_mask((1, 1), q_seq_len, k_seq_len, queries.device, queries.dtype)
            return (
                F.scaled_dot_product_attention(
                    queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), attn_mask=mask
                )
                .transpose(1, 2)
                .contiguous()
            )
        elif queries.shape[1] == 1:
            return (
                F.scaled_dot_product_attention(queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2))
                .transpose(1, 2)
                .contiguous()
            )
        else:
            return (
                F.scaled_dot_product_attention(
                    queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), is_causal=is_causal
                )
                .transpose(1, 2)
                .contiguous()
            )

    else:
        raise NotImplementedError("Currently supporting --mask-across-documents only with xformers attention.")


ATTN_ACTIVATIONS = {
1 change: 1 addition & 0 deletions open_lm/losses.py
@@ -18,4 +18,5 @@ def __init__(
        self.eps = eps

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        # TODO: should ignore_index be taken into account in the regularization term as well?
        return super().forward(input, target) + self.eps * torch.square(torch.logsumexp(input, dim=-1)).mean()
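
For context on the TODO (an observation, not a change made by the PR): the cross-entropy term skips positions whose target equals ignore_index, but the added regularizer eps * mean(logsumexp(input, dim=-1) ** 2) is averaged over every position, including the ones the new masking code in train.py marks as ignored.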
22 changes: 10 additions & 12 deletions open_lm/model.py
@@ -98,6 +98,7 @@ class Params:
    moe_freq: int = 0
    positional_embedding_type: str = "rotary"
    ffn_type: str = "swiglu"
    mask_across_documents: bool = False


def get_pos_embed(args: Params):
@@ -153,7 +154,7 @@ def reset_parameters(self):
        std = std / math.sqrt(2 * (self.layer_id + 1))
        torch.nn.init.trunc_normal_(self.out_proj.weight, std=std, a=-3 * std, b=3 * std)

    def forward(self, x: torch.Tensor, is_causal=True, past_key_value=None, use_cache=False, document_seqlens=None):
        batchsize, q_len, _ = x.shape
        queries, keys, vals = self.in_proj(x).chunk(3, dim=-1)

@@ -174,12 +175,7 @@ def forward(self, x: torch.Tensor, is_causal=True, past_key_value=None, use_cach
        if use_cache:
            past_key_value = [keys, vals]

        output = self.attn_fn(queries, keys, vals, is_causal=is_causal, document_seqlens=document_seqlens)

        output = output.view(batchsize, q_len, -1)

@@ -250,12 +246,13 @@ def reset_parameters(self):
        std = std / math.sqrt(2 * (self._layer_id + 1))
        torch.nn.init.trunc_normal_(self._ff_w2.weight, std=std, a=-3 * std, b=3 * std)

    def forward(self, x, past_key_value=None, use_cache=False, document_seqlens=None):
        h, past_key_value = self.attention(
            self.attention_norm(x),
            is_causal=True,
            past_key_value=past_key_value,
            use_cache=use_cache,
            document_seqlens=document_seqlens,
        )
        h = x + h
        if self._ffn_type == "moe":
@@ -320,19 +317,20 @@ def reset_parameters(self):
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    def forward(self, input, past_key_values=None, use_cache=False, document_seqlens=None):
        x = self.tok_embeddings(input)
        x = self.post_embed_norm(x)

        if past_key_values is None:
            past_key_values = [None] * self.n_layers
        elif isinstance(past_key_values, tuple):
            past_key_values = list(past_key_values)
        for i, layer in enumerate(self.layers):
            if self.grad_checkpointing:
                x, past_key_values[i] = checkpoint(layer, x, past_key_values[i], use_cache, document_seqlens)
            else:
                x, past_key_values[i] = layer(
                    x, past_key_values[i], use_cache=use_cache, document_seqlens=document_seqlens
                )
        if past_key_values[0] is None:
            past_key_values = None
        x = self.norm(x)
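
A minimal call-site sketch (illustrative, not from the PR; `model` and `input_ids` are placeholder names) of how the new keyword is meant to be used: the caller supplies one list of document lengths per batch element, and the model threads it down through every Block to the attention function.

# Hypothetical usage: batch of 2 sequences of total length 8 each.
# The first sample contains two documents (5 + 3 tokens), the second a single document.
document_seqlens = [[5, 3], [8]]
out, _, past_key_values = model(input_ids, document_seqlens=document_seqlens)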
5 changes: 5 additions & 0 deletions open_lm/params.py
@@ -742,6 +742,11 @@ def parse_args(args):
        action="store_true",
        help="If set, allow model to do multiple data passes over our dataset, in order to reach the desired number of tokens.",
    )
    parser.add_argument(
        "--mask-across-documents",
        action="store_true",
        help="If set, then tokens in the same sequence will be masked across EOT.",
    )

Review comment (Collaborator): i think this should be an int not a bool so that a user can specify their EOT token
Reply (Author): Makes sense - will update the parameter.
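
A rough sketch of the reviewer's suggestion (hypothetical, not the PR's final interface; the type and default below are made up for illustration): expose the EOT token id as an integer so open_lm stays tokenizer agnostic.

    # Hypothetical variant of the flag discussed above: an int carrying the EOT token id.
    parser.add_argument(
        "--mask-across-documents",
        type=int,
        default=None,
        help="If set, attention is masked across documents, treating this token id as the EOT separator.",
    )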

add_model_args(parser)

48 changes: 46 additions & 2 deletions open_lm/train.py
@@ -22,6 +22,7 @@
wandb = None

from open_lm.data import sample_chunk
from open_lm.datapreprocess.ray.tokenize_shuffle import SpecialTokens
from open_lm.distributed import is_master
from open_lm.precision import get_autocast
from open_lm.meters import AverageMeter
@@ -41,6 +42,34 @@ def backward(total_loss, scaler):
total_loss.backward()


def get_document_seqlens(inputs, args):
    """Get list of document sequence lengths.

    Return a list of lists. The length of the outer list is equal to the batch size, while the length of the inner
    list is equal to the number of distinct documents (recognized by EOT tokens). Each element of the inner lists is
    the length of the corresponding document.
    """
    if args.mask_across_documents:
        document_seqlens = []
        for idx in range(inputs.shape[0]):
            eot_idx = torch.nonzero(inputs[idx] == SpecialTokens.END_OF_TEXT.value)
            if eot_idx.shape[0] == 0:
                # All tokens come from the same document.
                document_seqlens.append([args.seq_len])
            else:
                start_idx = 0
                seqlens = []
                for k in range(eot_idx.shape[0]):
                    seqlens.append(eot_idx[k].item() - start_idx + 1)
                    start_idx = eot_idx[k].item() + 1
                if start_idx < args.seq_len:
                    seqlens.append(args.seq_len - start_idx)
                document_seqlens.append(seqlens)
    else:
        document_seqlens = None
    return document_seqlens
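
As a toy sanity check (illustrative only, not in the PR; assumes the EOT token id is 0 and uses SimpleNamespace as a stand-in for the real args object):

from types import SimpleNamespace
import torch

args = SimpleNamespace(mask_across_documents=True, seq_len=8)
inputs = torch.tensor([[5, 6, 0, 7, 8, 9, 0, 4]])  # EOT (id 0) at positions 2 and 6

# With SpecialTokens.END_OF_TEXT.value == 0, get_document_seqlens(inputs, args)
# returns [[3, 4, 1]]: "5 6 EOT" (3 tokens), "7 8 9 EOT" (4 tokens), "4" (1 trailing token).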


def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler, total_steps, args, tb_writer=None):
"""Trains model for one epoch on the provided data.

@@ -109,13 +138,21 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler

        (texts,) = batch
        texts = torch.LongTensor(texts).to(device)

        data_time_m.update(time.time() - end)
        optimizer.zero_grad()

        if args.accum_freq == 1:
            with autocast():
                inputs, targets = sample_chunk(texts, args)
                document_seqlens = get_document_seqlens(inputs, args)
                if args.mask_across_documents:
                    # Some input samples contain EOT as the final token. The prediction after that is meaningless,
                    # so it should not contribute to the loss.
                    ignore_indices = torch.nonzero(inputs == SpecialTokens.END_OF_TEXT.value, as_tuple=True)
                    targets[ignore_indices] = loss.ignore_index

                out, _, _ = model(inputs, document_seqlens=document_seqlens)

                if args.log_logit_mean:
                    logit_m.update(torch.mean(out).item())

Review comment (Collaborator): i prefer not to hard code our EOT to keep open_lm tokenizer agnostic
Reply (Author): Agreed - I'll change it so that it uses the user defined EOT token.
@@ -135,6 +172,12 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler
            per_batch = args.per_gpu_batch_size // args.accum_freq

            inputs, targets = sample_chunk(texts, args)
            if args.mask_across_documents:
                # Some input samples contain EOT as the final token. The prediction after that is meaningless,
                # so it should not contribute to the loss.
                ignore_indices = torch.nonzero(inputs == SpecialTokens.END_OF_TEXT.value, as_tuple=True)
                targets = targets.detach().clone()  # Clone this because it shares mem with input!
                targets[ignore_indices] = loss.ignore_index

            for ii in range(args.accum_freq):
                maybe_no_sync = nullcontext

Review comment (Collaborator): Interesting, is the detach necessary here? When args.mask_across_documents is False, should we also do a detach()?
Reply (Author): Detach is not necessary, but clone is - because the targets and the input share the underlying tensor, if the target is explicitly set then the input is also affected. When args.mask_across_documents is False, this is not an issue - neither the target nor the input are explicitly changed.
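
A tiny sketch (not part of the PR) of the memory-sharing point made in this thread, assuming sample_chunk-style slicing where inputs and targets are views over the same underlying tensor:

import torch

texts = torch.arange(10).reshape(1, 10)
inputs, targets = texts[:, :-1], texts[:, 1:]   # views sharing storage with texts

safe_targets = targets.detach().clone()
safe_targets[0, 0] = -100   # the clone has its own storage, inputs is untouched
# targets[0, 0] = -100      # without clone(), this would also overwrite inputs[0, 1]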
@@ -147,7 +190,8 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler
                    if inputs_ii.shape[0] == 0:
                        break
                    targets_ii = targets[ii * per_batch : (ii + 1) * per_batch]
                    document_seqlens = get_document_seqlens(inputs_ii, args)
                    out, _, _ = model(inputs_ii, document_seqlens=document_seqlens)

                    if args.log_logit_mean:
                        logit_m.update(torch.mean(out).item())