From 623daf71f42fd5bf39d2310b082a618125225abf Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Tue, 26 Nov 2024 12:46:34 -0800 Subject: [PATCH] add base_model_name_or_path --- torchtune/_cli/download.py | 51 ++++++++----------- torchtune/models/convert_weights.py | 5 +- torchtune/training/__init__.py | 22 ++++---- torchtune/training/checkpointing/__init__.py | 22 ++++---- .../training/checkpointing/_checkpointer.py | 33 ++++++++---- torchtune/training/checkpointing/_utils.py | 11 ++-- 6 files changed, 78 insertions(+), 66 deletions(-) diff --git a/torchtune/_cli/download.py b/torchtune/_cli/download.py index 55f5ef4ab1..8a2d4a3334 100644 --- a/torchtune/_cli/download.py +++ b/torchtune/_cli/download.py @@ -5,13 +5,14 @@ # LICENSE file in the root directory of this source tree. import argparse + +import json import os import textwrap import traceback from http import HTTPStatus from pathlib import Path -from typing import Literal, Union from warnings import warn from huggingface_hub import snapshot_download @@ -21,6 +22,7 @@ from kagglehub.auth import set_kaggle_credentials from kagglehub.exceptions import KaggleApiHTTPError from kagglehub.handle import parse_model_handle +from torchtune import training from torchtune._cli.subcommand import Subcommand @@ -85,18 +87,6 @@ def _add_arguments(self) -> None: default=None, help="Directory in which to save the model. Defaults to `/tmp/`.", ) - self._parser.add_argument( - "--output-dir-use-symlinks", - type=str, - required=False, - default="auto", - help=( - "To be used with `output-dir`. If set to 'auto', the cache directory will be used and the file will be" - " either duplicated or symlinked to the local directory depending on its size. It set to `True`, a" - " symlink will be created, no matter the file size. If set to `False`, the file will either be" - " duplicated from cache (if already exists) or downloaded from the Hub and not cached." - ), - ) self._parser.add_argument( "--hf-token", type=str, @@ -150,27 +140,11 @@ def _download_from_huggingface(self, args: argparse.Namespace) -> None: model_name = args.repo_id.split("/")[-1] output_dir = Path("/tmp") / model_name - # Raise if local_dir_use_symlinks is invalid - output_dir_use_symlinks: Union[Literal["auto"], bool] - use_symlinks_lowercase = args.output_dir_use_symlinks.lower() - if use_symlinks_lowercase == "true": - output_dir_use_symlinks = True - elif use_symlinks_lowercase == "false": - output_dir_use_symlinks = False - elif use_symlinks_lowercase == "auto": - output_dir_use_symlinks = "auto" - else: - self._parser.error( - f"'{args.output_dir_use_symlinks}' is not a valid value for `--output-dir-use-symlinks`. It must be either" - " 'auto', 'True' or 'False'." - ) - print(f"Ignoring files matching the following patterns: {args.ignore_patterns}") try: true_output_dir = snapshot_download( args.repo_id, local_dir=output_dir, - local_dir_use_symlinks=output_dir_use_symlinks, ignore_patterns=args.ignore_patterns, token=args.hf_token, ) @@ -196,6 +170,15 @@ def _download_from_huggingface(self, args: argparse.Namespace) -> None: msg = f"Failed to download {args.repo_id} with error: '{e}' and traceback: {tb}" self._parser.error(msg) + # save the repo_id. This is necessary because the download step is a separate command + # from the rest of the CLI. When saving a model adapter, we have to add the repo_id + # to the adapter config. 
+        file_path = Path(true_output_dir, training.REPO_ID_FNAME).with_suffix(
+            ".json"
+        )
+        with open(file_path, "w") as json_file:
+            json.dump({"repo_id": args.repo_id}, json_file, indent=4)
+
         print(
             "Successfully downloaded model repo and wrote to the following locations:",
             *list(Path(true_output_dir).iterdir()),
@@ -224,6 +207,16 @@ def _download_from_kaggle(self, args: argparse.Namespace) -> None:
 
         try:
             output_dir = model_download(model_handle)
+
+            # save the repo_id. This is necessary because the download step is a separate command
+            # from the rest of the CLI. When saving a model adapter, we have to add the repo_id
+            # to the adapter config.
+            file_path = Path(output_dir, training.REPO_ID_FNAME).with_suffix(
+                ".json"
+            )
+            with open(file_path, "w") as json_file:
+                json.dump({"repo_id": args.repo_id}, json_file, indent=4)
+
             print(
                 "Successfully downloaded model repo and wrote to the following locations:",
                 *list(Path(output_dir).iterdir()),
diff --git a/torchtune/models/convert_weights.py b/torchtune/models/convert_weights.py
index b96006d33a..b86b899647 100644
--- a/torchtune/models/convert_weights.py
+++ b/torchtune/models/convert_weights.py
@@ -10,7 +10,6 @@
 
 import torch
 
-
 # state dict key mappings from Meta's format to torchtune's format
 _FROM_META = {
     "tok_embeddings.weight": "tok_embeddings.weight",
@@ -231,6 +230,7 @@ def _permute(t, n_heads):
 
 def tune_to_peft_adapter_config(
     adapter_config: Dict[str, Any],
+    base_model_name_or_path: Optional[str] = None,
 ):
     if not all([x in adapter_config.keys() for x in _PEFT_CONFIG_EXPECTED_KEYS]):
         raise ValueError(
@@ -244,6 +244,9 @@ def tune_to_peft_adapter_config(
         map(_TO_PEFT_TARGET_MODULES.get, adapter_config["target_modules"])
     )
 
+    if base_model_name_or_path:
+        adapter_config["base_model_name_or_path"] = base_model_name_or_path
+
     return adapter_config
 
 
diff --git a/torchtune/training/__init__.py b/torchtune/training/__init__.py
index 504350941e..5e9394f5d4 100644
--- a/torchtune/training/__init__.py
+++ b/torchtune/training/__init__.py
@@ -35,9 +35,9 @@
 from torchtune.training.activations import apply_selective_activation_checkpointing
 from torchtune.training.checkpointing import (
     ADAPTER_CONFIG,
-    ADAPTER_CONFIG_FILENAME,
+    ADAPTER_CONFIG_FNAME,
     ADAPTER_KEY,
-    ADAPTER_MODEL_FILENAME,
+    ADAPTER_MODEL_FNAME,
     Checkpointer,
     EPOCHS_KEY,
     FormattedCheckpointFiles,
@@ -48,12 +48,13 @@
     MODEL_KEY,
     ModelType,
     OPT_KEY,
+    REPO_ID_FNAME,
     RNG_KEY,
-    SAFETENSOR_INDEX_FILENAME,
+    SAFETENSOR_INDEX_FNAME,
     SEED_KEY,
-    SHARD_FILENAME,
+    SHARD_FNAME,
     STEPS_KEY,
-    TORCHTUNE_INDEX_FILENAME,
+    TORCHTUNE_INDEX_FNAME,
     TOTAL_EPOCHS_KEY,
     update_state_dict_for_classifier,
 )
@@ -89,11 +90,12 @@
     "Checkpointer",
     "update_state_dict_for_classifier",
     "ADAPTER_CONFIG",
-    "ADAPTER_CONFIG_FILENAME",
-    "ADAPTER_MODEL_FILENAME",
-    "SHARD_FILENAME",
-    "SAFETENSOR_INDEX_FILENAME",
-    "TORCHTUNE_INDEX_FILENAME",
+    "ADAPTER_CONFIG_FNAME",
+    "ADAPTER_MODEL_FNAME",
+    "SHARD_FNAME",
+    "SAFETENSOR_INDEX_FNAME",
+    "TORCHTUNE_INDEX_FNAME",
+    "REPO_ID_FNAME",
     "ADAPTER_KEY",
     "EPOCHS_KEY",
     "MAX_STEPS_KEY",
diff --git a/torchtune/training/checkpointing/__init__.py b/torchtune/training/checkpointing/__init__.py
index 8a4f2d137f..c8bc111068 100644
--- a/torchtune/training/checkpointing/__init__.py
+++ b/torchtune/training/checkpointing/__init__.py
@@ -12,21 +12,22 @@
 )
 from torchtune.training.checkpointing._utils import (
     ADAPTER_CONFIG,
-    ADAPTER_CONFIG_FILENAME,
+    ADAPTER_CONFIG_FNAME,
     ADAPTER_KEY,
-    ADAPTER_MODEL_FILENAME,
+    ADAPTER_MODEL_FNAME,
     EPOCHS_KEY,
FormattedCheckpointFiles, MAX_STEPS_KEY, MODEL_KEY, ModelType, OPT_KEY, + REPO_ID_FNAME, RNG_KEY, - SAFETENSOR_INDEX_FILENAME, + SAFETENSOR_INDEX_FNAME, SEED_KEY, - SHARD_FILENAME, + SHARD_FNAME, STEPS_KEY, - TORCHTUNE_INDEX_FILENAME, + TORCHTUNE_INDEX_FNAME, TOTAL_EPOCHS_KEY, update_state_dict_for_classifier, ) @@ -45,11 +46,12 @@ "Checkpointer", "update_state_dict_for_classifier", "ADAPTER_CONFIG", - "ADAPTER_CONFIG_FILENAME", - "ADAPTER_MODEL_FILENAME", - "SHARD_FILENAME", - "SAFETENSOR_INDEX_FILENAME", - "TORCHTUNE_INDEX_FILENAME", + "ADAPTER_CONFIG_FNAME", + "ADAPTER_MODEL_FNAME", + "SHARD_FNAME", + "SAFETENSOR_INDEX_FNAME", + "TORCHTUNE_INDEX_FNAME", + "REPO_ID_FNAME", "ADAPTER_KEY", "EPOCHS_KEY", "MAX_STEPS_KEY", diff --git a/torchtune/training/checkpointing/_checkpointer.py b/torchtune/training/checkpointing/_checkpointer.py index c54b81c5c2..82297d0494 100644 --- a/torchtune/training/checkpointing/_checkpointer.py +++ b/torchtune/training/checkpointing/_checkpointer.py @@ -347,6 +347,17 @@ def __init__( ) -> None: self._checkpoint_dir = Path(checkpoint_dir) + # repo_id is necessary because, when saving a model adapter, we have to add it + # to the adapter config. This json file is produced and saved in the download step. + repo_id_path = Path.joinpath( + self._checkpoint_dir, training.REPO_ID_FNAME + ).with_suffix(".json") + self.repo_id = None + if repo_id_path.exists(): + with open(repo_id_path, "r") as json_file: + data = json.load(json_file) + self.repo_id = data.get("repo_id") + # e.g. # checkpoint_files: # filename_format: model-{}-of-{}.safetensors @@ -645,7 +656,7 @@ def save_checkpoint( num_shards = len(split_state_dicts) map_original_name_to_new_name = {} for cpt_idx, model_state_dict in split_state_dicts.items(): - shard_name = training.SHARD_FILENAME.format( + shard_name = training.SHARD_FNAME.format( cpt_idx=cpt_idx, num_shards=num_shards ) map_original_name_to_new_name[cpt_idx] = shard_name @@ -672,13 +683,13 @@ def save_checkpoint( k: map_original_name_to_new_name[int(cpt_idx)] + ".safetensors" for k, cpt_idx in self._weight_map.items() } - index_file_name = training.SAFETENSOR_INDEX_FILENAME + index_file_name = training.SAFETENSOR_INDEX_FNAME else: weight_map = { k: map_original_name_to_new_name[int(cpt_idx)] + ".bin" for k, cpt_idx in self._weight_map.items() } - index_file_name = training.TORCHTUNE_INDEX_FILENAME + index_file_name = training.TORCHTUNE_INDEX_FNAME index_path = Path.joinpath( self._output_dir, @@ -727,7 +738,7 @@ def save_checkpoint( ) # TODO: add "if self._safe_serialization:" peft_output_path = Path.joinpath( - self._output_dir, f"epoch_{epoch}", training.ADAPTER_MODEL_FILENAME + self._output_dir, f"epoch_{epoch}", training.ADAPTER_MODEL_FNAME ).with_suffix(".bin") torch.save(state_dict[training.ADAPTER_KEY], peft_output_path) logger.info( @@ -753,10 +764,12 @@ def save_checkpoint( state_dict[ training.ADAPTER_CONFIG ] = convert_weights.tune_to_peft_adapter_config( - state_dict[training.ADAPTER_CONFIG] + state_dict[training.ADAPTER_CONFIG], + base_model_name_or_path=self.repo_id, ) + output_path = Path.joinpath( - self._output_dir, f"epoch_{epoch}", training.ADAPTER_CONFIG_FILENAME + self._output_dir, f"epoch_{epoch}", training.ADAPTER_CONFIG_FNAME ).with_suffix(".json") with open(output_path, "w") as f: json.dump(state_dict[training.ADAPTER_CONFIG], f) @@ -945,11 +958,9 @@ def save_checkpoint( model_state_dict ) - # TODO: We should consider adding adapter/model config - # like we do for HF. 
-
+            # TODO: We should consider adding adapter/model config, like we do for HF.
             # Output file is always a .pt
-            model_filename = training.SHARD_FILENAME.format(cpt_idx=1, num_shards=1)
+            model_filename = training.SHARD_FNAME.format(cpt_idx=1, num_shards=1)
             checkpoint_file = Path.joinpath(
                 self._output_dir, f"epoch_{epoch}", model_filename
             ).with_suffix(".pt")
@@ -962,7 +973,7 @@ def save_checkpoint(
 
         if training.ADAPTER_KEY in state_dict:
             output_path = Path.joinpath(
-                self._output_dir, f"epoch_{epoch}", training.ADAPTER_MODEL_FILENAME
+                self._output_dir, f"epoch_{epoch}", training.ADAPTER_MODEL_FNAME
             ).with_suffix(".pt")
             torch.save(state_dict[training.ADAPTER_KEY], output_path)
             logger.info(
diff --git a/torchtune/training/checkpointing/_utils.py b/torchtune/training/checkpointing/_utils.py
index 5c480c8f66..7988037e16 100644
--- a/torchtune/training/checkpointing/_utils.py
+++ b/torchtune/training/checkpointing/_utils.py
@@ -23,11 +23,12 @@
 
 # default used by huggingface when looking for saved adapters
 # https://github.com/huggingface/peft/blob/d13d7a401ccf4808aaaf76480fea09a4cf4ac1f5/src/peft/config.py#L259C21-L259C32
-ADAPTER_CONFIG_FILENAME = "adapter_config"
-ADAPTER_MODEL_FILENAME = "adapter_model"
-SHARD_FILENAME = "model-{cpt_idx:05d}-of-{num_shards:05d}"
-SAFETENSOR_INDEX_FILENAME = "model.safetensors.index.json"
-TORCHTUNE_INDEX_FILENAME = "pytorch_model.bin.index.json"
+ADAPTER_CONFIG_FNAME = "adapter_config"
+ADAPTER_MODEL_FNAME = "adapter_model"
+SHARD_FNAME = "model-{cpt_idx:05d}-of-{num_shards:05d}"
+SAFETENSOR_INDEX_FNAME = "model.safetensors.index"
+TORCHTUNE_INDEX_FNAME = "pytorch_model.bin.index"
+REPO_ID_FNAME = "repo_id"
 # key used for adapter weights such as LoRA weights
 ADAPTER_KEY = "adapter"
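
Note for reviewers (illustration only, not part of the diff): the sketch below walks through how the repo_id plumbing is meant to fit together end to end, using only the stdlib. The checkpoint directory, repo id, and LoRA config values are made up for the example; the pieces taken from the patch are the repo_id.json file written by `tune download` (REPO_ID_FNAME + ".json"), the checkpointer reading it back in __init__, and tune_to_peft_adapter_config stamping it into the adapter config as base_model_name_or_path.

# Illustrative sketch only; paths and config values are hypothetical.
import json
from pathlib import Path
from tempfile import mkdtemp

checkpoint_dir = Path(mkdtemp())  # stands in for the `tune download` output dir

# 1) `tune download` now writes <output_dir>/repo_id.json next to the weights.
with open(checkpoint_dir / "repo_id.json", "w") as f:
    json.dump({"repo_id": "meta-llama/Llama-3.2-1B-Instruct"}, f, indent=4)

# 2) The checkpointer's __init__ reads it back if the file exists.
with open(checkpoint_dir / "repo_id.json") as f:
    repo_id = json.load(f).get("repo_id")

# 3) When saving an adapter, the checkpointer passes repo_id as
#    base_model_name_or_path to tune_to_peft_adapter_config, which adds it to
#    the PEFT adapter config before it is written to epoch_{N}/adapter_config.json.
adapter_config = {"r": 8, "lora_alpha": 16, "target_modules": ["q_proj", "v_proj"]}
if repo_id:
    adapter_config["base_model_name_or_path"] = repo_id
print(json.dumps(adapter_config, indent=2))

With base_model_name_or_path populated, tools that consume the PEFT adapter config (for example PEFT's AutoPeftModel helpers) can resolve the base model from the adapter directory alone, instead of the user having to pass the base model id separately.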