feat: Checkpoint utils safetensors #116

Merged
@@ -14,21 +14,25 @@

# Standard
from collections import defaultdict
from typing import Dict, List, Union
import json
import os
import re
import shutil

# Third Party
from accelerate.logging import get_logger
from accelerate.utils.constants import FSDP_MODEL_NAME, OPTIMIZER_NAME
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import load_file, save_file
from torch.distributed.checkpoint.default_planner import (
DefaultLoadPlanner,
DefaultSavePlanner,
)
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
from transformers import PretrainedConfig
from transformers.utils import CONFIG_NAME, SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
import torch
import torch.distributed.checkpoint as dcp

@@ -213,24 +217,10 @@ def _dict_from_json_file(resolved_config_file):
return os.path.dirname(result)


# function to get the state dict from dcp_checkpoint
def get_state_dict_from_dcp_checkpoint(
dcp_checkpoint_dir: str,
):
"""
Parameters:
dcp_checkpoint_dir (str): the DCP to be converted.
pretrained_model_name_or_path (str): Optional, if provided we will
use the hints to remap the
"""

# reference dcp_to_torch_save from torch.distributed.checkpoint.format_utils.py
# - strategy is to use _EmptyStateDictLoadPlanner to populate the state dict, then we remap

# guarded, load some internal functions
# pylint: disable=import-outside-toplevel
# Third Party
@@ -245,11 +235,46 @@ def recover_original_state_dict_from_dcp_checkpoint(
planner=_EmptyStateDictLoadPlanner(),
no_dist=True,
)
return sd[KEY_MODEL]
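
For reference, a minimal usage sketch of the loader above. It reads a torch.distributed DCP checkpoint directory into a plain in-memory state dict, using `_load_state_dict` with `_EmptyStateDictLoadPlanner` and `no_dist=True` so no process group is needed; the path here is hypothetical, following accelerate's FSDP_MODEL_NAME naming convention:

    # hypothetical checkpoint path; assumes this module's functions are importable
    sd = get_state_dict_from_dcp_checkpoint("outputs/checkpoint-100/pytorch_model_fsdp_0")
    print(f"loaded {sum(t.numel() for t in sd.values())} parameters")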


# function to get the state dict from a regular checkpoint
# - note this assumes sharded safetensors; we do not support
# the non-sharded case for now
def get_state_dict_from_safe_checkpoint(
safe_checkpoint_dir: str,
):
# Load the index
safe_index_file = os.path.join(safe_checkpoint_dir, SAFE_WEIGHTS_INDEX_NAME)
with open(safe_index_file, "r", encoding="utf-8") as f:
index = json.load(f)

sd = {}
shard_files = list(set(index["weight_map"].values()))
for shard_file in shard_files:
for key, v in load_file(os.path.join(safe_checkpoint_dir, shard_file)).items():
sd[key] = v

return sd
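
The loader above relies on the standard sharded-safetensors layout: an index JSON whose weight_map sends each tensor name to the shard file that holds it, so deduplicating the values of weight_map yields the set of shard files to read. A sketch of the layout this assumes (file names and sizes illustrative):

    # model.safetensors.index.json (abridged, illustrative)
    # {
    #   "metadata": {"total_size": 4973231104},
    #   "weight_map": {
    #     "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
    #     "model.layers.0.mlp.w1.weight": "model-00002-of-00002.safetensors"
    #   }
    # }
    sd = get_state_dict_from_safe_checkpoint("outputs/hf_checkpoint")  # hypothetical dir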

# function to recover the original model's state dict from a ScatterMoE state dict
# - if the original pretrained_model_name_or_path is specified, will use the checkpoint as hints
# to map the ScatterMoE checkpoint to that of the original model. This is useful so that we
# can restore the checkpoint to be loaded by the original architecture.
def recover_original_state_dict_from_checkpoint(
sd: Dict,
pretrained_model_name_or_path: str = None,
):
"""
Parameters:
dcp_checkpoint_dir (str): the DCP to be converted.
pretrained_model_name_or_path (str): Optional, if provided we will
use the hints to remap the
"""

# reference dcp_to_torch_save from torch.distributed.checkpoint.format_utils.py
# - strategy is to use _EmptyStateDictLoadPlanner to populate the state dict, then we remap

# now do the remap
loc = get_resolved_checkpoint_location(pretrained_model_name_or_path)
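
The remap logic itself sits in lines elided from this diff; conceptually it renames ScatterMoE's fused expert parameters back to the per-expert keys the original architecture expects, using the pretrained checkpoint as hints. A purely illustrative sketch of that kind of key remapping (the pattern and key layout are hypothetical, not this module's actual rules):

    import re
    import torch

    def _demo_remap(sd):
        # illustrative only: split fused (num_experts, out, in) expert weights
        # into one key per expert; the real mapping is derived from hints in
        # the pretrained checkpoint, not hardcoded patterns like this
        out = {}
        for key, tensor in sd.items():
            m = re.match(r"(.*)\.moe\.(w1|w2)$", key)  # hypothetical fused-key pattern
            if m is None:
                out[key] = tensor  # non-expert weights pass through unchanged
                continue
            prefix, name = m.groups()
            for idx, w in enumerate(torch.unbind(tensor, dim=0)):
                out[f"{prefix}.experts.{idx}.{name}.weight"] = w
        return out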
@@ -398,6 +423,37 @@ def _infer_prefixes_and_module_names(
return sd


def save_sharded_safetensors(
input_state_dict: Dict,
save_directory: str,
metadata: Dict,
max_shard_size: Union[int, str] = "5GB",
):
filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)
state_dict_split = split_torch_state_dict_into_shards(
input_state_dict,
filename_pattern=filename_pattern,
max_shard_size=max_shard_size,
)
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
# Save the index
with open(
os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8"
) as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)

filename_to_tensors = state_dict_split.filename_to_tensors.items()
for shard_file, tensors in filename_to_tensors:
shard = {tensor: input_state_dict[tensor].contiguous() for tensor in tensors}
save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
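
A usage sketch for the saver above (directory and tensors are illustrative): max_shard_size accepts a byte count or a size string, and metadata={"format": "pt"} marks the shards as PyTorch tensors, matching what transformers writes:

    import os
    import torch

    os.makedirs("converted", exist_ok=True)  # the script below assumes this exists
    demo_sd = {"linear.weight": torch.randn(4, 4), "linear.bias": torch.randn(4)}
    save_sharded_safetensors(
        demo_sd,
        "converted",
        metadata={"format": "pt"},
        max_shard_size="1GB",
    )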


# --------------------------- SCRIPT -------------------------


@@ -417,8 +473,8 @@ def _infer_prefixes_and_module_names(
)

parser.add_argument(
"dcp_checkpoint_dir",
help="Path to the distributed checkpoint.",
"checkpoint_dir",
help="Path to the checkpoint.",
)

parser.add_argument(
@@ -432,37 +488,62 @@
"the original pretrained model checkpoint (from which this "
"checkpoint is obtained)."
),
default=None,
)

args = parser.parse_args()

# search for an FSDP checkpoint. If it is an FSDP checkpoint, it must
# start with FSDP_MODEL_NAME
if args.checkpoint_dir.startswith(FSDP_MODEL_NAME):
checkpoint_dir = args.checkpoint_dir
loader = get_state_dict_from_dcp_checkpoint
else:
checkpoint_dir = [
x
for x in os.listdir(args.checkpoint_dir)
if os.path.isdir(os.path.join(args.checkpoint_dir, x))
and x.startswith(FSDP_MODEL_NAME)
]
if len(checkpoint_dir) == 1:
checkpoint_dir = os.path.join(args.checkpoint_dir, checkpoint_dir[0])
loader = get_state_dict_from_dcp_checkpoint
elif len(checkpoint_dir) > 1:
raise ValueError(
f"Found > 1 dirs in dcp checkpoint dir {args.dcp_checkpoint_dir} "
f"Found > 1 dirs in dcp checkpoint dir {args.checkpoint_dir} "
f"that starts with {FSDP_MODEL_NAME}. Please spectify the exact dir."
)
else:
# then take it as a safetensors checkpoint
# - do not support .bin checkpoints
checkpoint_dir = args.checkpoint_dir
loader = get_state_dict_from_safe_checkpoint

# - pretrained model name
_name_or_path = args.pretrained_model_name_or_path

# assume the output directory exists; we do not create it
# - copy the config file if it exists
config_file = os.path.join(checkpoint_dir, CONFIG_NAME)
target_config_file = os.path.join(args.output_dir, CONFIG_NAME)
if os.path.exists(config_file):
shutil.copyfile(config_file, target_config_file)

# try to populate pretrained_model_name_or_path from the copied
# config file if it was not given
if not _name_or_path and os.path.exists(target_config_file):
with open(target_config_file, "r", encoding="utf-8") as file:
_name_or_path = json.load(file).get("_name_or_path")

# get the state_dict
state_dict = loader(checkpoint_dir)

# recover the original state dict
state_dict = recover_original_state_dict_from_checkpoint(state_dict, _name_or_path)

# save it as a safetensors file
save_sharded_safetensors(
{k: v.contiguous() for k, v in state_dict.items()},
args.output_dir,
metadata={"format": "pt"},
)
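
Putting it together, a hypothetical invocation of this script (paths and script name illustrative; the output_dir argument is set up in parser lines elided from this diff, and the output directory must already exist):

    # python checkpoint_utils.py /path/to/checkpoint-100 /path/to/output_dir \
    #     --pretrained_model_name_or_path /path/to/original/model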
