Skip to content

Commit

Permalink
[torchtune][dcp]: Deprecation of resume_from_checkpoint in favor of s…
Browse files Browse the repository at this point in the history
…hould_load_recipe_state
  • Loading branch information
saumishr committed Dec 13, 2024
1 parent 21b6ded commit c30392c
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 72 deletions.
23 changes: 15 additions & 8 deletions recipes/full_finetune_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,14 +279,21 @@ def setup(self, cfg: DictConfig) -> None:
# Therefore the recipe needs to load the distributed checkpoint to restore the training
# progress.
if self._enable_async_checkpointing:
checkpoint_dict = self._checkpoint_client.load_distributed_checkpoint(
self._model,
(
self._optim_ckpt_wrapper
if self._optimizer_in_bwd
else self._optimizer
),
)
try:
checkpoint_dict = (
self._checkpoint_client.load_distributed_checkpoint(
self._model,
(
self._optim_ckpt_wrapper
if self._optimizer_in_bwd
else self._optimizer
),
)
)
except Exception as e:
log.warning(
f"Failed to load distributed checkpoint: {e}. Training will start from the base checkpoint."
)

# Update the recipe state from the checkpoint state dict.
self._update_recipe_state(checkpoint_dict)
Expand Down
1 change: 0 additions & 1 deletion torchtune/training/checkpointing/_checkpoint_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def _get_checkpointer(self):
self._cfg.checkpointer,
should_load_recipe_state=should_load_recipe_state,
)

return self._checkpointer

