feat: add gemma2b variants #1835
Did some quick math, I guess this will take at least 216GB total memory (54GB params + 54GB gradients + 108GB optimizer states for AdamW), which means to run on 4 devices we'd need people to be using A100s. I wonder whether we can use an 8-bit optimizer + optimizer-in-backward to get us down to a more reasonable peak VRAM here.
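For reference, a rough breakdown of that number (a back-of-the-envelope estimate only, assuming the 27B parameters, gradients, and both AdamW states are all held in bf16 at 2 bytes per element):

$$
\underbrace{27\mathrm{B} \times 2\,\mathrm{B}}_{\text{params}\,\approx\,54\,\mathrm{GB}} \;+\; \underbrace{27\mathrm{B} \times 2\,\mathrm{B}}_{\text{grads}\,\approx\,54\,\mathrm{GB}} \;+\; \underbrace{2 \times 27\mathrm{B} \times 2\,\mathrm{B}}_{\text{AdamW}\ m,v\,\approx\,108\,\mathrm{GB}} \;\approx\; 216\,\mathrm{GB}
$$

An 8-bit optimizer would roughly halve the third term relative to bf16 states, and optimizer-in-backward avoids holding all gradients at once.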
Does 8-bit work with distributed?
Oh yeah, duh... there may be some issues with bitsandbytes optimizers on that front. I just tried out the ao low-precision optimizers and they seem to work (though I haven't resumed from an intermediate checkpoint). Also, there may be a compile dependency there. Anyways, if it's too much hassle we can consider it separately; I don't wanna increase the scope of this already substantial PR more than necessary.
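For context, the swap would look something like the sketch below. This is only a sketch, not a tested config from this PR: the torchao optimizer import path has moved between releases, the learning rate is a placeholder, and whether `optimizer_in_bwd` plays nicely with this recipe would need to be verified.

```yaml
# Sketch only: replace the default AdamW component with torchao's 8-bit AdamW
# and run the optimizer step during backward so gradients can be freed early.
optimizer:
  _component_: torchao.prototype.low_bit_optim.AdamW8bit  # path may differ by torchao version
  lr: 2e-5  # placeholder value
optimizer_in_bwd: True  # optimizer-in-backward, as suggested above
```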
What should I do here? Change something, or expect users to change parameters according to their hardware?
Sorry, I missed this comment before now. I think it's fine to leave this as you have it and revisit these details in a later PR.
Sorry to potentially be a pain in the ass here. We have a parallel PR (#1872) which is helping standardize our configs and better expose the features we have. This means we always have `packed: False` in `dataset`, and `log_peak_memory_stats: True` and `compile: False` below, for every one of our configs (see the sketch after this comment). Would it be annoying to ask if we could update these in the same way while we're here, please?
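Illustrating the standardized keys mentioned above, a hypothetical excerpt (not copied from #1872; the dataset component and surrounding values are placeholders):

```yaml
dataset:
  _component_: torchtune.datasets.alpaca_dataset  # placeholder component
  packed: False  # always spelled out explicitly under dataset

# ...

log_peak_memory_stats: True
compile: False
```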
Done, I have updated all the configs to match the other PR!
Are we confident this'll fit on a single device?
Changed batch size to 2 and accumulation to 8 (see the snippet below). What is the expected GPU? Is there a CI running everything? Otherwise I guess each user should be responsible for playing with the batch size to get something suitable for their GPU, no?
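For reference, the relevant config lines would be something like the following (key names mirrored from other torchtune full-finetune configs; treat this as a sketch rather than the exact diff):

```yaml
batch_size: 2
gradient_accumulation_steps: 8  # effective batch size of 16 per device
```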
Generally we try to ship configs which we know will work on some common hardware configuration (see examples here: https://github.com/pytorch/torchtune?tab=readme-ov-file#memory-and-training-speed), so users can maintain the expectation that they can get started without any painful OOMs. Then they are free to play with the configs. We should make sure this config works with e.g. 1xA100; let me know if you need a hand here.
@SalmanMohammadi I do not have easy access to an A100; I would appreciate it if someone could run the code for the 27B model and let me know what batch size I should set.
I'll have a quick look when we're ready to land. We can also reasonably mirror the batch size from the config of another similarly sized model already in the codebase.