ema lora save for sdxl unet
bghira committed Nov 19, 2024
1 parent 2118dcc commit 6790c25
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions helpers/training/save_hooks.py
```diff
@@ -215,12 +215,20 @@ def _save_lora(self, models, weights, output_dir):
         ]
         self.ema_model.store(trainable_parameters)
         self.ema_model.copy_to(trainable_parameters)
-        self.pipeline_class.save_lora_weights(
-            os.path.join(output_dir, "ema"),
-            transformer_lora_layers=convert_state_dict_to_diffusers(
-                get_peft_model_state_dict(self._primary_model())
-            ),
-        )
+        if self.transformer is not None:
+            self.pipeline_class.save_lora_weights(
+                os.path.join(output_dir, "ema"),
+                transformer_lora_layers=convert_state_dict_to_diffusers(
+                    get_peft_model_state_dict(self._primary_model())
+                ),
+            )
+        elif self.unet is not None:
+            self.pipeline_class.save_lora_weights(
+                os.path.join(output_dir, "ema"),
+                unet_lora_layers=convert_state_dict_to_diffusers(
+                    get_peft_model_state_dict(self._primary_model())
+                ),
+            )
         self.ema_model.restore(trainable_parameters)
 
         for model in models:
```
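Context for the change: transformer-based models pass their adapter weights to `save_lora_weights` as `transformer_lora_layers`, while SDXL's denoiser is a UNet, so its pipeline expects `unet_lora_layers`; the EMA branch previously only handled the transformer case. Below is a minimal sketch, not part of the commit, of how the EMA LoRA written by this hook could be loaded back for inference. The model id and `output_dir` path are illustrative assumptions, not values from the repository.

```python
# A minimal sketch: load the EMA LoRA that _save_lora wrote into the
# "ema" subdirectory back onto an SDXL pipeline for inference.
import os

import torch
from diffusers import StableDiffusionXLPipeline

output_dir = "output/checkpoint-1000"  # hypothetical training checkpoint path

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
# save_lora_weights() placed the EMA copy under <output_dir>/ema,
# so point load_lora_weights() at that subdirectory rather than output_dir.
pipe.load_lora_weights(os.path.join(output_dir, "ema"))
```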
