torch.compile support for single device recipes #596

Merged: 34 commits merged into main on Apr 10, 2024

Conversation

@rohan-varma (Member) commented Mar 26, 2024

Context

  • Adds torch.compile support for speed-ups in single-device recipes (LoRA, QLoRA, full finetune); currently only tested with Llama2 7B.
  • Compile support is added via a utility, wrap_compile. It simply calls into torch.compile and also registers a state_dict post-hook on the model that strips the "_orig_mod." prefix added by compilation. This lets a compiled model produce state_dicts that behave exactly like an eager-mode model's, and centralizes the logic in a single utility instead of having each checkpointer check for and strip the prefix (see the sketch after this list).
  • Compile is configured via a cfg.compile flag in the recipe configs.
  • This PR is limited to compile support on single device for Llama2 7B. The interaction between distributed training (such as FSDP) and torch.compile is very nuanced and fast-evolving; discussions around it are deferred to a future PR.
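For illustration, a minimal sketch of what such a wrap_compile utility could look like, based on the description above (this is an assumed sketch, not the exact torchtune implementation; the inner hook name here is made up):

import torch

def wrap_compile(model: torch.nn.Module) -> torch.nn.Module:
    """Compile a model and keep its state_dict keys eager-compatible (sketch only)."""
    model = torch.compile(model)

    def _strip_compile_prefix(module, state_dict, *args, prefix="_orig_mod.", **kwargs):
        # torch.compile wraps the module, so keys come out as "_orig_mod.<name>";
        # strip that prefix in place so checkpoints look identical to eager-mode ones.
        for key in list(state_dict.keys()):
            if key.startswith(prefix):
                state_dict[key[len(prefix):]] = state_dict.pop(key)

    model._register_state_dict_hook(_strip_compile_prefix)
    return model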

Changelog

  • Adds a compile flag to the configs to control torch.compile.
  • If compile is true, the recipe runs torch.compile on the model (see the sketch after this list).
  • Checkpointing is updated so compiled-model state dicts save and load exactly like eager-mode ones.
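As a rough sketch of the recipe-side wiring (the import path and the stand-in model below are assumptions for illustration, not the exact recipe code):

import torch
from torchtune import utils  # assumed import path for the wrap_compile utility

model = torch.nn.Linear(8, 8)  # stand-in for the finetuning model
cfg_compile = True             # would come from cfg.compile in a recipe config

if cfg_compile:
    # wrap_compile both compiles the model and normalizes state_dict keys, so
    # checkpointing code can treat it exactly like an eager-mode model.
    model = utils.wrap_compile(model)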

Speed ups

  • All reported speed ups are averaged over 100 iterations, beginning at iteration 300 (a couple hundred iterations were allowed for warm-up).
  • Full finetune no compile: 3.0229, full finetune w/ torch.compile: 4.8892 (+61%)
  • LoRA finetune no compile: 3.0531, LoRA finetune w/ compile: 4.1235 (+35%)
  • QLoRA finetune no compile: 1.1967, QLoRA finetune w/ compile: 3.6064 (+201%)
  • QLoRA + compile is only supported with the latest PyTorch nightly; in CI, it only runs against the nightlies.
  • Averages computed with this script: https://gist.github.com/rohan-varma/a33e75471faccdad7f0b6dfeb82a678c
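The linked gist is not reproduced here; a minimal sketch of that kind of measurement (warm up, then average the rate over a fixed window) could look like the following, with step_fn standing in for one training iteration:

import time

def benchmark(step_fn, warmup_iters=300, measure_iters=100):
    """Average iterations per second after a warm-up window (illustrative only)."""
    for _ in range(warmup_iters):   # let torch.compile and allocator caches warm up
        step_fn()
    start = time.perf_counter()
    for _ in range(measure_iters):
        step_fn()
    return measure_iters / (time.perf_counter() - start)

# usage (assumed names): avg_rate = benchmark(lambda: train_step(batch))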

Compile Overhead

  • The full finetune's first iteration takes 108.48s, reflecting compile-time overhead. Similarly, the LoRA finetune's first iteration takes 122.85s. QLoRA compile-time overhead is ~300s.

Accuracy

  • LoRA w/compile: truthfulqa_mc2: {'acc,none': 0.48572842802124694, 'acc_stderr,none': 0.014656380523451883, 'alias': 'truthfulqa_mc2'}
  • Full w/compile: truthfulqa_mc2: {'acc,none': 0.41981779682199993, 'acc_stderr,none': 0.014149664935251433, 'alias': 'truthfulqa_mc2'}
  • These are on par with what we've seen in non-compile flows.

Checkpoint save/load

  • Save a compiled checkpoint: tune run full_finetune_single_device --config recipes/configs/llama2/7B_full_single_device.yaml device=cuda:4 compile=True max_steps_per_epoch=0 epochs=1
  • Load compiled checkpoint back: tune run full_finetune_single_device --config recipes/configs/llama2/7B_full_single_device.yaml device=cuda:4 compile=True (after updating appropriate paths in the config)

ghstack-source-id: aa906a002fccbc9e80acfe3c4848febe23d5071f
Pull Request resolved: #590

pytorch-bot bot commented Mar 26, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/596

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 2674d59 with merge base c9d1cdc:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@rohan-varma rohan-varma marked this pull request as draft March 26, 2024 21:36
@facebook-github-bot facebook-github-bot added the CLA Signed label Mar 26, 2024
@rohan-varma rohan-varma requested a review from ebsmothers March 28, 2024 20:11
@rohan-varma rohan-varma marked this pull request as ready for review March 28, 2024 20:11
return model


def _consume_torch_compile_prefix(model, state_dict, *args, prefix="_orig_mod.", **kwargs):
Contributor:

don't need args or kwargs?

Member Author (rohan-varma):

Kind of unfortunate, but I don't know how to work around this. The reason I have it is that this runs as a state_dict hook, and the nn.Module implementation invokes these hooks with additional arguments (such as the module, maybe additional metadata) that we don't need or use. So I just hide these in args/kwargs. Maybe I could call them (unused_args, unused_kwargs) to make it more explicit?
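For context, an illustration of the hook call convention (not torchtune code): nn.Module invokes registered state_dict hooks with extra positional arguments beyond the module and the state_dict, roughly hook(module, state_dict, prefix, local_metadata), which is why the signature absorbs them:

import torch

def noisy_hook(module, state_dict, *unused_args, **unused_kwargs):
    # nn.Module passes extra positionals (a key prefix and per-module metadata)
    # that this hook does not need; *args/**kwargs simply swallow them.
    print(f"hook received {len(unused_args)} extra positional args: {unused_args}")

m = torch.nn.Linear(2, 2)
m._register_state_dict_hook(noisy_hook)
m.state_dict()  # prints the extra arguments supplied by nn.Module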

Contributor:

A couple questions here:

(1) does prefix ever vary? If this is specific to compile do we actually need to parametrize it? (And if we do, are we able to move it out from in between args and kwargs)
(2) Anything else we have to watch out for here in terms of composability with other functionality (specifically other state dict hooks I guess)? I know you mentioned there are still some gaps with distributed, is this one of the rough edges there?

Member Author (rohan-varma):

prefix should remain constant modulo BC breaking changes from pytorch.

Re: composability with other state_dict hooks, the main thing to watch out for is hook execution order. If there is a wrapper such as DDP that also adds a prefix to module keys, and this hook runs before DDP's state_dict hook, it won't remove any prefixes, since each key begins with DDP's prefix rather than the torch.compile prefix.

If we hit this issue though, checkpointing issues will be raised, so we'd be aware that way.
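To make the ordering concern concrete, an illustrative snippet (not from the PR): once a DDP-style "module." prefix is already on the keys, a startswith-based strip of "_orig_mod." is a no-op:

state_dict = {"module._orig_mod.attn.weight": None}  # DDP-style prefix applied first
compile_prefix = "_orig_mod."
stripped = {
    (k[len(compile_prefix):] if k.startswith(compile_prefix) else k): v
    for k, v in state_dict.items()
}
print(stripped)  # {'module._orig_mod.attn.weight': None} -- prefix not removed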

torchtune/utils/_compile_utils.py (outdated review comment, resolved)
@skcoirz commented Mar 29, 2024

looks good. Any testing work I can help with? I can follow up on the loss curve comparison if it works for you?

@@ -70,6 +70,7 @@ loss:
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 1
compile: False
Contributor:

Why default to false here? (Same q for QLoRA)

Member Author (rohan-varma):

I know I have it set to true for one of the configs, but that's mostly for showcase at the moment. IMO, we should ship compile as False by default. I don't think compile is a fully baked, stable feature yet, i.e. the integration with distributed is not fleshed out, and there are concerns around things like long compile times. Even if we comprehensively verify accuracy, I'd be hesitant to ship it on by default, as users will likely modify recipes and could use features that don't compose well with compile.

Commenter:

Do you think it would be okay to ship some of the recipes with compile turned on by default? In particular, QLoRA, as shown in your benchmark, benefits a lot from torch.compile.

Member Author (rohan-varma):

So unfortunately compile + QLoRA only works on the latest nightly. The tradeoff here is QLoRA being runnable OOTB with a stable version of torch (which can't be done with compile), versus requiring some extra setup because compile is on by default. IMO, given our principles around being easy to use and removing barriers to usage where possible, I'd vote for the former and document the changes required to enable QLoRA with compile (i.e. upgrading to the nightly). cc @kartikayk @ebsmothers for their thoughts though.

Commenter:

Totally makes sense! I think that is likely the correct trade-off.

@ebsmothers (Contributor) left a review comment:

This is awesome! Looking forward to seeing more details on the speedups

@@ -56,6 +56,7 @@ loss:
_component_: torch.nn.CrossEntropyLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1
compile: True
Member Author (rohan-varma):

Plan to initially ship compile as False everywhere. See reasoning below.

@rohan-varma rohan-varma requested a review from ebsmothers April 3, 2024 12:55
@rohan-varma rohan-varma requested a review from RdoubleA April 3, 2024 21:39
@rohan-varma rohan-varma changed the title [RFC][WIP] torch.compile support for single device recipes torch.compile support for single device recipes Apr 3, 2024
@@ -176,6 +176,13 @@ Check out `tune --help` for all possible CLI commands and options.

---

#### Support for torch.compile
Contributor:

So I think this is a useful callout, but I don't know if putting it as a subheader under "Getting Started" 100% makes sense to me. I almost wonder if we should have a separate section for supported features. Then we can put compile, distributed, quantization, and other low-memory features in there, document each, and give caveats. Not a blocker for this particular PR, but let me know if that makes sense to you or not.

Member Author (rohan-varma):

Yeah agreed - we need to figure this out. README isn't in the best shape right now.

@@ -56,6 +56,7 @@ loss:
_component_: torch.nn.CrossEntropyLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1
compile: False
Contributor:
Probably need to add this to any other configs that will run on single device recipes, right? (I guess that's just Mistral.) Otherwise they will error out due to missing config field

Member Author (rohan-varma):

Yeah, I can check Mistral. Though IMO, "a config that we support on single device" implies that it runs in CI today (we can of course add additional ones to CI).

@ebsmothers (Contributor) left a review comment:

These speedups are awesome! Couple comments around configs, testing and documentation, but no major concerns from my side. Modulo those comments and green CI this looks good to me!

@rohan-varma rohan-varma merged commit d252375 into main Apr 10, 2024
26 checks passed
@joecummings joecummings deleted the compile branch April 11, 2024 15:40