Skip to content

Commit

Permalink
fix: lint
Browse files Browse the repository at this point in the history
Signed-off-by: Will Johnson <[email protected]>
  • Loading branch information
willmj committed Dec 10, 2024
1 parent 090fcbe commit 78c702d
Showing 1 changed file with 10 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
# Third Party
from accelerate.logging import get_logger
from accelerate.utils.constants import FSDP_MODEL_NAME, OPTIMIZER_NAME
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import load_file, save_file
from torch.distributed.checkpoint.default_planner import (
DefaultLoadPlanner,
Expand All @@ -32,7 +33,6 @@
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
from transformers import PretrainedConfig
from transformers.utils import CONFIG_NAME, SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
from huggingface_hub import split_torch_state_dict_into_shards
import torch
import torch.distributed.checkpoint as dcp

Expand Down Expand Up @@ -424,30 +424,31 @@ def _infer_prefixes_and_module_names(


def save_sharded_safetensors(
    input_state_dict: Dict,
    save_directory: str,
    metadata: Dict,
    max_shard_size: Union[int, str] = "5GB",
):
    """Split *input_state_dict* into shards and save them as safetensors files.

    Writes one ``.safetensors`` file per shard into *save_directory*, plus a
    JSON index file (``SAFE_WEIGHTS_INDEX_NAME``) mapping every tensor name to
    the shard file that contains it.

    Args:
        input_state_dict: Mapping of tensor name -> torch.Tensor to persist.
        save_directory: Existing directory that receives the shard files and
            the index file.
        metadata: Metadata dict embedded into every saved shard.
        max_shard_size: Maximum size per shard, either bytes or a string such
            as ``"5GB"`` (forwarded to ``split_torch_state_dict_into_shards``).
    """
    # Derive a "{suffix}"-bearing filename pattern from the canonical weights
    # name so multi-shard saves get "-00001-of-0000N"-style suffixes.
    filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace(
        ".safetensors", "{suffix}.safetensors"
    )
    state_dict_split = split_torch_state_dict_into_shards(
        input_state_dict,
        filename_pattern=filename_pattern,
        max_shard_size=max_shard_size,
    )

    # Index: shard-level metadata plus the tensor-name -> shard-file map.
    index = {
        "metadata": state_dict_split.metadata,
        "weight_map": state_dict_split.tensor_to_filename,
    }
    with open(
        os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8"
    ) as f:
        f.write(json.dumps(index, indent=2, sort_keys=True) + "\n")

    # Write each shard; tensors are made contiguous before serialization
    # (safetensors' save_file expects contiguous storage — TODO confirm).
    for shard_file, tensors in state_dict_split.filename_to_tensors.items():
        shard = {name: input_state_dict[name].contiguous() for name in tensors}
        save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)


Expand Down

0 comments on commit 78c702d

Please sign in to comment.