[RFC] Integration with DCP - Benchmark Results #443

Closed
LucasLLC wants to merge 8 commits

Conversation


@LucasLLC LucasLLC commented Mar 4, 2024

Context

We've pivoted this PR from an RFC on the torchtune + DCP UX to a PR that showcases async-saving benchmarks.
If consensus is reached that we want an async_checkpointing integration in torchtune, I'll pull the latest GPU recipes and implement this code in a less hacky way there. Please don't review this PR for code cleanliness, as it's a bit of a dumpster fire in its current state.

Results:

full_finetune:

| world_size | main, full state dict (s) | DCP, sharded state dict (s) |
|-----------:|--------------------------:|----------------------------:|
| 2          | 44.33777084               | 10.303921                   |
| 4          | 68.12809                  | 8.397324                    |
| 8          | 89.22356                  | 2.483598                    |

What's shown here is the blocking portion of saving, i.e., the part that blocks the training loop. LLM training is a great case for async checkpointing since the training loops are fairly long: by the time we reach the end of the epoch, serialization has already finished, so we never have to wait on the previous checkpoint before staging another one.

In the full_finetune case, we also see the opposite scaling behavior from main as the number of ranks grows. Whereas the current implementation's latency increases with the number of ranks, the sharded state dict gets smaller per rank as world_size increases, and DCP takes advantage of that. In short: more GPUs, faster saves.

lora_finetune:

| world_size | main, full state dict (s) | DCP, sharded state dict (s) |
|-----------:|--------------------------:|----------------------------:|
| 2          | 0.736959                  | 1.0394                      |
| 4          | 3.119773                  | 3.502441                    |
| 8          | 0.3324818                 | 1.03378                     |

In the LoRA case, we'll notice performance is actually at parity or worse for sharded state dict + DCP. There are a couple of reasons for this:

  1. Most of the time spent saving the LoRA checkpoints is actually spent in FSDP's state_dict calls. Looking at the table above, on the main branch with 2 ranks we see the full state dict take ~0.624 seconds; a sharded state dict can take slightly longer to materialize.

  2. The file sizes are really small, roughly 16 MB - 48 MB. At that scale the async overhead is almost equivalent to the serialization time of just using torch.save (at least that's what I'm seeing in this case).

Proposal:

For intermediate checkpoints in multi-GPU recipes, we should implement DCP.async_save as the checkpointing solution.
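As a rough illustration (not the final implementation), here is a minimal sketch of what this could look like in a recipe's checkpoint step, assuming a PyTorch version where `torch.distributed.checkpoint.async_save` is available; `ckpt_dict`, `self.log`, and `self._checkpoint_future` mirror names from the hacked-up code in this PR, the rest is illustrative:

```python
import time

import torch.distributed.checkpoint as dcp


def save_checkpoint_async(self, ckpt_dict, checkpoint_dir):
    # Block only until the *previous* async save has finished serializing.
    # With long training loops this wait is usually close to zero.
    if self._checkpoint_future is not None:
        t0 = time.monotonic()
        self._checkpoint_future.result()
        self.log(msg=f"Waited {time.monotonic() - t0:.3f}s for previous checkpoint.")

    # Stage the sharded state dict and kick off serialization in the background.
    # async_save returns a Future; training resumes as soon as staging is done.
    self._checkpoint_future = dcp.async_save(ckpt_dict, checkpoint_id=checkpoint_dir)
```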

Some decision points (comments are welcome):

  1. We can go about producing an un-sharded torch.save file in two ways. We can either:
    a) Fall back to the slower torch.save + full state dict implementation on the last call to save in the last epoch, or
    b) Always save in the DCP format and ask users to convert it themselves using torch.distributed.checkpoint.format_utils (see the sketch after this list).

  2. Using the DCP checkpointer for lora_finetune would keep things consistent, but could incur some (pretty minimal) performance hits. Any thoughts?
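For option 1b, the conversion is roughly a one-liner on the user side. A sketch, assuming `dcp_to_torch_save` from `torch.distributed.checkpoint.format_utils` is available in the installed PyTorch version (both paths are placeholders):

```python
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

# Convert a DCP checkpoint directory (sharded files + .metadata) into a single
# torch.save-style file that can then be loaded with torch.load.
dcp_to_torch_save("/path/to/dcp_checkpoint_dir", "/path/to/checkpoint.pt")
```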

Changelog

  • ... WIP

Test plan

  • .... WIP

@LucasLLC LucasLLC requested a review from rohan-varma March 4, 2024 16:04
@LucasLLC LucasLLC self-assigned this Mar 4, 2024
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 4, 2024

netlify bot commented Mar 4, 2024

Deploy Preview for torchtune-preview ready!

| Name | Link |
|------|------|
| 🔨 Latest commit | 1644396 |
| 🔍 Latest deploy log | https://app.netlify.com/sites/torchtune-preview/deploys/65ea660bc9f4f7000930dc8c |
| 😎 Deploy Preview | https://deploy-preview-443--torchtune-preview.netlify.app |

@LucasLLC LucasLLC changed the title [RFC] Integration with DCP [RFC] Integration with DCP - Benchmark Results Mar 12, 2024
@fegin fegin self-requested a review March 12, 2024 22:41
self.log(msg=f"Waited for {time.monotonic() - t0} seconds.")

t0 = time.monotonic()
with FSDP.state_dict_type(
We should use get_state_dict(), which will avoid future migration when FSDP2 is used.

```python
ckpt_dict[MODEL_KEY] = self._model.state_dict()

if OPT_KEY in ckpt_dict:
    ckpt_dict[OPT_KEY] = FSDP.optim_state_dict(self._model, self._optimizer)
```
ditto, we should use get_optimizer_state_dict().
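A sketch of what this and the previous suggestion might look like together, assuming the `torch.distributed.checkpoint.state_dict` helpers from recent PyTorch; `MODEL_KEY`, `OPT_KEY`, `self._model`, and `self._optimizer` are the names already used in the recipe:

```python
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    get_optimizer_state_dict,
)

# Sharded state dicts without the FSDP.state_dict_type() context manager;
# the same calls should keep working once recipes migrate to FSDP2.
options = StateDictOptions(full_state_dict=False)
ckpt_dict[MODEL_KEY] = get_model_state_dict(self._model, options=options)
if OPT_KEY in ckpt_dict:
    ckpt_dict[OPT_KEY] = get_optimizer_state_dict(
        self._model, self._optimizer, options=options
    )
```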

self.log(msg=f"creating sharded state_dict {time.monotonic() - t0} seconds.")

t0 = time.monotonic()
self._checkpoint_future = async_save(
nice! Good to see async_save is actually being used.

recipes/lora_finetune.py (resolved review comment)
@RdoubleA RdoubleA closed this Nov 18, 2024
@RdoubleA (Contributor) commented:

Superseded by #2006
