
Fix grad accum + FSDP CPU offload, pass None via CLI #1941

Merged: 1 commit into pytorch:main on Nov 1, 2024

Conversation

ebsmothers (Contributor) commented on Nov 1, 2024:

Fixes #1939

Two small fixes in this PR.

The first fixes a crash caused by not moving our grad scaler to CPU when CPU offloading is enabled. IMO it's cleanest to handle this directly in the utility by inferring the right device from the first parameter we see, rather than relying on the FSDP CPU offload flag.
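
A rough sketch of the idea (hypothetical code, not the torchtune utility itself): walk the model's parameters, infer the device from the first one we see, move the scaler there, and scale the gradients in place, so this works whether the parameters live on GPU or are CPU-offloaded by FSDP.

import torch
import torch.nn as nn

def scale_grads_sketch(model: nn.Module, scaler: torch.Tensor) -> None:
    # Infer the right device from the first parameter rather than trusting
    # the FSDP CPU-offload flag: with CPU offload enabled the parameters
    # (and their grads) live on CPU, so the scaler must follow them.
    device = None
    for p in model.parameters():
        if device is None:
            device = p.device
            scaler = scaler.to(device)
        if p.grad is not None:
            p.grad /= scaler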

The second one is a bit of a hack, but it makes it possible to pass some_config_field=None from the CLI and have it mean None in the Python sense. The tradeoff is that we can never use the literal string "None" in configs or CLI overrides, but it removes the confusion caused by OmegaConf expecting some_config_field=null and otherwise parsing None as the string "None".
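
A minimal illustration of the parsing behavior (a standalone snippet, not the PR's code; the real change lives in torchtune/config/_utils.py and is excerpted in the diff further down): rewrite the literal string "None" to YAML null before handing the override to OmegaConf.

from omegaconf import OmegaConf

override = "clip_grad_norm=None"
k, v = override.split("=", maxsplit=1)
if v == "None":
    v = "null"  # YAML null; OmegaConf parses this as Python None

cfg = OmegaConf.from_dotlist([f"{k}={v}"])
assert cfg.clip_grad_norm is None  # without the rewrite this would be the string "None"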

Test plan:

Fix 1:

tune run --nproc_per_node 2 full_finetune_distributed --config llama3/8B_full \
gradient_accumulation_steps=2 fsdp_cpu_offload=True

On main:

...
[rank0]: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

On this PR:

...
1|1|Loss: 3.1665189266204834:   0%| | 1/6500 [00:25<46:53:08, 25.97s/it]

Fix 2:

tune run full_finetune_single_device --config llama3/8B_full_single_device \
clip_grad_norm=None

On main:

...
    raise RuntimeError(
RuntimeError: Gradient clipping is not supported with optimizer in bwd.Please set clip_grad_norm=None, or optimizer_in_bwd=False.

On this PR:

...
1|5|Loss: 2.147994041442871:   0%| | 5/26001 [00:08<7:58:13,  1.10s/it]

pytorch-bot (bot) commented on Nov 1, 2024:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1941

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 216d1c3 with merge base f560cbb:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot added the CLA Signed label on Nov 1, 2024.
codecov-commenter commented:

Codecov Report

Attention: Patch coverage is 16.66667% with 5 lines in your changes missing coverage. Please review.

Project coverage is 65.98%. Comparing base (f560cbb) to head (216d1c3).

Files with missing lines             Patch %   Lines
torchtune/training/_grad_scaler.py     0.00%   4 Missing ⚠️
torchtune/config/_utils.py            50.00%   1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1941      +/-   ##
==========================================
- Coverage   68.39%   65.98%   -2.42%     
==========================================
  Files         311      311              
  Lines       16901    16907       +6     
==========================================
- Hits        11560    11156     -404     
- Misses       5341     5751     +410     

☔ View full report in Codecov by Sentry.

@@ -173,6 +173,11 @@ def _merge_yaml_and_cli_args(yaml_args: Namespace, cli_args: List[str]) -> DictC
         # key string to reflect this
         if k in yaml_kwargs and _has_component(yaml_kwargs[k]):
             k += "._component_"
 
+        # None passed via CLI will be parsed as string, but we really want OmegaConf null
+        if v == "None":
A contributor commented on the added line:

Should it be v.lower() == "none"? To avoid the None/none case.

ebsmothers (author) replied:

I might leave it as is. I know it's not a heavily used API in the library as of today, but e.g.

ac_mode (str): Activation checkpointing mode. ['none', 'full', 'selective']

Since we already have a case that's using 'none' as a string, I don't want to mess with that.
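
A tiny illustration (not from the PR) of the tradeoff being discussed here:

value = "none"  # a user intentionally passing the string 'none', e.g. ac_mode='none'

exact = (value == "None")                     # False: 'none' survives as a string (what the PR does)
case_insensitive = (value.lower() == "none")  # True: 'none' would also be nulled, breaking ac_mode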

felipemello1 (Contributor) left a review:

LGTM! Just a small comment

ebsmothers merged commit bc4acc1 into pytorch:main on Nov 1, 2024 (17 checks passed).
ebsmothers mentioned this pull request on Nov 26, 2024.
Labels: CLA Signed (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed)
Linked issue this PR may close: clip_grad_norm=None doesn't work
4 participants