add new num_logits_to_keep arg for llama.forward()
Signed-off-by: Anh Uong <[email protected]>
anhuong committed Oct 16, 2024
1 parent 6adfc2d commit a1f74b6
Showing 1 changed file with 14 additions and 3 deletions.
@@ -135,7 +135,7 @@ def fused_linear_cross_entropy_forward(

        # gradient of logits_chunk is computed in-place by the above triton kernel.
        # Following HuggingFace model source code, we do the forward and backward
-        # w.r.t. logits in fp32 for numerical stability especially as the num classes (vocab size) os huge.
+        # w.r.t. logits in fp32 for numerical stability especially as the num classes (vocab size) is huge.
        # (reference: https://github.com/huggingface/transformers/blob/v4.42.4/src/transformers/models/llama/modeling_llama.py#L1194)
        # Propagating to lm_head's backward, we'll switch back to the original dtype.
        logits_chunk = logits_chunk.to(dtype)
@@ -306,6 +306,7 @@ def lce_forward(
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
+    num_logits_to_keep: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Copy paste llama forward but replace torch cross entropy with liger fused linear cross entropy
@@ -317,6 +318,11 @@
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+        num_logits_to_keep (`int`, *optional*):
+            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

    Returns:

    Example:
@@ -390,9 +396,14 @@
            ]
            logits = torch.cat(logits, dim=-1)
        else:
-            logits = self.lm_head(hidden_states)
-        logits = logits.float()
+            # TODO: differing line below in granite models compared to llama/mistral model type
+            # logits = logits / self.config.logits_scaling
+            # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+            logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

        if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
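
Note on the fp32 comment touched in the first hunk: a minimal sketch in plain PyTorch (not the Triton kernel this file wraps), with made-up shapes, of the dtype handling the comment describes.

import torch
import torch.nn.functional as F

# Illustrative shapes only: one chunk of logits over a large vocabulary.
vocab_size, chunk_len = 128000, 8
logits_chunk = torch.randn(chunk_len, vocab_size, dtype=torch.bfloat16, requires_grad=True)
labels_chunk = torch.randint(0, vocab_size, (chunk_len,))

# Forward and backward w.r.t. logits in fp32: the log-sum-exp over a huge
# vocab loses precision in half precision (and can overflow in fp16).
loss = F.cross_entropy(logits_chunk.float(), labels_chunk, ignore_index=-100)
loss.backward()

# Autograd casts the gradient back to the original dtype at the .float() node,
# so the backward that reaches lm_head proceeds in bf16 -- the "switch back to
# the original dtype" the comment mentions.
assert logits_chunk.grad.dtype == torch.bfloat16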
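
Note on the num_logits_to_keep slicing added in the last hunk: a small standalone sketch with illustrative sizes showing that 0 keeps logits for the whole sequence (since -0 == 0), while 1 computes only the final position's logits, which is all that generation needs.

import torch

# Illustrative sizes: hidden_size=64, vocab_size=32000, batch=2, seq_len=16.
lm_head = torch.nn.Linear(64, 32000, bias=False)
hidden_states = torch.randn(2, 16, 64)

# Default num_logits_to_keep=0: -0 == 0, so the slice keeps every position.
logits_all = lm_head(hidden_states[:, -0:, :])
assert logits_all.shape == (2, 16, 32000)

# Generation only needs the last position, so keeping one logit row per
# sequence avoids materializing a full (batch, seq_len, vocab) tensor.
num_logits_to_keep = 1
logits_last = lm_head(hidden_states[:, -num_logits_to_keep:, :])
assert logits_last.shape == (2, 1, 32000)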