Update torchtune generation to be more flexible
Summary:
The existing softmax sampling trick implementation in the torchtune generator is not flexible enough to handle vocab-pruned models, i.e. models where the number of logits produced does not match the size of the embedding layer.
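For context, the sampling trick in question draws exponential noise and takes an argmax rather than sampling directly from the softmax distribution. A minimal, self-contained sketch (shapes and vocab size are illustrative, not torchtune code):

```python
import torch

# Exponential-race sampling: with q ~ Exponential(1) drawn i.i.d. per logit,
# argmax(probs / q) returns an index distributed according to probs.
logits = torch.randn(2, 32_000)              # (batch, num_output_logits) -- illustrative
probs = torch.softmax(logits / 1.0, dim=-1)  # temperature = 1.0

q = torch.empty_like(probs).exponential_(1)  # noise must match the logits' vocab dimension
token = torch.argmax(probs / q, dim=-1, keepdim=True)
```

The division `probs / q` only works when `q` has the same vocab dimension as the logits, which is exactly what breaks when the output head is pruned relative to `model.tok_embeddings.num_embeddings`.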

This limitation is unnecessary and easy to lift: simply create the `q` tensor to match the size of the logits tensor instead of the size of the embedding layer, as sketched below.
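A minimal sketch of that sizing change, assuming the pruned output width can be read from the decoder's output projection (`model.output.out_features` is a hypothetical accessor here; the eventual pull request may obtain the logits size differently):

```python
# Hypothetical sketch, not the final patch: size q from the number of logits the
# model actually produces rather than from the embedding table.
num_output_logits = model.output.out_features  # assumed accessor; may differ from num_embeddings
q = torch.empty(
    (bsz, num_output_logits), device=prompt.device
).exponential_(1, generator=rng)
```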

NOTE: this is just a draft diff to gather feedback on possible changes to the OSS torchtune package before submitting a proper pull request.

Differential Revision: D65480353
RylanC24 authored and facebook-github-bot committed Nov 8, 2024
1 parent 506e099 commit d42f319
Showing 1 changed file with 11 additions and 7 deletions.
torchtune/generation/_generation.py (11 additions, 7 deletions)

@@ -67,7 +67,7 @@ def generate_next_token(
     model: TransformerDecoder,
     input_pos: torch.Tensor,
     x: torch.Tensor,
-    q: torch.Tensor,
+    q: Optional[torch.Tensor] = None,
     *,
     mask: Optional[torch.Tensor] = None,
     temperature: float = 1.0,
@@ -302,9 +302,11 @@ def generate(
     # tensors are of identical shape to the prompt
     curr_masks = masks[:, :prompt_length, :prompt_length]

-    q = torch.empty(
-        (bsz, model.tok_embeddings.num_embeddings), device=prompt.device
-    ).exponential_(1, generator=rng)
+    q = None
+    if rng is not None:
+        q = torch.empty(
+            (bsz, model.tok_embeddings.num_embeddings), device=prompt.device
+        ).exponential_(1, generator=rng)
     tokens, generated_logits = generate_next_token(
         model,
         input_pos=input_pos[:, :prompt_length].squeeze(),
@@ -360,9 +362,11 @@ def generate(
         curr_input_pos = input_pos[:, : curr_pos + 1]
         curr_masks = masks[:, : curr_pos + 1, : curr_pos + 1]

-        q = torch.empty(
-            (bsz, model.tok_embeddings.num_embeddings), device=prompt.device
-        ).exponential_(1, generator=rng)
+        q = None
+        if rng is not None:
+            q = torch.empty(
+                (bsz, model.tok_embeddings.num_embeddings), device=prompt.device
+            ).exponential_(1, generator=rng)
         tokens, logits = custom_generate_next_token(
             model,
             input_pos=curr_input_pos,
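With `q` now defaulting to `None`, a caller that does not pass a seeded generator can omit it entirely. A hedged usage sketch against the signature visible in the first hunk (the surrounding variables are assumed to come from the `generate()` loop; argument names other than those shown in the diff are assumptions):

```python
# Usage sketch only: model, curr_input_pos, curr_tokens, and curr_masks are assumed
# to exist in the surrounding generate() loop, as in the diff above.
tokens, logits = generate_next_token(
    model,
    input_pos=curr_input_pos,
    x=curr_tokens,        # assumed name for the current token slice
    mask=curr_masks,
    temperature=1.0,
    # q is omitted and defaults to None; no exponential noise is pre-built
)
```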
