Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix lora single device fine tune checkpoint saving & nan loss when use_dora=True #1909

Merged
merged 5 commits into from
Oct 31, 2024

Conversation

mirceamironenco
Copy link
Contributor

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Fixes #1903 .

Changelog

What are the changes made in this PR?

  • Account for the magnitude parameter when constructing the state mapping dict (hf/peft).
  • Initialize dora magnitude in the single device lora finetune recipe.

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings

Copy link

pytorch-bot bot commented Oct 27, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1909

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b6e7239 with merge base e99b890 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 27, 2024
@codecov-commenter
Copy link

Codecov Report

Attention: Patch coverage is 66.66667% with 3 lines in your changes missing coverage. Please review.

Project coverage is 70.31%. Comparing base (23c8829) to head (eadd920).
Report is 2 commits behind head on main.

Files with missing lines Patch % Lines
recipes/lora_finetune_single_device.py 0.00% 3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1909      +/-   ##
==========================================
- Coverage   70.44%   70.31%   -0.14%     
==========================================
  Files         308      308              
  Lines       16270    16285      +15     
==========================================
- Hits        11462    11450      -12     
- Misses       4808     4835      +27     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@SalmanMohammadi
Copy link
Collaborator

SalmanMohammadi commented Oct 30, 2024

Thank you so much for contributing this fix @mirceamironenco. Would you be able to share a small training run showing that this works OK?

Another thing that would be great to see would be modifying one of our LoRA recipe tests to make sure the SD loading logic works as expected for DoRA. You could add this here

def test_training_state_on_resume(

And even just add a parameterize to the test with use_dora=True, and use_dora=False.

but I won't block on this - happy to file an issue for a follow up (and to add a unit test). Evan will have some thoughts here too so we should make sure he's happy.

@mirceamironenco
Copy link
Contributor Author

Thank you so much for contributing this fix @mirceamironenco. Would you be able to share a small training run showing that this works OK?

Running tune run lora_finetune_single_device --config llama3_2/1B_lora_single_device \ gradient_accumulation_steps=1 max_steps_per_epoch=5 model.use_dora=True:

(config here)
DEBUG:torchtune.utils._logging:Setting manual seed to local seed 3727881134. Local seed is seed + rank = 3727881134 + 0
Writing logs to /tmp/lora_finetune_output/log_1730294718.txt
INFO:torchtune.utils._logging:Model is initialized with precision torch.bfloat16.
INFO:torchtune.utils._logging:Memory stats after model init:
        GPU peak memory allocation: 2.52 GiB
        GPU peak memory reserved: 2.56 GiB
        GPU peak memory active: 2.52 GiB
INFO:torchtune.utils._logging:Tokenizer is initialized from file.
INFO:torchtune.utils._logging:Optimizer and loss are initialized.
INFO:torchtune.utils._logging:Loss is initialized.
INFO:torchtune.utils._logging:Dataset and Sampler are initialized.
INFO:torchtune.utils._logging:Learning rate scheduler is initialized.
WARNING:torchtune.utils._logging: Profiling disabled.
INFO:torchtune.utils._logging: Profiler config after instantiation: {'enabled': False}
1|5|Loss: 1.7691251039505005: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00,  2.80it/s]
INFO:torchtune.utils._logging:Starting checkpoint save...
INFO:torchtune.utils._logging:Model checkpoint of size 2.47 GB saved to /tmp/Llama-3.2-1B-Instruct/hf_model_0001_0.pt
INFO:torchtune.utils._logging:Adapter checkpoint of size 0.09 GB saved to /tmp/Llama-3.2-1B-Instruct/adapter_0.pt
INFO:torchtune.utils._logging:Adapter checkpoint of size 0.09 GB saved to /tmp/Llama-3.2-1B-Instruct/adapter_model.bin
INFO:torchtune.utils._logging:Adapter checkpoint of size 0.00 GB saved to /tmp/Llama-3.2-1B-Instruct/adapter_config.json
INFO:torchtune.utils._logging:Saving final epoch checkpoint.
INFO:torchtune.utils._logging:The full model checkpoint, including all weights and configurations, has been saved successfully.You can now use this checkpoint for further training or inference.
INFO:torchtune.utils._logging:Checkpoint saved in 222.20 seconds.
1|5|Loss: 1.7691251039505005: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [03:43<00:00, 44.78s/it]

Log file:

Step 1 | loss:2.0422861576080322 lr:2.9999999999999997e-06 tokens_per_second_per_gpu:1470.76318359375 peak_memory_active:5.774061679840088 peak_memory_alloc:5.774061679840088 peak_memory_reserved:6.09765625
Step 2 | loss:2.2708561420440674 lr:5.999999999999999e-06 tokens_per_second_per_gpu:2653.373291015625 peak_memory_active:5.9765753746032715 peak_memory_alloc:5.9765753746032715 peak_memory_reserved:6.216796875
Step 3 | loss:1.9193693399429321 lr:8.999999999999999e-06 tokens_per_second_per_gpu:2741.960693359375 peak_memory_active:7.017248630523682 peak_memory_alloc:7.017248630523682 peak_memory_reserved:7.466796875
Step 4 | loss:1.6952054500579834 lr:1.1999999999999999e-05 tokens_per_second_per_gpu:2495.660400390625 peak_memory_active:7.290591716766357 peak_memory_alloc:7.290591716766357 peak_memory_reserved:7.84765625
Step 5 | loss:1.7691251039505005 lr:1.4999999999999999e-05 tokens_per_second_per_gpu:2274.15478515625 peak_memory_active:10.266940593719482 peak_memory_alloc:10.266940593719482 peak_memory_reserved:10.51171875