def _get_dcp_checkpointer(self):
Expand Down
114 changes: 68 additions & 46 deletions torchtune/training/checkpointing/_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,11 @@ class _CheckpointerInterface(Protocol):
"""

def load_checkpoint(self, **kwargs) -> Dict[str, Any]: ...
def load_checkpoint(self, **kwargs) -> Dict[str, Any]:
...

def save_checkpoint(self, state_dict: Dict[str, Any], **kwargs) -> None: ...
def save_checkpoint(self, state_dict: Dict[str, Any], **kwargs) -> None:
...


class FullModelTorchTuneCheckpointer(_CheckpointerInterface):
Expand All @@ -130,13 +132,14 @@ class FullModelTorchTuneCheckpointer(_CheckpointerInterface):
model_type (str): Model type of the model for which the checkpointer is being loaded, e.g. LLAMA3.
output_dir (str): Directory to save the checkpoint files
adapter_checkpoint (Optional[str]): Path to the adapter weights. If None,
and `resume_from_checkpoint=True`, then look for adapter_model.pt in output_dir/epoch_{largest_epoch}.
and `should_load_recipe_state=True`, then look for adapter_model.pt in output_dir/epoch_{largest_epoch}.
Default is None.
recipe_checkpoint (Optional[str]): Path to the recipe state checkpoint file. If None,
and `resume_from_checkpoint=True`, then look for recipe_state.pt in output_dir/RECIPE_STATE_DIRNAME.
and `should_load_recipe_state=True`, then look for recipe_state.pt in output_dir/RECIPE_STATE_DIRNAME.
Default is None.
resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files to
resume training from a previous run. Default is False
resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files corresponding to
the recipe state from a previous run. Default is False. This flag is deprecated. Please use the
should_load_recipe_state flag instead.
should_load_recipe_state (bool): If True, the checkpointer will load the additional checkpoint files corresponding to
the recipe state from a previous run. Default is False
Expand All @@ -152,6 +155,7 @@ def __init__(
output_dir: str,
adapter_checkpoint: Optional[str] = None,
recipe_checkpoint: Optional[str] = None,
resume_from_checkpoint: bool = False,
should_load_recipe_state: bool = False,
) -> None:

Expand All @@ -164,8 +168,14 @@ def __init__(
)

self._checkpoint_dir = Path(checkpoint_dir)
self._resume_from_checkpoint = resume_from_checkpoint
self._should_load_recipe_state = should_load_recipe_state

if resume_from_checkpoint:
self._should_load_recipe_state = resume_from_checkpoint
logger.warning(
"*resume_from_checkpoint is deprecated. Please use the 'should_load_recipe_state' instead"
)

self._model_type = ModelType[model_type]
self._output_dir = Path(output_dir)
self._output_dir.mkdir(parents=True, exist_ok=True)
Expand All @@ -182,32 +192,32 @@ def __init__(
self._adapter_checkpoint = get_adapter_checkpoint_path(
output_dir=self._output_dir,
adapter_checkpoint=adapter_checkpoint,
resume_from_checkpoint=self._resume_from_checkpoint,
should_load_recipe_state=self._should_load_recipe_state,
pattern=r"^epoch_(\d+)",
)

# resume recipe_state ckpt
self._recipe_checkpoint = get_recipe_checkpoint_path(
output_dir=self._output_dir,
recipe_checkpoint=recipe_checkpoint,
resume_from_checkpoint=self._resume_from_checkpoint,
should_load_recipe_state=self._should_load_recipe_state,
)

# get ckpt paths
self._checkpoint_paths = get_model_checkpoint_path(
checkpoint_files=checkpoint_files,
checkpoint_dir=self._checkpoint_dir,
output_dir=self._output_dir,
resume_from_checkpoint=self._resume_from_checkpoint,
should_load_recipe_state=self._should_load_recipe_state,
has_adapter_checkpoint=self._adapter_checkpoint is not None,
)

# we currently accept only a single file
self._checkpoint_path = self._checkpoint_paths[0]

if self._resume_from_checkpoint:
if self._should_load_recipe_state:
logger.info(
"Resuming from checkpoint using:"
"Loading the recipe state using: "
f"\n\tcheckpoint_paths: {[str(path) for path in self._checkpoint_paths]}"
f"\n\trecipe_checkpoint: {self._recipe_checkpoint}"
f"\n\tadapter_checkpoint: {self._adapter_checkpoint}"
Expand Down Expand Up @@ -375,13 +385,14 @@ class FullModelHFCheckpointer(_CheckpointerInterface):
model_type (str): Model type of the model for which the checkpointer is being loaded, e.g. LLAMA3.
output_dir (str): Directory to save the checkpoint files
adapter_checkpoint (Optional[str]): Path to the adapter weights. If None,
and `resume_from_checkpoint=True`, then look for adapter_model.pt in output_dir/epoch_{largest_epoch}.
and `should_load_recipe_state=True`, then look for adapter_model.pt in output_dir/epoch_{largest_epoch}.
Default is None.
recipe_checkpoint (Optional[str]): Path to the recipe state checkpoint file. If None,
and `resume_from_checkpoint=True`, then look for recipe_state.pt in output_dir/RECIPE_STATE_DIRNAME.
and `should_load_recipe_state=True`, then look for recipe_state.pt in output_dir/RECIPE_STATE_DIRNAME.
Default is None.
resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files to
resume training from a previous run. Default is False
resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files corresponding to
the receipe state from a previous run. Default is False. This flag is deprecated. Please use
the should_load_recipe_state flag instead.
safe_serialization (bool): If True, the checkpointer will save the checkpoint file using `safetensors`.
Default is True.
should_load_recipe_state (bool): If True, the checkpointer will load the additional checkpoint files corresponding to
Expand All @@ -401,8 +412,13 @@ def __init__(
should_load_recipe_state: bool = False,
) -> None:

self._resume_from_checkpoint = resume_from_checkpoint
self._should_load_recipe_state = should_load_recipe_state
if resume_from_checkpoint:
self._should_load_recipe_state = resume_from_checkpoint
logger.warning(
"*resume_from_checkpoint is deprecated. Please use the 'should_load_recipe_state' instead"
)

self._safe_serialization = safe_serialization
self._checkpoint_dir = Path(checkpoint_dir)
self._model_type = ModelType[model_type]
Expand Down Expand Up @@ -443,29 +459,29 @@ def __init__(
self._adapter_checkpoint = get_adapter_checkpoint_path(
output_dir=self._output_dir,
adapter_checkpoint=adapter_checkpoint,
resume_from_checkpoint=self._resume_from_checkpoint,
should_load_recipe_state=self._should_load_recipe_state,
pattern=r"^epoch_(\d+)",
)

# resume recipe_state ckpt
self._recipe_checkpoint = get_recipe_checkpoint_path(
output_dir=self._output_dir,
recipe_checkpoint=recipe_checkpoint,
resume_from_checkpoint=self._resume_from_checkpoint,
should_load_recipe_state=self._should_load_recipe_state,
)

# get ckpt paths
self._checkpoint_paths = get_model_checkpoint_path(
checkpoint_files=checkpoint_files,
checkpoint_dir=self._checkpoint_dir,
output_dir=self._output_dir,
resume_from_checkpoint=self._resume_from_checkpoint,
should_load_recipe_state=self._should_load_recipe_state,
has_adapter_checkpoint=self._adapter_checkpoint is not None,
)

if self._resume_from_checkpoint:
if self._should_load_recipe_state:
logger.info(
"Resuming from checkpoint using:"
"Loading the recipe state using: "
f"\n\tcheckpoint_paths: {[str(path) for path in self._checkpoint_paths]}"
f"\n\trecipe_checkpoint: {self._recipe_checkpoint}"
f"\n\tadapter_checkpoint: {self._adapter_checkpoint}"
Expand Down Expand Up @@ -796,14 +812,14 @@ def save_checkpoint(
"Saving Llama3.2 Vision adapter weights to PEFT format is not supported, saving to torchtune format instead"
)
else:
state_dict[training.ADAPTER_KEY] = (
convert_weights.tune_to_peft_adapter_weights(
state_dict[training.ADAPTER_KEY],
num_heads=self._config["num_attention_heads"],
num_kv_heads=self._config["num_key_value_heads"],
dim=self._config["hidden_size"],
head_dim=self._config.get("head_dim", None),
)
state_dict[
training.ADAPTER_KEY
] = convert_weights.tune_to_peft_adapter_weights(
state_dict[training.ADAPTER_KEY],
num_heads=self._config["num_attention_heads"],
num_kv_heads=self._config["num_key_value_heads"],
dim=self._config["hidden_size"],
head_dim=self._config.get("head_dim", None),
)
output_path = Path.joinpath(
self._output_dir, f"epoch_{epoch}", ADAPTER_MODEL_FNAME
Expand Down Expand Up @@ -839,11 +855,11 @@ def save_checkpoint(
"PEFT integration for Llama3.2 Vision is not supported, skipping adapter config save"
)
else:
state_dict[training.ADAPTER_CONFIG] = (
convert_weights.tune_to_peft_adapter_config(
adapter_config=state_dict[training.ADAPTER_CONFIG],
base_model_name_or_path=self.repo_id,
)
state_dict[
training.ADAPTER_CONFIG
] = convert_weights.tune_to_peft_adapter_config(
adapter_config=state_dict[training.ADAPTER_CONFIG],
base_model_name_or_path=self.repo_id,
)

output_path = Path.joinpath(
Expand Down Expand Up @@ -903,13 +919,14 @@ class FullModelMetaCheckpointer(_CheckpointerInterface):
model_type (str): Model type of the model for which the checkpointer is being loaded, e.g. LLAMA3.
output_dir (str): Directory to save the checkpoint files
adapter_checkpoint (Optional[str]): Path to the adapter weights. If None,
and `resume_from_checkpoint=True`, then look for adapter_model.pt in output_dir/epoch_{largest_epoch}.
and `should_load_recipe_state=True`, then look for adapter_model.pt in output_dir/epoch_{largest_epoch}.
Default is None.
recipe_checkpoint (Optional[str]): Path to the recipe state checkpoint file. If None,
and `resume_from_checkpoint=True`, then look for recipe_state.pt in output_dir/recipe_state.
and `should_load_recipe_state=True`, then look for recipe_state.pt in output_dir/recipe_state.
Default is None.
resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files to
resume training from a previous run. Default is False
resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files corresponding to
the recipe state from a previous run. Default is False. This flag is deprecated. Please use the
should_load_recipe_state instead.
should_load_recipe_state (bool): If True, the checkpointer will load the additional checkpoint files corresponding to
the recipe state from a previous run. Default is False
Expand All @@ -926,6 +943,7 @@ def __init__(
output_dir: str,
adapter_checkpoint: Optional[str] = None,
recipe_checkpoint: Optional[str] = None,
resume_from_checkpoint: bool = False,
should_load_recipe_state: bool = False,
) -> None:

Expand All @@ -938,8 +956,12 @@ def __init__(
)

self._checkpoint_dir = Path(checkpoint_dir)
self._resume_from_checkpoint = resume_from_checkpoint
self._should_load_recipe_state = should_load_recipe_state
if resume_from_checkpoint:
self._should_load_recipe_state = resume_from_checkpoint
logger.warning(
"*resume_from_checkpoint is deprecated. Please use the 'should_load_recipe_state' instead"
)
self._model_type = ModelType[model_type]
self._output_dir = Path(output_dir)
self._output_dir.mkdir(parents=True, exist_ok=True)
Expand All @@ -956,32 +978,32 @@ def __init__(
self._adapter_checkpoint = get_adapter_checkpoint_path(
output_dir=self._output_dir,
adapter_checkpoint=adapter_checkpoint,
resume_from_checkpoint=self._resume_from_checkpoint,
should_load_recipe_state=self._should_load_recipe_state,
pattern=r"^epoch_(\d+)",
)

# resume recipe_state ckpt
self._recipe_checkpoint = get_recipe_checkpoint_path(
output_dir=self._output_dir,
recipe_checkpoint=recipe_checkpoint,
resume_from_checkpoint=self._resume_from_checkpoint,
should_load_recipe_state=self._should_load_recipe_state,
)

# get ckpt paths
self._checkpoint_paths = get_model_checkpoint_path(
checkpoint_files=checkpoint_files,
checkpoint_dir=self._checkpoint_dir,
output_dir=self._output_dir,
resume_from_checkpoint=self._resume_from_checkpoint,
should_load_recipe_state=self._should_load_recipe_state,
has_adapter_checkpoint=self._adapter_checkpoint is not None,
)

# we currently accept only a single file
self._checkpoint_path = self._checkpoint_paths[0]

if self._resume_from_checkpoint:
if self._should_load_recipe_state:
logger.info(
"Resuming from checkpoint using:"
"Loading the recipe state using: "
f"\n\tcheckpoint_paths: {[str(path) for path in self._checkpoint_paths]}"
f"\n\trecipe_checkpoint: {self._recipe_checkpoint}"
f"\n\tadapter_checkpoint: {self._adapter_checkpoint}"
Expand Down Expand Up @@ -1156,7 +1178,7 @@ def __init__(
self._checkpoint_dir = Path(checkpoint_dir)
self._output_dir = Path(output_dir)
self._checkpoint_future = None
self._checkpoint_dir_prefix = "checkpoint"
self._checkpoint_dir_prefix = "dist_epoch"
self._metadata_file = ".metadata"
_, self._rank = training.get_world_size_and_rank()
self._process_group: Optional[dist.ProcessGroup] = process_group
Expand Down
Loading

0 comments on commit c30392c

Please sign in to comment.