diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py
index ebeb4235..ab41b5a0 100644
--- a/helpers/configuration/cmd_args.py
+++ b/helpers/configuration/cmd_args.py
@@ -1338,7 +1338,7 @@ def get_argument_parser():
         help=(
             "Validations must be enabled for model evaluation to function. The default is to use no evaluator,"
             " and 'clip' will use a CLIP model to evaluate the resulting model's performance during validations."
-        )
+        ),
     )
     parser.add_argument(
         "--pretrained_evaluation_model_name_or_path",
@@ -1348,7 +1348,7 @@ def get_argument_parser():
             "Optionally provide a custom model to use for ViT evaluations."
             " The default is currently clip-vit-large-patch14-336, allowing for lower patch sizes (greater accuracy)"
             " and an input resolution of 336x336."
-        )
+        ),
    )
     parser.add_argument(
         "--validation_on_startup",
@@ -2351,13 +2351,9 @@ def parse_cmdline_args(input_args=None):
         )
         args.gradient_precision = "fp32"

-    if args.use_ema:
-        if args.model_family == "sd3":
-            raise ValueError(
-                "Using EMA is not currently supported for Stable Diffusion 3 training."
-            )
-        if "lora" in args.model_type:
-            raise ValueError("Using EMA is not currently supported for LoRA training.")
+    # if args.use_ema:
+    #     if "lora" in args.model_type:
+    #         raise ValueError("Using EMA is not currently supported for LoRA training.")
     args.logging_dir = os.path.join(args.output_dir, args.logging_dir)
     args.accelerator_project_config = ProjectConfiguration(
         project_dir=args.output_dir, logging_dir=args.logging_dir
diff --git a/helpers/training/ema.py b/helpers/training/ema.py
index 9fdb5f2d..d2b59a9a 100644
--- a/helpers/training/ema.py
+++ b/helpers/training/ema.py
@@ -7,9 +7,10 @@ from typing import Any, Dict, Iterable, Optional, Union

 from diffusers.utils.deprecation_utils import deprecate
 from diffusers.utils import is_transformers_available
+from helpers.training.state_tracker import StateTracker

 logger = logging.getLogger("EMAModel")
-logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO"))
+logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "WARNING"))


 def should_update_ema(args, step):
@@ -119,6 +120,62 @@ def __init__(
         self.model_config = model_config
         self.args = args
         self.accelerator = accelerator
+        self.training = True  # To emulate nn.Module's training mode
+
+    def save_state_dict(self, path: str) -> None:
+        """
+        Save the EMA model's state directly to a file.
+
+        Args:
+            path (str): The file path where the EMA state will be saved.
+        """
+        # if the folder containing the path does not exist, create it
+        os.makedirs(os.path.dirname(path), exist_ok=True)
+        # grab state dict
+        state_dict = self.state_dict()
+        # save it using torch.save
+        torch.save(state_dict, path)
+        logger.info(f"EMA model state saved to {path}")
+
+    def load_state_dict(self, path: str) -> None:
+        """
+        Load the EMA model's state from a file and apply it to this instance.
+
+        Args:
+            path (str): The file path from where the EMA state will be loaded.
+        """
+        state_dict = torch.load(path, map_location="cpu", weights_only=True)
+
+        # Load metadata
+        self.decay = state_dict.get("decay", self.decay)
+        self.min_decay = state_dict.get("min_decay", self.min_decay)
+        self.optimization_step = state_dict.get(
+            "optimization_step", self.optimization_step
+        )
+        self.update_after_step = state_dict.get(
+            "update_after_step", self.update_after_step
+        )
+        self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
+        self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
+        self.power = state_dict.get("power", self.power)
+
+        # Load shadow parameters
+        shadow_params = []
+        idx = 0
+        while f"shadow_params.{idx}" in state_dict:
+            shadow_params.append(state_dict[f"shadow_params.{idx}"])
+            idx += 1
+
+        if len(shadow_params) != len(self.shadow_params):
+            raise ValueError(
+                f"Mismatch in number of shadow parameters: expected {len(self.shadow_params)}, "
+                f"but found {len(shadow_params)} in the state dict."
+            )
+
+        for current_param, loaded_param in zip(self.shadow_params, shadow_params):
+            current_param.data.copy_(loaded_param.data)
+
+        logger.info(f"EMA model state loaded from {path}")

     @classmethod
     def from_pretrained(cls, path, model_cls) -> "EMAModel":
@@ -176,7 +233,6 @@ def get_decay(self, optimization_step: int = None) -> float:
     @torch.no_grad()
     def step(self, parameters: Iterable[torch.nn.Parameter], global_step: int = None):
         if not should_update_ema(self.args, global_step):
-
             return

         if self.args.ema_device == "cpu" and not self.args.ema_cpu_only:
@@ -290,6 +346,7 @@ def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
             )
         else:
             for s_param, param in zip(self.shadow_params, parameters):
+                logger.debug(f"From shape: {s_param.shape}, to shape: {param.shape}")
                 param.data.copy_(s_param.to(param.device).data)

     def pin_memory(self) -> None:
@@ -307,31 +364,22 @@ def pin_memory(self) -> None:
         # This probably won't work, but we'll do it anyway.
         self.shadow_params = [p.pin_memory() for p in self.shadow_params]

-    def to(self, device=None, dtype=None, non_blocking=False) -> None:
-        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
+    def to(self, *args, **kwargs):
+        for param in self.shadow_params:
+            param.data = param.data.to(*args, **kwargs)
+        return self

-        Args:
-            device: like `device` argument to `torch.Tensor.to`
-        """
-        # .to() on the tensors handles None correctly
-        self.shadow_params = [
-            (
-                p.to(device=device, dtype=dtype, non_blocking=non_blocking)
-                if p.is_floating_point()
-                else p.to(device=device, non_blocking=non_blocking)
-            )
-            for p in self.shadow_params
-        ]
+    def cuda(self, device=None):
+        return self.to(device="cuda" if device is None else f"cuda:{device}")
+
+    def cpu(self):
+        return self.to(device="cpu")

-    def state_dict(self) -> dict:
+    def state_dict(self, destination=None, prefix="", keep_vars=False):
         r"""
-        Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during
-        checkpointing to save the ema state dict.
+        Returns a dictionary containing a whole state of the EMA model.
         """
-        # Following PyTorch conventions, references to tensors are returned:
-        # "returns a reference to the state and not its copy!"
-        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
-        return {
+        state_dict = {
             "decay": self.decay,
             "min_decay": self.min_decay,
             "optimization_step": self.optimization_step,
@@ -339,27 +387,48 @@ def state_dict(self) -> dict:
             "use_ema_warmup": self.use_ema_warmup,
             "inv_gamma": self.inv_gamma,
             "power": self.power,
-            "shadow_params": self.shadow_params,
         }
+        for idx, param in enumerate(self.shadow_params):
+            state_dict[f"{prefix}shadow_params.{idx}"] = (
+                param if keep_vars else param.detach()
+            )
+        return state_dict
+
+    # def load_state_dict(self, state_dict: dict, strict=True) -> None:
+    #     r"""
+    #     Loads the EMA model's state.
+    #     """
+    #     self.decay = state_dict.get("decay", self.decay)
+    #     self.min_decay = state_dict.get("min_decay", self.min_decay)
+    #     self.optimization_step = state_dict.get(
+    #         "optimization_step", self.optimization_step
+    #     )
+    #     self.update_after_step = state_dict.get(
+    #         "update_after_step", self.update_after_step
+    #     )
+    #     self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
+    #     self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
+    #     self.power = state_dict.get("power", self.power)
+
+    #     # Load shadow parameters
+    #     shadow_params = []
+    #     idx = 0
+    #     while f"shadow_params.{idx}" in state_dict:
+    #         shadow_params.append(state_dict[f"shadow_params.{idx}"])
+    #         idx += 1
+    #     if len(shadow_params) != len(self.shadow_params):
+    #         raise ValueError("Mismatch in number of shadow parameters")
+    #     self.shadow_params = shadow_params

     def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
         r"""
-        Args:
         Save the current parameters for restoring later.
-            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
-                temporarily stored.
         """
         self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]

     def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
         r"""
-        Args:
-        Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without:
-        affecting the original optimization process. Store the parameters before the `copy_to()` method. After
-        validation (or model saving), use this to restore the former parameters.
-            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
-            updated with the stored parameters. If `None`, the parameters with which this
-            `ExponentialMovingAverage` was initialized will be used.
+        Restore the parameters stored with the `store` method.
         """
         if self.temp_stored_params is None:
             raise RuntimeError(
@@ -378,53 +447,45 @@ def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
             # Better memory-wise.
             self.temp_stored_params = None

-    def load_state_dict(self, state_dict: dict) -> None:
-        r"""
-        Args:
-        Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the
-        ema state dict.
-            state_dict (dict): EMA state. Should be an object returned
-                from a call to :meth:`state_dict`.
-        """
-        # deepcopy, to be consistent with module API
-        state_dict = copy.deepcopy(state_dict)
+    def parameter_count(self) -> int:
+        return sum(p.numel() for p in self.shadow_params)

-        self.decay = state_dict.get("decay", self.decay)
-        if self.decay < 0.0 or self.decay > 1.0:
-            raise ValueError("Decay must be between 0 and 1")
+    # Implementing nn.Module methods to emulate its behavior

-        self.min_decay = state_dict.get("min_decay", self.min_decay)
-        if not isinstance(self.min_decay, float):
-            raise ValueError("Invalid min_decay")
+    def named_children(self):
+        # No child modules
+        return iter([])

-        self.optimization_step = state_dict.get(
-            "optimization_step", self.optimization_step
-        )
-        if not isinstance(self.optimization_step, int):
-            raise ValueError("Invalid optimization_step")
+    def children(self):
+        return iter([])

-        self.update_after_step = state_dict.get(
-            "update_after_step", self.update_after_step
-        )
-        if not isinstance(self.update_after_step, int):
-            raise ValueError("Invalid update_after_step")
+    def modules(self):
+        yield self

-        self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
-        if not isinstance(self.use_ema_warmup, bool):
-            raise ValueError("Invalid use_ema_warmup")
+    def named_modules(self, memo=None, prefix=""):
+        yield prefix, self

-        self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
-        if not isinstance(self.inv_gamma, (float, int)):
-            raise ValueError("Invalid inv_gamma")
+    def parameters(self, recurse=True):
+        return iter(self.shadow_params)

-        self.power = state_dict.get("power", self.power)
-        if not isinstance(self.power, (float, int)):
-            raise ValueError("Invalid power")
-
-        shadow_params = state_dict.get("shadow_params", None)
-        if shadow_params is not None:
-            self.shadow_params = shadow_params
-            if not isinstance(self.shadow_params, list):
-                raise ValueError("shadow_params must be a list")
-            if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
-                raise ValueError("shadow_params must all be Tensors")
+    def named_parameters(self, prefix="", recurse=True):
+        for i, param in enumerate(self.shadow_params):
+            name = f"{prefix}shadow_params.{i}"
+            yield name, param
+
+    def buffers(self, recurse=True):
+        return iter([])
+
+    def named_buffers(self, prefix="", recurse=True):
+        return iter([])
+
+    def train(self, mode=True):
+        self.training = mode
+        return self
+
+    def eval(self):
+        return self.train(False)
+
+    def zero_grad(self):
+        # No gradients to zero in EMA model
+        pass
diff --git a/helpers/training/quantisation/__init__.py b/helpers/training/quantisation/__init__.py
index ac96770b..f39c21fe 100644
--- a/helpers/training/quantisation/__init__.py
+++ b/helpers/training/quantisation/__init__.py
@@ -206,7 +206,15 @@ def get_quant_fn(base_model_precision):


 def quantise_model(
-    unet, transformer, text_encoder_1, text_encoder_2, text_encoder_3, controlnet, args
+    unet=None,
+    transformer=None,
+    text_encoder_1=None,
+    text_encoder_2=None,
+    text_encoder_3=None,
+    controlnet=None,
+    ema=None,
+    args=None,
+    return_dict: bool = False,
 ):
     """
     Quantizes the provided models using the specified precision settings.
@@ -218,6 +226,7 @@ def quantise_model(
         text_encoder_2: The second text encoder to quantize.
         text_encoder_3: The third text encoder to quantize.
         controlnet: The ControlNet model to quantize.
+        ema: An EMAModel to quantize.
         args: An object containing precision settings and other arguments.

     Returns:
@@ -273,6 +282,14 @@ def quantise_model(
                 "base_model_precision": args.base_model_precision,
             },
         ),
+        (
+            ema,
+            {
+                "quant_fn": get_quant_fn(args.base_model_precision),
+                "model_precision": args.base_model_precision,
+                "quantize_activations": args.quantize_activations,
+            },
+        ),
     ]

     # Iterate over the models and apply quantization if the model is not None
@@ -293,8 +310,33 @@ def quantise_model(
             models[i] = (quant_fn(model, **quant_args_combined), quant_args)

     # Unpack the quantized models
-    transformer, unet, controlnet, text_encoder_1, text_encoder_2, text_encoder_3 = [
-        model for model, _ in models
-    ]
-
-    return unet, transformer, text_encoder_1, text_encoder_2, text_encoder_3, controlnet
+    (
+        transformer,
+        unet,
+        controlnet,
+        text_encoder_1,
+        text_encoder_2,
+        text_encoder_3,
+        ema,
+    ) = [model for model, _ in models]
+
+    if return_dict:
+        return {
+            "unet": unet,
+            "transformer": transformer,
+            "text_encoder_1": text_encoder_1,
+            "text_encoder_2": text_encoder_2,
+            "text_encoder_3": text_encoder_3,
+            "controlnet": controlnet,
+            "ema": ema,
+        }
+
+    return (
+        unet,
+        transformer,
+        text_encoder_1,
+        text_encoder_2,
+        text_encoder_3,
+        controlnet,
+        ema,
+    )
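quantise_model now takes every model as an optional keyword argument and can return a dict, which avoids relying on the seven-slot tuple order. A call-shape sketch only; the model objects and the parsed args namespace must come from a real SimpleTuner run, and quantise_for_training is a hypothetical helper rather than part of the patch:

from helpers.training.quantisation import quantise_model


def quantise_for_training(transformer, text_encoder_1, ema, args):
    # Models that are not passed default to None and are skipped by the loop above.
    quantised = quantise_model(
        transformer=transformer,
        text_encoder_1=text_encoder_1,
        ema=ema,
        args=args,
        return_dict=True,  # sidestep the positional unpack entirely
    )
    return quantised["transformer"], quantised["text_encoder_1"], quantised["ema"]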
logger.error(f"Error saving EMA model: {e}") if "lora" in self.args.model_type and self.args.lora_type == "standard": self._save_lora(models=models, weights=weights, output_dir=output_dir) return @@ -455,13 +465,6 @@ def _load_lycoris(self, models, input_dir): lycoris_logger.setLevel(logging.ERROR) def _load_full_model(self, models, input_dir): - if self.args.use_ema: - load_model = EMAModel.from_pretrained( - os.path.join(input_dir, self.ema_model_subdir), self.ema_model_cls - ) - self.ema_model.load_state_dict(load_model.state_dict()) - self.ema_model.to(self.accelerator.device) - del load_model if self.args.model_type == "full": return_exception = False for i in range(len(models)): @@ -508,6 +511,14 @@ def load_model_hook(self, models, input_dir): logger.warning( f"Could not find {training_state_path} in checkpoint dir {input_dir}" ) + if self.args.use_ema: + try: + self.ema_model.load_state_dict( + os.path.join(input_dir, self.ema_model_subdir, "ema_model.pt") + ) + # self.ema_model.to(self.accelerator.device) + except Exception as e: + logger.error(f"Could not load EMA model: {e}") if "lora" in self.args.model_type and self.args.lora_type == "standard": self._load_lora(models=models, input_dir=input_dir) elif "lora" in self.args.model_type and self.args.lora_type == "lycoris": diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py index b1525391..bc144212 100644 --- a/helpers/training/trainer.py +++ b/helpers/training/trainer.py @@ -176,6 +176,7 @@ def __init__( self.text_encoder_2 = None self.text_encoder_3 = None self.controlnet = None + self.ema_model = None self.validation = None def _config_to_obj(self, config): @@ -350,11 +351,7 @@ def _misc_init(self): self.config.is_torchao = True elif "bnb" in self.config.base_model_precision: self.config.is_bnb = True - if self.config.is_quanto: - from helpers.training.quantisation import quantise_model - - self.quantise_model = quantise_model - elif self.config.is_torchao: + if self.config.is_quanto or self.config.is_torchao: from helpers.training.quantisation import quantise_model self.quantise_model = quantise_model @@ -756,7 +753,9 @@ def init_unload_text_encoder(self): " The real memories were the friends we trained a model on along the way." 
diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py
index b1525391..bc144212 100644
--- a/helpers/training/trainer.py
+++ b/helpers/training/trainer.py
@@ -176,6 +176,7 @@ def __init__(
         self.text_encoder_2 = None
         self.text_encoder_3 = None
         self.controlnet = None
+        self.ema_model = None
         self.validation = None

     def _config_to_obj(self, config):
@@ -350,11 +351,7 @@ def _misc_init(self):
             self.config.is_torchao = True
         elif "bnb" in self.config.base_model_precision:
             self.config.is_bnb = True
-        if self.config.is_quanto:
-            from helpers.training.quantisation import quantise_model
-
-            self.quantise_model = quantise_model
-        elif self.config.is_torchao:
+        if self.config.is_quanto or self.config.is_torchao:
             from helpers.training.quantisation import quantise_model

             self.quantise_model = quantise_model
@@ -756,7 +753,9 @@ def init_unload_text_encoder(self):
             " The real memories were the friends we trained a model on along the way."
         )

-    def init_precision(self, preprocessing_models_only: bool = False):
+    def init_precision(
+        self, preprocessing_models_only: bool = False, ema_only: bool = False
+    ):
         self.config.enable_adamw_bf16 = (
             True if self.config.weight_dtype == torch.bfloat16 else False
         )
@@ -765,7 +764,7 @@ def init_precision(
         )

         if "bnb" in self.config.base_model_precision:
-            # can't cast or move bitsandbytes modelsthis
+            # can't cast or move bitsandbytes models
             return

         if not self.config.disable_accelerator and self.config.is_quantized:
@@ -793,6 +792,10 @@ def init_precision(

         if self.config.is_quanto:
             with self.accelerator.local_main_process_first():
+                if ema_only:
+                    self.quantise_model(ema=self.ema_model, args=self.config)
+
+                    return
                 self.quantise_model(
                     unet=self.unet if not preprocessing_models_only else None,
                     transformer=(
@@ -802,10 +805,17 @@ def init_precision(
                     text_encoder_2=self.text_encoder_2,
                     text_encoder_3=self.text_encoder_3,
                     controlnet=None,
+                    ema=self.ema_model,
                     args=self.config,
                 )
         elif self.config.is_torchao:
             with self.accelerator.local_main_process_first():
+                if ema_only:
+                    self.ema_model = self.quantise_model(
+                        ema=self.ema_model, args=self.config, return_dict=True
+                    )["ema"]
+
+                    return
                 (
                     self.unet,
                     self.transformer,
@@ -813,6 +823,7 @@ def init_precision(
                     self.text_encoder_2,
                     self.text_encoder_3,
                     self.controlnet,
+                    self.ema_model,
                 ) = self.quantise_model(
                     unet=self.unet if not preprocessing_models_only else None,
                     transformer=(
@@ -822,6 +833,7 @@ def init_precision(
                     text_encoder_2=self.text_encoder_2,
                     text_encoder_3=self.text_encoder_3,
                     controlnet=None,
+                    ema=self.ema_model,
                     args=self.config,
                 )

@@ -1010,6 +1022,22 @@ def init_post_load_freeze(self):
                     self.accelerator, self.text_encoder_2
                 ).gradient_checkpointing_enable()

+    def _get_trainable_parameters(self):
+        # Return just a list of the currently trainable parameters.
+        if self.config.model_type == "lora":
+            if self.config.lora_type == "lycoris":
+                return self.lycoris_wrapped_network.parameters()
+        if self.config.controlnet:
+            return [
+                param for param in self.controlnet.parameters() if param.requires_grad
+            ]
+        if self.unet is not None:
+            return [param for param in self.unet.parameters() if param.requires_grad]
+        if self.transformer is not None:
+            return [
+                param for param in self.transformer.parameters() if param.requires_grad
+            ]
+
     def _recalculate_training_steps(self):
         # Scheduler and math around the number of training steps.
         if not hasattr(self.config, "overrode_max_train_steps"):
@@ -1190,37 +1218,33 @@ def init_ema_model(self):
         logger.info("Using EMA. Creating EMAModel.")

         ema_model_cls = None
-        if self.unet is not None:
-            ema_model_cls = UNet2DConditionModel
-        elif self.config.model_family == "pixart_sigma":
-            ema_model_cls = PixArtTransformer2DModel
-        elif self.config.model_family == "flux":
-            ema_model_cls = FluxTransformer2DModel
-        else:
-            raise ValueError(
-                f"Please open a bug report or disable EMA. Unknown EMA model family: {self.config.model_family}"
-            )
-
-        ema_model_config = None
-        if self.unet is not None:
+        if self.config.controlnet:
+            ema_model_cls = self.controlnet.__class__
+            ema_model_config = self.controlnet.config
+        elif self.unet is not None:
+            ema_model_cls = self.unet.__class__
             ema_model_config = self.unet.config
         elif self.transformer is not None:
+            ema_model_cls = self.transformer.__class__
             ema_model_config = self.transformer.config
+        else:
+            raise ValueError(
+                f"Please open a bug report or disable EMA. Unknown EMA model family: {self.config.model_family}"
+            )

         self.ema_model = EMAModel(
             self.config,
             self.accelerator,
-            parameters=(
-                self.unet.parameters()
-                if self.unet is not None
-                else self.transformer.parameters()
-            ),
+            parameters=self._get_trainable_parameters(),
             model_cls=ema_model_cls,
             model_config=ema_model_config,
             decay=self.config.ema_decay,
             foreach=not self.config.ema_foreach_disable,
         )
-        logger.info("EMA model creation complete.")
+        logger.info(
+            f"EMA model creation completed with {self.ema_model.parameter_count():,} parameters"
+        )

         self.accelerator.wait_for_everyone()
@@ -1296,6 +1320,7 @@ def init_prepare_models(self, lr_scheduler):
         if self.config.use_ema and self.ema_model is not None:
             if self.config.ema_device == "accelerator":
                 logger.info("Moving EMA model weights to accelerator...")
+                logger.debug(f"EMA model: {self.ema_model}")
                 self.ema_model.to(
                     (
                         self.accelerator.device
@@ -2624,11 +2649,7 @@ def train(self):
                     if self.ema_model is not None:
                         training_logger.debug("Stepping EMA forward")
                         self.ema_model.step(
-                            parameters=(
-                                self.unet.parameters()
-                                if self.unet is not None
-                                else self.transformer.parameters()
-                            ),
+                            parameters=self._get_trainable_parameters(),
                             global_step=self.state["global_step"],
                         )
                         wandb_logs["ema_decay_value"] = self.ema_model.get_decay()
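Because init_ema_model and the training loop now both go through _get_trainable_parameters(), the same parameter list can also drive the usual store/copy_to/restore swap at validation time. That code is outside this patch, so the following is only a sketch of the pattern; ema_model and trainable_params correspond to self.ema_model and a materialised list from self._get_trainable_parameters():

def validate_with_ema(ema_model, trainable_params, run_validation):
    # Stash the live training weights, evaluate with the EMA shadow weights,
    # then put the training weights back even if validation raises.
    ema_model.store(trainable_params)
    ema_model.copy_to(trainable_params)
    try:
        run_validation()
    finally:
        ema_model.restore(trainable_params)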
diff --git a/tests/test_ema.py b/tests/test_ema.py
new file mode 100644
index 00000000..a05099e9
--- /dev/null
+++ b/tests/test_ema.py
@@ -0,0 +1,106 @@
+import unittest
+import torch
+import tempfile
+import os
+from helpers.training.ema import EMAModel
+
+
+class TestEMAModel(unittest.TestCase):
+    def setUp(self):
+        # Set up a simple model and its parameters
+        self.model = torch.nn.Linear(10, 5)  # Simple linear model
+        self.args = type(
+            "Args",
+            (),
+            {"ema_update_interval": None, "ema_device": "cpu", "ema_cpu_only": True},
+        )
+        self.accelerator = None  # For simplicity, assuming no accelerator in tests
+        self.ema_model = EMAModel(
+            args=self.args,
+            accelerator=self.accelerator,
+            parameters=self.model.parameters(),
+            decay=0.999,
+            min_decay=0.999,  # Force decay to be 0.999
+            update_after_step=-1,  # Ensure decay is used from step 1
+            use_ema_warmup=False,  # Disable EMA warmup
+            foreach=False,
+        )
+
+    def test_ema_initialization(self):
+        """Test that the EMA model initializes correctly."""
+        self.assertEqual(
+            len(self.ema_model.shadow_params), len(list(self.model.parameters()))
+        )
+        for shadow_param, model_param in zip(
+            self.ema_model.shadow_params, self.model.parameters()
+        ):
+            self.assertTrue(torch.equal(shadow_param, model_param))
+
+    def test_ema_step(self):
+        """Test that the EMA model updates correctly after a step."""
+        # Perform a model parameter update
+        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)
+        dummy_input = torch.randn(1, 10)  # Adjust to match input size
+        dummy_output = self.model(dummy_input)
+        loss = dummy_output.sum()  # A dummy loss function
+        loss.backward()
+        optimizer.step()
+
+        # Save a copy of the model parameters after the update but before the EMA update.
+        model_params = [p.clone() for p in self.model.parameters()]
+        # Save a copy of the shadow parameters before the EMA update.
+        shadow_params_before = [p.clone() for p in self.ema_model.shadow_params]
+
+        # Perform an EMA update
+        self.ema_model.step(self.model.parameters(), global_step=1)
+        decay = self.ema_model.cur_decay_value  # This should be 0.999
+
+        # Verify that the decay used is as expected
+        self.assertAlmostEqual(
+            decay, 0.999, places=6, msg="Decay value is not as expected."
+        )
+
+        # Verify shadow parameters have changed
+        for shadow_param, shadow_param_before in zip(
+            self.ema_model.shadow_params, shadow_params_before
+        ):
+            self.assertFalse(
+                torch.equal(shadow_param, shadow_param_before),
+                "Shadow parameters did not update correctly.",
+            )
+
+        # Compute and check expected shadow parameter values
+        for shadow_param, shadow_param_before, model_param in zip(
+            self.ema_model.shadow_params, shadow_params_before, self.model.parameters()
+        ):
+            expected_shadow = decay * shadow_param_before + (1 - decay) * model_param
+            self.assertTrue(
+                torch.allclose(shadow_param, expected_shadow, atol=1e-6),
+                f"Shadow parameter does not match expected value.",
+            )
+
+    def test_save_and_load_state_dict(self):
+        with tempfile.TemporaryDirectory() as temp_dir:
+            temp_path = os.path.join(temp_dir, "ema_model_state.pth")
+
+            # Save the state
+            self.ema_model.save_state_dict(temp_path)
+
+            # Create a new EMA model and load the state
+            new_ema_model = EMAModel(
+                args=self.args,
+                accelerator=self.accelerator,
+                parameters=self.model.parameters(),
+                decay=0.999,
+            )
+            new_ema_model.load_state_dict(temp_path)
+
+            # Check that the new EMA model's shadow parameters match the saved state
+            for shadow_param, new_shadow_param in zip(
+                self.ema_model.shadow_params, new_ema_model.shadow_params
+            ):
+                self.assertTrue(torch.equal(shadow_param, new_shadow_param))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/train.py b/train.py
index 7e00e46b..13642a4d 100644
--- a/train.py
+++ b/train.py
@@ -27,6 +27,7 @@
     trainer.init_huggingface_hub()

     trainer.init_preprocessing_models()
+    trainer.init_precision(preprocessing_models_only=True)
     trainer.init_data_backend()
     trainer.init_validation_prompts()
     trainer.init_unload_text_encoder()
@@ -38,6 +39,8 @@
     trainer.init_freeze_models()
     trainer.init_trainable_peft_adapter()
     trainer.init_ema_model()
+    # EMA must be quantised if the base model is as well.
+    trainer.init_precision(ema_only=True)

     trainer.move_models(destination="accelerator")
     trainer.init_validations()
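The patch ends by threading the new precision and EMA steps into train.py's startup sequence. The accompanying unit tests exercise the EMA rework without the rest of the trainer stack and can be run on their own; a sketch, assuming the repository root is on the Python path (equivalent to python -m unittest tests.test_ema -v):

import unittest

from tests.test_ema import TestEMAModel

suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestEMAModel)
unittest.TextTestRunner(verbosity=2).run(suite)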