-
Notifications
You must be signed in to change notification settings - Fork 5.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix issues for saving checkpointing steps #9891
base: main
Are you sure you want to change the base?
Conversation
@sayakpaul Please take a look at this PR, thanks for your help! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for your PR. Can you please modify a single file first and discuss the changes first?
When I did the deambooth flux without Lora, and save the check pointing. It stuck for a while and break. So I think they all need these modifications. I can only do the flux ones if you want |
Yeah let's change a single file first and then we can discuss the changes first. |
Sure |
@sayakpaul I already changed the modifications only on FLUX models |
if global_step % args.checkpointing_steps == 0: | ||
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well, there is a better way to handle it:
if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thaks for your PR!
You can refer to the following scripts:
- https://github.com/a-r-r-o-w/cogvideox-factory/blob/main/training/cogvideox_image_to_video_lora.py
- https://github.com/a-r-r-o-w/cogvideox-factory/blob/main/training/cogvideox_text_to_video_sft.py
To see how we handle saving and loading from checkpoints when using DeepSpeed.
Search for DistributedType.DEEPSPEED
.
@sayakpaul I've changed it based on the reference |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Left more comments. LMK if they're clear.
model.load_state_dict(load_model.state_dict()) | ||
except Exception: | ||
elif isinstance(unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't support fine-tuning the T5 model. So, this seems wrong. It should just be CLIPTextModelWithProjection
, no?
try: | ||
load_model = T5EncoderModel.from_pretrained(input_dir, subfolder="text_encoder_2") | ||
model(**load_model.config) | ||
model.load_state_dict(load_model.state_dict()) | ||
except Exception: | ||
raise ValueError(f"Couldn't load the model of type: ({type(model)}).") | ||
else: | ||
raise ValueError(f"Unsupported model found: {type(model)=}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same for this.
try: | ||
load_model = CLIPTextModelWithProjection.from_pretrained(input_dir, subfolder="text_encoder") | ||
model(**load_model.config) | ||
if not accelerator.distributed_type == DistributedType.DEEPSPEED: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We also need to handle the case when we're actually doing DeepSpeed training. Similar to:
https://github.com/a-r-r-o-w/cogvideox-factory/blob/d63a826f37758eccf226710f94f6c3a4d4ee7a25/training/cogvideox_text_to_video_sft.py#L385
@@ -1262,15 +1263,16 @@ def load_model_hook(models, input_dir): | |||
transformer_ = None | |||
text_encoder_one_ = None | |||
|
|||
while len(models) > 0: | |||
model = models.pop() | |||
if not accelerator.distributed_type == DistributedType.DEEPSPEED: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same. We need to handle the case when we're doing DeepSpeed training. Reference:
https://github.com/a-r-r-o-w/cogvideox-factory/blob/d63a826f37758eccf226710f94f6c3a4d4ee7a25/training/cogvideox_text_to_video_lora.py#L396
@@ -1187,7 +1187,8 @@ def save_model_hook(models, weights, output_dir): | |||
raise ValueError(f"Wrong model supplied: {type(model)=}.") | |||
|
|||
# make sure to pop weight so that corresponding model is not saved again | |||
weights.pop() | |||
if weights: | |||
weights.pop() | |||
|
|||
def load_model_hook(models, input_dir): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems like we're not handling the loading case appropriately here. I repeated this multiple times now but please refer to the changes here to get an idea of what is required.
In summary, we're not dealing with the changes required to load the state dict in the models being trained when DeepSpeed is enabled.
What does this PR do?
These modification can help to save the checkpoint steps while training. Otherwise it will just stuck for too long and timeout.
Fixes get stuck when save_state using DeepSpeed backend under training train_text_to_image_lora #2606
Bug fix for weight pop from empty list
Before submitting
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.