diff --git a/open_lm/utils/transformers/hf_model.py b/open_lm/utils/transformers/hf_model.py
index 83353a19..c1766925 100644
--- a/open_lm/utils/transformers/hf_model.py
+++ b/open_lm/utils/transformers/hf_model.py
@@ -172,11 +172,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             checkpoint_path = kwargs["config"].checkpoint_file
             checkpoint = torch.load(checkpoint_path)
 
-            state_dict = checkpoint["state_dict"]
-            state_dict = {x.replace("module.", ""): y for x, y in state_dict.items()}
-            state_dict = {f"model.{x}": y for x, y in state_dict.items()}
-
-            return super().from_pretrained(None, state_dict=state_dict, **kwargs)
+            sd = checkpoint["state_dict"]
+            if next(iter(sd.items()))[0].startswith("module"):
+                sd = {k[len("module.") :]: v for k, v in sd.items()}
+            if "_orig_mod" in next(iter(sd.items()))[0]:
+                sd = {k.replace("_orig_mod.", ""): v for k, v in sd.items()}
+            sd = {f"model.{x}": y for x, y in sd.items()}
+
+            return super().from_pretrained(None, state_dict=sd, **kwargs)
         else:
             return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
 
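
For context, a minimal standalone sketch of the key normalization this patch performs; load_normalized_state_dict is an illustrative name, not a helper that exists in open_lm. Checkpoints saved under DistributedDataParallel prefix every parameter key with "module.", and checkpoints saved from a torch.compile'd model carry an "_orig_mod." prefix; the patch strips both before re-prefixing every key with "model." for the HF wrapper.

    import torch


    def load_normalized_state_dict(checkpoint_path: str) -> dict:
        # Hypothetical helper mirroring the patched from_pretrained logic.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        sd = checkpoint["state_dict"]

        first_key = next(iter(sd))

        # DistributedDataParallel saves parameters under a "module." prefix.
        if first_key.startswith("module."):
            sd = {k[len("module."):]: v for k, v in sd.items()}
            first_key = next(iter(sd))

        # torch.compile wraps the module and saves keys under "_orig_mod.".
        if "_orig_mod." in first_key:
            sd = {k.replace("_orig_mod.", ""): v for k, v in sd.items()}

        # The HF wrapper stores the open_lm module on its "model" attribute,
        # so every key is re-prefixed accordingly.
        return {f"model.{k}": v for k, v in sd.items()}

Inspecting only the first key is sufficient here because each wrapper applies its prefix uniformly to every parameter in a checkpoint.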