torch.compile support for single device recipes #596
Conversation
ghstack-source-id: aa906a002fccbc9e80acfe3c4848febe23d5071f Pull Request resolved: #590
torchtune/utils/_compile_utils.py
return model


def _consume_torch_compile_prefix(model, state_dict, *args, prefix="_orig_mod.", **kwargs):
don't need args or kwargs?
Kind of unfortunate, but I don't know how to work around this. The reason I have it is that this runs as a state_dict hook, and the nn.Module implementation invokes these hooks with additional arguments (such as the module and possibly extra metadata) that we don't need or use. So I just hide these in args, kwargs. Maybe I could call them (unused_args, unused_kwargs) to make it more explicit?
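For context, a rough sketch of what the hook body and call pattern could look like (illustrative only; the actual torchtune implementation and the exact hook arguments may differ across PyTorch versions):

```python
def _consume_torch_compile_prefix(
    model, state_dict, *args, prefix="_orig_mod.", **kwargs
):
    """Strip the torch.compile wrapper prefix from state_dict keys in place."""
    # nn.Module invokes state_dict hooks with extra positional arguments
    # (e.g. the save-time prefix and local metadata) that this hook doesn't
    # use, so they are absorbed by *args / **kwargs.
    for key in list(state_dict.keys()):
        if key.startswith(prefix):
            state_dict[key[len(prefix):]] = state_dict.pop(key)
```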
A couple questions here:
(1) does prefix ever vary? If this is specific to compile do we actually need to parametrize it? (And if we do, are we able to move it out from in between args and kwargs)
(2) Anything else we have to watch out for here in terms of composability with other functionality (specifically other state dict hooks I guess)? I know you mentioned there are still some gaps with distributed, is this one of the rough edges there?
`prefix` should remain constant modulo BC-breaking changes from PyTorch.
Re: composability with other state_dict hooks, the main thing to watch out for is hook execution order. If there is a wrapper such as DDP that also adds a prefix to module names, and this hook runs before DDP's state_dict hook, it won't remove any of the prefixes, since each key begins with DDP's prefix rather than the torch.compile prefix.
If we hit this issue though, checkpointing errors will be raised, so we'd become aware of it that way.
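To make the ordering concern concrete, a tiny illustration of the per-key logic (key names here are made up for the example):

```python
def strip_compile_prefix(key: str, prefix: str = "_orig_mod.") -> str:
    # Mirrors the hook's per-key behavior, for illustration only.
    return key[len(prefix):] if key.startswith(prefix) else key

# Compile-only wrapping: the prefix is stripped as intended.
assert strip_compile_prefix("_orig_mod.layers.0.attn.q_proj.weight") == "layers.0.attn.q_proj.weight"

# DDP wrapping a compiled model: the key starts with DDP's "module." prefix,
# so the compile prefix is not at the start and nothing is stripped.
assert strip_compile_prefix("module._orig_mod.layers.0.attn.q_proj.weight") == "module._orig_mod.layers.0.attn.q_proj.weight"
```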
looks good. Any testing work I can help with? I can follow up on the loss curve comparison if it works for you?
@@ -70,6 +70,7 @@ loss:
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 1
compile: False
Why default to false here? (Same q for QLoRA)
I know I have it set to true for one of the configs, but that's mostly for showcase atm. IMO, we should ship compile as False by default. I don't think compile is a fully baked, stable feature yet, i.e. the integration with distributed is not fleshed out, and there are some concerns around things like long compile times. Even if we comprehensively verify accuracy, I'd be hesitant to ship it on by default, as users will likely modify recipes and could use features that don't compose well with compile.
Do you think it would be okay to ship some of these recipes with compile turned on by default? In particular QLoRA, which as shown in your benchmark benefits a lot from torch.compile.
So unfortunately compile + QLoRA only works on the latest nightly. The tradeoff here is QLoRA being runnable OOTB with the stable version of torch (which cannot be done with compile), versus requiring some extra setup to run because compile is on by default. IMO, given our principles around being easy to use and removing barriers to usage where possible, I'd vote for the former and document the changes required to enable compile for QLoRA (i.e. upgrading to nightly). cc @kartikayk @ebsmothers for their thoughts though
Totally makes sense! I think that is likely the correct trade-off.
This is awesome! Looking forward to seeing more details on the speedups
@@ -56,6 +56,7 @@ loss:
_component_: torch.nn.CrossEntropyLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1
compile: True
Plan to initially ship `compile` as `False` everywhere. See reasoning below.
@@ -176,6 +176,13 @@ Check out `tune --help` for all possible CLI commands and options.

---

#### Support for torch.compile
So I think this is a useful callout, but also idk if putting it as a subheader under "Getting Started" 100% makes sense to me. I almost wonder if we should have a separate section for supported features. Then we can put compile, distributed, quantization, other low-memory features in there, document each and give caveats. Not a blocker for this particular PR anyways, but lmk if that makes sense to you or not.
Yeah agreed - we need to figure this out. README isn't in the best shape right now.
@@ -56,6 +56,7 @@ loss:
_component_: torch.nn.CrossEntropyLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1
compile: False
Probably need to add this to any other configs that will run on single device recipes, right? (I guess that's just Mistral.) Otherwise they will error out due to a missing config field.
Yeah, can check Mistral. Though IMO "a config that we support on single device" implies that it runs in CI today (we can of course add additional ones to CI).
These speedups are awesome! Couple comments around configs, testing and documentation, but no major concerns from my side. Modulo those comments and green CI this looks good to me!
Context

`compile` support is added via a utility `wrap_compile`. `wrap_compile` is very simple: it calls into `torch.compile`, and also registers a `state_dict` post hook on the model to strip off the `_orig_mod.` prefix that is added by the compile function. This enables state_dicts produced by a compiled model to behave exactly like those of an eager-mode model, and centralizes this logic into a single utility instead of various checkpointers each having to check for this prefix and strip it out. Compilation is controlled by a `cfg.compile` flag in recipes.

Changelog
- Add `compile` flag to config to control torch.compile
- If `compile` is true, runs `torch.compile` on the model
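As a rough sketch of the utility described above and how a recipe might consume the flag (illustrative; the registration API and recipe-side names are assumptions, not the exact torchtune code; `_consume_torch_compile_prefix` is the hook from torchtune/utils/_compile_utils.py shown earlier):

```python
import torch
import torch.nn as nn

def wrap_compile(model: nn.Module) -> nn.Module:
    """Compile the model while keeping its state_dict keys eager-compatible."""
    model = torch.compile(model)
    # State dict post hook that strips the "_orig_mod." prefix torch.compile
    # adds, so checkpoints look identical to eager-mode checkpoints.
    # (_register_state_dict_hook is a private nn.Module API; this mirrors the
    # approach described above rather than the exact implementation.)
    model._register_state_dict_hook(_consume_torch_compile_prefix)
    return model

# Hypothetical recipe-side usage of the new config flag:
# if cfg.compile:
#     model = wrap_compile(model)
```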
Speed ups
Compile Overhead
Accuracy
truthfulqa_mc2: {'acc,none': 0.48572842802124694, 'acc_stderr,none': 0.014656380523451883, 'alias': 'truthfulqa_mc2'}
truthfulqa_mc2: {'acc,none': 0.41981779682199993, 'acc_stderr,none': 0.014149664935251433, 'alias': 'truthfulqa_mc2'}
Checkpoint save/load
tune run full_finetune_single_device --config recipes/configs/llama2/7B_full_single_device.yaml device=cuda:4 compile=True max_steps_per_epoch=0 epochs=1
tune run full_finetune_single_device --config recipes/configs/llama2/7B_full_single_device.yaml device=cuda:4 compile=True
(after updating appropriate paths in the config)
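One possible sanity check on the checkpoint written by the compile=True run (the path and state_dict layout here are assumptions; adjust to the checkpointer and output_dir in your config):

```python
import torch

# Hypothetical output path; substitute the file the recipe actually wrote.
ckpt = torch.load("/tmp/llama2-7b-full-finetune/model_0.pt", map_location="cpu")
state_dict = ckpt.get("model", ckpt)  # key layout depends on the checkpointer

# If the state_dict hook did its job, no key should carry the compile prefix.
assert not any(k.startswith("_orig_mod.") for k in state_dict), \
    "torch.compile prefix leaked into the checkpoint"
```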