Skip to content

Commit

Permalink
fix the FSDP name stripping
Browse files Browse the repository at this point in the history
  • Loading branch information
[email protected] committed May 28, 2024
1 parent d6709dd commit eee2526
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions src/sparseml/utils/fsdp/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
"fix_fsdp_module_name",
]

FSDP_WRAPPER_NAME = "_fsdp_wrapped_module."
FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"


def summon_full_params_context(model, offload_to_cpu: bool = False):
Expand Down Expand Up @@ -66,4 +66,12 @@ def fix_fsdp_module_name(name: str) -> str:
:param name: name to strip
:return: stripped name
"""
return name.replace(FSDP_WRAPPER_NAME, "")
if FSDP_WRAPPER_NAME + "." in name:
# accounting for the scenario, where the FSDP_WRAPPER_NAME
# is not the last part of the name
return name.replace(FSDP_WRAPPER_NAME + ".", "")
elif "." + FSDP_WRAPPER_NAME in name:
# accounting for the scenario, where the FSDP_WRAPPER_NAME
# is the last part of the name
return name.replace("." + FSDP_WRAPPER_NAME, "")
return name

0 comments on commit eee2526

Please sign in to comment.