diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py
index 31638c93..6de9c34f 100644
--- a/helpers/training/trainer.py
+++ b/helpers/training/trainer.py
@@ -469,16 +469,17 @@ def init_vae(self, move_to_accelerator: bool = True):
             from diffusers import AutoencoderDC as AutoencoderClass
         else:
             from diffusers import AutoencoderKL as AutoencoderClass
+        self.vae_cls = AutoencoderClass
         with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
             try:
-                self.vae = AutoencoderClass.from_pretrained(**self.config.vae_kwargs)
+                self.vae = self.vae_cls.from_pretrained(**self.config.vae_kwargs)
             except:
                 logger.warning(
                     "Couldn't load VAE with default path. Trying without a subfolder.."
                 )
                 self.config.vae_kwargs["subfolder"] = None
-                self.vae = AutoencoderClass.from_pretrained(**self.config.vae_kwargs)
+                self.vae = self.vae_cls.from_pretrained(**self.config.vae_kwargs)

         if (
             self.vae is not None
             and self.config.vae_enable_tiling
@@ -3233,7 +3234,7 @@ def train(self):
                     tokenizer_3=self.tokenizer_3,
                     vae=self.vae
                     or (
-                        AutoencoderKL.from_pretrained(
+                        self.vae_cls.from_pretrained(
                             self.config.vae_path,
                             subfolder=(
                                 "vae"
@@ -3305,7 +3306,7 @@ def train(self):
                     tokenizer=self.tokenizer_1,
                     vae=self.vae
                     or (
-                        AutoencoderKL.from_pretrained(
+                        self.vae_cls.from_pretrained(
                             self.config.vae_path,
                             subfolder=(
                                 "vae"
@@ -3339,7 +3340,7 @@ def train(self):
                     tokenizer=self.tokenizer_1,
                     vae=self.vae
                     or (
-                        AutoencoderKL.from_pretrained(
+                        self.vae_cls.from_pretrained(
                             self.config.vae_path,
                             subfolder=(
                                 "vae"
@@ -3408,7 +3409,7 @@ def train(self):
                     tokenizer=self.tokenizer_1,
                     tokenizer_2=self.tokenizer_2,
                     vae=StateTracker.get_vae()
-                    or AutoencoderKL.from_pretrained(
+                    or self.vae_cls.from_pretrained(
                         self.config.vae_path,
                         subfolder=(
                             "vae"