[RFC] Integration with DCP - Benchmark Results #443
Conversation
self.log(msg=f"Waited for {time.monotonic() - t0} seconds.") | ||
|
||
t0 = time.monotonic() | ||
with FSDP.state_dict_type( |
We should use `get_state_dict()`, which will avoid a future migration when FSDP2 is used.
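A rough sketch of what that could look like with the distributed state-dict API (assuming PyTorch >= 2.2, where `torch.distributed.checkpoint.state_dict` is available; `ckpt_dict`, `MODEL_KEY`, and `self._model` come from the surrounding recipe code):

```python
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict

# Sharded, per-rank model state dict without the FSDP.state_dict_type context manager.
ckpt_dict[MODEL_KEY] = get_model_state_dict(
    self._model,
    options=StateDictOptions(full_state_dict=False),
)
```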
```python
ckpt_dict[MODEL_KEY] = self._model.state_dict()

if OPT_KEY in ckpt_dict:
    ckpt_dict[OPT_KEY] = FSDP.optim_state_dict(self._model, self._optimizer)
```
Ditto, we should use `get_optimizer_state_dict()`.
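Similarly, a sketch for the optimizer (same assumption about the `torch.distributed.checkpoint.state_dict` module; `OPT_KEY` and the other names come from the recipe):

```python
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_optimizer_state_dict

# Sharded optimizer state dict, replacing FSDP.optim_state_dict().
if OPT_KEY in ckpt_dict:
    ckpt_dict[OPT_KEY] = get_optimizer_state_dict(
        self._model,
        self._optimizer,
        options=StateDictOptions(full_state_dict=False),
    )
```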
self.log(msg=f"creating sharded state_dict {time.monotonic() - t0} seconds.") | ||
|
||
t0 = time.monotonic() | ||
self._checkpoint_future = async_save( |
nice! Good to see async_save is actually being used.
Superseded by #2006
Context
We've pivoted this PR from an RFC on the torchtune + DCP UX to a PR that showcases async-saving benchmarks.
If consensus is reached that we want an `async_checkpointing` integration in torchtune, I'll pull the latest GPU recipes and implement this code in a less hacky way there. Please don't review this PR for code cleanliness, as it's a bit of a dumpster fire in its current state.

Results:
full_finetune:
What's shown here is the blocking portion of saving -- specifically, the portion that blocks the training loop. LLM training is a great case for async checkpointing since the training loops are fairly long. This means that by the time we reach the end of the epoch, serialization has finished, and we never have to wait on the previous checkpoint before staging another one.
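A sketch of how that wait-then-save pattern would look in a recipe (assuming PyTorch >= 2.3, where `torch.distributed.checkpoint.async_save` is available; the method name and `self._output_dir` are placeholders, while `self._checkpoint_future` matches the snippet above):

```python
import time

import torch.distributed.checkpoint as dcp


def save_intermediate_checkpoint(self, ckpt_dict, epoch):
    # Block only if the previous async save is still serializing in the background.
    # With long epochs this wait is usually ~0 seconds.
    if self._checkpoint_future is not None:
        t0 = time.monotonic()
        self._checkpoint_future.result()
        self.log(msg=f"Waited for {time.monotonic() - t0} seconds.")

    # Staging the sharded state dict is the only work that blocks the training loop;
    # writing it to disk continues asynchronously.
    self._checkpoint_future = dcp.async_save(
        ckpt_dict,
        checkpoint_id=f"{self._output_dir}/epoch_{epoch}",  # placeholder output dir
    )
```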
In the `full_finetune` case, we'll also notice an inverse relation to main in terms of ranks. As opposed to the current implementation, whose latency increases with the number of ranks, we expect the sharded state dict to get smaller per rank as we increase world_size, a fact DCP takes advantage of. Meaning: "more GPUs, faster save".

lora_finetune:
In the LoRA case, we'll notice performance is actually at parity or worse for sharded state dict + DCP. There are a couple of reasons for this:
- Most of the time spent saving the LoRA checkpoints is actually spent in FSDP's `state_dict` calls. Corresponding to the graph above, on the main branch with 2 ranks we see the full state dict takes ~0.624 seconds; the sharded state dict can take slightly longer to materialize.
- The file sizes are really small, looks like 16MB - 48MB. Essentially, the async overhead will be almost equivalent to the serialization time of just using `torch.save` (at least that's what I'm seeing in this case).

Proposal:
For intermediary checkpoints on multi-GPU recipes, we should implement `DCP.async_save` as the checkpointing solution.
Some decision points (comments are welcome):
We can go about converting to an un-sharded `torch.save` file in two ways. We can either:

a) Go back to the slower `torch.save` + full state dict implementation on the last call to save in the last epoch, or
b) Always save in the DCP format, and ask users to convert the checkpoint themselves using `torch.distributed.checkpoint.format_utils` (a conversion sketch is shown below).
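For option (b), the conversion is a small utility call. A sketch, with placeholder paths, using the `dcp_to_torch_save` helper from `torch.distributed.checkpoint.format_utils`:

```python
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

# Convert a sharded DCP checkpoint directory into a single un-sharded torch.save file.
dcp_to_torch_save("output/epoch_0", "output/epoch_0.pt")  # placeholder paths
```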
Using the DCP checkpointer for `lora_finetune` would keep things consistent, but could incur some (pretty minimal) performance hits. Any thoughts?

Changelog
Test plan