From d7afc8c3d634644f585ffb79ba186fcd53d33ff6 Mon Sep 17 00:00:00 2001 From: George Smyrnis Date: Sun, 28 Jan 2024 22:55:12 -0600 Subject: [PATCH 01/12] Added extra attention parts. --- open_lm/attention.py | 36 ++++++++++++++++++++++++------------ open_lm/model.py | 12 +++++++----- open_lm/params.py | 5 +++++ 3 files changed, 36 insertions(+), 17 deletions(-) diff --git a/open_lm/attention.py b/open_lm/attention.py index 9709d891..e9c8099c 100644 --- a/open_lm/attention.py +++ b/open_lm/attention.py @@ -16,24 +16,36 @@ def get_rectangular_mask(shape, q_seq_len, k_seq_len, device, dtype): )[:, :, :, :k_seq_len] -def xformers_attn(queries, keys, values, is_causal): +def xformers_attn(queries, keys, values, is_causal, document_seqlens = None): # xformers assumes q, k, v are [batch, seq_len, heads, embed_dim] # We assume that queries match the last part of the key / value sequences # see (https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.fmha.attn_bias.LowerTriangularFromBottomRightMask) # we would like to replace the mask generation with: mask = xops.fmha.attn_bias.LowerTriangularFromBottomRightMask() # sadly we cannot us this because it needs xformers>=0.0.23 and this is not compatible with torch<2.1.1 while llm-foundry requires torch<2.1.1 - mask = None - # If queries have shape [batch, 1, heads, dim] it means there is only one query in the sequence. - # In this case, there is no notion of causal masking, so we can just set the mask to None. - # This is actually needed to get the desired behavior with seq_len=1. - if is_causal and queries.shape[1] == keys.shape[1]: - mask = xops.LowerTriangularMask() - elif is_causal and queries.shape[1] > 1: - # Build causal mask that assumes queries are in the end of the sequence. - batch, q_seq_len, heads, _ = queries.shape - k_seq_len = keys.shape[1] - mask = get_rectangular_mask((batch, heads), q_seq_len, k_seq_len, queries.device, queries.dtype) + if document_seqlens is None or len(document_seqlens) == 1: + # In this case, all the tokens inside the sequence (are considered to) come from the same document. + # The attention mask is constructed as a simple causal mask + + mask = None + # If queries have shape [batch, 1, heads, dim] it means there is only one query in the sequence. + # In this case, there is no notion of causal masking, so we can just set the mask to None. + # This is actually needed to get the desired behavior with seq_len=1. + if is_causal and queries.shape[1] == keys.shape[1]: + mask = xops.LowerTriangularMask() + elif is_causal and queries.shape[1] > 1: + # Build causal mask that assumes queries are in the end of the sequence. + batch, q_seq_len, heads, _ = queries.shape + k_seq_len = keys.shape[1] + mask = get_rectangular_mask((batch, heads), q_seq_len, k_seq_len, queries.device, queries.dtype) + + else: + mask = None + if is_causal and queries.shape[1] == keys.shape[1]: + mask = xops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(document_seqlens) + elif is_causal and queries.shape[1] > 1: + mask = xops.fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask.from_seqlens(document_seqlens) + return xops.memory_efficient_attention(queries, keys, values, attn_bias=mask) diff --git a/open_lm/model.py b/open_lm/model.py index 7484ca5d..453cdb9a 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -98,6 +98,7 @@ class Params: moe_freq: int = 0 positional_embedding_type: str = "rotary" ffn_type: str = "swiglu" + mask_across_documents: bool = False def get_pos_embed(args: Params): @@ -153,7 +154,7 @@ def reset_parameters(self): std = std / math.sqrt(2 * (self.layer_id + 1)) torch.nn.init.trunc_normal_(self.out_proj.weight, std=std, a=-3 * std, b=3 * std) - def forward(self, x: torch.Tensor, is_causal=True, past_key_value=None, use_cache=False): + def forward(self, x: torch.Tensor, is_causal=True, past_key_value=None, use_cache=False, document_seqlens=None): batchsize, q_len, _ = x.shape queries, keys, vals = self.in_proj(x).chunk(3, dim=-1) @@ -250,12 +251,13 @@ def reset_parameters(self): std = std / math.sqrt(2 * (self._layer_id + 1)) torch.nn.init.trunc_normal_(self._ff_w2.weight, std=std, a=-3 * std, b=3 * std) - def forward(self, x, past_key_value=None, use_cache=False): + def forward(self, x, past_key_value=None, use_cache=False, document_seqlens=None): h, past_key_value = self.attention( self.attention_norm(x), is_causal=True, past_key_value=past_key_value, use_cache=use_cache, + document_seqlens=document_seqlens ) h = x + h if self._ffn_type == "moe": @@ -320,7 +322,7 @@ def reset_parameters(self): def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable - def forward(self, input, past_key_values=None, use_cache=False): + def forward(self, input, past_key_values=None, use_cache=False, document_seqlens=None): x = self.tok_embeddings(input) x = self.post_embed_norm(x) @@ -330,9 +332,9 @@ def forward(self, input, past_key_values=None, use_cache=False): past_key_values = list(past_key_values) for i, layer in enumerate(self.layers): if self.grad_checkpointing: - x, past_key_values[i] = checkpoint(layer, x, past_key_values[i], use_cache) + x, past_key_values[i] = checkpoint(layer, x, past_key_values[i], use_cache, document_seqlens) else: - x, past_key_values[i] = layer(x, past_key_values[i], use_cache=use_cache) + x, past_key_values[i] = layer(x, past_key_values[i], use_cache=use_cache, document_seqlens=document_seqlens) if past_key_values[0] is None: past_key_values = None x = self.norm(x) diff --git a/open_lm/params.py b/open_lm/params.py index 9d5efa7c..b336e766 100644 --- a/open_lm/params.py +++ b/open_lm/params.py @@ -742,6 +742,11 @@ def parse_args(args): action="store_true", help="If set, allow model to do multiple data passes over our dataset, in order to reach the desired number of tokens.", ) + parser.add_argument( + "--mask-across-documents", + action="store_true", + help="If set, then tokens in the same sequence will be masked across EOT." + ) add_model_args(parser) From 119df7177cca2cdbb1ceabcc430842ad0284fece Mon Sep 17 00:00:00 2001 From: GeorgiosSmyrnis Date: Sun, 28 Jan 2024 21:20:48 -0600 Subject: [PATCH 02/12] Update .gitignore (#208) * Update .gitignore * Fix requirements for env. * Remove test data prep file erroneously committed. * Revert requirements. * Update makefile. --------- Co-authored-by: George Smyrnis --- Makefile | 1 + 1 file changed, 1 insertion(+) diff --git a/Makefile b/Makefile index d9ac59b8..7d0249d4 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,6 @@ install: ## [Local development] Upgrade pip, install requirements, install package. python -m pip install -U pip + python -m pip install -r requirements.txt python -m pip install -e . install-dev: ## [Local development] Install test requirements From e24738c8891072f08062339f32128c8f08a72b3a Mon Sep 17 00:00:00 2001 From: George Smyrnis Date: Mon, 29 Jan 2024 12:30:10 -0600 Subject: [PATCH 03/12] Different mask per element in batch. --- open_lm/attention.py | 78 ++++++++++++++++++++++++-------------------- 1 file changed, 43 insertions(+), 35 deletions(-) diff --git a/open_lm/attention.py b/open_lm/attention.py index e9c8099c..0db42cea 100644 --- a/open_lm/attention.py +++ b/open_lm/attention.py @@ -23,7 +23,7 @@ def xformers_attn(queries, keys, values, is_causal, document_seqlens = None): # we would like to replace the mask generation with: mask = xops.fmha.attn_bias.LowerTriangularFromBottomRightMask() # sadly we cannot us this because it needs xformers>=0.0.23 and this is not compatible with torch<2.1.1 while llm-foundry requires torch<2.1.1 - if document_seqlens is None or len(document_seqlens) == 1: + if document_seqlens is None or all(len(d) == 1 for ds in document_seqlens): # In this case, all the tokens inside the sequence (are considered to) come from the same document. # The attention mask is constructed as a simple causal mask @@ -40,46 +40,54 @@ def xformers_attn(queries, keys, values, is_causal, document_seqlens = None): mask = get_rectangular_mask((batch, heads), q_seq_len, k_seq_len, queries.device, queries.dtype) else: - mask = None - if is_causal and queries.shape[1] == keys.shape[1]: - mask = xops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(document_seqlens) - elif is_causal and queries.shape[1] > 1: - mask = xops.fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask.from_seqlens(document_seqlens) + masks = [] + for ds in document_seqlens: + if is_causal and queries.shape[1] == keys.shape[1]: + masks.append(xops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(document_seqlens).materialize(shape=(1, queries.shape[1], queries.shape[1]))) + elif is_causal and queries.shape[1] > 1: + masks.append(xops.fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask.from_seqlens(document_seqlens).materialize(shape=(1, queries.shape[1], keys.shape[1]))) + mask = torch.cat(masks, dim=0) return xops.memory_efficient_attention(queries, keys, values, attn_bias=mask) -def torch_attn(queries, keys, values, is_causal): - # Need to call contiguous in torch >=2.1, otherwise later calls to .view() fail. - # Possibly related: https://github.com/pytorch/pytorch/issues/110213 - behavior of scaled_dot_product_attention - # changed between 2.0 and 2.1 - if is_causal and keys.shape[1] > queries.shape[1] > 1: - q_seq_len = queries.shape[1] - k_seq_len = keys.shape[1] - # Same as above, we would like to use: - # mask = xops.fmha.attn_bias.LowerTriangularFromBottomRightMask().materialize((1, 1, q_seq_len, k_seq_len), queries.dtype, queries.device) - mask = get_rectangular_mask((1, 1), q_seq_len, k_seq_len, queries.device, queries.dtype) - return ( - F.scaled_dot_product_attention( - queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), attn_mask=mask +def torch_attn(queries, keys, values, is_causal, document_seqlens = None): + + if document_seqlens is None or len(document_seqlens) == 1: + + # Need to call contiguous in torch >=2.1, otherwise later calls to .view() fail. + # Possibly related: https://github.com/pytorch/pytorch/issues/110213 - behavior of scaled_dot_product_attention + # changed between 2.0 and 2.1 + if is_causal and keys.shape[1] > queries.shape[1] > 1: + q_seq_len = queries.shape[1] + k_seq_len = keys.shape[1] + # Same as above, we would like to use: + # mask = xops.fmha.attn_bias.LowerTriangularFromBottomRightMask().materialize((1, 1, q_seq_len, k_seq_len), queries.dtype, queries.device) + mask = get_rectangular_mask((1, 1), q_seq_len, k_seq_len, queries.device, queries.dtype) + return ( + F.scaled_dot_product_attention( + queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), attn_mask=mask + ) + .transpose(1, 2) + .contiguous() ) - .transpose(1, 2) - .contiguous() - ) - elif queries.shape[1] == 1: - return ( - F.scaled_dot_product_attention(queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)) - .transpose(1, 2) - .contiguous() - ) - else: - return ( - F.scaled_dot_product_attention( - queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), is_causal=is_causal + elif queries.shape[1] == 1: + return ( + F.scaled_dot_product_attention(queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)) + .transpose(1, 2) + .contiguous() ) - .transpose(1, 2) - .contiguous() - ) + else: + return ( + F.scaled_dot_product_attention( + queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), is_causal=is_causal + ) + .transpose(1, 2) + .contiguous() + ) + + else: + raise NotImplementedError("Currently supporting --mask-across-documents only with xformers attention.") ATTN_ACTIVATIONS = { From 3e8890786fac070777b64a6e408db8c83ee42da6 Mon Sep 17 00:00:00 2001 From: George Smyrnis Date: Tue, 30 Jan 2024 00:39:11 -0600 Subject: [PATCH 04/12] Add attention calls training. --- open_lm/train.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/open_lm/train.py b/open_lm/train.py index 4ff5a27e..76b8982b 100644 --- a/open_lm/train.py +++ b/open_lm/train.py @@ -22,11 +22,13 @@ wandb = None from open_lm.data import sample_chunk +from open_lm.datapreprocess.ray.tokenize_shuffle import SpecialTokens from open_lm.distributed import is_master from open_lm.precision import get_autocast from open_lm.meters import AverageMeter + def unwrap_model(model): if hasattr(model, "module"): return model.module @@ -109,13 +111,34 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler (texts,) = batch texts = torch.LongTensor(texts).to(device) + data_time_m.update(time.time() - end) optimizer.zero_grad() if args.accum_freq == 1: with autocast(): inputs, targets = sample_chunk(texts, args) - out, _, _ = model(inputs) + + if args.mask_across_documents: + document_seqlens = [] + for idx in range(inputs.shape[0]): + eot_idx = torch.nonzero(inputs[idx] == SpecialTokens.END_OF_TEXT.value) + if len(eot_idx.shape) == 0: + # Fallback case - an eot token should appear at the end. + document_seqlens.append([args.seq_len + 1]) + else: + start_idx = 0 + seqlens = [] + for eidx in eot_idx: + seqlens.append(eidx - start_idx + 1) + start_idx = eidx + 1 + if start_idx < args.seq_len + 1: + seqlens.append(args.seq_len - start_idx) + document_seqlens.append(seqlens) + else: + document_seqlens = None + + out, _, _ = model(inputs, document_seqlens=document_seqlens) if args.log_logit_mean: logit_m.update(torch.mean(out).item()) From d24b533a6e6a077269a8fce09f7a3a9d11064ece Mon Sep 17 00:00:00 2001 From: George Smyrnis Date: Wed, 31 Jan 2024 13:55:53 -0600 Subject: [PATCH 05/12] Running version. --- open_lm/attention.py | 10 ++++++--- open_lm/model.py | 1 + open_lm/train.py | 51 +++++++++++++++++++++++++++----------------- 3 files changed, 39 insertions(+), 23 deletions(-) diff --git a/open_lm/attention.py b/open_lm/attention.py index 0db42cea..df9de736 100644 --- a/open_lm/attention.py +++ b/open_lm/attention.py @@ -23,7 +23,7 @@ def xformers_attn(queries, keys, values, is_causal, document_seqlens = None): # we would like to replace the mask generation with: mask = xops.fmha.attn_bias.LowerTriangularFromBottomRightMask() # sadly we cannot us this because it needs xformers>=0.0.23 and this is not compatible with torch<2.1.1 while llm-foundry requires torch<2.1.1 - if document_seqlens is None or all(len(d) == 1 for ds in document_seqlens): + if document_seqlens is None or all(len(ds) == 1 for ds in document_seqlens): # In this case, all the tokens inside the sequence (are considered to) come from the same document. # The attention mask is constructed as a simple causal mask @@ -41,11 +41,15 @@ def xformers_attn(queries, keys, values, is_causal, document_seqlens = None): else: masks = [] + batch, q_seq_len, heads, _ = queries.shape + k_seq_len = keys.shape[1] + dtype = queries.dtype + device = queries.device for ds in document_seqlens: if is_causal and queries.shape[1] == keys.shape[1]: - masks.append(xops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(document_seqlens).materialize(shape=(1, queries.shape[1], queries.shape[1]))) + masks.append(xops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(ds).materialize(shape=(1, heads, q_seq_len, k_seq_len), device=device, dtype=dtype)) elif is_causal and queries.shape[1] > 1: - masks.append(xops.fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask.from_seqlens(document_seqlens).materialize(shape=(1, queries.shape[1], keys.shape[1]))) + masks.append(xops.fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask.from_seqlens(ds).materialize(shape=(1, heads, q_seq_len, k_seq_len), device=device, dtype=dtype)) mask = torch.cat(masks, dim=0) return xops.memory_efficient_attention(queries, keys, values, attn_bias=mask) diff --git a/open_lm/model.py b/open_lm/model.py index 453cdb9a..280d3ca7 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -180,6 +180,7 @@ def forward(self, x: torch.Tensor, is_causal=True, past_key_value=None, use_cach keys, vals, is_causal=is_causal, + document_seqlens=document_seqlens ) output = output.view(batchsize, q_len, -1) diff --git a/open_lm/train.py b/open_lm/train.py index 76b8982b..45a2bf2f 100644 --- a/open_lm/train.py +++ b/open_lm/train.py @@ -43,6 +43,34 @@ def backward(total_loss, scaler): total_loss.backward() +def get_document_seqlens(inputs, args): + """Get list of document sequence lengths. + + Return a list of lists. The length of the outer list is equal to the batch size, while the length of the inner list + is equal to the the number of distinct documents (recognized by EOT tokens). Each element of the inner lists is the + length of that corresponding document + """ + if args.mask_across_documents: + document_seqlens = [] + for idx in range(inputs.shape[0]): + eot_idx = torch.nonzero(inputs[idx] == SpecialTokens.END_OF_TEXT.value) + if len(eot_idx.shape) == 0: + # Fallback case - an eot token should appear at the end. + document_seqlens.append([args.seq_len + 1]) + else: + start_idx = 0 + seqlens = [] + for k in range(eot_idx.shape[0]): + seqlens.append(eot_idx[k] - start_idx + 1) + start_idx = eot_idx[k] + 1 + if start_idx < args.seq_len + 1: + seqlens.append(args.seq_len - start_idx) + document_seqlens.append(seqlens) + else: + document_seqlens = None + return document_seqlens + + def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler, total_steps, args, tb_writer=None): """Trains model for one epoch on the provided data. @@ -118,25 +146,7 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler if args.accum_freq == 1: with autocast(): inputs, targets = sample_chunk(texts, args) - - if args.mask_across_documents: - document_seqlens = [] - for idx in range(inputs.shape[0]): - eot_idx = torch.nonzero(inputs[idx] == SpecialTokens.END_OF_TEXT.value) - if len(eot_idx.shape) == 0: - # Fallback case - an eot token should appear at the end. - document_seqlens.append([args.seq_len + 1]) - else: - start_idx = 0 - seqlens = [] - for eidx in eot_idx: - seqlens.append(eidx - start_idx + 1) - start_idx = eidx + 1 - if start_idx < args.seq_len + 1: - seqlens.append(args.seq_len - start_idx) - document_seqlens.append(seqlens) - else: - document_seqlens = None + document_seqlens = get_document_seqlens(inputs, args) out, _, _ = model(inputs, document_seqlens=document_seqlens) @@ -170,7 +180,8 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler if inputs_ii.shape[0] == 0: break targets_ii = targets[ii * per_batch : (ii + 1) * per_batch] - out, _, _ = model(inputs_ii) + document_seqlens = get_document_seqlens(inputs_ii, args) + out, _, _ = model(inputs_ii, document_seqlens=document_seqlens) if args.log_logit_mean: logit_m.update(torch.mean(out).item()) From af1a4cd2dfce317375b7a6d9a493fb41e329631f Mon Sep 17 00:00:00 2001 From: George Smyrnis Date: Wed, 31 Jan 2024 15:17:06 -0600 Subject: [PATCH 06/12] Ignore predictions right after EOT. --- open_lm/losses.py | 1 + open_lm/train.py | 10 ++++++++++ 2 files changed, 11 insertions(+) diff --git a/open_lm/losses.py b/open_lm/losses.py index ef3839d6..76805cdd 100644 --- a/open_lm/losses.py +++ b/open_lm/losses.py @@ -18,4 +18,5 @@ def __init__( self.eps = eps def forward(self, input: Tensor, target: Tensor) -> Tensor: + # TODO: should ignore_index be taken into account in the regularization term as well? return super().forward(input, target) + self.eps * torch.square(torch.logsumexp(input, dim=-1)).mean() diff --git a/open_lm/train.py b/open_lm/train.py index 45a2bf2f..385a3bb3 100644 --- a/open_lm/train.py +++ b/open_lm/train.py @@ -147,6 +147,11 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler with autocast(): inputs, targets = sample_chunk(texts, args) document_seqlens = get_document_seqlens(inputs, args) + if args.mask_across_documents: + # Some input samples contain EOT as the final token. The prediction after that is meaningless, so it + # should not contribute to the loss. + ignore_indices = torch.nonzero(inputs == SpecialTokens.END_OF_TEXT, as_tuple=True) + targets[ignore_indices] = loss.ignore_index out, _, _ = model(inputs, document_seqlens=document_seqlens) @@ -168,6 +173,11 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler per_batch = args.per_gpu_batch_size // args.accum_freq inputs, targets = sample_chunk(texts, args) + if args.mask_across_documents: + # Some input samples contain EOT as the final token. The prediction after that is meaningless, so it + # should not contribute to the loss. + ignore_indices = torch.nonzero(inputs == SpecialTokens.END_OF_TEXT, as_tuple=True) + targets[ignore_indices] = loss.ignore_index for ii in range(args.accum_freq): maybe_no_sync = nullcontext From 3e0036b23247cb9e6c6f107b52dae92ff8abc718 Mon Sep 17 00:00:00 2001 From: George Smyrnis Date: Wed, 31 Jan 2024 15:19:16 -0600 Subject: [PATCH 07/12] Formatting. --- open_lm/attention.py | 20 +++++++++++++------- open_lm/model.py | 14 +++++--------- open_lm/params.py | 2 +- open_lm/train.py | 3 +-- 4 files changed, 20 insertions(+), 19 deletions(-) diff --git a/open_lm/attention.py b/open_lm/attention.py index df9de736..783fbb63 100644 --- a/open_lm/attention.py +++ b/open_lm/attention.py @@ -16,7 +16,7 @@ def get_rectangular_mask(shape, q_seq_len, k_seq_len, device, dtype): )[:, :, :, :k_seq_len] -def xformers_attn(queries, keys, values, is_causal, document_seqlens = None): +def xformers_attn(queries, keys, values, is_causal, document_seqlens=None): # xformers assumes q, k, v are [batch, seq_len, heads, embed_dim] # We assume that queries match the last part of the key / value sequences # see (https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.fmha.attn_bias.LowerTriangularFromBottomRightMask) @@ -47,18 +47,24 @@ def xformers_attn(queries, keys, values, is_causal, document_seqlens = None): device = queries.device for ds in document_seqlens: if is_causal and queries.shape[1] == keys.shape[1]: - masks.append(xops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(ds).materialize(shape=(1, heads, q_seq_len, k_seq_len), device=device, dtype=dtype)) + masks.append( + xops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(ds).materialize( + shape=(1, heads, q_seq_len, k_seq_len), device=device, dtype=dtype + ) + ) elif is_causal and queries.shape[1] > 1: - masks.append(xops.fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask.from_seqlens(ds).materialize(shape=(1, heads, q_seq_len, k_seq_len), device=device, dtype=dtype)) + masks.append( + xops.fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask.from_seqlens(ds).materialize( + shape=(1, heads, q_seq_len, k_seq_len), device=device, dtype=dtype + ) + ) mask = torch.cat(masks, dim=0) return xops.memory_efficient_attention(queries, keys, values, attn_bias=mask) -def torch_attn(queries, keys, values, is_causal, document_seqlens = None): - +def torch_attn(queries, keys, values, is_causal, document_seqlens=None): if document_seqlens is None or len(document_seqlens) == 1: - # Need to call contiguous in torch >=2.1, otherwise later calls to .view() fail. # Possibly related: https://github.com/pytorch/pytorch/issues/110213 - behavior of scaled_dot_product_attention # changed between 2.0 and 2.1 @@ -89,7 +95,7 @@ def torch_attn(queries, keys, values, is_causal, document_seqlens = None): .transpose(1, 2) .contiguous() ) - + else: raise NotImplementedError("Currently supporting --mask-across-documents only with xformers attention.") diff --git a/open_lm/model.py b/open_lm/model.py index 280d3ca7..767bd2ad 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -175,13 +175,7 @@ def forward(self, x: torch.Tensor, is_causal=True, past_key_value=None, use_cach if use_cache: past_key_value = [keys, vals] - output = self.attn_fn( - queries, - keys, - vals, - is_causal=is_causal, - document_seqlens=document_seqlens - ) + output = self.attn_fn(queries, keys, vals, is_causal=is_causal, document_seqlens=document_seqlens) output = output.view(batchsize, q_len, -1) @@ -258,7 +252,7 @@ def forward(self, x, past_key_value=None, use_cache=False, document_seqlens=None is_causal=True, past_key_value=past_key_value, use_cache=use_cache, - document_seqlens=document_seqlens + document_seqlens=document_seqlens, ) h = x + h if self._ffn_type == "moe": @@ -335,7 +329,9 @@ def forward(self, input, past_key_values=None, use_cache=False, document_seqlens if self.grad_checkpointing: x, past_key_values[i] = checkpoint(layer, x, past_key_values[i], use_cache, document_seqlens) else: - x, past_key_values[i] = layer(x, past_key_values[i], use_cache=use_cache, document_seqlens=document_seqlens) + x, past_key_values[i] = layer( + x, past_key_values[i], use_cache=use_cache, document_seqlens=document_seqlens + ) if past_key_values[0] is None: past_key_values = None x = self.norm(x) diff --git a/open_lm/params.py b/open_lm/params.py index b336e766..f0d7155c 100644 --- a/open_lm/params.py +++ b/open_lm/params.py @@ -745,7 +745,7 @@ def parse_args(args): parser.add_argument( "--mask-across-documents", action="store_true", - help="If set, then tokens in the same sequence will be masked across EOT." + help="If set, then tokens in the same sequence will be masked across EOT.", ) add_model_args(parser) diff --git a/open_lm/train.py b/open_lm/train.py index 385a3bb3..d1262d3a 100644 --- a/open_lm/train.py +++ b/open_lm/train.py @@ -28,7 +28,6 @@ from open_lm.meters import AverageMeter - def unwrap_model(model): if hasattr(model, "module"): return model.module @@ -45,7 +44,7 @@ def backward(total_loss, scaler): def get_document_seqlens(inputs, args): """Get list of document sequence lengths. - + Return a list of lists. The length of the outer list is equal to the batch size, while the length of the inner list is equal to the the number of distinct documents (recognized by EOT tokens). Each element of the inner lists is the length of that corresponding document From e4a5baca6005da789091474397e9bc91777fe7cf Mon Sep 17 00:00:00 2001 From: Alex Fang Date: Thu, 1 Feb 2024 19:26:36 -0800 Subject: [PATCH 08/12] doc attention eot enum value --- open_lm/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/open_lm/train.py b/open_lm/train.py index d1262d3a..277830aa 100644 --- a/open_lm/train.py +++ b/open_lm/train.py @@ -149,7 +149,7 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler if args.mask_across_documents: # Some input samples contain EOT as the final token. The prediction after that is meaningless, so it # should not contribute to the loss. - ignore_indices = torch.nonzero(inputs == SpecialTokens.END_OF_TEXT, as_tuple=True) + ignore_indices = torch.nonzero(inputs == SpecialTokens.END_OF_TEXT.value, as_tuple=True) targets[ignore_indices] = loss.ignore_index out, _, _ = model(inputs, document_seqlens=document_seqlens) @@ -175,7 +175,7 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler if args.mask_across_documents: # Some input samples contain EOT as the final token. The prediction after that is meaningless, so it # should not contribute to the loss. - ignore_indices = torch.nonzero(inputs == SpecialTokens.END_OF_TEXT, as_tuple=True) + ignore_indices = torch.nonzero(inputs == SpecialTokens.END_OF_TEXT.value, as_tuple=True) targets[ignore_indices] = loss.ignore_index for ii in range(args.accum_freq): From 7234b3189c9a94a89b57a04dc7f39fd9246156a4 Mon Sep 17 00:00:00 2001 From: George Smyrnis Date: Fri, 2 Feb 2024 17:25:57 -0600 Subject: [PATCH 09/12] Trying to debug. --- open_lm/attention.py | 2 +- open_lm/model.py | 3 ++- open_lm/train.py | 12 ++++++------ 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/open_lm/attention.py b/open_lm/attention.py index 783fbb63..1f0d8693 100644 --- a/open_lm/attention.py +++ b/open_lm/attention.py @@ -22,7 +22,7 @@ def xformers_attn(queries, keys, values, is_causal, document_seqlens=None): # see (https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.fmha.attn_bias.LowerTriangularFromBottomRightMask) # we would like to replace the mask generation with: mask = xops.fmha.attn_bias.LowerTriangularFromBottomRightMask() # sadly we cannot us this because it needs xformers>=0.0.23 and this is not compatible with torch<2.1.1 while llm-foundry requires torch<2.1.1 - + print("attention called") if document_seqlens is None or all(len(ds) == 1 for ds in document_seqlens): # In this case, all the tokens inside the sequence (are considered to) come from the same document. # The attention mask is constructed as a simple causal mask diff --git a/open_lm/model.py b/open_lm/model.py index 767bd2ad..511fd867 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -155,6 +155,7 @@ def reset_parameters(self): torch.nn.init.trunc_normal_(self.out_proj.weight, std=std, a=-3 * std, b=3 * std) def forward(self, x: torch.Tensor, is_causal=True, past_key_value=None, use_cache=False, document_seqlens=None): + print("attention called") batchsize, q_len, _ = x.shape queries, keys, vals = self.in_proj(x).chunk(3, dim=-1) @@ -247,6 +248,7 @@ def reset_parameters(self): torch.nn.init.trunc_normal_(self._ff_w2.weight, std=std, a=-3 * std, b=3 * std) def forward(self, x, past_key_value=None, use_cache=False, document_seqlens=None): + print("block called") h, past_key_value = self.attention( self.attention_norm(x), is_causal=True, @@ -320,7 +322,6 @@ def set_grad_checkpointing(self, enable=True): def forward(self, input, past_key_values=None, use_cache=False, document_seqlens=None): x = self.tok_embeddings(input) x = self.post_embed_norm(x) - if past_key_values is None: past_key_values = [None] * self.n_layers elif isinstance(past_key_values, tuple): diff --git a/open_lm/train.py b/open_lm/train.py index 277830aa..63f3c4e6 100644 --- a/open_lm/train.py +++ b/open_lm/train.py @@ -53,16 +53,16 @@ def get_document_seqlens(inputs, args): document_seqlens = [] for idx in range(inputs.shape[0]): eot_idx = torch.nonzero(inputs[idx] == SpecialTokens.END_OF_TEXT.value) - if len(eot_idx.shape) == 0: - # Fallback case - an eot token should appear at the end. - document_seqlens.append([args.seq_len + 1]) + if eot_idx.shape[0] == 0: + # All tokens come from the same document. + document_seqlens.append([args.seq_len]) else: start_idx = 0 seqlens = [] for k in range(eot_idx.shape[0]): - seqlens.append(eot_idx[k] - start_idx + 1) - start_idx = eot_idx[k] + 1 - if start_idx < args.seq_len + 1: + seqlens.append(eot_idx[k].item() - start_idx + 1) + start_idx = eot_idx[k].item() + 1 + if start_idx < args.seq_len: seqlens.append(args.seq_len - start_idx) document_seqlens.append(seqlens) else: From 427291fc8a352db7026d5141d28f9e730fb078db Mon Sep 17 00:00:00 2001 From: George Smyrnis Date: Fri, 2 Feb 2024 17:30:27 -0600 Subject: [PATCH 10/12] Revert mistake on makefile. --- Makefile | 1 - 1 file changed, 1 deletion(-) diff --git a/Makefile b/Makefile index 7d0249d4..d9ac59b8 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,5 @@ install: ## [Local development] Upgrade pip, install requirements, install package. python -m pip install -U pip - python -m pip install -r requirements.txt python -m pip install -e . install-dev: ## [Local development] Install test requirements From f28d9840f0424e89e5b7951757c570604d36ced9 Mon Sep 17 00:00:00 2001 From: George Smyrnis Date: Fri, 2 Feb 2024 17:58:00 -0600 Subject: [PATCH 11/12] Fixed mem sharing. --- open_lm/model.py | 2 -- open_lm/train.py | 1 + 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/open_lm/model.py b/open_lm/model.py index 511fd867..d54c9421 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -155,7 +155,6 @@ def reset_parameters(self): torch.nn.init.trunc_normal_(self.out_proj.weight, std=std, a=-3 * std, b=3 * std) def forward(self, x: torch.Tensor, is_causal=True, past_key_value=None, use_cache=False, document_seqlens=None): - print("attention called") batchsize, q_len, _ = x.shape queries, keys, vals = self.in_proj(x).chunk(3, dim=-1) @@ -248,7 +247,6 @@ def reset_parameters(self): torch.nn.init.trunc_normal_(self._ff_w2.weight, std=std, a=-3 * std, b=3 * std) def forward(self, x, past_key_value=None, use_cache=False, document_seqlens=None): - print("block called") h, past_key_value = self.attention( self.attention_norm(x), is_causal=True, diff --git a/open_lm/train.py b/open_lm/train.py index 63f3c4e6..b1b66b24 100644 --- a/open_lm/train.py +++ b/open_lm/train.py @@ -176,6 +176,7 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler # Some input samples contain EOT as the final token. The prediction after that is meaningless, so it # should not contribute to the loss. ignore_indices = torch.nonzero(inputs == SpecialTokens.END_OF_TEXT.value, as_tuple=True) + targets = targets.detach().clone() # Clone this because it shares mem with input! targets[ignore_indices] = loss.ignore_index for ii in range(args.accum_freq): From 3521ce1df35a76bfb7a1144fef6692a9db921262 Mon Sep 17 00:00:00 2001 From: George Smyrnis Date: Fri, 2 Feb 2024 17:59:30 -0600 Subject: [PATCH 12/12] Remove debug. --- open_lm/attention.py | 1 - 1 file changed, 1 deletion(-) diff --git a/open_lm/attention.py b/open_lm/attention.py index 1f0d8693..7357a512 100644 --- a/open_lm/attention.py +++ b/open_lm/attention.py @@ -22,7 +22,6 @@ def xformers_attn(queries, keys, values, is_causal, document_seqlens=None): # see (https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.fmha.attn_bias.LowerTriangularFromBottomRightMask) # we would like to replace the mask generation with: mask = xops.fmha.attn_bias.LowerTriangularFromBottomRightMask() # sadly we cannot us this because it needs xformers>=0.0.23 and this is not compatible with torch<2.1.1 while llm-foundry requires torch<2.1.1 - print("attention called") if document_seqlens is None or all(len(ds) == 1 for ds in document_seqlens): # In this case, all the tokens inside the sequence (are considered to) come from the same document. # The attention mask is constructed as a simple causal mask