
Update torchtune generation to be more flexible #1970

Merged 1 commit into pytorch:main on Nov 8, 2024

Conversation

@RylanC24 (Contributor) commented Nov 8, 2024

Summary:
The existing softmax sampling trick implementation in the torchtune generator is not flexible enough to handle vocab-pruned models (i.e., when the number of logits produced does not match the size of the embedding layer).

This is an unnecessary limitation that is easy to fix: simply create the `q` tensor to match the size of the logits tensor instead of the embedding layer.

Differential Revision: D65480353
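
For context, here is a minimal, self-contained sketch of the exponential ("softmax trick") sampling this change targets; the function name and shapes are illustrative, not the exact torchtune code:

    import torch

    def sample_one(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
        """Sample one token id per row via the exponential (Gumbel-style) softmax trick."""
        probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
        # q is created to match the logits, so a vocab-pruned output head
        # (fewer logits than embedding rows) works without any shape mismatch.
        q = torch.empty_like(probs).exponential_(1)
        # argmax(probs / q) draws a sample from the categorical distribution over probs.
        return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)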

pytorch-bot (bot) commented Nov 8, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1970

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 344e99f with merge base 7bfb333:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Nov 8, 2024
@facebook-github-bot
This pull request was exported from Phabricator. Differential Revision: D65480353

RylanC24 added a commit to RylanC24/torchtune that referenced this pull request Nov 8, 2024
Summary:

The existing softmax sampling trick implementation in the torchtune generator is not flexible enough to deal with vocab pruned models (when the number of logits produced does not match the size of the embedding layer). 

This is an unnecessary limitation and is easy to fix if we simply create the `q` tensor to match the size of the logits tensor instead of the embedding layer.

NOTE: this is just a draft diff to get feedback on possible changes to the OSS torchtune package before submitting a proper pull request

Differential Revision: D65480353
@facebook-github-bot
This pull request was exported from Phabricator. Differential Revision: D65480353

@SalmanMohammadi (Collaborator) commented

Hey @RylanC24! Thanks for opening this : )

It looks like the main change is to make the default path sample `q` using:

    probs = torch.nn.functional.softmax(logits, dim=-1)

    # if q is None, we use the default softmax sampling trick
    if q is None: # <---- q is now None by default
        q = torch.empty_like(probs).exponential_(1)

Is that right? If so, that makes sense to me at a high level.

Out of curiosity, what's your use case here? Are you adding this change to use with the generate.py recipe? FWIW, we'll eventually be deprecating this recipe (I think) in favour of the dev/generate_v2.py recipe, which is significantly neater and uses this proposed behaviour by default, since it calls sample directly without going through generate_next_token. I think this change makes sense to fix the existing generation utils, though.

cc @joecummings

@RylanC24 (Contributor, Author) commented Nov 8, 2024

@SalmanMohammadi yes, that's right. The use-case is a subtle one but comes up anytime you want to trim the embedding and/or output layers to remove unnecessary tokens (e.g., if the output space is constrained and we don't want to keep 128k x 2048 dimensional vectors in our model). The issue comes up when you want to map this trimmed output space back to the original (so we can still use the same tokenizer). In this situation the dimension of the output logits will not match the dimension of the embedding layer, leading to an error when we try to divide the logits by q (which was previously set to the size of the embedding layer).
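
To make the failure mode concrete, a small sketch follows; the layer sizes and variable names are hypothetical and only meant to illustrate the shape mismatch:

    import torch
    import torch.nn as nn

    full_vocab, pruned_vocab, dim = 128_256, 4_096, 2_048   # illustrative sizes

    # The tokenizer (and embedding table) still covers full_vocab ids,
    # but the output head has been trimmed to pruned_vocab logits.
    output_proj = nn.Linear(dim, pruned_vocab, bias=False)

    hidden = torch.randn(1, dim)
    logits = output_proj(hidden)                 # shape: [1, pruned_vocab]
    probs = torch.softmax(logits, dim=-1)

    # Old behaviour: q sized from the embedding table, so probs / q fails
    # because [1, pruned_vocab] and [1, full_vocab] do not broadcast.
    q_old = torch.empty(1, full_vocab).exponential_(1)
    # probs / q_old  # RuntimeError: mismatched sizes

    # With this change: q is created from the logits themselves.
    q_new = torch.empty_like(probs).exponential_(1)
    token = torch.argmax(probs / q_new, dim=-1)  # works for any output width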

@RylanC24 (Contributor, Author) commented Nov 8, 2024

@SalmanMohammadi forgot to add that yes, the new generator shouldn't have this issue, but this fix will allow us to patch the old one in the meantime :-)

@RdoubleA requested a review from joecummings on November 8, 2024 at 14:00
@@ -67,7 +67,7 @@ def generate_next_token(
     model: TransformerDecoder,
     input_pos: torch.Tensor,
     x: torch.Tensor,
-    q: torch.Tensor,
+    q: Optional[torch.Tensor] = None,
Contributor

Need to update the docstring to reflect the new typing; lint is failing.

@SalmanMohammadi (Collaborator) commented Nov 8, 2024

> @SalmanMohammadi yes, that's right. The use-case is a subtle one but comes up anytime you want to trim the embedding and/or output layers to remove unnecessary tokens (e.g., if the output space is constrained and we don't want to keep 128k x 2048 dimensional vectors in our model). The issue comes up when you want to map this trimmed output space back to the original (so we can still use the same tokenizer). In this situation the dimension of the output logits will not match the dimension of the embedding layer, leading to an error when we try to divide the logits by q (which was previously set to the size of the embedding layer).

Thanks! I have a couple of points:

  1. This fix won't actually work when we have an rng, right? I'm not sure I see an immediate neat solution here, though; is there a way to infer the size of the output space here? rng is just used for PPO, so it'd be a very rare interaction.

  2. How annoying would it be to add a test for this? We have some tests in tests/torchtune/generation/test_generation.py which build some dummy models. Would it be simple enough to create another dummy model fixture which has the embedding replaced with a trimmed embedding, and ensures that we can correctly generate without any issues? (See the sketch below.)
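
A rough sketch of what such a test might look like; the fixture name, model attributes, and the exact generate signature below are assumptions for illustration rather than the real test_generation.py code:

    import pytest
    import torch

    from torchtune import generation


    @pytest.fixture
    def vocab_pruned_model(generation_model):
        # Hypothetical: reuse an existing dummy-model fixture and swap in an
        # output head that is narrower than the embedding table.
        pruned_vocab = generation_model.tok_embeddings.num_embeddings // 2
        generation_model.output = torch.nn.Linear(
            generation_model.output.in_features, pruned_vocab, bias=False
        )
        return generation_model


    def test_generate_with_vocab_pruned_model(vocab_pruned_model):
        prompt = torch.arange(0, 10).unsqueeze(0)
        # Should run without a shape error now that q is sized from the logits.
        tokens, _ = generation.generate(
            vocab_pruned_model,
            prompt,
            max_generated_tokens=5,
            temperature=1.0,
            top_k=None,
        )
        assert tokens.ndim == 2

Whether this can simply wrap one of the existing dummy decoders depends on how those fixtures expose their output projection.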

@codecov-commenter commented
Codecov Report

Attention: Patch coverage is 0% with 6 lines in your changes missing coverage. Please review.

Project coverage is 25.18%. Comparing base (9eced21) to head (6cae056).
Report is 9 commits behind head on main.

Files with missing lines               Patch %   Lines
torchtune/generation/_generation.py    0.00%     6 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1970       +/-   ##
===========================================
- Coverage   68.40%   25.18%   -43.22%     
===========================================
  Files         311      311               
  Lines       16973    17038       +65     
===========================================
- Hits        11610     4291     -7319     
- Misses       5363    12747     +7384     

☔ View full report in Codecov by Sentry.

@RylanC24 (Contributor, Author) commented Nov 8, 2024

> Thanks! I have a couple points:
>
> 1. This fix won't actually work for when we have rng right? I'm not sure I see an immediate neat solution here though, is there a way to infer the size of the output space here? rng is just used for PPO so it'd be a very rare interaction.

Yes, it won't work when an rng is used, but I figured these were both pretty niche use-cases that are unlikely to clash. Since there's already a plan to migrate to the new generator, where this won't be an issue, I think the risk of ignoring this corner case for the time being is pretty minimal. wdyt?

> 2. How annoying would it be to add a test for this? We have some tests in tests/torchtune/generation/test_generation.py which build some dummy models. Would it be simple enough to create another dummy model fixture which has the embedding replaced with a trimmed embedding, and ensures that we can correctly generate without any issues?

This is doable but would be a bit annoying, since the vocab-pruned model types are not defined in the torchtune repo. The existing tests should validate that the normal generation use-cases are not affected, and I've verified with our vocab-pruned model definitions that it works as expected. Again, since this is really just a stopgap fix until the new generator is released, maybe we can forgo the additional tests?

Summary:

The existing softmax sampling trick implementation in the torchtune generator is not flexible enough to deal with vocab pruned models (when the number of logits produced does not match the size of the embedding layer). 

This is an unnecessary limitation and is easy to fix if we simply create the `q` tensor to match the size of the logits tensor instead of the embedding layer.

Differential Revision: D65480353
@facebook-github-bot
This pull request was exported from Phabricator. Differential Revision: D65480353

@SalmanMohammadi (Collaborator) commented

Yeah, makes sense to me. I'll verify it works OK with compile in a follow-up :)

@facebook-github-bot merged commit eb67cc5 into pytorch:main on Nov 8, 2024
18 of 19 checks passed
Labels: CLA Signed, fb-exported

5 participants