def save_sharded_safetensors(
    state_dict: Dict,
    save_directory: str,
    metadata: Dict,
    max_shard_size: Union[int, str] = "5GB",
):
    """Save ``state_dict`` to ``save_directory`` as (possibly sharded) safetensors.

    Tensors are partitioned into shards of at most ``max_shard_size`` bytes
    via ``huggingface_hub.split_torch_state_dict_into_shards`` and each shard
    is written with ``safetensors.torch.save_file``. When more than one shard
    is produced, a ``SAFE_WEIGHTS_INDEX_NAME`` (``model.safetensors.index.json``)
    file mapping each tensor name to its shard file is written alongside them.

    Args:
        state_dict: mapping of parameter name -> tensor to persist.
        save_directory: existing directory the files are written into.
        metadata: metadata dict embedded in every safetensors file header
            (e.g. ``{"format": "pt"}``).
        max_shard_size: maximum shard size — an int (bytes) or a size string
            such as ``"5GB"``.
    """
    # Derive "model{suffix}.safetensors" (or "{suffix}.bin") from the
    # canonical weights filename; "suffix" expands to "-00001-of-00002" etc.
    filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace(
        ".safetensors", "{suffix}.safetensors"
    )
    state_dict_split = split_torch_state_dict_into_shards(
        state_dict,
        filename_pattern=filename_pattern,
        max_shard_size=max_shard_size,
    )

    # Write every shard; safetensors requires contiguous tensors.
    for shard_file, tensors in state_dict_split.filename_to_tensors.items():
        shard = {name: state_dict[name].contiguous() for name in tensors}
        save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)

    # Write the index only for genuinely sharded checkpoints. A single-file
    # checkpoint is loaded directly from SAFE_WEIGHTS_NAME by transformers,
    # and pairing it with an index file deviates from the HF save convention
    # (this mirrors huggingface_hub.save_torch_state_dict's is_sharded gate).
    if state_dict_split.is_sharded:
        index = {
            "metadata": state_dict_split.metadata,
            "weight_map": state_dict_split.tensor_to_filename,
        }
        with open(
            os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME),
            "w",
            encoding="utf-8",
        ) as f:
            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
            f.write(content)