Update torchtune generation to be more flexible #1970
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1970
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 344e99f with merge base 7bfb333. This comment was automatically generated by Dr. CI and updates every 15 minutes.
This pull request was exported from Phabricator. Differential Revision: D65480353
Force-pushed from d42f319 to 6cae056 (Compare)
Summary: The existing softmax sampling trick implementation in the torchtune generator is not flexible enough to deal with vocab-pruned models (when the number of logits produced does not match the size of the embedding layer). This is an unnecessary limitation and is easy to fix if we simply create the `q` tensor to match the size of the logits tensor instead of the embedding layer.

NOTE: this is just a draft diff to get feedback on possible changes to the OSS torchtune package before submitting a proper pull request.

Differential Revision: D65480353
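A minimal sketch of the idea (not the actual torchtune source; the function name and signature are simplified): size `q` from the logits the model actually produces, so the sampling trick keeps working when the output head is smaller than the embedding table.

```python
import torch
from typing import Optional

def sample(logits: torch.Tensor, q: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Softmax sampling trick: argmax(probs / q) with q ~ Exponential(1)."""
    probs = torch.nn.functional.softmax(logits, dim=-1)
    if q is None:
        # q now mirrors the logits' last (vocab) dimension instead of the
        # embedding table, so vocab-pruned models no longer hit a shape error.
        q = torch.empty_like(probs).exponential_(1)
    return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)
```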
This pull request was exported from Phabricator. Differential Revision: D65480353
Hey @RylanC24! Thanks for opening this : ) It looks like the main change is to set the default path to sample

```python
probs = torch.nn.functional.softmax(logits, dim=-1)
# if q is None, we use the default softmax sampling trick
if q is None:  # <---- q is now None by default
    q = torch.empty_like(probs).exponential_(1)
```

Is that right? If so, that makes sense to me at a high level. Out of curiosity, what's your use case here? Are you adding this change to use with the

cc @joecummings
@SalmanMohammadi yes, that's right. The use case is a subtle one, but it comes up any time you want to trim the embedding and/or output layers to remove unnecessary tokens (e.g., if the output space is constrained and we don't want to keep 128k x 2048-dimensional vectors in our model). The issue arises when you want to map this trimmed output space back to the original (so we can still use the same tokenizer). In this situation the dimension of the output logits will not match the dimension of the embedding layer, leading to an error when we try to divide the logits by `q` (which was previously set to the size of the embedding layer).
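To illustrate the shape clash being described, here is a small hypothetical example (the sizes and variable names are made up and this is not torchtune code): the old behaviour sized `q` from the full embedding vocabulary, while a vocab-pruned model emits fewer logits.

```python
import torch

full_vocab, pruned_vocab, batch = 128_256, 2_048, 1
logits = torch.randn(batch, pruned_vocab)        # pruned model's output logits
probs = torch.softmax(logits, dim=-1)

# Old behaviour: q sized from the embedding table -> shapes no longer line up
q_old = torch.empty(batch, full_vocab).exponential_(1)
# probs / q_old  # would raise a shape/broadcasting error

# Behaviour proposed in this PR: q sized from the logits themselves
q_new = torch.empty_like(probs).exponential_(1)
token = torch.argmax(probs / q_new, dim=-1, keepdim=True)
```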
@SalmanMohammadi forgot to add that yes, the new generator shouldn't have this issue, but this fix will allow us to patch the old one in the meantime :-)
```diff
@@ -67,7 +67,7 @@ def generate_next_token(
     model: TransformerDecoder,
     input_pos: torch.Tensor,
     x: torch.Tensor,
-    q: torch.Tensor,
+    q: Optional[torch.Tensor] = None,
```
Need to update the docstring to reflect the new typing, lint is failing
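For reference, the updated docstring entry might read something like the following (wording is only a suggestion; the stub below exists just to carry the docstring and is not torchtune code):

```python
from typing import Optional
import torch

def generate_next_token_stub(q: Optional[torch.Tensor] = None) -> None:
    """Stub showing only the docstring change for the now-optional argument.

    Args:
        q (Optional[torch.Tensor]): randomly sampled tensor for the softmax
            sampling trick. If None, a tensor matching the shape of the
            produced logits is created internally. Default: None.
    """
```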
Thanks! I have a couple of points:
Codecov Report
Attention: Patch coverage is

Additional details and impacted files

```diff
@@            Coverage Diff            @@
##             main    #1970       +/-   ##
===========================================
- Coverage   68.40%   25.18%   -43.22%
===========================================
  Files         311      311
  Lines       16973    17038       +65
===========================================
- Hits        11610     4291     -7319
- Misses       5363    12747     +7384
```

☔ View full report in Codecov by Sentry.
Yes, it won't work when an rng is used, but I figured these were both pretty niche use cases that are unlikely to clash. Since there's already a plan to migrate to the new generator, where this won't be an issue, I think the risk of ignoring this very corner use case for the time being is pretty minimal. wdyt?

This is doable but would be a bit annoying, since the vocab-pruned model types are not defined in the torchtune repo. The existing tests should validate that the normal generation use cases are not affected, and I've verified with our vocab-pruned model definitions that it works as expected. Again, since this is really just a stopgap fix until the new generator is released, maybe we can forgo the additional tests?
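For context on the rng point, a rough sketch of how a seeded path could still line up with the new default (variable names and sizes here are illustrative, not torchtune's actual API): a caller who wants reproducible sampling can pass their own `q`, as long as it is sized from the logits rather than the embedding table.

```python
import torch

rng = torch.Generator().manual_seed(0)
probs = torch.softmax(torch.randn(1, 2048), dim=-1)

# Caller-supplied q drawn from a seeded generator, shaped like the logits
q = torch.empty_like(probs).exponential_(1, generator=rng)
token = torch.argmax(probs / q, dim=-1, keepdim=True)
```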
Summary: The existing softmax sampling trick implementation in the torchtune generator is not flexible enough to deal with vocab-pruned models (when the number of logits produced does not match the size of the embedding layer). This is an unnecessary limitation and is easy to fix if we simply create the `q` tensor to match the size of the logits tensor instead of the embedding layer.

Differential Revision: D65480353
Force-pushed from 6cae056 to 344e99f (Compare)
This pull request was exported from Phabricator. Differential Revision: D65480353
Yeah, makes sense to me. I'll verify it works OK with compile in a follow-up :)
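If useful, a quick self-contained way to smoke-test the compile path on a sampling sketch like the one above (illustrative only; the real check would go through torchtune's own generation utilities):

```python
import torch

def sample(logits, q=None):
    # Same softmax sampling trick as sketched earlier in the thread
    probs = torch.nn.functional.softmax(logits, dim=-1)
    if q is None:
        q = torch.empty_like(probs).exponential_(1)
    return torch.argmax(probs / q, dim=-1, keepdim=True)

compiled = torch.compile(sample)
print(compiled(torch.randn(1, 2048)).shape)  # torch.Size([1, 1])
```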
Summary:
The existing softmax sampling trick implementation in the torchtune generator is not flexible enough to deal with vocab-pruned models (when the number of logits produced does not match the size of the embedding layer). This is an unnecessary limitation and is easy to fix if we simply create the `q` tensor to match the size of the logits tensor instead of the embedding layer.

Differential Revision: D65480353