Skip to content

Commit

Permalink
Merge pull request #1246 from bghira/bugfix/export-autoencoderkl-clas…
Browse files Browse the repository at this point in the history
…s-name

use correct autoencoder class for different model exports (#1245)
  • Loading branch information
bghira authored Dec 25, 2024
2 parents 933789c + ff924cb commit 75ecc79
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions helpers/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,16 +469,17 @@ def init_vae(self, move_to_accelerator: bool = True):
from diffusers import AutoencoderDC as AutoencoderClass
else:
from diffusers import AutoencoderKL as AutoencoderClass
self.vae_cls = AutoencoderClass

with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
try:
self.vae = AutoencoderClass.from_pretrained(**self.config.vae_kwargs)
self.vae = self.vae_cls.from_pretrained(**self.config.vae_kwargs)
except:
logger.warning(
"Couldn't load VAE with default path. Trying without a subfolder.."
)
self.config.vae_kwargs["subfolder"] = None
self.vae = AutoencoderClass.from_pretrained(**self.config.vae_kwargs)
self.vae = self.vae_cls.from_pretrained(**self.config.vae_kwargs)
if (
self.vae is not None
and self.config.vae_enable_tiling
Expand Down Expand Up @@ -3233,7 +3234,7 @@ def train(self):
tokenizer_3=self.tokenizer_3,
vae=self.vae
or (
AutoencoderKL.from_pretrained(
self.vae_cls.from_pretrained(
self.config.vae_path,
subfolder=(
"vae"
Expand Down Expand Up @@ -3305,7 +3306,7 @@ def train(self):
tokenizer=self.tokenizer_1,
vae=self.vae
or (
AutoencoderKL.from_pretrained(
self.vae_cls.from_pretrained(
self.config.vae_path,
subfolder=(
"vae"
Expand Down Expand Up @@ -3339,7 +3340,7 @@ def train(self):
tokenizer=self.tokenizer_1,
vae=self.vae
or (
AutoencoderKL.from_pretrained(
self.vae_cls.from_pretrained(
self.config.vae_path,
subfolder=(
"vae"
Expand Down Expand Up @@ -3408,7 +3409,7 @@ def train(self):
tokenizer=self.tokenizer_1,
tokenizer_2=self.tokenizer_2,
vae=StateTracker.get_vae()
or AutoencoderKL.from_pretrained(
or self.vae_cls.from_pretrained(
self.config.vae_path,
subfolder=(
"vae"
Expand Down

0 comments on commit 75ecc79

Please sign in to comment.