Faster intermediate checkpoints with DCP async save in TorchTune (pytorch#2006)

Co-authored-by: Saurabh Mishra <[email protected]>
2 people authored and rahul-sarvam committed Dec 23, 2024
1 parent 4a0a53d commit e6d7ef5
Showing 2 changed files with 75 additions and 47 deletions.
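In the two files shown below, the main API change is the rename of the checkpointer's resume flag from resume_from_checkpoint to should_load_recipe_state. A minimal usage sketch, assuming FullModelHFCheckpointer is importable from torchtune.training and using the constructor arguments shown in the diff (the directories are placeholders; the checkpoint filenames come from the docstring example further down in _utils.py):

from torchtune.training import FullModelHFCheckpointer  # import path assumed

checkpointer = FullModelHFCheckpointer(
    checkpoint_dir="/path/to/checkpoints",  # hypothetical base checkpoint directory
    checkpoint_files=[
        "model-00001-of-00002.safetensors",
        "model-00002-of-00002.safetensors",
    ],
    model_type="LLAMA3",
    output_dir="/path/to/output",  # hypothetical output directory
    # Previously resume_from_checkpoint=True; the old flag still works but logs a deprecation warning.
    should_load_recipe_state=True,
)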
88 changes: 58 additions & 30 deletions torchtune/training/checkpointing/_checkpointer.py
@@ -132,13 +132,16 @@ class FullModelTorchTuneCheckpointer(_CheckpointerInterface):
model_type (str): Model type of the model for which the checkpointer is being loaded, e.g. LLAMA3.
output_dir (str): Directory to save the checkpoint files
adapter_checkpoint (Optional[str]): Path to the adapter weights. If None,
and `resume_from_checkpoint=True`, then look for adapter_model.pt in output_dir/epoch_{largest_epoch}.
and `should_load_recipe_state=True`, then look for adapter_model.pt in output_dir/epoch_{largest_epoch}.
Default is None.
recipe_checkpoint (Optional[str]): Path to the recipe state checkpoint file. If None,
and `resume_from_checkpoint=True`, then look for recipe_state.pt in output_dir/RECIPE_STATE_DIRNAME.
and `should_load_recipe_state=True`, then look for recipe_state.pt in output_dir/RECIPE_STATE_DIRNAME.
Default is None.
resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files to
resume training from a previous run. Default is False
resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files corresponding to
the recipe state from a previous run. Default is False. This flag is deprecated. Please use the
should_load_recipe_state flag instead.
should_load_recipe_state (bool): If True, the checkpointer will load the additional checkpoint files corresponding to
the recipe state from a previous run. Default is False
Raises:
ValueError: If more than one checkpoint file is provided
@@ -165,7 +168,14 @@ def __init__(
)

self._checkpoint_dir = Path(checkpoint_dir)
self._resume_from_checkpoint = resume_from_checkpoint
self._should_load_recipe_state = should_load_recipe_state

if resume_from_checkpoint:
self._should_load_recipe_state = resume_from_checkpoint
logger.warning(
"*resume_from_checkpoint is deprecated. Please use the 'should_load_recipe_state' instead"
)

self._model_type = ModelType[model_type]
self._output_dir = Path(output_dir)
self._output_dir.mkdir(parents=True, exist_ok=True)
@@ -182,32 +192,32 @@ def __init__(
self._adapter_checkpoint = get_adapter_checkpoint_path(
output_dir=self._output_dir,
adapter_checkpoint=adapter_checkpoint,
resume_from_checkpoint=self._resume_from_checkpoint,
should_load_recipe_state=self._should_load_recipe_state,
pattern=r"^epoch_(\d+)",
)

# resume recipe_state ckpt
self._recipe_checkpoint = get_recipe_checkpoint_path(
output_dir=self._output_dir,
recipe_checkpoint=recipe_checkpoint,
resume_from_checkpoint=self._resume_from_checkpoint,
should_load_recipe_state=self._should_load_recipe_state,
)

# get ckpt paths
self._checkpoint_paths = get_model_checkpoint_path(
checkpoint_files=checkpoint_files,
checkpoint_dir=self._checkpoint_dir,
output_dir=self._output_dir,
resume_from_checkpoint=self._resume_from_checkpoint,
should_load_recipe_state=self._should_load_recipe_state,
has_adapter_checkpoint=self._adapter_checkpoint is not None,
)

# we currently accept only a single file
self._checkpoint_path = self._checkpoint_paths[0]

if self._resume_from_checkpoint:
if self._should_load_recipe_state:
logger.info(
"Resuming from checkpoint using:"
"Loading the recipe state using: "
f"\n\tcheckpoint_paths: {[str(path) for path in self._checkpoint_paths]}"
f"\n\trecipe_checkpoint: {self._recipe_checkpoint}"
f"\n\tadapter_checkpoint: {self._adapter_checkpoint}"
@@ -375,15 +385,18 @@ class FullModelHFCheckpointer(_CheckpointerInterface):
model_type (str): Model type of the model for which the checkpointer is being loaded, e.g. LLAMA3.
output_dir (str): Directory to save the checkpoint files
adapter_checkpoint (Optional[str]): Path to the adapter weights. If None,
and `resume_from_checkpoint=True`, then look for adapter_model.pt in output_dir/epoch_{largest_epoch}.
and `should_load_recipe_state=True`, then look for adapter_model.pt in output_dir/epoch_{largest_epoch}.
Default is None.
recipe_checkpoint (Optional[str]): Path to the recipe state checkpoint file. If None,
and `resume_from_checkpoint=True`, then look for recipe_state.pt in output_dir/RECIPE_STATE_DIRNAME.
and `should_load_recipe_state=True`, then look for recipe_state.pt in output_dir/RECIPE_STATE_DIRNAME.
Default is None.
resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files to
resume training from a previous run. Default is False
resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files corresponding to
the recipe state from a previous run. Default is False. This flag is deprecated. Please use
the should_load_recipe_state flag instead.
safe_serialization (bool): If True, the checkpointer will save the checkpoint file using `safetensors`.
Default is True.
should_load_recipe_state (bool): If True, the checkpointer will load the additional checkpoint files corresponding to
the recipe state from a previous run. Default is False
"""

def __init__(
@@ -396,9 +409,16 @@ def __init__(
recipe_checkpoint: Optional[str] = None,
resume_from_checkpoint: bool = False,
safe_serialization: bool = True,
should_load_recipe_state: bool = False,
) -> None:

self._resume_from_checkpoint = resume_from_checkpoint
self._should_load_recipe_state = should_load_recipe_state
if resume_from_checkpoint:
self._should_load_recipe_state = resume_from_checkpoint
logger.warning(
"*resume_from_checkpoint is deprecated. Please use the 'should_load_recipe_state' instead"
)

self._safe_serialization = safe_serialization
self._checkpoint_dir = Path(checkpoint_dir)
self._model_type = ModelType[model_type]
@@ -439,29 +459,29 @@ def __init__(
self._adapter_checkpoint = get_adapter_checkpoint_path(
output_dir=self._output_dir,
adapter_checkpoint=adapter_checkpoint,
resume_from_checkpoint=self._resume_from_checkpoint,
should_load_recipe_state=self._should_load_recipe_state,
pattern=r"^epoch_(\d+)",
)

# resume recipe_state ckpt
self._recipe_checkpoint = get_recipe_checkpoint_path(
output_dir=self._output_dir,
recipe_checkpoint=recipe_checkpoint,
resume_from_checkpoint=self._resume_from_checkpoint,
should_load_recipe_state=self._should_load_recipe_state,
)

# get ckpt paths
self._checkpoint_paths = get_model_checkpoint_path(
checkpoint_files=checkpoint_files,
checkpoint_dir=self._checkpoint_dir,
output_dir=self._output_dir,
resume_from_checkpoint=self._resume_from_checkpoint,
should_load_recipe_state=self._should_load_recipe_state,
has_adapter_checkpoint=self._adapter_checkpoint is not None,
)

if self._resume_from_checkpoint:
if self._should_load_recipe_state:
logger.info(
"Resuming from checkpoint using:"
"Loading the recipe state using: "
f"\n\tcheckpoint_paths: {[str(path) for path in self._checkpoint_paths]}"
f"\n\trecipe_checkpoint: {self._recipe_checkpoint}"
f"\n\tadapter_checkpoint: {self._adapter_checkpoint}"
@@ -899,13 +919,16 @@ class FullModelMetaCheckpointer(_CheckpointerInterface):
model_type (str): Model type of the model for which the checkpointer is being loaded, e.g. LLAMA3.
output_dir (str): Directory to save the checkpoint files
adapter_checkpoint (Optional[str]): Path to the adapter weights. If None,
and `resume_from_checkpoint=True`, then look for adapter_model.pt in output_dir/epoch_{largest_epoch}.
and `should_load_recipe_state=True`, then look for adapter_model.pt in output_dir/epoch_{largest_epoch}.
Default is None.
recipe_checkpoint (Optional[str]): Path to the recipe state checkpoint file. If None,
and `resume_from_checkpoint=True`, then look for recipe_state.pt in output_dir/recipe_state.
and `should_load_recipe_state=True`, then look for recipe_state.pt in output_dir/recipe_state.
Default is None.
resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files to
resume training from a previous run. Default is False
resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files corresponding to
the recipe state from a previous run. Default is False. This flag is deprecated. Please use the
should_load_recipe_state flag instead.
should_load_recipe_state (bool): If True, the checkpointer will load the additional checkpoint files corresponding to
the recipe state from a previous run. Default is False
Raises:
ValueError: If ``checkpoint_files`` is not a list of length 1
@@ -933,7 +956,12 @@ def __init__(
)

self._checkpoint_dir = Path(checkpoint_dir)
self._resume_from_checkpoint = resume_from_checkpoint
self._should_load_recipe_state = should_load_recipe_state
if resume_from_checkpoint:
self._should_load_recipe_state = resume_from_checkpoint
logger.warning(
"*resume_from_checkpoint is deprecated. Please use the 'should_load_recipe_state' instead"
)
self._model_type = ModelType[model_type]
self._output_dir = Path(output_dir)
self._output_dir.mkdir(parents=True, exist_ok=True)
@@ -950,32 +978,32 @@ def __init__(
self._adapter_checkpoint = get_adapter_checkpoint_path(
output_dir=self._output_dir,
adapter_checkpoint=adapter_checkpoint,
resume_from_checkpoint=self._resume_from_checkpoint,
should_load_recipe_state=self._should_load_recipe_state,
pattern=r"^epoch_(\d+)",
)

# resume recipe_state ckpt
self._recipe_checkpoint = get_recipe_checkpoint_path(
output_dir=self._output_dir,
recipe_checkpoint=recipe_checkpoint,
resume_from_checkpoint=self._resume_from_checkpoint,
should_load_recipe_state=self._should_load_recipe_state,
)

# get ckpt paths
self._checkpoint_paths = get_model_checkpoint_path(
checkpoint_files=checkpoint_files,
checkpoint_dir=self._checkpoint_dir,
output_dir=self._output_dir,
resume_from_checkpoint=self._resume_from_checkpoint,
should_load_recipe_state=self._should_load_recipe_state,
has_adapter_checkpoint=self._adapter_checkpoint is not None,
)

# we currently accept only a single file
self._checkpoint_path = self._checkpoint_paths[0]

if self._resume_from_checkpoint:
if self._should_load_recipe_state:
logger.info(
"Resuming from checkpoint using:"
"Loading the recipe state using: "
f"\n\tcheckpoint_paths: {[str(path) for path in self._checkpoint_paths]}"
f"\n\trecipe_checkpoint: {self._recipe_checkpoint}"
f"\n\tadapter_checkpoint: {self._adapter_checkpoint}"
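Each checkpointer __init__ above repeats the same back-compat shim: the deprecated resume_from_checkpoint flag is mapped onto should_load_recipe_state and a warning is logged. As a standalone sketch of that pattern (resolve_load_flag is a hypothetical helper, not part of the diff):

import logging

logger = logging.getLogger(__name__)

def resolve_load_flag(resume_from_checkpoint: bool, should_load_recipe_state: bool) -> bool:
    # Map the deprecated flag onto the new one, warning when the old flag is set.
    if resume_from_checkpoint:
        logger.warning(
            "'resume_from_checkpoint' is deprecated. Please use 'should_load_recipe_state' instead."
        )
        return True
    return should_load_recipe_state

Keeping the old flag as an alias avoids breaking existing recipe configs while the warning nudges users toward the new name.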
34 changes: 17 additions & 17 deletions torchtune/training/checkpointing/_utils.py
@@ -391,7 +391,7 @@ def copy_files(
def get_recipe_checkpoint_path(
output_dir: Path,
recipe_checkpoint: Optional[str] = None,
resume_from_checkpoint: bool = False,
should_load_recipe_state: bool = False,
) -> Optional[Path]:
"""
If recipe_checkpoint is None, look for recipe_state.pt in {output_dir}/{RECIPE_STATE_DIRNAME}/recipe_state.pt.
@@ -400,13 +400,13 @@ def get_recipe_checkpoint_path(
Args:
output_dir (Path): Directory containing the recipe checkpoint.
recipe_checkpoint (Optional[str]): Name of the recipe checkpoint file. Defaults to None.
resume_from_checkpoint (bool): Whether to resume from a checkpoint.
should_load_recipe_state (bool): Whether to load the recipe state from the checkpoint.
Returns:
Optional[Path]: Path to the recipe checkpoint file if resume_from_checkpoint is True, otherwise None.
Optional[Path]: Path to the recipe checkpoint file if should_load_recipe_state is True, otherwise None.
Raises:
ValueError: If resume_from_checkpoint is True and the recipe checkpoint file is missing.
ValueError: If should_load_recipe_state is True and the recipe checkpoint file is missing.
"""
if not resume_from_checkpoint:
if not should_load_recipe_state:
return None

recipe_checkpoint_path = None
@@ -420,7 +420,7 @@ def get_recipe_checkpoint_path(
# TODO: improve this msg
if not recipe_checkpoint_path or not os.path.exists(recipe_checkpoint_path):
raise ValueError(
"If resume_from_checkpoint is True, recipe_checkpoint file must be provided."
"If should_load_recipe_state is True, recipe_checkpoint file must be provided."
)

return Path(recipe_checkpoint_path)
@@ -429,7 +429,7 @@ def get_recipe_checkpoint_path(
def get_adapter_checkpoint_path(
output_dir: Path,
adapter_checkpoint: Optional[str] = None,
resume_from_checkpoint: bool = False,
should_load_recipe_state: bool = False,
pattern: str = r"^epoch_(\d+)",
) -> Optional[Path]:
r"""
@@ -439,13 +439,13 @@ def get_adapter_checkpoint_path(
Args:
output_dir (Path): Directory containing the adapter checkpoint.
adapter_checkpoint (Optional[str]): Name of the adapter checkpoint file. Defaults to None.
resume_from_checkpoint (bool): Whether to resume from a checkpoint.
should_load_recipe_state (bool): Whether to load the recipe state from the checkpoint.
pattern (str): Regex pattern to match the epoch folder. Defaults to "epoch_(\d+)".
Returns:
Optional[Path]: Path to the adapter checkpoint file, or None if not applicable.
"""
if not resume_from_checkpoint:
if not should_load_recipe_state:
return None

adapter_checkpoint_path = None
@@ -470,7 +470,7 @@ def get_model_checkpoint_path(
checkpoint_files: Union[List[str], Dict[str, str]],
checkpoint_dir: Union[str, Path],
output_dir: Union[str, Path],
resume_from_checkpoint: bool,
should_load_recipe_state: bool,
has_adapter_checkpoint: bool,
) -> list[Path]:
"""
@@ -488,21 +488,21 @@ def get_model_checkpoint_path(
it is converted to a list of formatted checkpoint filenames.
checkpoint_dir (Union[str, Path]): Directory containing the checkpoint files.
output_dir (Union[str, Path]): Directory to use when resuming from a checkpoint.
resume_from_checkpoint (bool): Whether to resume from a checkpoint.
should_load_recipe_state (bool): Whether to load the recipe state from a previous run.
has_adapter_checkpoint (bool): Indicates if there is an adapter checkpoint.
Returns:
list[Path]: Sorted list of paths to the checkpoint files.
Example:
>>> checkpoint_files = ["model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"]
>>> checkpoint_dir = "/path/to/checkpoints"
>>> output_dir = "/path/to/output"
>>> resume_from_checkpoint = True
>>> should_load_recipe_state = True
>>> has_adapter_checkpoint = False
>>> paths = get_model_checkpoint_path(
... checkpoint_files,
... checkpoint_dir,
... output_dir,
... resume_from_checkpoint,
... should_load_recipe_state,
... has_adapter_checkpoint
... )
>>> print(paths)
@@ -540,15 +540,15 @@ def validate_checkpoint_files(
)
checkpoint_files = formatted_checkpoint_files.build_checkpoint_filenames()

# Case 1: no resuming from ckpt
if not resume_from_checkpoint:
# Case 1: not loading the recipe state
if not should_load_recipe_state:
input_dir = checkpoint_dir

# Case 2: Resuming from ckpt, but its full finetuning (no adapter)
# Case 2: Loading the recipe state, but it's a full finetune (no adapter)
elif not has_adapter_checkpoint:
input_dir = output_dir

# Case 3: Resuming from ckpt and has an adapter.
# Case 3: Loading the recipe state, and the model has an adapter.
else:
# FIXME
# TODO: if the model has lora + trained weights, e.g. embeddings,
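The three cases at the end of get_model_checkpoint_path decide which directory the model weights are read from. A condensed sketch of that branch (select_input_dir is a hypothetical helper; the diff above is truncated before the adapter branch's assignment, so that case is left unimplemented here):

from pathlib import Path
from typing import Union

def select_input_dir(
    checkpoint_dir: Union[str, Path],
    output_dir: Union[str, Path],
    should_load_recipe_state: bool,
    has_adapter_checkpoint: bool,
) -> Path:
    # Case 1: not loading the recipe state -> read the original base checkpoints.
    if not should_load_recipe_state:
        return Path(checkpoint_dir)
    # Case 2: loading the recipe state for a full finetune (no adapter) -> read from output_dir.
    if not has_adapter_checkpoint:
        return Path(output_dir)
    # Case 3: loading the recipe state with an adapter; the diff is truncated before this
    # branch's assignment, so it is not reproduced here.
    raise NotImplementedError("see the full _utils.py for the adapter case")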
