add base_model_name_or_path
Felipe Mello committed Nov 26, 2024
1 parent 1c907df commit 623daf7
Showing 6 changed files with 78 additions and 66 deletions.
51 changes: 22 additions & 29 deletions torchtune/_cli/download.py
@@ -5,13 +5,14 @@
# LICENSE file in the root directory of this source tree.

import argparse

import json
import os
import textwrap
import traceback

from http import HTTPStatus
from pathlib import Path
from typing import Literal, Union
from warnings import warn

from huggingface_hub import snapshot_download
@@ -21,6 +22,7 @@
from kagglehub.auth import set_kaggle_credentials
from kagglehub.exceptions import KaggleApiHTTPError
from kagglehub.handle import parse_model_handle
from torchtune import training
from torchtune._cli.subcommand import Subcommand


@@ -85,18 +87,6 @@ def _add_arguments(self) -> None:
default=None,
help="Directory in which to save the model. Defaults to `/tmp/<model_name>`.",
)
self._parser.add_argument(
"--output-dir-use-symlinks",
type=str,
required=False,
default="auto",
help=(
"To be used with `output-dir`. If set to 'auto', the cache directory will be used and the file will be"
" either duplicated or symlinked to the local directory depending on its size. It set to `True`, a"
" symlink will be created, no matter the file size. If set to `False`, the file will either be"
" duplicated from cache (if already exists) or downloaded from the Hub and not cached."
),
)
self._parser.add_argument(
"--hf-token",
type=str,
Expand Down Expand Up @@ -150,27 +140,11 @@ def _download_from_huggingface(self, args: argparse.Namespace) -> None:
model_name = args.repo_id.split("/")[-1]
output_dir = Path("/tmp") / model_name

# Raise if local_dir_use_symlinks is invalid
output_dir_use_symlinks: Union[Literal["auto"], bool]
use_symlinks_lowercase = args.output_dir_use_symlinks.lower()
if use_symlinks_lowercase == "true":
output_dir_use_symlinks = True
elif use_symlinks_lowercase == "false":
output_dir_use_symlinks = False
elif use_symlinks_lowercase == "auto":
output_dir_use_symlinks = "auto"
else:
self._parser.error(
f"'{args.output_dir_use_symlinks}' is not a valid value for `--output-dir-use-symlinks`. It must be either"
" 'auto', 'True' or 'False'."
)

print(f"Ignoring files matching the following patterns: {args.ignore_patterns}")
try:
true_output_dir = snapshot_download(
args.repo_id,
local_dir=output_dir,
local_dir_use_symlinks=output_dir_use_symlinks,
ignore_patterns=args.ignore_patterns,
token=args.hf_token,
)
@@ -196,6 +170,15 @@ def _download_from_huggingface(self, args: argparse.Namespace) -> None:
msg = f"Failed to download {args.repo_id} with error: '{e}' and traceback: {tb}"
self._parser.error(msg)

# save the repo_id. This is necessary because the download step is a separate command
# from the rest of the CLI. When saving a model adapter, we have to add the repo_id
# to the adapter config.
file_path = Path(true_output_dir, training.REPO_ID_FNAME).with_suffix(".json")
with open(file_path, "w") as json_file:
json.dump({"repo_id": args.repo_id}, json_file, indent=4)

print(
"Successfully downloaded model repo and wrote to the following locations:",
*list(Path(true_output_dir).iterdir()),
@@ -224,6 +207,16 @@ def _download_from_kaggle(self, args: argparse.Namespace) -> None:

try:
output_dir = model_download(model_handle)

# save the repo_id. This is necessary because the download step is a separate command
# from the rest of the CLI. When saving a model adapter, we have to add the repo_id
# to the adapter config.
file_path = Path(output_dir, training.REPO_ID_FNAME).with_suffix(".json")
with open(file_path, "w") as json_file:
json.dump({"repo_id": args.repo_id}, json_file, indent=4)

print(
"Successfully downloaded model repo and wrote to the following locations:",
*list(Path(output_dir).iterdir()),
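For orientation, both download paths above leave the same one-key marker next to the downloaded weights. A minimal sketch of reading it back, assuming a hypothetical download location /tmp/Llama-3.2-1B-Instruct and repo meta-llama/Llama-3.2-1B-Instruct (both illustrative, not taken from the diff):

import json
from pathlib import Path

from torchtune import training

checkpoint_dir = Path("/tmp/Llama-3.2-1B-Instruct")  # hypothetical download location
repo_id_path = Path(checkpoint_dir, training.REPO_ID_FNAME).with_suffix(".json")

# The download step wrote {"repo_id": "<org>/<model>"}; read it back if present.
repo_id = None
if repo_id_path.exists():
    with open(repo_id_path, "r") as json_file:
        repo_id = json.load(json_file).get("repo_id")
print(repo_id)  # e.g. "meta-llama/Llama-3.2-1B-Instruct"
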
5 changes: 4 additions & 1 deletion torchtune/models/convert_weights.py
@@ -10,7 +10,6 @@

import torch


# state dict key mappings from Meta's format to torchtune's format
_FROM_META = {
"tok_embeddings.weight": "tok_embeddings.weight",
@@ -231,6 +230,7 @@ def _permute(t, n_heads):

def tune_to_peft_adapter_config(
adapter_config: Dict[str, Any],
base_model_name_or_path: Optional[str] = None,
):
if not all([x in adapter_config.keys() for x in _PEFT_CONFIG_EXPECTED_KEYS]):
raise ValueError(
@@ -244,6 +244,9 @@ def tune_to_peft_adapter_config(
map(_TO_PEFT_TARGET_MODULES.get, adapter_config["target_modules"])
)

if base_model_name_or_path:
adapter_config["base_model_name_or_path"] = base_model_name_or_path

return adapter_config


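A hedged usage sketch of the new argument: the adapter_config keys below are placeholders chosen for illustration (the real required set is whatever _PEFT_CONFIG_EXPECTED_KEYS defines), but they show how base_model_name_or_path flows into the returned PEFT config:

from torchtune.models import convert_weights

# Illustrative torchtune-style adapter config; key names and values are assumptions.
adapter_config = {
    "r": 8,
    "lora_alpha": 16,
    "target_modules": ["q_proj", "v_proj"],
    "peft_type": "LORA",
}
peft_config = convert_weights.tune_to_peft_adapter_config(
    adapter_config,
    base_model_name_or_path="meta-llama/Llama-3.2-1B-Instruct",  # assumed repo
)
# The returned dict now also carries "base_model_name_or_path", which PEFT uses
# to locate the base checkpoint when the adapter is loaded.
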
22 changes: 12 additions & 10 deletions torchtune/training/__init__.py
@@ -35,9 +35,9 @@
from torchtune.training.activations import apply_selective_activation_checkpointing
from torchtune.training.checkpointing import (
ADAPTER_CONFIG,
ADAPTER_CONFIG_FILENAME,
ADAPTER_CONFIG_FNAME,
ADAPTER_KEY,
ADAPTER_MODEL_FILENAME,
ADAPTER_MODEL_FNAME,
Checkpointer,
EPOCHS_KEY,
FormattedCheckpointFiles,
@@ -48,12 +48,13 @@
MODEL_KEY,
ModelType,
OPT_KEY,
REPO_ID_FNAME,
RNG_KEY,
SAFETENSOR_INDEX_FILENAME,
SAFETENSOR_INDEX_FNAME,
SEED_KEY,
SHARD_FILENAME,
SHARD_FNAME,
STEPS_KEY,
TORCHTUNE_INDEX_FILENAME,
TORCHTUNE_INDEX_FNAME,
TOTAL_EPOCHS_KEY,
update_state_dict_for_classifier,
)
@@ -89,11 +90,12 @@
"Checkpointer",
"update_state_dict_for_classifier",
"ADAPTER_CONFIG",
"ADAPTER_CONFIG_FILENAME",
"ADAPTER_MODEL_FILENAME",
"SHARD_FILENAME",
"SAFETENSOR_INDEX_FILENAME",
"TORCHTUNE_INDEX_FILENAME",
"ADAPTER_CONFIG_FNAME",
"ADAPTER_MODEL_FNAME",
"SHARD_FNAME",
"SAFETENSOR_INDEX_FNAME",
"TORCHTUNE_INDEX_FNAME",
"REPO_ID_FNAME",
"ADAPTER_KEY",
"EPOCHS_KEY",
"MAX_STEPS_KEY",
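A quick migration note for downstream code (the names come from this diff; the usage lines are only illustrative):

from torchtune import training

training.ADAPTER_CONFIG_FNAME  # formerly ADAPTER_CONFIG_FILENAME
training.REPO_ID_FNAME         # new: stem of the repo_id marker written by `tune download`
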
22 changes: 12 additions & 10 deletions torchtune/training/checkpointing/__init__.py
@@ -12,21 +12,22 @@
)
from torchtune.training.checkpointing._utils import (
ADAPTER_CONFIG,
ADAPTER_CONFIG_FILENAME,
ADAPTER_CONFIG_FNAME,
ADAPTER_KEY,
ADAPTER_MODEL_FILENAME,
ADAPTER_MODEL_FNAME,
EPOCHS_KEY,
FormattedCheckpointFiles,
MAX_STEPS_KEY,
MODEL_KEY,
ModelType,
OPT_KEY,
REPO_ID_FNAME,
RNG_KEY,
SAFETENSOR_INDEX_FILENAME,
SAFETENSOR_INDEX_FNAME,
SEED_KEY,
SHARD_FILENAME,
SHARD_FNAME,
STEPS_KEY,
TORCHTUNE_INDEX_FILENAME,
TORCHTUNE_INDEX_FNAME,
TOTAL_EPOCHS_KEY,
update_state_dict_for_classifier,
)
@@ -45,11 +46,12 @@
"Checkpointer",
"update_state_dict_for_classifier",
"ADAPTER_CONFIG",
"ADAPTER_CONFIG_FILENAME",
"ADAPTER_MODEL_FILENAME",
"SHARD_FILENAME",
"SAFETENSOR_INDEX_FILENAME",
"TORCHTUNE_INDEX_FILENAME",
"ADAPTER_CONFIG_FNAME",
"ADAPTER_MODEL_FNAME",
"SHARD_FNAME",
"SAFETENSOR_INDEX_FNAME",
"TORCHTUNE_INDEX_FNAME",
"REPO_ID_FNAME",
"ADAPTER_KEY",
"EPOCHS_KEY",
"MAX_STEPS_KEY",
33 changes: 22 additions & 11 deletions torchtune/training/checkpointing/_checkpointer.py
@@ -347,6 +347,17 @@ def __init__(
) -> None:
self._checkpoint_dir = Path(checkpoint_dir)

# The repo_id is needed because, when saving a model adapter, we add it to the
# adapter config as base_model_name_or_path. The JSON file is written by the download step.
repo_id_path = Path.joinpath(
self._checkpoint_dir, training.REPO_ID_FNAME
).with_suffix(".json")
self.repo_id = None
if repo_id_path.exists():
with open(repo_id_path, "r") as json_file:
data = json.load(json_file)
self.repo_id = data.get("repo_id")

# e.g.
# checkpoint_files:
# filename_format: model-{}-of-{}.safetensors
@@ -645,7 +656,7 @@ def save_checkpoint(
num_shards = len(split_state_dicts)
map_original_name_to_new_name = {}
for cpt_idx, model_state_dict in split_state_dicts.items():
shard_name = training.SHARD_FILENAME.format(
shard_name = training.SHARD_FNAME.format(
cpt_idx=cpt_idx, num_shards=num_shards
)
map_original_name_to_new_name[cpt_idx] = shard_name
@@ -672,13 +683,13 @@
k: map_original_name_to_new_name[int(cpt_idx)] + ".safetensors"
for k, cpt_idx in self._weight_map.items()
}
index_file_name = training.SAFETENSOR_INDEX_FILENAME
index_file_name = training.SAFETENSOR_INDEX_FNAME
else:
weight_map = {
k: map_original_name_to_new_name[int(cpt_idx)] + ".bin"
for k, cpt_idx in self._weight_map.items()
}
index_file_name = training.TORCHTUNE_INDEX_FILENAME
index_file_name = training.TORCHTUNE_INDEX_FNAME

index_path = Path.joinpath(
self._output_dir,
@@ -727,7 +738,7 @@ def save_checkpoint(
)
# TODO: add "if self._safe_serialization:"
peft_output_path = Path.joinpath(
self._output_dir, f"epoch_{epoch}", training.ADAPTER_MODEL_FILENAME
self._output_dir, f"epoch_{epoch}", training.ADAPTER_MODEL_FNAME
).with_suffix(".bin")
torch.save(state_dict[training.ADAPTER_KEY], peft_output_path)
logger.info(
@@ -753,10 +764,12 @@
state_dict[
training.ADAPTER_CONFIG
] = convert_weights.tune_to_peft_adapter_config(
state_dict[training.ADAPTER_CONFIG]
state_dict[training.ADAPTER_CONFIG],
base_model_name_or_path=self.repo_id,
)

output_path = Path.joinpath(
self._output_dir, f"epoch_{epoch}", training.ADAPTER_CONFIG_FILENAME
self._output_dir, f"epoch_{epoch}", training.ADAPTER_CONFIG_FNAME
).with_suffix(".json")
with open(output_path, "w") as f:
json.dump(state_dict[training.ADAPTER_CONFIG], f)
@@ -945,11 +958,9 @@ def save_checkpoint(
model_state_dict
)

# TODO: We should consider adding adapter/model config
# like we do for HF.

# TODO: We should consider adding adapter/model config, like we do for HF.
# Output file is always a .pt
model_filename = training.SHARD_FILENAME.format(cpt_idx=1, num_shards=1)
model_filename = training.SHARD_FNAME.format(cpt_idx=1, num_shards=1)
checkpoint_file = Path.joinpath(
self._output_dir, f"epoch_{epoch}", model_filename
).with_suffix(".pt")
@@ -962,7 +973,7 @@

if training.ADAPTER_KEY in state_dict:
output_path = Path.joinpath(
self._output_dir, f"epoch_{epoch}", training.ADAPTER_MODEL_FILENAME
self._output_dir, f"epoch_{epoch}", training.ADAPTER_MODEL_FNAME
).with_suffix(".pt")
torch.save(state_dict[training.ADAPTER_KEY], output_path)
logger.info(
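The downstream payoff, as a sketch with assumed paths: because epoch_{n}/adapter_config.json now carries base_model_name_or_path, Hugging Face PEFT can resolve the base weights from the adapter directory alone:

from peft import AutoPeftModelForCausalLM

# epoch_0/ holds the adapter model and adapter config written above; PEFT reads
# base_model_name_or_path from the config and fetches the base model itself.
model = AutoPeftModelForCausalLM.from_pretrained("/tmp/torchtune_output/epoch_0")  # assumed path
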
11 changes: 6 additions & 5 deletions torchtune/training/checkpointing/_utils.py
@@ -23,11 +23,12 @@

# default used by huggingface when looking for saved adapters
# https://github.com/huggingface/peft/blob/d13d7a401ccf4808aaaf76480fea09a4cf4ac1f5/src/peft/config.py#L259C21-L259C32
ADAPTER_CONFIG_FILENAME = "adapter_config"
ADAPTER_MODEL_FILENAME = "adapter_model"
SHARD_FILENAME = "model-{int(cpt_idx):05d}-of-{int(num_shards):05d}"
SAFETENSOR_INDEX_FILENAME = "model.safetensors.index.json"
TORCHTUNE_INDEX_FILENAME = "pytorch_model.bin.index.json"
ADAPTER_CONFIG_FNAME = "adapter_config"
ADAPTER_MODEL_FNAME = "adapter_model"
SHARD_FNAME = "model-{cpt_idx:05d}-of-{num_shards:05d}"
SAFETENSOR_INDEX_FNAME = "model.safetensors.index"
TORCHTUNE_INDEX_FNAME = "pytorch_model.bin.index"
REPO_ID_FNAME = "repo_id"

# key used for adapter weights such as LoRA weights
ADAPTER_KEY = "adapter"
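The *_FNAME constants are extension-less stems; callers join them with a directory and attach the extension via with_suffix, mirroring the checkpointer and download code above (the output directory below is an assumption for illustration):

from pathlib import Path
from torchtune import training

out_dir = "/tmp/torchtune_output"
Path(out_dir, training.ADAPTER_CONFIG_FNAME).with_suffix(".json")  # .../adapter_config.json
Path(out_dir, training.ADAPTER_MODEL_FNAME).with_suffix(".bin")    # .../adapter_model.bin
Path(out_dir, training.REPO_ID_FNAME).with_suffix(".json")         # .../repo_id.json
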
