Skip to content

Commit

Permalink
fix: sharded safetensors save
Browse files Browse the repository at this point in the history
Signed-off-by: Will Johnson <[email protected]>
  • Loading branch information
willmj authored and fabianlim committed Dec 10, 2024
1 parent c84d6f9 commit 090fcbe
Showing 1 changed file with 26 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

# Standard
from collections import defaultdict
from typing import Dict, List
from typing import Dict, List, Union
import json
import os
import re
Expand All @@ -32,6 +32,7 @@
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
from transformers import PretrainedConfig
from transformers.utils import CONFIG_NAME, SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
from huggingface_hub import split_torch_state_dict_into_shards
import torch
import torch.distributed.checkpoint as dcp

Expand Down Expand Up @@ -422,16 +423,32 @@ def _infer_prefixes_and_module_names(
return sd


def save_sharded_safetensors(
    state_dict: Dict,
    save_directory: str,
    metadata: Dict,
    max_shard_size: Union[int, str] = "5GB",
):
    """Save a torch state dict to ``save_directory`` as (possibly sharded) safetensors.

    The state dict is split into shards no larger than ``max_shard_size``;
    each shard is written as its own ``.safetensors`` file. When more than
    one shard is produced, a ``SAFE_WEIGHTS_INDEX_NAME`` JSON index mapping
    each tensor name to its shard file is also written.

    Args:
        state_dict: mapping of tensor name -> torch tensor to be saved.
        save_directory: existing directory the shard files are written into.
        metadata: string metadata dict embedded in every shard file
            (e.g. ``{"format": "pt"}``).
        max_shard_size: maximum shard size, either bytes as an int or a
            human-readable string such as ``"5GB"``.
    """
    # Derive the per-shard filename pattern from the canonical weights name,
    # e.g. "model.safetensors" -> "model{suffix}.safetensors". The ".bin"
    # replacement is a no-op for safetensors names; it mirrors the pattern
    # used by transformers so the code stays robust to either constant.
    filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace(
        ".safetensors", "{suffix}.safetensors"
    )
    state_dict_split = split_torch_state_dict_into_shards(
        state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
    )

    # Write the weight index only when the checkpoint is actually sharded.
    # A single-file checkpoint must not carry an index file, otherwise
    # loaders may go looking for shards that do not exist.
    if state_dict_split.is_sharded:
        index = {
            "metadata": state_dict_split.metadata,
            "weight_map": state_dict_split.tensor_to_filename,
        }
        with open(
            os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME),
            "w",
            encoding="utf-8",
        ) as f:
            f.write(json.dumps(index, indent=2, sort_keys=True) + "\n")

    # Write each shard; safetensors requires contiguous tensors.
    for shard_file, tensor_names in state_dict_split.filename_to_tensors.items():
        shard = {name: state_dict[name].contiguous() for name in tensor_names}
        save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)


# --------------------------- SCRIPT -------------------------
Expand Down Expand Up @@ -522,7 +539,7 @@ def save_single_safetensor(
state_dict = recover_original_state_dict_from_checkpoint(state_dict, _name_or_path)

# save it as a safetensors file
save_single_safetensor(
save_sharded_safetensors(
{k: v.contiguous() for k, v in state_dict.items()},
args.output_dir,
metadata={"format": "pt"},
Expand Down

0 comments on commit 090fcbe

Please sign in to comment.