diff --git a/open_lm/attention.py b/open_lm/attention.py
index 9709d891..7357a512 100644
--- a/open_lm/attention.py
+++ b/open_lm/attention.py
@@ -16,58 +16,87 @@ def get_rectangular_mask(shape, q_seq_len, k_seq_len, device, dtype):
     )[:, :, :, :k_seq_len]
 
 
-def xformers_attn(queries, keys, values, is_causal):
+def xformers_attn(queries, keys, values, is_causal, document_seqlens=None):
     # xformers assumes q, k, v are [batch, seq_len, heads, embed_dim]
     # We assume that queries match the last part of the key / value sequences
     # see (https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.fmha.attn_bias.LowerTriangularFromBottomRightMask)
     # we would like to replace the mask generation with: mask = xops.fmha.attn_bias.LowerTriangularFromBottomRightMask()
     # sadly we cannot use this because it needs xformers>=0.0.23 and this is not compatible with torch<2.1.1 while llm-foundry requires torch<2.1.1
+    if document_seqlens is None or all(len(ds) == 1 for ds in document_seqlens):
+        # In this case, all the tokens inside the sequence (are considered to) come from the same document.
+        # The attention mask is constructed as a simple causal mask.
+
+        mask = None
+        # If queries have shape [batch, 1, heads, dim] it means there is only one query in the sequence.
+        # In this case, there is no notion of causal masking, so we can just set the mask to None.
+        # This is actually needed to get the desired behavior with seq_len=1.
+        if is_causal and queries.shape[1] == keys.shape[1]:
+            mask = xops.LowerTriangularMask()
+        elif is_causal and queries.shape[1] > 1:
+            # Build causal mask that assumes queries are at the end of the sequence.
+            batch, q_seq_len, heads, _ = queries.shape
+            k_seq_len = keys.shape[1]
+            mask = get_rectangular_mask((batch, heads), q_seq_len, k_seq_len, queries.device, queries.dtype)
-    mask = None
-    # If queries have shape [batch, 1, heads, dim] it means there is only one query in the sequence.
-    # In this case, there is no notion of causal masking, so we can just set the mask to None.
-    # This is actually needed to get the desired behavior with seq_len=1.
-    if is_causal and queries.shape[1] == keys.shape[1]:
-        mask = xops.LowerTriangularMask()
-    elif is_causal and queries.shape[1] > 1:
-        # Build causal mask that assumes queries are in the end of the sequence.
+    else:
+        masks = []
         batch, q_seq_len, heads, _ = queries.shape
         k_seq_len = keys.shape[1]
-        mask = get_rectangular_mask((batch, heads), q_seq_len, k_seq_len, queries.device, queries.dtype)
+        dtype = queries.dtype
+        device = queries.device
+        for ds in document_seqlens:
+            if is_causal and queries.shape[1] == keys.shape[1]:
+                masks.append(
+                    xops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(ds).materialize(
+                        shape=(1, heads, q_seq_len, k_seq_len), device=device, dtype=dtype
+                    )
+                )
+            elif is_causal and queries.shape[1] > 1:
+                masks.append(
+                    xops.fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask.from_seqlens(ds).materialize(
+                        shape=(1, heads, q_seq_len, k_seq_len), device=device, dtype=dtype
+                    )
+                )
+        mask = torch.cat(masks, dim=0)
+
     return xops.memory_efficient_attention(queries, keys, values, attn_bias=mask)
 
 
-def torch_attn(queries, keys, values, is_causal):
-    # Need to call contiguous in torch >=2.1, otherwise later calls to .view() fail.
-    # Possibly related: https://github.com/pytorch/pytorch/issues/110213 - behavior of scaled_dot_product_attention
-    # changed between 2.0 and 2.1
-    if is_causal and keys.shape[1] > queries.shape[1] > 1:
-        q_seq_len = queries.shape[1]
-        k_seq_len = keys.shape[1]
-        # Same as above, we would like to use:
-        # mask = xops.fmha.attn_bias.LowerTriangularFromBottomRightMask().materialize((1, 1, q_seq_len, k_seq_len), queries.dtype, queries.device)
-        mask = get_rectangular_mask((1, 1), q_seq_len, k_seq_len, queries.device, queries.dtype)
-        return (
-            F.scaled_dot_product_attention(
-                queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), attn_mask=mask
+def torch_attn(queries, keys, values, is_causal, document_seqlens=None):
+    if document_seqlens is None or len(document_seqlens) == 1:
+        # Need to call contiguous in torch >=2.1, otherwise later calls to .view() fail.
+        # Possibly related: https://github.com/pytorch/pytorch/issues/110213 - behavior of scaled_dot_product_attention
+        # changed between 2.0 and 2.1
+        if is_causal and keys.shape[1] > queries.shape[1] > 1:
+            q_seq_len = queries.shape[1]
+            k_seq_len = keys.shape[1]
+            # Same as above, we would like to use:
+            # mask = xops.fmha.attn_bias.LowerTriangularFromBottomRightMask().materialize((1, 1, q_seq_len, k_seq_len), queries.dtype, queries.device)
+            mask = get_rectangular_mask((1, 1), q_seq_len, k_seq_len, queries.device, queries.dtype)
+            return (
+                F.scaled_dot_product_attention(
+                    queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), attn_mask=mask
+                )
+                .transpose(1, 2)
+                .contiguous()
             )
-            .transpose(1, 2)
-            .contiguous()
-        )
-    elif queries.shape[1] == 1:
-        return (
-            F.scaled_dot_product_attention(queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2))
-            .transpose(1, 2)
-            .contiguous()
-        )
-    else:
-        return (
-            F.scaled_dot_product_attention(
-                queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), is_causal=is_causal
+        elif queries.shape[1] == 1:
+            return (
+                F.scaled_dot_product_attention(queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2))
+                .transpose(1, 2)
+                .contiguous()
             )
-            .transpose(1, 2)
-            .contiguous()
-        )
+        else:
+            return (
+                F.scaled_dot_product_attention(
+                    queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), is_causal=is_causal
+                )
+                .transpose(1, 2)
+                .contiguous()
+            )
+
+    else:
+        raise NotImplementedError("Currently supporting --mask-across-documents only with xformers attention.")
 
 
 ATTN_ACTIVATIONS = {
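
Note (not part of the patch): the per-document branch added to xformers_attn builds one block-diagonal causal bias per sequence and concatenates them along the batch dimension. Below is a minimal standalone sketch of that construction; it mirrors the materialize() calls in the hunk above, assumes xformers is installed, and uses toy document lengths.

import torch
import xformers.ops as xops

heads, q_seq_len, k_seq_len = 1, 8, 8
document_seqlens = [[3, 4, 1]]  # one sequence of length 8 holding three documents

masks = []
for ds in document_seqlens:
    masks.append(
        xops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(ds).materialize(
            shape=(1, heads, q_seq_len, k_seq_len), device="cpu", dtype=torch.float32
        )
    )
mask = torch.cat(masks, dim=0)  # [batch, heads, q_seq_len, k_seq_len]
print(mask[0, 0])  # 0 where attention is allowed, -inf elsewhere: causal blocks of sizes 3, 4, 1
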
diff --git a/open_lm/losses.py b/open_lm/losses.py
index ef3839d6..76805cdd 100644
--- a/open_lm/losses.py
+++ b/open_lm/losses.py
@@ -18,4 +18,5 @@ def __init__(
         self.eps = eps
 
     def forward(self, input: Tensor, target: Tensor) -> Tensor:
+        # TODO: should ignore_index be taken into account in the regularization term as well?
         return super().forward(input, target) + self.eps * torch.square(torch.logsumexp(input, dim=-1)).mean()
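
Sketch only (not part of the patch): one way the TODO above could be addressed is to restrict the logsumexp regularizer to positions whose target is not ignore_index, so that ignored positions contribute to neither term. The class name and eps default below are illustrative rather than open_lm's exact definitions; only the super().forward(...) plus eps * logsumexp**2 structure mirrors the hunk.

import torch
from torch import Tensor
from torch.nn import CrossEntropyLoss


class CrossEntropyLossWithZLoss(CrossEntropyLoss):  # illustrative name
    def __init__(self, eps: float = 1e-4, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        keep = (target != self.ignore_index).float()
        zloss = torch.square(torch.logsumexp(input, dim=-1))
        # Average the regularizer only over positions that also contribute to the cross-entropy term.
        return super().forward(input, target) + self.eps * (zloss * keep).sum() / keep.sum().clamp(min=1)
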
diff --git a/open_lm/model.py b/open_lm/model.py
index 7484ca5d..d54c9421 100644
--- a/open_lm/model.py
+++ b/open_lm/model.py
@@ -98,6 +98,7 @@ class Params:
     moe_freq: int = 0
     positional_embedding_type: str = "rotary"
     ffn_type: str = "swiglu"
+    mask_across_documents: bool = False
 
 
 def get_pos_embed(args: Params):
@@ -153,7 +154,7 @@ def reset_parameters(self):
         std = std / math.sqrt(2 * (self.layer_id + 1))
         torch.nn.init.trunc_normal_(self.out_proj.weight, std=std, a=-3 * std, b=3 * std)
 
-    def forward(self, x: torch.Tensor, is_causal=True, past_key_value=None, use_cache=False):
+    def forward(self, x: torch.Tensor, is_causal=True, past_key_value=None, use_cache=False, document_seqlens=None):
         batchsize, q_len, _ = x.shape
         queries, keys, vals = self.in_proj(x).chunk(3, dim=-1)
 
@@ -174,12 +175,7 @@ def forward(self, x: torch.Tensor, is_causal=True, past_key_value=None, use_cach
         if use_cache:
             past_key_value = [keys, vals]
 
-        output = self.attn_fn(
-            queries,
-            keys,
-            vals,
-            is_causal=is_causal,
-        )
+        output = self.attn_fn(queries, keys, vals, is_causal=is_causal, document_seqlens=document_seqlens)
 
         output = output.view(batchsize, q_len, -1)
 
@@ -250,12 +246,13 @@ def reset_parameters(self):
         std = std / math.sqrt(2 * (self._layer_id + 1))
         torch.nn.init.trunc_normal_(self._ff_w2.weight, std=std, a=-3 * std, b=3 * std)
 
-    def forward(self, x, past_key_value=None, use_cache=False):
+    def forward(self, x, past_key_value=None, use_cache=False, document_seqlens=None):
         h, past_key_value = self.attention(
             self.attention_norm(x),
             is_causal=True,
             past_key_value=past_key_value,
             use_cache=use_cache,
+            document_seqlens=document_seqlens,
         )
         h = x + h
         if self._ffn_type == "moe":
@@ -320,19 +317,20 @@ def reset_parameters(self):
     def set_grad_checkpointing(self, enable=True):
         self.grad_checkpointing = enable
 
-    def forward(self, input, past_key_values=None, use_cache=False):
+    def forward(self, input, past_key_values=None, use_cache=False, document_seqlens=None):
         x = self.tok_embeddings(input)
         x = self.post_embed_norm(x)
-
         if past_key_values is None:
             past_key_values = [None] * self.n_layers
         elif isinstance(past_key_values, tuple):
             past_key_values = list(past_key_values)
         for i, layer in enumerate(self.layers):
             if self.grad_checkpointing:
-                x, past_key_values[i] = checkpoint(layer, x, past_key_values[i], use_cache)
+                x, past_key_values[i] = checkpoint(layer, x, past_key_values[i], use_cache, document_seqlens)
             else:
-                x, past_key_values[i] = layer(x, past_key_values[i], use_cache=use_cache)
+                x, past_key_values[i] = layer(
+                    x, past_key_values[i], use_cache=use_cache, document_seqlens=document_seqlens
+                )
         if past_key_values[0] is None:
             past_key_values = None
         x = self.norm(x)
diff --git a/open_lm/params.py b/open_lm/params.py
index 9d5efa7c..f0d7155c 100644
--- a/open_lm/params.py
+++ b/open_lm/params.py
@@ -742,6 +742,11 @@ def parse_args(args):
         action="store_true",
         help="If set, allow model to do multiple data passes over our dataset, in order to reach the desired number of tokens.",
     )
+    parser.add_argument(
+        "--mask-across-documents",
+        action="store_true",
+        help="If set, attention within a sequence will be masked across EOT tokens (document boundaries).",
+    )
 
     add_model_args(parser)
 
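
For reference, a standalone sketch (not part of the patch) of the new flag in isolation. The parser below is a toy rather than open_lm's full parse_args, but the add_argument call mirrors the params.py hunk; argparse exposes the option as args.mask_across_documents, which is the attribute train.py checks.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--mask-across-documents",
    action="store_true",
    help="If set, attention within a sequence will be masked across EOT tokens (document boundaries).",
)
args = parser.parse_args(["--mask-across-documents"])
print(args.mask_across_documents)  # True; dashes become underscores in the attribute name
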
diff --git a/open_lm/train.py b/open_lm/train.py
index 4ff5a27e..b1b66b24 100644
--- a/open_lm/train.py
+++ b/open_lm/train.py
@@ -22,6 +22,7 @@
     wandb = None
 
 from open_lm.data import sample_chunk
+from open_lm.datapreprocess.ray.tokenize_shuffle import SpecialTokens
 from open_lm.distributed import is_master
 from open_lm.precision import get_autocast
 from open_lm.meters import AverageMeter
@@ -41,6 +42,34 @@ def backward(total_loss, scaler):
         total_loss.backward()
 
 
+def get_document_seqlens(inputs, args):
+    """Get list of document sequence lengths.
+
+    Return a list of lists. The length of the outer list is equal to the batch size, while the length of the inner list
+    is equal to the number of distinct documents (recognized by EOT tokens). Each element of the inner lists is the
+    length of the corresponding document.
+    """
+    if args.mask_across_documents:
+        document_seqlens = []
+        for idx in range(inputs.shape[0]):
+            eot_idx = torch.nonzero(inputs[idx] == SpecialTokens.END_OF_TEXT.value)
+            if eot_idx.shape[0] == 0:
+                # All tokens come from the same document.
+                document_seqlens.append([args.seq_len])
+            else:
+                start_idx = 0
+                seqlens = []
+                for k in range(eot_idx.shape[0]):
+                    seqlens.append(eot_idx[k].item() - start_idx + 1)
+                    start_idx = eot_idx[k].item() + 1
+                if start_idx < args.seq_len:
+                    seqlens.append(args.seq_len - start_idx)
+                document_seqlens.append(seqlens)
+    else:
+        document_seqlens = None
+    return document_seqlens
+
+
 def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler, total_steps, args, tb_writer=None):
     """Trains model for one epoch on the provided data.
 
@@ -109,13 +138,21 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler
         (texts,) = batch
         texts = torch.LongTensor(texts).to(device)
+
         data_time_m.update(time.time() - end)
         optimizer.zero_grad()
 
         if args.accum_freq == 1:
             with autocast():
                 inputs, targets = sample_chunk(texts, args)
-                out, _, _ = model(inputs)
+                document_seqlens = get_document_seqlens(inputs, args)
+                if args.mask_across_documents:
+                    # Some input samples contain EOT as the final token. The prediction after that is meaningless, so it
+                    # should not contribute to the loss.
+                    ignore_indices = torch.nonzero(inputs == SpecialTokens.END_OF_TEXT.value, as_tuple=True)
+                    targets[ignore_indices] = loss.ignore_index
+
+                out, _, _ = model(inputs, document_seqlens=document_seqlens)
 
                 if args.log_logit_mean:
                     logit_m.update(torch.mean(out).item())
@@ -135,6 +172,12 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler
             per_batch = args.per_gpu_batch_size // args.accum_freq
 
             inputs, targets = sample_chunk(texts, args)
+            if args.mask_across_documents:
+                # Some input samples contain EOT as the final token. The prediction after that is meaningless, so it
+                # should not contribute to the loss.
+                ignore_indices = torch.nonzero(inputs == SpecialTokens.END_OF_TEXT.value, as_tuple=True)
+                targets = targets.detach().clone()  # Clone this because it shares mem with input!
+                targets[ignore_indices] = loss.ignore_index
 
             for ii in range(args.accum_freq):
                 maybe_no_sync = nullcontext
@@ -147,7 +190,8 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler
                     if inputs_ii.shape[0] == 0:
                         break
                     targets_ii = targets[ii * per_batch : (ii + 1) * per_batch]
-                    out, _, _ = model(inputs_ii)
+                    document_seqlens = get_document_seqlens(inputs_ii, args)
+                    out, _, _ = model(inputs_ii, document_seqlens=document_seqlens)
 
                     if args.log_logit_mean:
                         logit_m.update(torch.mean(out).item())
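
Finally, a standalone sketch (not part of the patch) of the document_seqlens format that get_document_seqlens produces and xformers_attn consumes. The EOT id and seq_len below are toy values; the real code uses SpecialTokens.END_OF_TEXT.value and args.seq_len. Each inner list sums to the sequence length, matching the (q_seq_len, k_seq_len) shape used when the block-diagonal mask is materialized.

import torch

EOT = 0  # toy EOT token id
seq_len = 8
inputs = torch.tensor(
    [
        [3, 5, EOT, 7, 2, 9, EOT, 4],  # three documents of lengths 3, 4 and 1
        [1, 2, 3, 4, 5, 6, 7, 8],  # no EOT: a single document of length 8
    ]
)

document_seqlens = []
for idx in range(inputs.shape[0]):
    eot_idx = torch.nonzero(inputs[idx] == EOT)
    if eot_idx.shape[0] == 0:
        document_seqlens.append([seq_len])
    else:
        start_idx, seqlens = 0, []
        for k in range(eot_idx.shape[0]):
            seqlens.append(eot_idx[k].item() - start_idx + 1)
            start_idx = eot_idx[k].item() + 1
        if start_idx < seq_len:
            seqlens.append(seq_len - start_idx)
        document_seqlens.append(seqlens)

print(document_seqlens)  # [[3, 4, 1], [8]]
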