Skip to content

Commit

Permalink
fix load state dict for transformers eval (#534) (#535)
Browse files Browse the repository at this point in the history
  • Loading branch information
bfineran authored Jan 28, 2022
1 parent 3463202 commit d1b0622
Showing 1 changed file with 34 additions and 29 deletions.
63 changes: 34 additions & 29 deletions src/sparseml/transformers/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,46 +137,51 @@ def apply_recipes(self, epoch=0.0):
Applies all recipes from checkpoint_recipes. Runs architecture-changing
modifiers to prepare the model for state dict loading
"""
# get state dict before recipe application
org_state_dict = self.model.state_dict()

# apply any checkpoint recipes
for checkpoint_recipe in self.checkpoint_recipes:
if checkpoint_recipe is not None:
ScheduledModifierManager.from_yaml(checkpoint_recipe).apply(self.model)

# init current training recipe
if self.manager is not None:
org_state_dict = self.model.state_dict()
self.manager.initialize(
self.model,
epoch=epoch,
distillation_teacher=self.teacher,
loggers=self.loggers,
)
new_state_dict = self.model.state_dict()
new_params = [p for p in new_state_dict.keys() if p not in org_state_dict]

if os.path.isdir(self.model_name_or_path):
if os.path.isfile(os.path.join(self.model_name_or_path, WEIGHTS_NAME)):
archive_file = os.path.join(self.model_name_or_path, WEIGHTS_NAME)
state_dict = torch.load(archive_file, map_location="cpu")
new_params_to_init = [
p for p in new_params if p in state_dict.keys()
]
if new_params_to_init:
# parameters from dict are dependent on recipe
(
_,
missing_keys,
unexpected_keys,
_,
) = self.model._load_state_dict_into_model(
self.model,
state_dict,
self.model_name_or_path,
_fast_init=False,

# if model structure changed, load in new params from state dict
new_state_dict = self.model.state_dict()
new_params = [p for p in new_state_dict.keys() if p not in org_state_dict]

if os.path.isdir(self.model_name_or_path):
if os.path.isfile(os.path.join(self.model_name_or_path, WEIGHTS_NAME)):
archive_file = os.path.join(self.model_name_or_path, WEIGHTS_NAME)
state_dict = torch.load(archive_file, map_location="cpu")
new_params_to_init = [p for p in new_params if p in state_dict.keys()]
if new_params_to_init:
# parameters from dict are dependent on recipe
(
_,
missing_keys,
unexpected_keys,
_,
) = self.model._load_state_dict_into_model(
self.model,
state_dict,
self.model_name_or_path,
_fast_init=False,
)
if missing_keys or unexpected_keys:
raise RuntimeError(
"Unexpected or missing keys detected when applying "
f"recipes to models\nMissing keys: {missing_keys}\n"
f"Unexpected keys: {unexpected_keys}\n"
)
if missing_keys or unexpected_keys:
raise RuntimeError(
"Unexpected or missing keys detected when applying "
f"recipes to models\nMissing keys: {missing_keys}\n"
f"Unexpected keys: {unexpected_keys}\n"
)

def create_optimizer(self):
"""
Expand Down

0 comments on commit d1b0622

Please sign in to comment.