-
Notifications
You must be signed in to change notification settings - Fork 471
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix grad accum + FSDP CPU offload, pass None via CLI #1941
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1941
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 216d1c3 with merge base f560cbb (): This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1941 +/- ##
==========================================
- Coverage 68.39% 65.98% -2.42%
==========================================
Files 311 311
Lines 16901 16907 +6
==========================================
- Hits 11560 11156 -404
- Misses 5341 5751 +410 ☔ View full report in Codecov by Sentry. |
@@ -173,6 +173,11 @@ def _merge_yaml_and_cli_args(yaml_args: Namespace, cli_args: List[str]) -> DictC | |||
# key string to reflect this | |||
if k in yaml_kwargs and _has_component(yaml_kwargs[k]): | |||
k += "._component_" | |||
|
|||
# None passed via CLI will be parsed as string, but we really want OmegaConf null | |||
if v == "None": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should it be v.lower() == “none”? To avoid the None/none case
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I might leave it as is.. I know it's not a heavily-used API in the library as of today, but e.g.
ac_mode (str): Activation checkpointing mode. ['none', 'full', 'selective'] |
Since we already have a case that's using 'none' as a string, I don't wanna mess with that
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lgtm! Just a small comment
Fixes #1939
Two small fixes in this PR.
The first one is due to not moving our grad scaler to CPU when CPU offloading is enabled. Imo it's cleanest to just do this directly in the utility by inferring the right device from the first parameter we see, rather than relying on the FSDP CPU offload flag.
The second one is a bit of a hack but makes it possible to pass
some_config_field=None
from CLI and have it mean None in the Python sense. This means we can't ever use "None" as a string in configs or CLI overrides, but it eliminates some confusion around the fact that OmegaConf would expectsome_config_field=null
and instead parses None as a string.Test plan:
Fix 1:
On main
On this PR
Fix 2:
On main:
On this PR: