def save_sharded_safetensors(
    input_state_dict: Dict,
    save_directory: str,
    metadata: Dict,
    max_shard_size: Union[int, str] = "5GB",
):
    """Write ``input_state_dict`` to ``save_directory`` as sharded safetensors.

    Splits the state dict into shards no larger than ``max_shard_size`` and
    writes one ``.safetensors`` file per shard, plus a JSON index file
    (``SAFE_WEIGHTS_INDEX_NAME``) mapping each tensor name to the shard file
    that contains it.

    Args:
        input_state_dict: Mapping of tensor name -> tensor to persist.
        save_directory: Directory the shard files and index are written into.
            NOTE(review): assumed to already exist — confirm callers create it.
        metadata: safetensors metadata dict stored in every shard file.
        max_shard_size: Per-shard size limit forwarded to
            ``split_torch_state_dict_into_shards`` (e.g. ``"5GB"``).
    """
    # Derive the per-shard filename template from the canonical weights
    # filename, e.g. "model.safetensors" -> "model{suffix}.safetensors".
    shard_name_template = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace(
        ".safetensors", "{suffix}.safetensors"
    )
    split = split_torch_state_dict_into_shards(
        input_state_dict,
        filename_pattern=shard_name_template,
        max_shard_size=max_shard_size,
    )

    # Persist the index so loaders can locate each tensor's shard.
    index_payload = {
        "metadata": split.metadata,
        "weight_map": split.tensor_to_filename,
    }
    index_path = os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME)
    with open(index_path, "w", encoding="utf-8") as handle:
        handle.write(json.dumps(index_payload, indent=2, sort_keys=True) + "\n")

    # Write each shard; tensors are made contiguous before serialization.
    for shard_file, tensor_names in split.filename_to_tensors.items():
        shard = {
            name: input_state_dict[name].contiguous() for name in tensor_names
        }
        save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)