@SalmanMohammadi
Copy link
Collaborator

Step 1 | loss:2.0422861576080322 lr:2.9999999999999997e-06 tokens_per_second_per_gpu:1470.76318359375 peak_memory_active:5.774061679840088 peak_memory_alloc:5.774061679840088 peak_memory_reserved:6.09765625
Step 2 | loss:2.2708561420440674 lr:5.999999999999999e-06 tokens_per_second_per_gpu:2653.373291015625 peak_memory_active:5.9765753746032715 peak_memory_alloc:5.9765753746032715 peak_memory_reserved:6.216796875
Step 3 | loss:1.9193693399429321 lr:8.999999999999999e-06 tokens_per_second_per_gpu:2741.960693359375 peak_memory_active:7.017248630523682 peak_memory_alloc:7.017248630523682 peak_memory_reserved:7.466796875
Step 4 | loss:1.6952054500579834 lr:1.1999999999999999e-05 tokens_per_second_per_gpu:2495.660400390625 peak_memory_active:7.290591716766357 peak_memory_alloc:7.290591716766357 peak_memory_reserved:7.84765625
Step 5 | loss:1.7691251039505005 lr:1.4999999999999999e-05 tokens_per_second_per_gpu:2274.15478515625 peak_memory_active:10.266940593719482 peak_memory_alloc:10.266940593719482 peak_memory_reserved:10.51171875

Thanks for running this! I'll try out the vision config.
One thing - what's up with the memory usage? 5GB going up to 10GB in 5 steps - is that normal?

@mirceamironenco
Copy link
Contributor Author

Thanks for running this! I'll try out the vision config. One thing - what's up with the memory usage? 5GB going up to 10GB in 5 steps - is that normal?

My logs with use_dora=False:

Step 1 | loss:2.0422861576080322 lr:2.9999999999999997e-06 tokens_per_second_per_gpu:1277.9771728515625 peak_memory_active:4.533076763153076 peak_memory_alloc:4.533076763153076 peak_memory_reserved:4.98828125
Step 2 | loss:2.2708561420440674 lr:5.999999999999999e-06 tokens_per_second_per_gpu:3626.0888671875 peak_memory_active:4.7215142250061035 peak_memory_alloc:4.7215142250061035 peak_memory_reserved:5.044921875
Step 3 | loss:1.920769214630127 lr:8.999999999999999e-06 tokens_per_second_per_gpu:3585.767822265625 peak_memory_active:5.397563457489014 peak_memory_alloc:5.397563457489014 peak_memory_reserved:6.015625
Step 4 | loss:1.695857048034668 lr:1.1999999999999999e-05 tokens_per_second_per_gpu:3310.01806640625 peak_memory_active:5.569389820098877 peak_memory_alloc:5.569389820098877 peak_memory_reserved:6.779296875
Step 5 | loss:1.7684353590011597 lr:1.4999999999999999e-05 tokens_per_second_per_gpu:2942.122802734375 peak_memory_active:7.440155506134033 peak_memory_alloc:7.440155506134033 peak_memory_reserved:8.814453125

Still seems to go up just not as much. Note that we are processing more tokens so it's not that unusual, but I'm not familiar enough with how to recipe should behave to say more.

@ebsmothers
Copy link
Contributor

Thanks for working on this! Looks like CI is red though. Re @SalmanMohammadi's question about memory consumption, we can pull some profiles to figure out what's going on. In my experience DoRA is definitely more memory hungry, but a 2 GB difference seems high for a model of this size. It's not a blocker to land the bug fix, but I'm curious whether the increase in memory scales with model size (if so we could run into problems with 7B+ models). As for the fact that it increases over the first N steps that could be OK (and expected) given batch-dependent sequence lengths, as long as it stabilizes.

@mirceamironenco
Copy link
Contributor Author

I've added a fix that should address the failing CI for dora. This changes the initialization of the magnitude parameter so that a new Parameter isn't being created:

@torch.no_grad()
def initialize_dora_magnitude(self):
"""
DoRA initializes the magnitude vector such that its outputs are initially
identical to standard LoRA's outputs.
"""
base_weight = self.weight.to(self.lora_a.weight.dtype)
lora_weight = self.lora_b.weight @ self.lora_a.weight
weight_norm = self._get_weight_norm(base_weight, lora_weight)
self.magnitude.copy_(weight_norm)

Currently the recipes construct the adapter params dict before the initialization:

self.adapter_params = get_adapter_params(model)

Because self.magnitude was being redeclared that dict held a reference to the randomly initialized magnitude:

self.magnitude = nn.Parameter(torch.empty(out_dim))

Copy link
Contributor

@ebsmothers ebsmothers left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @mirceamironenco for fixing this! This looks great to me

@ebsmothers ebsmothers merged commit eab21f0 into pytorch:main Oct 31, 2024
17 checks passed
@mirceamironenco mirceamironenco deleted the fix-dora branch October 31, 2024 19:30
@ebsmothers
Copy link
Contributor

(Very) belated discovery: this breaks DoRA distributed training. E.g.

tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama3/8B_dora

fails after this commit due to magnitude still being on meta device. 22 @SLR722

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[bug] DoRA is broken
5 participants