diff --git a/helpers/training/save_hooks.py b/helpers/training/save_hooks.py
index 876d4200..dbb33b68 100644
--- a/helpers/training/save_hooks.py
+++ b/helpers/training/save_hooks.py
@@ -215,12 +215,20 @@ def _save_lora(self, models, weights, output_dir):
             ]
             self.ema_model.store(trainable_parameters)
             self.ema_model.copy_to(trainable_parameters)
-            self.pipeline_class.save_lora_weights(
-                os.path.join(output_dir, "ema"),
-                transformer_lora_layers=convert_state_dict_to_diffusers(
-                    get_peft_model_state_dict(self._primary_model())
-                ),
-            )
+            if self.transformer is not None:
+                self.pipeline_class.save_lora_weights(
+                    os.path.join(output_dir, "ema"),
+                    transformer_lora_layers=convert_state_dict_to_diffusers(
+                        get_peft_model_state_dict(self._primary_model())
+                    ),
+                )
+            elif self.unet is not None:
+                self.pipeline_class.save_lora_weights(
+                    os.path.join(output_dir, "ema"),
+                    unet_lora_layers=convert_state_dict_to_diffusers(
+                        get_peft_model_state_dict(self._primary_model())
+                    ),
+                )
             self.ema_model.restore(trainable_parameters)
 
         for model in models: