[torchtune][dcp] Do not expose checkpoint_client as a public API yet
saumishr committed Dec 13, 2024
1 parent e745717 commit 21b6ded
Showing 5 changed files with 5 additions and 17 deletions.
7 changes: 3 additions & 4 deletions recipes/full_finetune_distributed.py
@@ -24,13 +24,12 @@
 from torchtune.data import padded_collate_packed
 from torchtune.datasets import ConcatDataset
 from torchtune.recipe_interfaces import FTRecipeInterface
-from torchtune.training import (
+from torchtune.training import DummyProfiler, PROFILER_KEY
+from torchtune.training.activations import apply_selective_activation_checkpointing
+from torchtune.training.checkpointing._checkpoint_client import (
     CheckpointClient,
-    DummyProfiler,
-    PROFILER_KEY,
     TrainingProgress,
 )
-from torchtune.training.activations import apply_selective_activation_checkpointing
 from torchtune.training.lr_schedulers import get_lr
 
 from tqdm import tqdm
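
The import change above is the user-visible core of the commit: CheckpointClient and TrainingProgress are no longer re-exported from the public torchtune.training namespace, so the recipe now reaches for them through the private module path. A minimal before/after sketch of the import, with the surrounding recipe code elided:

    # Before this commit (public re-export, now removed):
    # from torchtune.training import CheckpointClient, TrainingProgress

    # After this commit (private module path, as the recipe does above):
    from torchtune.training.checkpointing._checkpoint_client import (
        CheckpointClient,
        TrainingProgress,
    )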
2 changes: 1 addition & 1 deletion recipes/qat_lora_finetune_distributed.py
@@ -213,7 +213,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
         """
         self._checkpointer = config.instantiate(
             cfg_checkpointer,
-            resume_from_checkpoint=self._resume_from_checkpoint,
+            should_load_recipe_state=self._resume_from_checkpoint,
         )
         checkpoint_dict = self._checkpointer.load_checkpoint()
 
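The keyword rename above is worth a note: should_load_recipe_state presumably states more directly that the flag controls whether the checkpointer also restores saved recipe state (the seed, epoch, and step keys exported elsewhere in this commit), rather than overloading a generic resume_from_checkpoint flag. A hedged, standalone sketch of the same call; the checkpointer class and config fields are typical torchtune recipe settings and are assumptions here, only the should_load_recipe_state keyword comes from this diff:

    from omegaconf import OmegaConf
    from torchtune import config

    # Illustrative checkpointer config; fields mirror common recipe configs.
    cfg_checkpointer = OmegaConf.create(
        {
            "_component_": "torchtune.training.FullModelHFCheckpointer",
            "checkpoint_dir": "/tmp/Meta-Llama-3-8B-Instruct",
            "checkpoint_files": ["model-00001-of-00004.safetensors"],
            "output_dir": "/tmp/finetune-output",
            "model_type": "LLAMA3",
        }
    )

    # The kwarg augments the config at instantiation time; before this commit
    # the recipe passed resume_from_checkpoint here instead.
    checkpointer = config.instantiate(
        cfg_checkpointer,
        should_load_recipe_state=False,
    )
    checkpoint_dict = checkpointer.load_checkpoint()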
4 changes: 0 additions & 4 deletions torchtune/training/__init__.py
@@ -36,7 +36,6 @@
 from torchtune.training.checkpointing import (
     ADAPTER_CONFIG,
     ADAPTER_KEY,
-    CheckpointClient,
     Checkpointer,
     DistributedCheckpointer,
     EPOCHS_KEY,
@@ -52,7 +51,6 @@
     SEED_KEY,
     STEPS_KEY,
     TOTAL_EPOCHS_KEY,
-    TrainingProgress,
     update_state_dict_for_classifier,
 )
 from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup, get_lr
@@ -80,7 +78,6 @@
     "get_dtype",
     "set_default_dtype",
     "validate_expected_param_dtype",
-    "CheckpointClient",
     "FullModelHFCheckpointer",
     "FullModelMetaCheckpointer",
     "DistributedCheckpointer",
@@ -134,5 +131,4 @@
     "OffloadActivations",
     "FormattedCheckpointFiles",
     "scale_grads",
-    "TrainingProgress",
 ]
7 changes: 0 additions & 7 deletions torchtune/training/checkpointing/__init__.py
@@ -5,11 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 from typing import Union
 
-from torchtune.training.checkpointing._checkpoint_client import (
-    CheckpointClient,
-    TrainingProgress,
-)
-
 from torchtune.training.checkpointing._checkpointer import (
     DistributedCheckpointer,
     FullModelHFCheckpointer,
@@ -41,7 +36,6 @@
 ]
 
 __all__ = [
-    "CheckpointClient",
     "FullModelHFCheckpointer",
     "FullModelMetaCheckpointer",
     "FullModelTorchTuneCheckpointer",
@@ -61,5 +55,4 @@
     "STEPS_KEY",
     "TOTAL_EPOCHS_KEY",
     "FormattedCheckpointFiles",
-    "TrainingProgress",
 ]
2 changes: 1 addition & 1 deletion torchtune/training/checkpointing/_checkpoint_client.py
@@ -298,7 +298,7 @@ def load_distributed_checkpoint(
     ) -> Dict[str, Any]:
         """
         This method is used to resume training from a distributed checkpoint state.
-        Due to being disributed, this mehod is called on every rank.
+        Due to being distributed, this method is called on every rank.
         """
         if self._is_rank_zero:
             dcp_load_start = time.perf_counter()
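
For context on what stays private: CheckpointClient wraps the distributed recipes' checkpoint load/save flow (including the distributed resume path whose docstring is fixed above), and TrainingProgress is the small record of recipe state saved alongside the model. A rough usage sketch under assumptions: the constructor and method names below are inferred from how the recipes use the client and may not match the private API exactly.

    from omegaconf import DictConfig

    from torchtune.training.checkpointing._checkpoint_client import (
        CheckpointClient,
        TrainingProgress,
    )

    def save_epoch_checkpoint(cfg: DictConfig, model, optimizer, epochs_run: int) -> None:
        # Illustrative only: these names are assumptions, not a documented public API.
        client = CheckpointClient(cfg)
        client.save_checkpoint(
            model=model,
            optimizer=optimizer,
            training_progress=TrainingProgress(
                seed=cfg.seed,
                epochs_run=epochs_run,
                total_epochs=cfg.epochs,
                max_steps_per_epoch=cfg.max_steps_per_epoch,
            ),
            epoch=epochs_run,
        )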