diff --git a/src/sparseml/transformers/utils/trainer.py b/src/sparseml/transformers/utils/trainer.py index 5df07d6e512..b824e6821df 100644 --- a/src/sparseml/transformers/utils/trainer.py +++ b/src/sparseml/transformers/utils/trainer.py @@ -137,46 +137,51 @@ def apply_recipes(self, epoch=0.0): Applies all recipes from checkpoint_recipes. Runs architecture changing modifiers to prepare model for state dict loading """ + # get state dict before recipe application + org_state_dict = self.model.state_dict() + + # apply any checkpoint recipes for checkpoint_recipe in self.checkpoint_recipes: if checkpoint_recipe is not None: ScheduledModifierManager.from_yaml(checkpoint_recipe).apply(self.model) + + # init current training recipe if self.manager is not None: - org_state_dict = self.model.state_dict() self.manager.initialize( self.model, epoch=epoch, distillation_teacher=self.teacher, loggers=self.loggers, ) - new_state_dict = self.model.state_dict() - new_params = [p for p in new_state_dict.keys() if p not in org_state_dict] - - if os.path.isdir(self.model_name_or_path): - if os.path.isfile(os.path.join(self.model_name_or_path, WEIGHTS_NAME)): - archive_file = os.path.join(self.model_name_or_path, WEIGHTS_NAME) - state_dict = torch.load(archive_file, map_location="cpu") - new_params_to_init = [ - p for p in new_params if p in state_dict.keys() - ] - if new_params_to_init: - # parameters from dict are dependent on recipe - ( - _, - missing_keys, - unexpected_keys, - _, - ) = self.model._load_state_dict_into_model( - self.model, - state_dict, - self.model_name_or_path, - _fast_init=False, + + # if model structure changed, load in new params from state dict + new_state_dict = self.model.state_dict() + new_params = [p for p in new_state_dict.keys() if p not in org_state_dict] + + if os.path.isdir(self.model_name_or_path): + if os.path.isfile(os.path.join(self.model_name_or_path, WEIGHTS_NAME)): + archive_file = os.path.join(self.model_name_or_path, WEIGHTS_NAME) + state_dict = torch.load(archive_file, map_location="cpu") + new_params_to_init = [p for p in new_params if p in state_dict.keys()] + if new_params_to_init: + # parameters from dict are dependent on recipe + ( + _, + missing_keys, + unexpected_keys, + _, + ) = self.model._load_state_dict_into_model( + self.model, + state_dict, + self.model_name_or_path, + _fast_init=False, + ) + if missing_keys or unexpected_keys: + raise RuntimeError( + "Unexpected or missing keys detected when applying " + f"recipes to models\nMissing keys: {missing_keys}\n" + f"Unexpected keys: {unexpected_keys}\n" ) - if missing_keys or unexpected_keys: - raise RuntimeError( - "Unexpected or missing keys detected when applying " - f"recipes to models\nMissing keys: {missing_keys}\n" - f"Unexpected keys: {unexpected_keys}\n" - ) def create_optimizer(self): """