From c30392ccdca036f033fd28fed84d8d50c29d1be2 Mon Sep 17 00:00:00 2001 From: Saurabh M Date: Tue, 10 Dec 2024 18:28:19 -0800 Subject: [PATCH] [torchtune][dcp]: Deprecation of resume_from_checkpoint in favor of should_load_recipe_state --- recipes/full_finetune_distributed.py | 23 ++-- .../checkpointing/_checkpoint_client.py | 1 - .../training/checkpointing/_checkpointer.py | 114 +++++++++++------- torchtune/training/checkpointing/_utils.py | 34 +++--- 4 files changed, 100 insertions(+), 72 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 7399832e9f..df8274f597 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -279,14 +279,21 @@ def setup(self, cfg: DictConfig) -> None: # Therefore the recipe needs to load the distributed checkpoint to restore the training # progress. if self._enable_async_checkpointing: - checkpoint_dict = self._checkpoint_client.load_distributed_checkpoint( - self._model, - ( - self._optim_ckpt_wrapper - if self._optimizer_in_bwd - else self._optimizer - ), - ) + try: + checkpoint_dict = ( + self._checkpoint_client.load_distributed_checkpoint( + self._model, + ( + self._optim_ckpt_wrapper + if self._optimizer_in_bwd + else self._optimizer + ), + ) + ) + except Exception as e: + log.warning( + f"Failed to load distributed checkpoint: {e}. Training will start from the base checkpoint." + ) # Update the recipe state from the checkpoint state dict. self._update_recipe_state(checkpoint_dict) diff --git a/torchtune/training/checkpointing/_checkpoint_client.py b/torchtune/training/checkpointing/_checkpoint_client.py index 8a1109a1e5..90a4208b6b 100644 --- a/torchtune/training/checkpointing/_checkpoint_client.py +++ b/torchtune/training/checkpointing/_checkpoint_client.py @@ -89,7 +89,6 @@ def _get_checkpointer(self): self._cfg.checkpointer, should_load_recipe_state=should_load_recipe_state, ) - return self._checkpointer def _get_dcp_checkpointer(self): diff --git a/torchtune/training/checkpointing/_checkpointer.py b/torchtune/training/checkpointing/_checkpointer.py index 91dc4de3e7..fd9d957b56 100644 --- a/torchtune/training/checkpointing/_checkpointer.py +++ b/torchtune/training/checkpointing/_checkpointer.py @@ -110,9 +110,11 @@ class _CheckpointerInterface(Protocol): """ - def load_checkpoint(self, **kwargs) -> Dict[str, Any]: ... + def load_checkpoint(self, **kwargs) -> Dict[str, Any]: + ... - def save_checkpoint(self, state_dict: Dict[str, Any], **kwargs) -> None: ... + def save_checkpoint(self, state_dict: Dict[str, Any], **kwargs) -> None: + ... class FullModelTorchTuneCheckpointer(_CheckpointerInterface): @@ -130,13 +132,14 @@ class FullModelTorchTuneCheckpointer(_CheckpointerInterface): model_type (str): Model type of the model for which the checkpointer is being loaded, e.g. LLAMA3. output_dir (str): Directory to save the checkpoint files adapter_checkpoint (Optional[str]): Path to the adapter weights. If None, - and `resume_from_checkpoint=True`, then look for adapter_model.pt in output_dir/epoch_{largest_epoch}. + and `should_load_recipe_state=True`, then look for adapter_model.pt in output_dir/epoch_{largest_epoch}. Default is None. recipe_checkpoint (Optional[str]): Path to the recipe state checkpoint file. If None, - and `resume_from_checkpoint=True`, then look for recipe_state.pt in output_dir/RECIPE_STATE_DIRNAME. + and `should_load_recipe_state=True`, then look for recipe_state.pt in output_dir/RECIPE_STATE_DIRNAME. Default is None. - resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files to - resume training from a previous run. Default is False + resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files corresponding to + the recipe state from a previous run. Default is False. This flag is deprecated. Please use the + should_load_recipe_state flag instead. should_load_recipe_state (bool): If True, the checkpointer will load the additional checkpoint files corresponding to the recipe state from a previous run. Default is False @@ -152,6 +155,7 @@ def __init__( output_dir: str, adapter_checkpoint: Optional[str] = None, recipe_checkpoint: Optional[str] = None, + resume_from_checkpoint: bool = False, should_load_recipe_state: bool = False, ) -> None: @@ -164,8 +168,14 @@ def __init__( ) self._checkpoint_dir = Path(checkpoint_dir) - self._resume_from_checkpoint = resume_from_checkpoint self._should_load_recipe_state = should_load_recipe_state + + if resume_from_checkpoint: + self._should_load_recipe_state = resume_from_checkpoint + logger.warning( + "*resume_from_checkpoint is deprecated. Please use the 'should_load_recipe_state' instead" + ) + self._model_type = ModelType[model_type] self._output_dir = Path(output_dir) self._output_dir.mkdir(parents=True, exist_ok=True) @@ -182,7 +192,7 @@ def __init__( self._adapter_checkpoint = get_adapter_checkpoint_path( output_dir=self._output_dir, adapter_checkpoint=adapter_checkpoint, - resume_from_checkpoint=self._resume_from_checkpoint, + should_load_recipe_state=self._should_load_recipe_state, pattern=r"^epoch_(\d+)", ) @@ -190,7 +200,7 @@ def __init__( self._recipe_checkpoint = get_recipe_checkpoint_path( output_dir=self._output_dir, recipe_checkpoint=recipe_checkpoint, - resume_from_checkpoint=self._resume_from_checkpoint, + should_load_recipe_state=self._should_load_recipe_state, ) # get ckpt paths @@ -198,16 +208,16 @@ def __init__( checkpoint_files=checkpoint_files, checkpoint_dir=self._checkpoint_dir, output_dir=self._output_dir, - resume_from_checkpoint=self._resume_from_checkpoint, + should_load_recipe_state=self._should_load_recipe_state, has_adapter_checkpoint=self._adapter_checkpoint is not None, ) # we currently accept only a single file self._checkpoint_path = self._checkpoint_paths[0] - if self._resume_from_checkpoint: + if self._should_load_recipe_state: logger.info( - "Resuming from checkpoint using:" + "Loading the recipe state using: " f"\n\tcheckpoint_paths: {[str(path) for path in self._checkpoint_paths]}" f"\n\trecipe_checkpoint: {self._recipe_checkpoint}" f"\n\tadapter_checkpoint: {self._adapter_checkpoint}" @@ -375,13 +385,14 @@ class FullModelHFCheckpointer(_CheckpointerInterface): model_type (str): Model type of the model for which the checkpointer is being loaded, e.g. LLAMA3. output_dir (str): Directory to save the checkpoint files adapter_checkpoint (Optional[str]): Path to the adapter weights. If None, - and `resume_from_checkpoint=True`, then look for adapter_model.pt in output_dir/epoch_{largest_epoch}. + and `should_load_recipe_state=True`, then look for adapter_model.pt in output_dir/epoch_{largest_epoch}. Default is None. recipe_checkpoint (Optional[str]): Path to the recipe state checkpoint file. If None, - and `resume_from_checkpoint=True`, then look for recipe_state.pt in output_dir/RECIPE_STATE_DIRNAME. + and `should_load_recipe_state=True`, then look for recipe_state.pt in output_dir/RECIPE_STATE_DIRNAME. Default is None. - resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files to - resume training from a previous run. Default is False + resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files corresponding to + the receipe state from a previous run. Default is False. This flag is deprecated. Please use + the should_load_recipe_state flag instead. safe_serialization (bool): If True, the checkpointer will save the checkpoint file using `safetensors`. Default is True. should_load_recipe_state (bool): If True, the checkpointer will load the additional checkpoint files corresponding to @@ -401,8 +412,13 @@ def __init__( should_load_recipe_state: bool = False, ) -> None: - self._resume_from_checkpoint = resume_from_checkpoint self._should_load_recipe_state = should_load_recipe_state + if resume_from_checkpoint: + self._should_load_recipe_state = resume_from_checkpoint + logger.warning( + "*resume_from_checkpoint is deprecated. Please use the 'should_load_recipe_state' instead" + ) + self._safe_serialization = safe_serialization self._checkpoint_dir = Path(checkpoint_dir) self._model_type = ModelType[model_type] @@ -443,7 +459,7 @@ def __init__( self._adapter_checkpoint = get_adapter_checkpoint_path( output_dir=self._output_dir, adapter_checkpoint=adapter_checkpoint, - resume_from_checkpoint=self._resume_from_checkpoint, + should_load_recipe_state=self._should_load_recipe_state, pattern=r"^epoch_(\d+)", ) @@ -451,7 +467,7 @@ def __init__( self._recipe_checkpoint = get_recipe_checkpoint_path( output_dir=self._output_dir, recipe_checkpoint=recipe_checkpoint, - resume_from_checkpoint=self._resume_from_checkpoint, + should_load_recipe_state=self._should_load_recipe_state, ) # get ckpt paths @@ -459,13 +475,13 @@ def __init__( checkpoint_files=checkpoint_files, checkpoint_dir=self._checkpoint_dir, output_dir=self._output_dir, - resume_from_checkpoint=self._resume_from_checkpoint, + should_load_recipe_state=self._should_load_recipe_state, has_adapter_checkpoint=self._adapter_checkpoint is not None, ) - if self._resume_from_checkpoint: + if self._should_load_recipe_state: logger.info( - "Resuming from checkpoint using:" + "Loading the recipe state using: " f"\n\tcheckpoint_paths: {[str(path) for path in self._checkpoint_paths]}" f"\n\trecipe_checkpoint: {self._recipe_checkpoint}" f"\n\tadapter_checkpoint: {self._adapter_checkpoint}" @@ -796,14 +812,14 @@ def save_checkpoint( "Saving Llama3.2 Vision adapter weights to PEFT format is not supported, saving to torchtune format instead" ) else: - state_dict[training.ADAPTER_KEY] = ( - convert_weights.tune_to_peft_adapter_weights( - state_dict[training.ADAPTER_KEY], - num_heads=self._config["num_attention_heads"], - num_kv_heads=self._config["num_key_value_heads"], - dim=self._config["hidden_size"], - head_dim=self._config.get("head_dim", None), - ) + state_dict[ + training.ADAPTER_KEY + ] = convert_weights.tune_to_peft_adapter_weights( + state_dict[training.ADAPTER_KEY], + num_heads=self._config["num_attention_heads"], + num_kv_heads=self._config["num_key_value_heads"], + dim=self._config["hidden_size"], + head_dim=self._config.get("head_dim", None), ) output_path = Path.joinpath( self._output_dir, f"epoch_{epoch}", ADAPTER_MODEL_FNAME @@ -839,11 +855,11 @@ def save_checkpoint( "PEFT integration for Llama3.2 Vision is not supported, skipping adapter config save" ) else: - state_dict[training.ADAPTER_CONFIG] = ( - convert_weights.tune_to_peft_adapter_config( - adapter_config=state_dict[training.ADAPTER_CONFIG], - base_model_name_or_path=self.repo_id, - ) + state_dict[ + training.ADAPTER_CONFIG + ] = convert_weights.tune_to_peft_adapter_config( + adapter_config=state_dict[training.ADAPTER_CONFIG], + base_model_name_or_path=self.repo_id, ) output_path = Path.joinpath( @@ -903,13 +919,14 @@ class FullModelMetaCheckpointer(_CheckpointerInterface): model_type (str): Model type of the model for which the checkpointer is being loaded, e.g. LLAMA3. output_dir (str): Directory to save the checkpoint files adapter_checkpoint (Optional[str]): Path to the adapter weights. If None, - and `resume_from_checkpoint=True`, then look for adapter_model.pt in output_dir/epoch_{largest_epoch}. + and `should_load_recipe_state=True`, then look for adapter_model.pt in output_dir/epoch_{largest_epoch}. Default is None. recipe_checkpoint (Optional[str]): Path to the recipe state checkpoint file. If None, - and `resume_from_checkpoint=True`, then look for recipe_state.pt in output_dir/recipe_state. + and `should_load_recipe_state=True`, then look for recipe_state.pt in output_dir/recipe_state. Default is None. - resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files to - resume training from a previous run. Default is False + resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files corresponding to + the recipe state from a previous run. Default is False. This flag is deprecated. Please use the + should_load_recipe_state instead. should_load_recipe_state (bool): If True, the checkpointer will load the additional checkpoint files corresponding to the recipe state from a previous run. Default is False @@ -926,6 +943,7 @@ def __init__( output_dir: str, adapter_checkpoint: Optional[str] = None, recipe_checkpoint: Optional[str] = None, + resume_from_checkpoint: bool = False, should_load_recipe_state: bool = False, ) -> None: @@ -938,8 +956,12 @@ def __init__( ) self._checkpoint_dir = Path(checkpoint_dir) - self._resume_from_checkpoint = resume_from_checkpoint self._should_load_recipe_state = should_load_recipe_state + if resume_from_checkpoint: + self._should_load_recipe_state = resume_from_checkpoint + logger.warning( + "*resume_from_checkpoint is deprecated. Please use the 'should_load_recipe_state' instead" + ) self._model_type = ModelType[model_type] self._output_dir = Path(output_dir) self._output_dir.mkdir(parents=True, exist_ok=True) @@ -956,7 +978,7 @@ def __init__( self._adapter_checkpoint = get_adapter_checkpoint_path( output_dir=self._output_dir, adapter_checkpoint=adapter_checkpoint, - resume_from_checkpoint=self._resume_from_checkpoint, + should_load_recipe_state=self._should_load_recipe_state, pattern=r"^epoch_(\d+)", ) @@ -964,7 +986,7 @@ def __init__( self._recipe_checkpoint = get_recipe_checkpoint_path( output_dir=self._output_dir, recipe_checkpoint=recipe_checkpoint, - resume_from_checkpoint=self._resume_from_checkpoint, + should_load_recipe_state=self._should_load_recipe_state, ) # get ckpt paths @@ -972,16 +994,16 @@ def __init__( checkpoint_files=checkpoint_files, checkpoint_dir=self._checkpoint_dir, output_dir=self._output_dir, - resume_from_checkpoint=self._resume_from_checkpoint, + should_load_recipe_state=self._should_load_recipe_state, has_adapter_checkpoint=self._adapter_checkpoint is not None, ) # we currently accept only a single file self._checkpoint_path = self._checkpoint_paths[0] - if self._resume_from_checkpoint: + if self._should_load_recipe_state: logger.info( - "Resuming from checkpoint using:" + "Loading the recipe state using: " f"\n\tcheckpoint_paths: {[str(path) for path in self._checkpoint_paths]}" f"\n\trecipe_checkpoint: {self._recipe_checkpoint}" f"\n\tadapter_checkpoint: {self._adapter_checkpoint}" @@ -1156,7 +1178,7 @@ def __init__( self._checkpoint_dir = Path(checkpoint_dir) self._output_dir = Path(output_dir) self._checkpoint_future = None - self._checkpoint_dir_prefix = "checkpoint" + self._checkpoint_dir_prefix = "dist_epoch" self._metadata_file = ".metadata" _, self._rank = training.get_world_size_and_rank() self._process_group: Optional[dist.ProcessGroup] = process_group diff --git a/torchtune/training/checkpointing/_utils.py b/torchtune/training/checkpointing/_utils.py index 770a3f889c..1d8a63daab 100644 --- a/torchtune/training/checkpointing/_utils.py +++ b/torchtune/training/checkpointing/_utils.py @@ -387,7 +387,7 @@ def copy_files( def get_recipe_checkpoint_path( output_dir: Path, recipe_checkpoint: Optional[str] = None, - resume_from_checkpoint: bool = False, + should_load_recipe_state: bool = False, ) -> Optional[Path]: """ If recipe_checkpoint is None, look for recipe_state.pt in {output_dir}/{RECIPE_STATE_DIRNAME}/recipe_state.pt. @@ -396,13 +396,13 @@ def get_recipe_checkpoint_path( Args: output_dir (Path): Directory containing the recipe checkpoint. recipe_checkpoint (Optional[str]): Name of the recipe checkpoint file. Defaults to None. - resume_from_checkpoint (bool): Whether to resume from a checkpoint. + should_load_recipe_state (bool): Whether to load the recipe state from the checkpoint. Returns: - Optional[Path]: Path to the recipe checkpoint file if resume_from_checkpoint is True, otherwise None. + Optional[Path]: Path to the recipe checkpoint file if should_load_recipe_state is True, otherwise None. Raises: - ValueError: If resume_from_checkpoint is True and the recipe checkpoint file is missing. + ValueError: If should_load_recipe_state is True and the recipe checkpoint file is missing. """ - if not resume_from_checkpoint: + if not should_load_recipe_state: return None recipe_checkpoint_path = None @@ -416,7 +416,7 @@ def get_recipe_checkpoint_path( # TODO: improve this msg if not recipe_checkpoint_path or not os.path.exists(recipe_checkpoint_path): raise ValueError( - "If resume_from_checkpoint is True, recipe_checkpoint file must be provided." + "If should_load_recipe_state is True, recipe_checkpoint file must be provided." ) return Path(recipe_checkpoint_path) @@ -425,7 +425,7 @@ def get_recipe_checkpoint_path( def get_adapter_checkpoint_path( output_dir: Path, adapter_checkpoint: Optional[str] = None, - resume_from_checkpoint: bool = False, + should_load_recipe_state: bool = False, pattern: str = r"^epoch_(\d+)", ) -> Optional[Path]: r""" @@ -435,13 +435,13 @@ def get_adapter_checkpoint_path( Args: output_dir (Path): Directory containing the adapter checkpoint. adapter_checkpoint (Optional[str]): Name of the adapter checkpoint file. Defaults to None. - resume_from_checkpoint (bool): Whether to resume from a checkpoint. + should_load_recipe_state (bool): Whether to load the recipe state from checkpoint. pattern (str): Regex pattern to match the epoch folder. Defaults to "epoch_(\d+)". Returns: Optional[Path]: Path to the adapter checkpoint file, or None if not applicable. """ - if not resume_from_checkpoint: + if not should_load_recipe_state: return None adapter_checkpoint_path = None @@ -466,7 +466,7 @@ def get_model_checkpoint_path( checkpoint_files: Union[List[str], Dict[str, str]], checkpoint_dir: Union[str, Path], output_dir: Union[str, Path], - resume_from_checkpoint: bool, + should_load_recipe_state: bool, has_adapter_checkpoint: bool, ) -> list[Path]: """ @@ -484,7 +484,7 @@ def get_model_checkpoint_path( it is converted to a list of formatted checkpoint filenames. checkpoint_dir (Union[str, Path]): Directory containing the checkpoint files. output_dir (Union[str, Path]): Directory to use when resuming from a checkpoint. - resume_from_checkpoint (bool): Whether to resume from a checkpoint. + should_load_recipe_state (bool): Whether to resume from a checkpoint. has_adapter_checkpoint (bool): Indicates if there is an adapter checkpoint. Returns: list[Path]: Sorted list of paths to the checkpoint files. @@ -492,13 +492,13 @@ def get_model_checkpoint_path( >>> checkpoint_files = ["model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"] >>> checkpoint_dir = "/path/to/checkpoints" >>> output_dir = "/path/to/output" - >>> resume_from_checkpoint = True + >>> should_load_recipe_state = True >>> has_adapter_checkpoint = False >>> paths = get_model_checkpoint_path( ... checkpoint_files, ... checkpoint_dir, ... output_dir, - ... resume_from_checkpoint, + ... should_load_recipe_state, ... has_adapter_checkpoint ... ) >>> print(paths) @@ -536,15 +536,15 @@ def validate_checkpoint_files( ) checkpoint_files = formatted_checkpoint_files.build_checkpoint_filenames() - # Case 1: no resuming from ckpt - if not resume_from_checkpoint: + # Case 1: not loading the recipe state + if not should_load_recipe_state: input_dir = checkpoint_dir - # Case 2: Resuming from ckpt, but its full finetuning (no adapter) + # Case 2: Loading the recipe state, but its full finetuning (no adapter) elif not has_adapter_checkpoint: input_dir = output_dir - # Case 3: Resuming from ckpt and has an adapter. + # Case 3: Loading the recipe state and has an adapter. else: # FIXME # TODO: if the model has lora + trained weights, e.g. embeddings,