diff --git a/docs/source/api_ref_training.rst b/docs/source/api_ref_training.rst index 0f0a392efe..9cba6fb9ea 100644 --- a/docs/source/api_ref_training.rst +++ b/docs/source/api_ref_training.rst @@ -52,7 +52,6 @@ Utilities for enabling and working with distributed training. init_distributed is_distributed - get_world_size_and_rank gather_cpu_state_dict .. _ac_label: diff --git a/docs/source/api_ref_utilities.rst b/docs/source/api_ref_utilities.rst index dd86817281..05c9283ddc 100644 --- a/docs/source/api_ref_utilities.rst +++ b/docs/source/api_ref_utilities.rst @@ -18,3 +18,4 @@ Miscellaneous get_device get_logger torch_version_ge + get_world_size_and_rank diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py index aed914a463..642dabb15c 100644 --- a/recipes/dev/early_exit_finetune_distributed.py +++ b/recipes/dev/early_exit_finetune_distributed.py @@ -183,7 +183,7 @@ def __init__(self, cfg: DictConfig) -> None: # _is_rank_zero is used primarily for logging. In the future, the logger # should directly take care of this - _, rank = training.get_world_size_and_rank() + _, rank = utils.get_world_size_and_rank() self._is_rank_zero = rank == 0 # Training cfg @@ -646,7 +646,7 @@ def _setup_data( DistributedSamplers with Map-style Datasets which fit into memory. Other samplers, iterable datasets and streaming datasets are not supported. """ - world_size, rank = training.get_world_size_and_rank() + world_size, rank = utils.get_world_size_and_rank() if isinstance(cfg_dataset, ListConfig): datasets = [ @@ -826,7 +826,7 @@ def train(self) -> None: # clean up before training begins training.cleanup_before_training() - world_size, rank = training.get_world_size_and_rank() + world_size, rank = utils.get_world_size_and_rank() # zero out the gradients before starting training if not self._optimizer_in_bwd: diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 4a227701d7..01c0607bbf 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -26,6 +26,10 @@ from torchtune.recipe_interfaces import FTRecipeInterface from torchtune.training import DummyProfiler, PROFILER_KEY from torchtune.training.activations import apply_selective_activation_checkpointing +from torchtune.training.checkpointing._checkpoint_client import ( + CheckpointClient, + TrainingProgress, +) from torchtune.training.lr_schedulers import get_lr from tqdm import tqdm @@ -133,14 +137,16 @@ def __init__(self, cfg: DictConfig) -> None: ) self._log_peak_memory_stats = False - _, rank = training.get_world_size_and_rank() + _, rank = utils.get_world_size_and_rank() self._is_rank_zero = rank == 0 # Training cfg self._resume_from_checkpoint = cfg.resume_from_checkpoint + self._enable_async_checkpointing = cfg.get("enable_async_checkpointing", False) self._gradient_accumulation_steps = cfg.gradient_accumulation_steps self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False) self._clip_grad_norm = cfg.get("clip_grad_norm", None) + self._checkpoint_client = CheckpointClient(cfg) # Optimizer in backward is not compatible with gradient accumulation or gradient clipping if self._optimizer_in_bwd: @@ -189,21 +195,6 @@ def __init__(self, cfg: DictConfig) -> None: self.max_steps_per_epoch = cfg.max_steps_per_epoch self.global_step = 0 - def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: - """ - Extract the checkpoint state from file and validate. 
If resume_from_checkpoint - is True, this also includes the recipe state. - """ - self._checkpointer = config.instantiate( - cfg_checkpointer, - resume_from_checkpoint=self._resume_from_checkpoint, - ) - checkpoint_dict = self._checkpointer.load_checkpoint() - - if self._resume_from_checkpoint: - self._update_recipe_state(checkpoint_dict) - return checkpoint_dict - def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: """ Updates the recipe state from checkpoint. @@ -255,7 +246,8 @@ def setup(self, cfg: DictConfig) -> None: # log config with parameter override self._metric_logger.log_config(cfg) - checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) + # Load the base model + checkpoint_dict = self._checkpoint_client.load_base_checkpoint() self._compile = cfg.get("compile", False) self._model = self._setup_model( @@ -276,11 +268,36 @@ def setup(self, cfg: DictConfig) -> None: optimizer_in_bwd=self._optimizer_in_bwd, opt_state_dict=( checkpoint_dict[training.OPT_KEY] - if self._resume_from_checkpoint + if training.OPT_KEY in checkpoint_dict else None ), ) + if self._resume_from_checkpoint: + # If async checkpointing is enabled, intermediate checkpoints are saved asynchronously + # using the DistributedCheckpointer. + # Therefore the recipe needs to load the distributed checkpoint to restore the training + # progress. + if self._enable_async_checkpointing: + try: + checkpoint_dict = ( + self._checkpoint_client.load_distributed_checkpoint( + self._model, + ( + self._optim_ckpt_wrapper + if self._optimizer_in_bwd + else self._optimizer + ), + ) + ) + except Exception as e: + log.warning( + f"Failed to load distributed checkpoint: {e}. Training will start from the base checkpoint." + ) + + # Update the recipe state from the checkpoint state dict. + self._update_recipe_state(checkpoint_dict) + # initialize loss self._loss_fn = config.instantiate(cfg.loss) @@ -547,6 +564,7 @@ def _setup_model( log, f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs", ) + if self._is_rank_zero: memory_stats = training.get_memory_stats(device=self._device) training.log_memory_stats(memory_stats) @@ -619,7 +637,7 @@ def _setup_data( DistributedSamplers with Map-style Datasets which fit into memory. Other samplers, iterable datasets and streaming datasets are not supported. """ - world_size, rank = training.get_world_size_and_rank() + world_size, rank = utils.get_world_size_and_rank() if isinstance(cfg_dataset, ListConfig): datasets = [ @@ -661,95 +679,6 @@ def _setup_data( return sampler, dataloader - def save_checkpoint( - self, - epoch: int, - ) -> None: - """ - Checkpoint the state of the recipe. The constructed checkpoint state dict - contains the following information: - - Model weights with key training.MODEL_KEY - - Relevant recipe state if training is not complete - - Checkpointer will save the model weights and recipe state in - different checkpoint files. To correctly resume training from an intermediate checkpoint, - the model weights and recipe state must be provided. - """ - # final dict passed onto the checkpointer - checkpoint_dict = {} - - intermediate_checkpoint = epoch + 1 < self.total_epochs - - utils.log_rank_zero( - log, - "Saving checkpoint. This may take some time. 
Retrieving full model state dict...", - ) - start = time.perf_counter() - - # To prevent GPU memory from spiking during checkpoint save, - # we consolidate the full model and optim state dicts on CPU for rank 0 - cpu_state_dict = training.gather_cpu_state_dict( - self._model.state_dict(), - self._is_rank_zero, - device=self._device, - ) - - utils.log_rank_zero( - log, - f"Getting full model state dict took {time.perf_counter() - start:.2f} secs", - ) - - if intermediate_checkpoint: - start = time.perf_counter() - utils.log_rank_zero(log, "Getting optimizer state dict...") - if not self._optimizer_in_bwd: - opt_state_dict = training.get_full_optimizer_state_dict( - self._optimizer, - self._is_rank_zero, - device=self._device, - ) - else: - opt_state_dict = {} - for param, opt in self._optim_ckpt_wrapper.optim_map.items(): - opt_state_dict[param] = training.get_full_optimizer_state_dict( - opt, self._is_rank_zero, device=self._device - ) - utils.log_rank_zero( - log, - f"Getting optimizer state dict took {time.perf_counter() - start:.2f} secs", - ) - else: - opt_state_dict = None - - # Now that we have the model and opt state dict, create the actual checkpoint dict - # to be sent to the checkpointer and ultimately written to file - - if self._is_rank_zero: - start = time.perf_counter() - checkpoint_dict.update({training.MODEL_KEY: cpu_state_dict}) - - # if training is in-progress, checkpoint the optimizer state and recipe state - # as well. - if intermediate_checkpoint: - checkpoint_dict.update( - { - training.OPT_KEY: opt_state_dict, - training.SEED_KEY: self.seed, - training.EPOCHS_KEY: self.epochs_run, - training.TOTAL_EPOCHS_KEY: self.total_epochs, - training.MAX_STEPS_KEY: self.max_steps_per_epoch, - } - ) - - self._checkpointer.save_checkpoint( - checkpoint_dict, - epoch=epoch, - intermediate_checkpoint=intermediate_checkpoint, - ) - log.info(f"Saving checkpoint took {time.perf_counter() - start:.2f} secs") - - torch.distributed.barrier() - def train(self) -> None: """ The core training loop. 
@@ -757,7 +686,7 @@ def train(self) -> None: # clean up before training begins training.cleanup_before_training() - world_size, rank = training.get_world_size_and_rank() + world_size, rank = utils.get_world_size_and_rank() # zero out the gradients before starting training if not self._optimizer_in_bwd: @@ -922,7 +851,21 @@ def train(self) -> None: self._profiler.step() self.epochs_run += 1 - self.save_checkpoint(epoch=curr_epoch) + self._checkpoint_client.save_checkpoint( + model=self._model, + optimizer=( + self._optimizer + if not self._optimizer_in_bwd + else self._optim_ckpt_wrapper + ), + training_progress=TrainingProgress( + seed=self.seed, + epochs_run=self.epochs_run, + total_epochs=self.total_epochs, + max_steps_per_epoch=self.max_steps_per_epoch, + ), + epoch=curr_epoch, + ) self._profiler.stop() diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py index 0ab6ff3e63..946e970206 100644 --- a/recipes/full_finetune_single_device.py +++ b/recipes/full_finetune_single_device.py @@ -197,7 +197,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: """ self._checkpointer = config.instantiate( cfg_checkpointer, - resume_from_checkpoint=self._resume_from_checkpoint, + should_load_recipe_state=self._resume_from_checkpoint, ) checkpoint_dict = self._checkpointer.load_checkpoint() diff --git a/recipes/knowledge_distillation_distributed.py b/recipes/knowledge_distillation_distributed.py index 7bf76b93bf..b7467e3286 100644 --- a/recipes/knowledge_distillation_distributed.py +++ b/recipes/knowledge_distillation_distributed.py @@ -116,7 +116,7 @@ def __init__(self, cfg: DictConfig) -> None: "fp16 precision is not supported in this recipe. Please use fp32 or bf16." ) - _, rank = training.get_world_size_and_rank() + _, rank = utils.get_world_size_and_rank() self._is_rank_zero = rank == 0 @@ -149,7 +149,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: """ self._checkpointer = config.instantiate( cfg_checkpointer, - resume_from_checkpoint=self._resume_from_checkpoint, + should_load_recipe_state=self._resume_from_checkpoint, ) checkpoint_dict = self._checkpointer.load_checkpoint() @@ -646,7 +646,7 @@ def _setup_data( Map-style Datasets which fit into memory and an option for random shuffling. Samplers, iterable datasets, and streaming datasets are not supported. 
""" - world_size, rank = training.get_world_size_and_rank() + world_size, rank = utils.get_world_size_and_rank() if isinstance(cfg_dataset, ListConfig): datasets = [ @@ -815,7 +815,7 @@ def train(self) -> None: # clean up before training begins training.cleanup_before_training() - world_size, rank = training.get_world_size_and_rank() + world_size, rank = utils.get_world_size_and_rank() # zero out the gradients before starting training self._optimizer.zero_grad() diff --git a/recipes/knowledge_distillation_single_device.py b/recipes/knowledge_distillation_single_device.py index cd7995267b..71d850d791 100644 --- a/recipes/knowledge_distillation_single_device.py +++ b/recipes/knowledge_distillation_single_device.py @@ -147,7 +147,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: """ self._checkpointer = config.instantiate( cfg_checkpointer, - resume_from_checkpoint=self._resume_from_checkpoint, + should_load_recipe_state=self._resume_from_checkpoint, ) checkpoint_dict = self._checkpointer.load_checkpoint() diff --git a/recipes/lora_dpo_distributed.py b/recipes/lora_dpo_distributed.py index 96d9b80101..c86d5720b2 100644 --- a/recipes/lora_dpo_distributed.py +++ b/recipes/lora_dpo_distributed.py @@ -131,7 +131,7 @@ def __init__(self, cfg: DictConfig) -> None: "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead." ) - _, rank = training.get_world_size_and_rank() + _, rank = utils.get_world_size_and_rank() self._is_rank_zero = rank == 0 @@ -188,7 +188,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: """ self._checkpointer = config.instantiate( cfg_checkpointer, - resume_from_checkpoint=self._resume_from_checkpoint, + should_load_recipe_state=self._resume_from_checkpoint, ) checkpoint_dict = self._checkpointer.load_checkpoint() @@ -492,7 +492,7 @@ def _setup_data( DistributedSamplers with Map-style Datasets which fit into memory. Other samplers, iterable datasets and streaming datasets are not supported. """ - world_size, rank = training.get_world_size_and_rank() + world_size, rank = utils.get_world_size_and_rank() if isinstance(cfg_dataset, ListConfig): datasets = [ @@ -642,7 +642,7 @@ def train(self) -> None: # clean up before training begins training.cleanup_before_training() - _, rank = training.get_world_size_and_rank() + _, rank = utils.get_world_size_and_rank() # zero out the gradients before starting training self._optimizer.zero_grad() diff --git a/recipes/lora_dpo_single_device.py b/recipes/lora_dpo_single_device.py index 17f985e75f..bee78ad0d3 100644 --- a/recipes/lora_dpo_single_device.py +++ b/recipes/lora_dpo_single_device.py @@ -145,7 +145,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: """ self._checkpointer = config.instantiate( cfg_checkpointer, - resume_from_checkpoint=self._resume_from_checkpoint, + should_load_recipe_state=self._resume_from_checkpoint, ) checkpoint_dict = self._checkpointer.load_checkpoint() diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py index d71434fb74..74516e7fa2 100644 --- a/recipes/lora_finetune_distributed.py +++ b/recipes/lora_finetune_distributed.py @@ -135,7 +135,7 @@ def __init__(self, cfg: DictConfig) -> None: "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead." 
) - _, rank = training.get_world_size_and_rank() + _, rank = utils.get_world_size_and_rank() self._is_rank_zero = rank == 0 @@ -197,7 +197,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: """ self._checkpointer = config.instantiate( cfg_checkpointer, - resume_from_checkpoint=self._resume_from_checkpoint, + should_load_recipe_state=self._resume_from_checkpoint, ) checkpoint_dict = self._checkpointer.load_checkpoint() @@ -584,7 +584,7 @@ def _setup_data( DistributedSamplers with Map-style Datasets which fit into memory. Other samplers, iterable datasets and streaming datasets are not supported. """ - world_size, rank = training.get_world_size_and_rank() + world_size, rank = utils.get_world_size_and_rank() if isinstance(cfg_dataset, ListConfig): datasets = [ @@ -746,7 +746,7 @@ def train(self) -> None: # clean up before training begins training.cleanup_before_training() - world_size, rank = training.get_world_size_and_rank() + world_size, rank = utils.get_world_size_and_rank() # zero out the gradients before starting training self._optimizer.zero_grad() diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index d1b5e3e421..9a3f3eacfb 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -188,7 +188,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: """ self._checkpointer = config.instantiate( cfg_checkpointer, - resume_from_checkpoint=self._resume_from_checkpoint, + should_load_recipe_state=self._resume_from_checkpoint, ) checkpoint_dict = self._checkpointer.load_checkpoint() diff --git a/recipes/ppo_full_finetune_single_device.py b/recipes/ppo_full_finetune_single_device.py index 1030217d74..cb6357c3dc 100644 --- a/recipes/ppo_full_finetune_single_device.py +++ b/recipes/ppo_full_finetune_single_device.py @@ -377,22 +377,22 @@ def _setup_checkpointers( policy_checkpointer = config.instantiate( policy_cfg, - resume_from_checkpoint=self._resume_from_checkpoint, + should_load_recipe_state=self._resume_from_checkpoint, ) ref_policy_checkpointer = config.instantiate( ref_policy_cfg, - resume_from_checkpoint=False, + should_load_recipe_state=False, ) value_checkpointer = config.instantiate( value_cfg, - resume_from_checkpoint=False, + should_load_recipe_state=False, ) reward_checkpointer = config.instantiate( reward_cfg, - resume_from_checkpoint=False, + should_load_recipe_state=False, ) return ( diff --git a/recipes/qat_distributed.py b/recipes/qat_distributed.py index e005dc0247..6c79a6cefa 100644 --- a/recipes/qat_distributed.py +++ b/recipes/qat_distributed.py @@ -144,7 +144,7 @@ def __init__(self, cfg: DictConfig) -> None: ) self._log_peak_memory_stats = False - _, rank = training.get_world_size_and_rank() + _, rank = utils.get_world_size_and_rank() self._is_rank_zero = rank == 0 # Training cfg @@ -209,7 +209,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: """ self._checkpointer = config.instantiate( cfg_checkpointer, - resume_from_checkpoint=self._resume_from_checkpoint, + should_load_recipe_state=self._resume_from_checkpoint, ) checkpoint_dict = self._checkpointer.load_checkpoint() @@ -591,7 +591,7 @@ def _setup_data( DistributedSamplers with Map-style Datasets which fit into memory. Other samplers, iterable datasets and streaming datasets are not supported. 
""" - world_size, rank = training.get_world_size_and_rank() + world_size, rank = utils.get_world_size_and_rank() if isinstance(cfg_dataset, ListConfig): datasets = [ @@ -729,7 +729,7 @@ def train(self) -> None: # clean up before training begins training.cleanup_before_training() - world_size, rank = training.get_world_size_and_rank() + world_size, rank = utils.get_world_size_and_rank() # zero out the gradients before starting training if not self._optimizer_in_bwd: diff --git a/recipes/qat_lora_finetune_distributed.py b/recipes/qat_lora_finetune_distributed.py index 6368fffc8e..d047d77d41 100644 --- a/recipes/qat_lora_finetune_distributed.py +++ b/recipes/qat_lora_finetune_distributed.py @@ -149,7 +149,7 @@ def __init__(self, cfg: DictConfig) -> None: "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead." ) - _, rank = training.get_world_size_and_rank() + _, rank = utils.get_world_size_and_rank() # _is_rank_zero is used primarily for logging. In the future, the logger # should directly take care of this @@ -213,7 +213,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: """ self._checkpointer = config.instantiate( cfg_checkpointer, - resume_from_checkpoint=self._resume_from_checkpoint, + should_load_recipe_state=self._resume_from_checkpoint, ) checkpoint_dict = self._checkpointer.load_checkpoint() @@ -620,7 +620,7 @@ def _setup_data( DistributedSamplers with Map-style Datasets which fit into memory. Other samplers, iterable datasets and streaming datasets are not supported. """ - world_size, rank = training.get_world_size_and_rank() + world_size, rank = utils.get_world_size_and_rank() if isinstance(cfg_dataset, ListConfig): datasets = [ @@ -784,7 +784,7 @@ def train(self) -> None: # clean up before training begins training.cleanup_before_training() - world_size, rank = training.get_world_size_and_rank() + world_size, rank = utils.get_world_size_and_rank() # zero out the gradients before starting training self._optimizer.zero_grad() diff --git a/tests/recipes/test_full_finetune_distributed.py b/tests/recipes/test_full_finetune_distributed.py index 6db3b0a25d..4cdc42d96b 100644 --- a/tests/recipes/test_full_finetune_distributed.py +++ b/tests/recipes/test_full_finetune_distributed.py @@ -35,27 +35,34 @@ class TestFullFinetuneDistributedRecipe: - def _get_test_config_overrides(self): + def _get_test_config_overrides(self, epochs: int = 2): return [ "dtype=fp32", "enable_activation_checkpointing=False", "enable_activation_offloading=False", "dataset.train_on_input=False", "seed=9", - "epochs=2", + f"epochs={epochs}", "max_steps_per_epoch=2", "optimizer=torch.optim.AdamW", "optimizer.lr=2e-5", "log_every_n_steps=1", ] + dummy_alpaca_dataset_config() - def _fetch_expected_loss_values(self, model_type): + def _fetch_expected_loss_values_multi_rank(self, model_type): loss_values_map = { "llama2": [10.5209, 10.5217, 10.4945, 10.5136], "llama3": [11.9839, 11.9684, 11.9596, 11.93656], } return loss_values_map[model_type] + def _fetch_expected_loss_values_single_rank(self, model_type): + loss_values_map = { + "llama2": [10.5051, 10.5572, 10.4780, 10.5678], + "llama3": [11.9742, 12.0049, 11.9382, 12.0464], + } + return loss_values_map[model_type] + @pytest.mark.integration_test @pytest.mark.parametrize( "config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps, optim_in_bwd", @@ -117,7 +124,71 @@ def test_loss( monkeypatch.setattr(sys, "argv", cmd) runpy.run_path(TUNE_PATH, run_name="__main__") loss_values = 
get_loss_values_from_metric_logger(log_file) - expected_loss_values = self._fetch_expected_loss_values(model_type) + expected_loss_values = self._fetch_expected_loss_values_multi_rank(model_type) + torch.testing.assert_close( + loss_values, expected_loss_values, rtol=1e-4, atol=1e-4 + ) + + @pytest.mark.integration_test + @pytest.mark.parametrize( + "config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps, optim_in_bwd", + [ + ("llama2/7B_full", "llama2", "hf", 1, 4, False), + ("llama3/8B_full", "llama3", "tune", 1, 4, False), + ("llama3/8B_full", "llama3", "tune", 4, 1, True), + ], + ) + @gpu_test(gpu_count=1) + def test_loss_single_rank( + self, + micro_batch_size, + gradient_accumulation_steps, + config, + model_type, + ckpt_type, + optim_in_bwd, + tmpdir, + monkeypatch, + ): + ckpt_component = CKPT_COMPONENT_MAP[ckpt_type] + ckpt = model_type + "_" + ckpt_type + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + tokenizer_path = Path(TOKENIZER_PATHS[model_type]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + + # Config file needed for model conversion. + write_hf_ckpt_config(ckpt_dir) + + cmd = f""" + tune run --nnodes 1 --nproc_per_node 1 full_finetune_distributed \ + --config {config} \ + batch_size={micro_batch_size} \ + gradient_accumulation_steps={gradient_accumulation_steps} \ + output_dir={tmpdir} \ + checkpointer._component_={ckpt_component} \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type={model_type.upper()} \ + tokenizer.path='{tokenizer_path}' \ + tokenizer.prompt_template=null \ + metric_logger.filename={log_file} \ + """.split() + model_config = MODEL_TEST_CONFIGS[model_type] + cmd = cmd + self._get_test_config_overrides() + model_config + # "optimizer_in_bwd=True" would free gradient info before clip_grad, causing + # wrong grad_norm, so we only test one of them each time. But loss values + # should be the same. + if not optim_in_bwd: + cmd.append("clip_grad_norm=100") + else: + cmd.append("optimizer_in_bwd=True") + + monkeypatch.setattr(sys, "argv", cmd) + runpy.run_path(TUNE_PATH, run_name="__main__") + loss_values = get_loss_values_from_metric_logger(log_file) + expected_loss_values = self._fetch_expected_loss_values_single_rank(model_type) torch.testing.assert_close( loss_values, expected_loss_values, rtol=1e-4, atol=1e-4 ) @@ -206,9 +277,253 @@ def test_training_state_on_resume( monkeypatch.setattr(sys, "argv", cmd_2) runpy.run_path(TUNE_PATH, run_name="__main__") - expected_loss_values = self._fetch_expected_loss_values(model_type)[2:] + expected_loss_values = self._fetch_expected_loss_values_multi_rank(model_type)[ + 2: + ] loss_values = get_loss_values_from_metric_logger(log_file) torch.testing.assert_close( loss_values, expected_loss_values, rtol=1e-4, atol=1e-4 ) + + @pytest.mark.integration_test + @pytest.mark.parametrize( + "config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps, optim_in_bwd", + [ + ("llama2/7B_full", "llama2", "hf", 1, 4, False), + ("llama3/8B_full", "llama3", "tune", 1, 4, False), + ("llama3/8B_full", "llama3", "tune", 4, 1, True), + ], + ) + @gpu_test(gpu_count=1) + def test_training_state_on_resume_from_distributed_checkpoint_single_rank( + self, + micro_batch_size, + gradient_accumulation_steps, + config, + model_type, + ckpt_type, + optim_in_bwd, + tmpdir, + monkeypatch, + ): + """Test whether the recipe state is correctly updated on resume. 
Since this + is model agnostic, we should run this on the small model only. The test + consists of three stages: + - Train a model for 2 epochs + - Resume training after epoch 1 + - Make sure final loss matches the expected value of a model successfully resumed from a ckpt + """ + + ckpt_component = CKPT_COMPONENT_MAP[ckpt_type] + ckpt = model_type + "_" + ckpt_type + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + tokenizer_path = Path(TOKENIZER_PATHS[model_type]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + + # Config file needed for model conversion. + # Create a second copy for training resume + write_hf_ckpt_config(ckpt_dir) + write_hf_ckpt_config(tmpdir) + + # Train for two epochs + cmd_1 = f""" + tune run --nnodes 1 --nproc_per_node 1 full_finetune_distributed \ + --config {config} \ + batch_size={micro_batch_size} \ + gradient_accumulation_steps={gradient_accumulation_steps} \ + output_dir={tmpdir} \ + checkpointer._component_={ckpt_component} \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type={model_type.upper()} \ + tokenizer.path='{tokenizer_path}' \ + tokenizer.prompt_template=null \ + metric_logger.filename={log_file} \ + enable_async_checkpointing=True \ + """.split() + + model_config = MODEL_TEST_CONFIGS[model_type] + cmd_1 = cmd_1 + self._get_test_config_overrides() + model_config + # "optimizer_in_bwd=True" would free gradient info before clip_grad, causing + # wrong grad_norm, so we only test one of them each time. But loss values + # should be the same. + if not optim_in_bwd: + cmd_1.append("clip_grad_norm=100") + cmd_1.append("optimizer_in_bwd=False") + else: + cmd_1.append("optimizer_in_bwd=True") + + monkeypatch.setattr(sys, "argv", cmd_1) + runpy.run_path(TUNE_PATH, run_name="__main__") + + expected_loss_values_first_run = get_loss_values_from_metric_logger(log_file) + + resumed_log_dir = (tmpdir / "resumed/").mkdir() + resumed_log_file = gen_log_file_name(resumed_log_dir) + + # Resume training + cmd_2 = f""" + tune run --nnodes 1 --nproc_per_node 1 full_finetune_distributed \ + --config {config} \ + batch_size={micro_batch_size} \ + gradient_accumulation_steps={gradient_accumulation_steps} \ + output_dir={tmpdir} \ + checkpointer._component_={ckpt_component} \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type={model_type.upper()} \ + tokenizer.path='{tokenizer_path}' \ + tokenizer.prompt_template=null \ + metric_logger.filename={resumed_log_file} \ + resume_from_checkpoint=True \ + enable_async_checkpointing=True \ + """.split() + + cmd_2 = cmd_2 + self._get_test_config_overrides(epochs=3) + model_config + + if not optim_in_bwd: + cmd_2.append("clip_grad_norm=100") + cmd_2.append("optimizer_in_bwd=False") + else: + cmd_2.append("optimizer_in_bwd=True") + + monkeypatch.setattr(sys, "argv", cmd_2) + runpy.run_path(TUNE_PATH, run_name="__main__") + + # Validate that the expected loss values are close to the ones observed in the first run + expected_loss_values = self._fetch_expected_loss_values_single_rank(model_type) + torch.testing.assert_close( + expected_loss_values_first_run, expected_loss_values, rtol=1e-4, atol=1e-4 + ) + + # Second epoch only + # Validate that the expected loss values are close to the ones observed after the resume + resumed_loss_values = get_loss_values_from_metric_logger(resumed_log_file) + 
torch.testing.assert_close( + resumed_loss_values[:2], expected_loss_values[2:], rtol=1e-4, atol=1e-4 + ) + + @pytest.mark.integration_test + @pytest.mark.parametrize( + "config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps, optim_in_bwd", + [ + ("llama2/7B_full", "llama2", "hf", 1, 4, False), + ("llama3/8B_full", "llama3", "tune", 1, 4, False), + ("llama3/8B_full", "llama3", "tune", 4, 1, True), + ], + ) + @gpu_test(gpu_count=2) + def test_training_state_on_resume_from_distributed_checkpoint_multi_rank( + self, + micro_batch_size, + gradient_accumulation_steps, + config, + model_type, + ckpt_type, + optim_in_bwd, + tmpdir, + monkeypatch, + ): + """Test whether the recipe state is correctly updated on resume. Since this + is model agnostic, we should run this on the small model only. The test + consists of three stages: + - Train a model for 2 epochs + - Resume training after epoch 1 + - Make sure final loss matches the expected value of a model successfully resumed from a ckpt + """ + + ckpt_component = CKPT_COMPONENT_MAP[ckpt_type] + ckpt = model_type + "_" + ckpt_type + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + tokenizer_path = Path(TOKENIZER_PATHS[model_type]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + + # Config file needed for model conversion. + # Create a second copy for training resume + write_hf_ckpt_config(ckpt_dir) + write_hf_ckpt_config(tmpdir) + + # Train for two epochs + cmd_1 = f""" + tune run --nnodes 1 --nproc_per_node 2 full_finetune_distributed \ + --config {config} \ + batch_size={micro_batch_size} \ + gradient_accumulation_steps={gradient_accumulation_steps} \ + output_dir={tmpdir} \ + checkpointer._component_={ckpt_component} \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type={model_type.upper()} \ + tokenizer.path='{tokenizer_path}' \ + tokenizer.prompt_template=null \ + metric_logger.filename={log_file} \ + enable_async_checkpointing=True \ + """.split() + + model_config = MODEL_TEST_CONFIGS[model_type] + cmd_1 = cmd_1 + self._get_test_config_overrides() + model_config + # "optimizer_in_bwd=True" would free gradient info before clip_grad, causing + # wrong grad_norm, so we only test one of them each time. But loss values + # should be the same. 
+ if not optim_in_bwd: + cmd_1.append("clip_grad_norm=100") + cmd_1.append("optimizer_in_bwd=False") + else: + cmd_1.append("optimizer_in_bwd=True") + + monkeypatch.setattr(sys, "argv", cmd_1) + runpy.run_path(TUNE_PATH, run_name="__main__") + + expected_loss_values_first_run = get_loss_values_from_metric_logger(log_file) + + resumed_log_dir = (tmpdir / "resumed/").mkdir() + resumed_log_file = gen_log_file_name(resumed_log_dir) + + # Resume training + cmd_2 = f""" + tune run --nnodes 1 --nproc_per_node 2 full_finetune_distributed \ + --config {config} \ + batch_size={micro_batch_size} \ + gradient_accumulation_steps={gradient_accumulation_steps} \ + output_dir={tmpdir} \ + checkpointer._component_={ckpt_component} \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type={model_type.upper()} \ + tokenizer.path='{tokenizer_path}' \ + tokenizer.prompt_template=null \ + metric_logger.filename={resumed_log_file} \ + resume_from_checkpoint=True \ + enable_async_checkpointing=True \ + """.split() + + cmd_2 = cmd_2 + self._get_test_config_overrides(epochs=3) + model_config + + if not optim_in_bwd: + cmd_2.append("clip_grad_norm=100") + cmd_2.append("optimizer_in_bwd=False") + else: + cmd_2.append("optimizer_in_bwd=True") + + monkeypatch.setattr(sys, "argv", cmd_2) + runpy.run_path(TUNE_PATH, run_name="__main__") + + # Validate that the expected loss values are close to the ones observed in the first run + expected_loss_values = self._fetch_expected_loss_values_multi_rank(model_type) + torch.testing.assert_close( + expected_loss_values_first_run, expected_loss_values, rtol=1e-4, atol=1e-4 + ) + + # Second epoch only + # Validate that the expected loss values are close to the ones observed after the resume + resumed_loss_values = get_loss_values_from_metric_logger(resumed_log_file) + torch.testing.assert_close( + resumed_loss_values[:2], expected_loss_values[2:], rtol=1e-4, atol=1e-4 + ) diff --git a/tests/test_utils.py b/tests/test_utils.py index f7c72965e4..6497539869 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -13,7 +13,7 @@ from functools import partial from io import StringIO from pathlib import Path -from typing import Any, Dict, Generator, List, Mapping, Optional, TextIO, Tuple, Union +from typing import Any, Generator, List, Mapping, Optional, TextIO, Tuple, Union import pytest @@ -308,7 +308,7 @@ def gpu_test(gpu_count: int = 1): return pytest.mark.skipif(local_gpu_count < gpu_count, reason=message) -def get_loss_values_from_metric_logger(log_file_path: str) -> Dict[str, float]: +def get_loss_values_from_metric_logger(log_file_path: str) -> List[float]: """ Given an output directory containing metric logger .txt file, parse the .txt and return a list of losses from each logged iteration. diff --git a/tests/torchtune/training/checkpointing/test_distributed_checkpointer.py b/tests/torchtune/training/checkpointing/test_distributed_checkpointer.py new file mode 100644 index 0000000000..e325499a58 --- /dev/null +++ b/tests/torchtune/training/checkpointing/test_distributed_checkpointer.py @@ -0,0 +1,153 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import os +import shutil +from pathlib import Path + +import pytest +import torch +from torch import randn, zeros + +from torchtune.training.checkpointing import DistributedCheckpointer +from torchtune.training.seed import set_seed + +_VOCAB_SIZE = 100 +_DIM = 64 +_HIDDEN_DIM = 256 + + +@pytest.fixture(autouse=True) +def random(): + set_seed(16) + + +class TestDistributedCheckpointer: + @pytest.fixture + def weight_dtype(self): + return torch.float16 + + @pytest.fixture + def state_dict(self, weight_dtype): + """ + State dict + """ + state_dict = { + "model.embed_tokens.weight": randn(_VOCAB_SIZE, _DIM, dtype=weight_dtype), + "model.layers.0.input_layernorm.weight": randn(_DIM, dtype=weight_dtype), + "model.layers.0.self_attn.q_proj.weight": randn( + _DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.self_attn.k_proj.weight": randn( + _DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.self_attn.v_proj.weight": randn( + _DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.self_attn.o_proj.weight": randn( + _DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.post_attention_layernorm.weight": randn( + _DIM, dtype=weight_dtype + ), + "model.layers.0.self_attn.rotary_emb.inv_freq": randn( + _DIM, dtype=weight_dtype + ), + "model.layers.0.mlp.gate_proj.weight": randn( + _HIDDEN_DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.mlp.down_proj.weight": randn( + _DIM, _HIDDEN_DIM, dtype=weight_dtype + ), + "model.layers.0.mlp.up_proj.weight": randn( + _HIDDEN_DIM, _DIM, dtype=weight_dtype + ), + "model.norm.weight": randn(_DIM, dtype=weight_dtype), + "lm_head.weight": randn(_VOCAB_SIZE, _DIM, dtype=weight_dtype), + } + + return state_dict + + @pytest.fixture + def empty_state_dict(self, weight_dtype): + """ + State dict + """ + state_dict = { + "model.embed_tokens.weight": zeros(_VOCAB_SIZE, _DIM, dtype=weight_dtype), + "model.layers.0.input_layernorm.weight": zeros(_DIM, dtype=weight_dtype), + "model.layers.0.self_attn.q_proj.weight": zeros( + _DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.self_attn.k_proj.weight": zeros( + _DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.self_attn.v_proj.weight": zeros( + _DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.self_attn.o_proj.weight": zeros( + _DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.post_attention_layernorm.weight": zeros( + _DIM, dtype=weight_dtype + ), + "model.layers.0.self_attn.rotary_emb.inv_freq": zeros( + _DIM, dtype=weight_dtype + ), + "model.layers.0.mlp.gate_proj.weight": zeros( + _HIDDEN_DIM, _DIM, dtype=weight_dtype + ), + "model.layers.0.mlp.down_proj.weight": zeros( + _DIM, _HIDDEN_DIM, dtype=weight_dtype + ), + "model.layers.0.mlp.up_proj.weight": zeros( + _HIDDEN_DIM, _DIM, dtype=weight_dtype + ), + "model.norm.weight": zeros(_DIM, dtype=weight_dtype), + "lm_head.weight": zeros(_VOCAB_SIZE, _DIM, dtype=weight_dtype), + } + + return state_dict + + @pytest.fixture + def distributed_checkpointer(self, tmp_path) -> DistributedCheckpointer: + return DistributedCheckpointer( + checkpoint_dir=tmp_path, + output_dir=tmp_path, + ) + + def test_save_load_checkpoint( + self, distributed_checkpointer, state_dict, empty_state_dict + ): + """ + Test ``load_checkpoint`` method within the DistributedCheckpointer. + + We test: + * ``load_checkpoint`` loads the right sets of keys + * Internal state of the checkpointer is correctly updated. 
+ """ + + distributed_checkpointer.save_checkpoint( + state_dict=state_dict, epoch=1, save_async=False + ) + + checkpoint_path = Path.joinpath( + distributed_checkpointer._output_dir, + f"{distributed_checkpointer._checkpoint_dir_prefix}_1", + ) + + assert os.path.exists(checkpoint_path) + + distributed_checkpointer.load_checkpoint( + state_dict=empty_state_dict, + ) + + for key in state_dict.keys(): + assert torch.equal(state_dict[key], empty_state_dict[key]) + + # clean ups + shutil.rmtree(checkpoint_path) diff --git a/tests/torchtune/training/test_distributed.py b/tests/torchtune/training/test_distributed.py index 638e7799a3..87f3656e21 100644 --- a/tests/torchtune/training/test_distributed.py +++ b/tests/torchtune/training/test_distributed.py @@ -56,15 +56,6 @@ def _test_worker_fn(init_pg_explicit: bool) -> None: pg_backend == "gloo" ), f"Expected 'gloo' backend, but received {pg_backend}" - @staticmethod - def _test_world_size_with_cpu_device(expected_world_size: int) -> None: - training.init_distributed(backend="gloo") - world_size, _ = training.get_world_size_and_rank() - if world_size != expected_world_size: - raise AssertionError( - f"Expected different world size: received {world_size}, expected {expected_world_size}" - ) - def _test_launch_worker( self, get_pet_launch_config, @@ -84,13 +75,6 @@ def test_init_from_env_dup(self, get_pet_launch_config) -> None: # trivial test case to ensure test passes with no exceptions assert True - def test_world_size_with_cpu(self, get_pet_launch_config) -> None: - desired_world_size = 4 - lc = get_pet_launch_config(desired_world_size) - launcher.elastic_launch(lc, entrypoint=self._test_world_size_with_cpu_device)( - desired_world_size - ) - def test_validate_no_params_on_meta_device(self) -> None: with torch.device("meta"): model = torch.nn.Linear(3, 3) diff --git a/tests/torchtune/utils/test_device.py b/tests/torchtune/utils/test_device.py index b96eb5ae3b..37d0063828 100644 --- a/tests/torchtune/utils/test_device.py +++ b/tests/torchtune/utils/test_device.py @@ -12,6 +12,8 @@ import pytest import torch + +from torch.distributed import launcher from torchtune.utils._device import ( _get_device_type_from_env, _setup_device, @@ -20,6 +22,7 @@ get_device, get_device_support, get_torch_device_namespace, + get_world_size_and_rank, ) @@ -27,6 +30,24 @@ class TestDevice: cuda_available: bool = torch.cuda.is_available() + def _create_world(self, expected_world_size: int) -> None: + torch.distributed.init_process_group(backend="gloo") + world_size, _ = get_world_size_and_rank() + if world_size != expected_world_size: + raise AssertionError( + f"Expected different world size: received {world_size}, expected {expected_world_size}" + ) + + def test_world_size_with_cpu(self, get_pet_launch_config) -> None: + desired_world_size = 4 + lc = get_pet_launch_config(desired_world_size) + launcher.elastic_launch(lc, entrypoint=self._create_world)(desired_world_size) + + def test_rank_with_cpu_device(self) -> None: + """Very, very basic test""" + _, rank = get_world_size_and_rank() + assert rank == 0 + @patch("torch.cuda.is_available", return_value=False) def test_get_cpu_device(self, mock_cuda): devices = [None, "cpu", "meta"] diff --git a/torchtune/training/__init__.py b/torchtune/training/__init__.py index 5b7ae4d8d3..9dd31246c3 100644 --- a/torchtune/training/__init__.py +++ b/torchtune/training/__init__.py @@ -37,6 +37,7 @@ ADAPTER_CONFIG, ADAPTER_KEY, Checkpointer, + DistributedCheckpointer, EPOCHS_KEY, FormattedCheckpointFiles, FullModelHFCheckpointer, @@ 
-79,6 +80,7 @@ "validate_expected_param_dtype", "FullModelHFCheckpointer", "FullModelMetaCheckpointer", + "DistributedCheckpointer", "FullModelTorchTuneCheckpointer", "ModelType", "Checkpointer", diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py index 96c9e6f65b..025b9db159 100644 --- a/torchtune/training/_distributed.py +++ b/torchtune/training/_distributed.py @@ -22,9 +22,8 @@ from torch.optim import Optimizer from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4 from torchtune.modules import TransformerDecoder -from torchtune.utils import get_logger - -from torchtune.utils._device import get_device +from torchtune.utils import get_device, get_logger +from torchtune.utils._logging import deprecated _log: logging.Logger = get_logger() @@ -117,6 +116,10 @@ def set_torch_num_threads() -> None: _log.info(f"Set intra op parallelism no. of threads to {num_threads}") +@deprecated( + msg="`get_world_size_and_rank` will move to `torchtune.utils._device` in future releases. " + "Please use `torchtune.utils.get_world_size_and_rank` instead." +) def get_world_size_and_rank() -> Tuple[int, int]: """Function that gets the current world size (aka total number of ranks) and rank number of the current process in the default process group. diff --git a/torchtune/training/_profiler.py b/torchtune/training/_profiler.py index d296006b5d..5fe3d74b5c 100644 --- a/torchtune/training/_profiler.py +++ b/torchtune/training/_profiler.py @@ -18,9 +18,8 @@ from omegaconf import DictConfig from torch._C._profiler import _ExperimentalConfig from torch.profiler import tensorboard_trace_handler -from torchtune.training import get_world_size_and_rank -from torchtune.utils import get_logger +from torchtune.utils import get_logger, get_world_size_and_rank log = get_logger("INFO") diff --git a/torchtune/training/checkpointing/__init__.py b/torchtune/training/checkpointing/__init__.py index a142856e6b..9db9dca608 100644 --- a/torchtune/training/checkpointing/__init__.py +++ b/torchtune/training/checkpointing/__init__.py @@ -6,6 +6,7 @@ from typing import Union from torchtune.training.checkpointing._checkpointer import ( + DistributedCheckpointer, FullModelHFCheckpointer, FullModelMetaCheckpointer, FullModelTorchTuneCheckpointer, @@ -28,6 +29,7 @@ ) Checkpointer = Union[ + DistributedCheckpointer, FullModelHFCheckpointer, FullModelMetaCheckpointer, FullModelTorchTuneCheckpointer, @@ -37,6 +39,7 @@ "FullModelHFCheckpointer", "FullModelMetaCheckpointer", "FullModelTorchTuneCheckpointer", + "DistributedCheckpointer", "ModelType", "Checkpointer", "update_state_dict_for_classifier", diff --git a/torchtune/training/checkpointing/_checkpoint_client.py b/torchtune/training/checkpointing/_checkpoint_client.py new file mode 100644 index 0000000000..90a4208b6b --- /dev/null +++ b/torchtune/training/checkpointing/_checkpoint_client.py @@ -0,0 +1,354 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+
+import time
+from dataclasses import dataclass
+from typing import Any, Dict, Union
+
+import torch
+from omegaconf import DictConfig
+
+from torch.distributed.checkpoint.state_dict import (
+    _init_optim_state,
+    set_model_state_dict,
+    set_state_dict,
+)
+from torchtune import config, training, utils
+from torchtune.training.checkpointing._checkpointer import DistributedCheckpointer
+from torchtune.training.memory import OptimizerInBackwardWrapper
+
+log = utils.get_logger("DEBUG")
+
+
+@dataclass
+class TrainingProgress:
+    """
+    This is training progress metadata.
+    """
+
+    seed: int
+    epochs_run: int
+    total_epochs: int
+    max_steps_per_epoch: int
+
+    def state_dict(self) -> Dict[str, object]:
+        return {
+            training.SEED_KEY: self.seed,
+            training.EPOCHS_KEY: self.epochs_run,
+            training.TOTAL_EPOCHS_KEY: self.total_epochs,
+            training.MAX_STEPS_KEY: self.max_steps_per_epoch,
+        }
+
+
+class CheckpointClient:
+    """
+    Stateful checkpointing client for TorchTune recipes. This class is responsible for
+    saving and loading checkpoints using the user configured checkpointer, or the
+    DistributedCheckpointer if asynchronous checkpointing is enabled.
+
+    Args:
+        cfg (DictConfig): Configuration object used to instantiate the recipe.
+    """
+
+    def __init__(
+        self,
+        cfg: DictConfig,
+    ) -> None:
+        self._cfg = cfg
+
+        # _checkpointer is the user configured checkpointer
+        self._checkpointer = None
+
+        # DistributedCheckpointer is used for asynchronous checkpointing, if enabled.
+        self._dcp_checkpointer = None
+
+        self._resume_from_checkpoint = self._cfg.get("resume_from_checkpoint", False)
+        self._enable_async_checkpointing = self._cfg.get(
+            "enable_async_checkpointing", False
+        )
+        self._optimizer_in_bwd = self._cfg.get("optimizer_in_bwd", False)
+        self._device = utils.get_device(device=self._cfg.device)
+
+        _, self._rank = training.get_world_size_and_rank()
+        self._is_rank_zero = self._rank == 0
+
+    def _get_checkpointer(self):
+        """
+        Builds and returns the user configured Checkpointer.
+        """
+        if not self._checkpointer:
+            should_load_recipe_state: bool = (
+                False
+                if self._enable_async_checkpointing
+                else self._resume_from_checkpoint
+            )
+            self._checkpointer = config.instantiate(
+                self._cfg.checkpointer,
+                should_load_recipe_state=should_load_recipe_state,
+            )
+        return self._checkpointer
+
+    def _get_dcp_checkpointer(self):
+        """
+        Builds and returns the DistributedCheckpointer.
+        DistributedCheckpointer is used for asynchronous checkpointing, if enabled.
+        Uses the checkpoint and output directories of the user configured checkpointer.
+        """
+        if not self._dcp_checkpointer:
+            checkpointer = self._get_checkpointer()
+
+            self._dcp_checkpointer = DistributedCheckpointer(
+                checkpoint_dir=checkpointer._checkpoint_dir,
+                output_dir=checkpointer._output_dir,
+            )
+
+        return self._dcp_checkpointer
+
+    def _save_checkpoint_async(
+        self,
+        model: torch.nn.Module,
+        optimizer: Union[torch.optim.Optimizer, OptimizerInBackwardWrapper],
+        training_progress: TrainingProgress,
+        epoch: int,
+    ) -> None:
+        """
+        Checkpoint the training state asynchronously as a distributed checkpoint. Saving
+        asynchronously unblocks training sooner so the next epoch can start.
+
+        The constructed checkpoint state dict contains the following information:
+        - Model weights with key training.MODEL_KEY
+        - Relevant recipe state, including optimizer, if training is not complete
+
+        To correctly resume training from a distributed checkpoint, the user needs to have both
+        the resume_from_checkpoint and enable_async_checkpointing flags set to True in the config.
+        The user does not need to provide any paths to checkpoint or recipe files. The latest valid
+        intermediate checkpoint will be loaded from the output directory and training progress will be
+        restored automatically.
+        """
+
+        if self._is_rank_zero:
+            log.info("Saving checkpoint asynchronously. Retrieving full state dict...")
+            cp_start = time.perf_counter()
+
+        # Create the checkpoint dict to be sent to the checkpointer and ultimately persisted to storage
+        ckpt_dict = {}
+        ckpt_dict.update(training_progress.state_dict())
+
+        ckpt_dict[training.MODEL_KEY] = model.state_dict()
+        ckpt_dict[training.OPT_KEY] = optimizer.state_dict()
+
+        dcp_saver = self._get_dcp_checkpointer()
+        dcp_saver.save_checkpoint(
+            ckpt_dict,
+            epoch=epoch,
+            save_async=True,
+        )
+
+        if self._is_rank_zero:
+            log.info(
+                f"Saving asynchronous checkpoint took {time.perf_counter() - cp_start:.2f} secs"
+            )
+
+    def _save_checkpoint_sync(
+        self,
+        model: torch.nn.Module,
+        optimizer: Union[torch.optim.Optimizer, OptimizerInBackwardWrapper],
+        training_progress: TrainingProgress,
+        epoch: int,
+    ) -> None:
+        """
+        Checkpoint the training state synchronously.
+        The constructed checkpoint state dict contains the following information:
+        - Model weights with key training.MODEL_KEY
+        - Relevant recipe state, including optimizer, if training is not complete
+
+        To correctly resume training from this checkpoint, the user needs to have both
+        the resume_from_checkpoint flag set to True and the recipe file paths set in the config.
+        """
+
+        intermediate_checkpoint = epoch + 1 < training_progress.total_epochs
+        checkpointer = self._get_checkpointer()
+        no_dist = not isinstance(checkpointer, DistributedCheckpointer)
+
+        # final dict passed onto the checkpointer
+        checkpoint_dict = {}
+
+        if self._is_rank_zero:
+            log.info(
+                "Saving checkpoint. This may take some time. Retrieving full model state dict..."
+            )
+            cp_start = time.perf_counter()
+
+        model_state_dict = {}
+        optim_state_dict = {}
+
+        if no_dist:
+            # To prevent GPU memory from spiking during checkpoint save,
+            # we consolidate the full model and optim state dicts on CPU for rank 0
+            model_state_dict = training.gather_cpu_state_dict(
+                model.state_dict(),
+                self._is_rank_zero,
+                device=self._device,
+            )
+
+            if self._is_rank_zero:
+                log.info(
+                    f"Getting full model state dict took {time.perf_counter() - cp_start:.2f} secs"
+                )
+        else:
+            model_state_dict = model.state_dict()
+
+        if intermediate_checkpoint:
+            if self._is_rank_zero:
+                log.info("Getting optimizer state dict...")
+                optim_start = time.perf_counter()
+
+            if no_dist:
+                if not self._optimizer_in_bwd:
+                    optim_state_dict = training.get_full_optimizer_state_dict(
+                        optimizer,
+                        self._is_rank_zero,
+                        device=self._device,
+                    )
+                else:
+                    for param, opt in optimizer.optim_map.items():
+                        optim_state_dict[
+                            param
+                        ] = training.get_full_optimizer_state_dict(
+                            opt, self._is_rank_zero, device=self._device
+                        )
+            else:
+                optim_state_dict = optimizer.state_dict()
+
+            if self._is_rank_zero:
+                log.info(
+                    f"Getting optimizer state dict took {time.perf_counter() - optim_start:.2f} secs"
+                )
+        else:
+            optim_state_dict = None
+
+        def _save_checkpoint_helper():
+            checkpoint_dict.update({training.MODEL_KEY: model_state_dict})
+            # if training is in-progress, checkpoint the optimizer state and recipe state
+            # as well.
+            if intermediate_checkpoint:
+                checkpoint_dict.update({training.OPT_KEY: optim_state_dict})
+                checkpoint_dict.update(training_progress.state_dict())
+
+            self._get_checkpointer().save_checkpoint(
+                checkpoint_dict,
+                epoch=epoch,
+                intermediate_checkpoint=intermediate_checkpoint,
+            )
+
+            if self._is_rank_zero:
+                log.info(
+                    f"Saving checkpoint took {time.perf_counter() - cp_start:.2f} secs"
+                )
+
+        # Now that we have the model and optim state dict, create the actual checkpoint dict
+        # to be sent to the checkpointer and ultimately written to file
+        if no_dist:
+            if self._is_rank_zero:
+                _save_checkpoint_helper()
+
+            torch.distributed.barrier()
+        else:
+            _save_checkpoint_helper()
+
+    def save_checkpoint(
+        self,
+        model: torch.nn.Module,
+        optimizer: Union[torch.optim.Optimizer, OptimizerInBackwardWrapper],
+        training_progress: TrainingProgress,
+        epoch: int,
+    ) -> None:
+        """
+        Checkpoint the training state.
+        The constructed checkpoint state dict contains the following information:
+        - Model weights with key training.MODEL_KEY
+        - Relevant recipe state, including optimizer state, if training is not complete
+
+        If asynchronous checkpointing is enabled, the checkpoint will be saved asynchronously
+        as a distributed checkpoint.
+        Otherwise, the checkpoint will be saved synchronously with the
+        checkpointer the user has configured.
+        """
+        intermediate_checkpoint = epoch + 1 < training_progress.total_epochs
+
+        if intermediate_checkpoint and self._enable_async_checkpointing:
+            self._save_checkpoint_async(model, optimizer, training_progress, epoch)
+        else:
+            self._save_checkpoint_sync(model, optimizer, training_progress, epoch)
+
+    def load_base_checkpoint(self) -> Dict[str, Any]:
+        """
+        This method is used to load the base model from the checkpoint
+        configured by the user.
+        """
+        return self._get_checkpointer().load_checkpoint()
+
+    def load_distributed_checkpoint(
+        self,
+        model: torch.nn.Module,
+        optimizer: Union[torch.optim.Optimizer, OptimizerInBackwardWrapper],
+    ) -> Dict[str, Any]:
+        """
+        This method is used to resume training from a distributed checkpoint state.
+        Because the checkpoint is distributed, this method is called on every rank.
+        """
+        if self._is_rank_zero:
+            dcp_load_start = time.perf_counter()
+
+        if not self._optimizer_in_bwd:
+            _init_optim_state(optimizer)
+
+        # Build the state dict to be loaded from the distributed checkpoint
+        checkpoint_dict: Dict[str, Any] = {}
+        model_state_dict = model.state_dict()
+        optim_state_dict = optimizer.state_dict()
+        checkpoint_dict.update(
+            {
+                training.MODEL_KEY: model_state_dict,
+                training.OPT_KEY: optim_state_dict,
+                training.SEED_KEY: 0,
+                training.EPOCHS_KEY: 0,
+                training.TOTAL_EPOCHS_KEY: 0,
+                training.MAX_STEPS_KEY: 0,
+            }
+        )
+
+        # Load the checkpoint state dict from the distributed checkpoint
+        checkpoint_dict = self._get_dcp_checkpointer().load_checkpoint(checkpoint_dict)
+
+        # Load the checkpoint state dict into model and optimizer
+        if not self._optimizer_in_bwd:
+            if training.OPT_KEY in checkpoint_dict:
+                set_state_dict(
+                    model,
+                    optimizer,
+                    model_state_dict=checkpoint_dict[training.MODEL_KEY],
+                    optim_state_dict=checkpoint_dict[training.OPT_KEY],
+                )
+            else:
+                set_model_state_dict(
+                    model=model,
+                    model_state_dict=checkpoint_dict[training.MODEL_KEY],
+                )
+        else:
+            set_model_state_dict(
+                model=model,
+                model_state_dict=checkpoint_dict[training.MODEL_KEY],
+            )
+
+            if training.OPT_KEY in checkpoint_dict:
+                optimizer.load_state_dict(checkpoint_dict[training.OPT_KEY])
+
+        if self._is_rank_zero:
+            log.info(
+                f"DistributedCheckpointer loaded the checkpoint in {time.perf_counter() - dcp_load_start:.2f} seconds."
+            )
+
+        return checkpoint_dict
diff --git a/torchtune/training/checkpointing/_checkpointer.py b/torchtune/training/checkpointing/_checkpointer.py
index 559fca84ba..a5d72af320 100644
--- a/torchtune/training/checkpointing/_checkpointer.py
+++ b/torchtune/training/checkpointing/_checkpointer.py
@@ -7,11 +7,22 @@
 import gc
 import json
 import os
+import re
+import time
+from concurrent.futures import Future
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Protocol, Union

 import torch
+import torch.distributed as dist
 from safetensors.torch import save_file
+from torch.distributed.checkpoint import (
+    async_save,
+    FileSystemReader,
+    FileSystemWriter,
+    load,
+    save,
+)

 from torchtune import training
 from torchtune.models import convert_weights
@@ -121,13 +132,16 @@ class FullModelTorchTuneCheckpointer(_CheckpointerInterface):
         model_type (str): Model type of the model for which the checkpointer is being loaded, e.g. LLAMA3.
         output_dir (str): Directory to save the checkpoint files
         adapter_checkpoint (Optional[str]): Path to the adapter weights. If None,
-            and `resume_from_checkpoint=True`, then look for adapter_model.pt in output_dir/epoch_{largest_epoch}.
+            and `should_load_recipe_state=True`, then look for adapter_model.pt in output_dir/epoch_{largest_epoch}.
             Default is None.
         recipe_checkpoint (Optional[str]): Path to the recipe state checkpoint file. If None,
-            and `resume_from_checkpoint=True`, then look for recipe_state.pt in output_dir/RECIPE_STATE_DIRNAME.
+            and `should_load_recipe_state=True`, then look for recipe_state.pt in output_dir/RECIPE_STATE_DIRNAME.
             Default is None.
-        resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files to
-            resume training from a previous run. Default is False
+        resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files corresponding to
+            the recipe state from a previous run. Default is False. This flag is deprecated.
Please use the + should_load_recipe_state flag instead. + should_load_recipe_state (bool): If True, the checkpointer will load the additional checkpoint files corresponding to + the recipe state from a previous run. Default is False Raises: ValueError: If more than one checkpoint file is provided @@ -142,6 +156,7 @@ def __init__( adapter_checkpoint: Optional[str] = None, recipe_checkpoint: Optional[str] = None, resume_from_checkpoint: bool = False, + should_load_recipe_state: bool = False, ) -> None: # Fail fast if ``checkpoint_files`` is invalid @@ -153,7 +168,14 @@ def __init__( ) self._checkpoint_dir = Path(checkpoint_dir) - self._resume_from_checkpoint = resume_from_checkpoint + self._should_load_recipe_state = should_load_recipe_state + + if resume_from_checkpoint: + self._should_load_recipe_state = resume_from_checkpoint + logger.warning( + "*resume_from_checkpoint is deprecated. Please use the 'should_load_recipe_state' instead" + ) + self._model_type = ModelType[model_type] self._output_dir = Path(output_dir) self._output_dir.mkdir(parents=True, exist_ok=True) @@ -170,7 +192,7 @@ def __init__( self._adapter_checkpoint = get_adapter_checkpoint_path( output_dir=self._output_dir, adapter_checkpoint=adapter_checkpoint, - resume_from_checkpoint=self._resume_from_checkpoint, + should_load_recipe_state=self._should_load_recipe_state, pattern=r"^epoch_(\d+)", ) @@ -178,7 +200,7 @@ def __init__( self._recipe_checkpoint = get_recipe_checkpoint_path( output_dir=self._output_dir, recipe_checkpoint=recipe_checkpoint, - resume_from_checkpoint=self._resume_from_checkpoint, + should_load_recipe_state=self._should_load_recipe_state, ) # get ckpt paths @@ -186,16 +208,16 @@ def __init__( checkpoint_files=checkpoint_files, checkpoint_dir=self._checkpoint_dir, output_dir=self._output_dir, - resume_from_checkpoint=self._resume_from_checkpoint, + should_load_recipe_state=self._should_load_recipe_state, has_adapter_checkpoint=self._adapter_checkpoint is not None, ) # we currently accept only a single file self._checkpoint_path = self._checkpoint_paths[0] - if self._resume_from_checkpoint: + if self._should_load_recipe_state: logger.info( - "Resuming from checkpoint using:" + "Loading the recipe state using: " f"\n\tcheckpoint_paths: {[str(path) for path in self._checkpoint_paths]}" f"\n\trecipe_checkpoint: {self._recipe_checkpoint}" f"\n\tadapter_checkpoint: {self._adapter_checkpoint}" @@ -206,7 +228,7 @@ def load_checkpoint(self, weights_only: bool = True) -> Dict[str, Any]: Load torchtune checkpoint from file. Currently only loading from a single file is supported. The output state_dict has the following format, with keys other than "model" only present if - ``resume_from_checkpoint`` is True: + ``should_load_recipe_state`` is True: >>> { >>> "model": { @@ -233,7 +255,7 @@ def load_checkpoint(self, weights_only: bool = True) -> Dict[str, Any]: adapter_state_dict = safe_torch_load(self._adapter_checkpoint) state_dict[training.ADAPTER_KEY] = adapter_state_dict - if self._resume_from_checkpoint: + if self._should_load_recipe_state: recipe_state = safe_torch_load(self._recipe_checkpoint, mmap=False) state_dict.update(recipe_state) return state_dict @@ -363,15 +385,18 @@ class FullModelHFCheckpointer(_CheckpointerInterface): model_type (str): Model type of the model for which the checkpointer is being loaded, e.g. LLAMA3. output_dir (str): Directory to save the checkpoint files adapter_checkpoint (Optional[str]): Path to the adapter weights. 
If None, - and `resume_from_checkpoint=True`, then look for adapter_model.pt in output_dir/epoch_{largest_epoch}. + and `should_load_recipe_state=True`, then look for adapter_model.pt in output_dir/epoch_{largest_epoch}. Default is None. recipe_checkpoint (Optional[str]): Path to the recipe state checkpoint file. If None, - and `resume_from_checkpoint=True`, then look for recipe_state.pt in output_dir/RECIPE_STATE_DIRNAME. + and `should_load_recipe_state=True`, then look for recipe_state.pt in output_dir/RECIPE_STATE_DIRNAME. Default is None. - resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files to - resume training from a previous run. Default is False + resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files corresponding to + the recipe state from a previous run. Default is False. This flag is deprecated. Please use + the should_load_recipe_state flag instead. safe_serialization (bool): If True, the checkpointer will save the checkpoint file using `safetensors`. Default is True. + should_load_recipe_state (bool): If True, the checkpointer will load the additional checkpoint files corresponding to + the recipe state from a previous run. Default is False """ def __init__( @@ -384,9 +409,16 @@ def __init__( recipe_checkpoint: Optional[str] = None, resume_from_checkpoint: bool = False, safe_serialization: bool = True, + should_load_recipe_state: bool = False, ) -> None: - self._resume_from_checkpoint = resume_from_checkpoint + self._should_load_recipe_state = should_load_recipe_state + if resume_from_checkpoint: + self._should_load_recipe_state = resume_from_checkpoint + logger.warning( + "*resume_from_checkpoint is deprecated. Please use the 'should_load_recipe_state' instead" + ) + self._safe_serialization = safe_serialization self._checkpoint_dir = Path(checkpoint_dir) self._model_type = ModelType[model_type] @@ -427,7 +459,7 @@ def __init__( self._adapter_checkpoint = get_adapter_checkpoint_path( output_dir=self._output_dir, adapter_checkpoint=adapter_checkpoint, - resume_from_checkpoint=self._resume_from_checkpoint, + should_load_recipe_state=self._should_load_recipe_state, pattern=r"^epoch_(\d+)", ) @@ -435,7 +467,7 @@ def __init__( self._recipe_checkpoint = get_recipe_checkpoint_path( output_dir=self._output_dir, recipe_checkpoint=recipe_checkpoint, - resume_from_checkpoint=self._resume_from_checkpoint, + should_load_recipe_state=self._should_load_recipe_state, ) # get ckpt paths @@ -443,13 +475,13 @@ def __init__( checkpoint_files=checkpoint_files, checkpoint_dir=self._checkpoint_dir, output_dir=self._output_dir, - resume_from_checkpoint=self._resume_from_checkpoint, + should_load_recipe_state=self._should_load_recipe_state, has_adapter_checkpoint=self._adapter_checkpoint is not None, ) - if self._resume_from_checkpoint: + if self._should_load_recipe_state: logger.info( - "Resuming from checkpoint using:" + "Loading the recipe state using: " f"\n\tcheckpoint_paths: {[str(path) for path in self._checkpoint_paths]}" f"\n\trecipe_checkpoint: {self._recipe_checkpoint}" f"\n\tadapter_checkpoint: {self._adapter_checkpoint}" @@ -472,6 +504,7 @@ def load_checkpoint(self) -> Dict[str, Any]: Raises: ValueError: If the values in the input state_dict are not Tensors """ + self._weight_map = {} # merged state_dict contains keys and weights from all the checkpoint files @@ -581,7 +614,7 @@ def load_checkpoint(self) -> Dict[str, Any]: adapter_state_dict = safe_torch_load(self._adapter_checkpoint) 
converted_state_dict[training.ADAPTER_KEY] = adapter_state_dict - if self._resume_from_checkpoint: + if self._should_load_recipe_state: recipe_state = safe_torch_load(self._recipe_checkpoint, mmap=False) converted_state_dict.update(recipe_state) @@ -886,17 +919,20 @@ class FullModelMetaCheckpointer(_CheckpointerInterface): model_type (str): Model type of the model for which the checkpointer is being loaded, e.g. LLAMA3. output_dir (str): Directory to save the checkpoint files adapter_checkpoint (Optional[str]): Path to the adapter weights. If None, - and `resume_from_checkpoint=True`, then look for adapter_model.pt in output_dir/epoch_{largest_epoch}. + and `should_load_recipe_state=True`, then look for adapter_model.pt in output_dir/epoch_{largest_epoch}. Default is None. recipe_checkpoint (Optional[str]): Path to the recipe state checkpoint file. If None, - and `resume_from_checkpoint=True`, then look for recipe_state.pt in output_dir/recipe_state. + and `should_load_recipe_state=True`, then look for recipe_state.pt in output_dir/recipe_state. Default is None. - resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files to - resume training from a previous run. Default is False + resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files corresponding to + the recipe state from a previous run. Default is False. This flag is deprecated. Please use the + should_load_recipe_state instead. + should_load_recipe_state (bool): If True, the checkpointer will load the additional checkpoint files corresponding to + the recipe state from a previous run. Default is False Raises: ValueError: If ``checkpoint_files`` is not a list of length 1 - ValueError: If ``resume_from_checkpoint`` is True but ``recipe_checkpoint`` is None + ValueError: If ``should_load_recipe_state`` is True but ``recipe_checkpoint`` is None """ def __init__( @@ -908,6 +944,7 @@ def __init__( adapter_checkpoint: Optional[str] = None, recipe_checkpoint: Optional[str] = None, resume_from_checkpoint: bool = False, + should_load_recipe_state: bool = False, ) -> None: # Fail fast if ``checkpoint_files`` is invalid @@ -919,7 +956,12 @@ def __init__( ) self._checkpoint_dir = Path(checkpoint_dir) - self._resume_from_checkpoint = resume_from_checkpoint + self._should_load_recipe_state = should_load_recipe_state + if resume_from_checkpoint: + self._should_load_recipe_state = resume_from_checkpoint + logger.warning( + "*resume_from_checkpoint is deprecated. 
Please use the 'should_load_recipe_state' instead" + ) self._model_type = ModelType[model_type] self._output_dir = Path(output_dir) self._output_dir.mkdir(parents=True, exist_ok=True) @@ -936,7 +978,7 @@ def __init__( self._adapter_checkpoint = get_adapter_checkpoint_path( output_dir=self._output_dir, adapter_checkpoint=adapter_checkpoint, - resume_from_checkpoint=self._resume_from_checkpoint, + should_load_recipe_state=self._should_load_recipe_state, pattern=r"^epoch_(\d+)", ) @@ -944,7 +986,7 @@ def __init__( self._recipe_checkpoint = get_recipe_checkpoint_path( output_dir=self._output_dir, recipe_checkpoint=recipe_checkpoint, - resume_from_checkpoint=self._resume_from_checkpoint, + should_load_recipe_state=self._should_load_recipe_state, ) # get ckpt paths @@ -952,16 +994,16 @@ def __init__( checkpoint_files=checkpoint_files, checkpoint_dir=self._checkpoint_dir, output_dir=self._output_dir, - resume_from_checkpoint=self._resume_from_checkpoint, + should_load_recipe_state=self._should_load_recipe_state, has_adapter_checkpoint=self._adapter_checkpoint is not None, ) # we currently accept only a single file self._checkpoint_path = self._checkpoint_paths[0] - if self._resume_from_checkpoint: + if self._should_load_recipe_state: logger.info( - "Resuming from checkpoint using:" + "Loading the recipe state using: " f"\n\tcheckpoint_paths: {[str(path) for path in self._checkpoint_paths]}" f"\n\trecipe_checkpoint: {self._recipe_checkpoint}" f"\n\tadapter_checkpoint: {self._adapter_checkpoint}" @@ -999,7 +1041,7 @@ def load_checkpoint(self) -> Dict[str, Any]: adapter_state_dict = safe_torch_load(self._adapter_checkpoint) state_dict[training.ADAPTER_KEY] = adapter_state_dict - if self._resume_from_checkpoint: + if self._should_load_recipe_state: recipe_state = safe_torch_load(self._recipe_checkpoint, mmap=False) state_dict.update(recipe_state) return state_dict @@ -1113,3 +1155,190 @@ def save_checkpoint( "The full model checkpoint, including all weights and configurations, has been saved successfully." "You can now use this checkpoint for further training or inference." ) + + +class DistributedCheckpointer(_CheckpointerInterface): + """ + Checkpointer which reads and writes checkpoints in the DistributedCheckpointing format. + + Args: + checkpoint_dir (str): Directory containing the checkpoint files + output_dir (str): Directory to save the checkpoint files + process_group (Optional[dist.ProcessGroup]): Optional process group to use + for distributed saving/loading. If None, the default process group will be used. + For checkpointing, gloo CPU-based backend is needed. + """ + + def __init__( + self, + checkpoint_dir: str, + output_dir: str, + process_group: Optional[dist.ProcessGroup] = None, + ) -> None: + self._checkpoint_dir = Path(checkpoint_dir) + self._output_dir = Path(output_dir) + self._checkpoint_future = None + self._checkpoint_dir_prefix = "dist_epoch" + self._metadata_file = ".metadata" + _, self._rank = training.get_world_size_and_rank() + self._process_group: Optional[dist.ProcessGroup] = process_group + + def _get_latest_intermediate_checkpoint(self) -> Optional[str]: + """ + This method iterates over the available intermediate distributed checkpoints and + finds the latest checkpoint to load. + + Returns: + str: The fully qualified path of the checkpoint directory containing the latest and valid + intermediate checkpoint. A valid checkpoint needs to have the metadata file. 
+ """ + + checkpoint_dir_pattern = re.compile(f"{self._checkpoint_dir_prefix}_(\\d+)") + checkpoint_paths = [ + name + for name in os.listdir(self._output_dir) + if re.match(checkpoint_dir_pattern, name) + and os.path.isfile( + os.path.join(self._output_dir, name, self._metadata_file) + ) + ] + + if checkpoint_paths: + latest_checkpoint_dir = sorted( + checkpoint_paths, key=lambda x: int(x.split("_")[-1]) + )[-1] + return os.path.join(self._output_dir, latest_checkpoint_dir) + return None + + def load_checkpoint( + self, state_dict: Dict[str, Any] = None, checkpoint_path: Optional[str] = None + ) -> Dict[str, Any]: + """ + Load a Distributed checkpoint saved at the + If no path is provided, latest intermediate checkpoint is loaded. + """ + + if state_dict is None: + raise ValueError( + "State dict must be provided to load a distributed checkpoint." + ) + + # If no checkpoint path is provided, load the latest intermediate checkpoint. + if checkpoint_path is None: + checkpoint_path = self._get_latest_intermediate_checkpoint() + + if checkpoint_path is None: + raise ValueError( + "No checkpoint path was provided." + "Also, No intermediate checkpoint was found in the output directory." + "Please ensure that a checkpoint exists to load." + ) + + log_rank_zero(logger, msg=f"Loading checkpoint from {checkpoint_path}") + + load( + state_dict=state_dict, + storage_reader=FileSystemReader(checkpoint_path), + process_group=self._process_group, + ) + + return state_dict + + def save_checkpoint( + self, + state_dict: Dict[str, Any], + epoch: int, + save_async: bool = False, + ) -> None: + """ + Save a distributed checkpoint to storage. + If ``save_async`` is True, the save happens asynchronously unblocking the GPUs sooner. This + should only be used for the intermediate checkpoints. Final checkpoint has to be a synchronous + one as the finetuning job can not terminate until the checkpoint gets persisted. + + Args: + state_dict (Dict[str, Any]): Checkpoint state dict to be written out to file + epoch (int): Epoch number. Used to create the checkpoint file name + save_async (bool): If True, save the checkpoint asynchronously + """ + + log_rank_zero( + logger, + msg=f"DistributedCheckpointer is saving a checkpoint for the epoch {epoch}", + ) + + checkpoint_path = Path.joinpath( + self._output_dir, f"{self._checkpoint_dir_prefix}_{epoch}" + ) + + if self._checkpoint_future and not self._checkpoint_future.done(): + # Previous checkpoint needs to finish before saving the next one. + wait_start = time.perf_counter() + + logger.info( + f"Rank {self._rank}: previous checkpoint has not finished. Checkpointing frequency is too high. 
Waiting...", + ) + + self._checkpoint_future.result() + + logger.info( + f"Rank {self._rank}: waited {time.perf_counter() - wait_start:.2f} seconds for previous checkpoint to finish", + ) + self._checkpoint_future = None + + cp_start = time.perf_counter() + + if save_async: + + def callback( + f: Future, + ) -> None: + if f.exception() is None: + logger.info( + f"Rank {self._rank}: Checkpoint is saved asynchronously to {checkpoint_path} successfully.", + ) + else: + logger.error( + f"Rank {self._rank}: Checkpoint failed to save asynchronously to {checkpoint_path} " + f"with the exception {f.exception()}" + ) + + self._checkpoint_future = async_save( + state_dict=state_dict, + storage_writer=FileSystemWriter( + checkpoint_path, + thread_count=16, + single_file_per_rank=False, + sync_files=False, + ), + process_group=self._process_group, + ) + + logger.info( + f"Rank {self._rank}: Trainer was blocked for {time.perf_counter() - cp_start:.2f} seconds " + "for checkpointing to finish...", + ) + + self._checkpoint_future.add_done_callback(callback) + else: + log_rank_zero( + logger, + msg=f"Saving model checkpoint synchronously to {checkpoint_path}.", + ) + + save( + state_dict=state_dict, + storage_writer=FileSystemWriter( + checkpoint_path, + thread_count=16, + single_file_per_rank=False, + sync_files=False, + ), + process_group=self._process_group, + ) + + log_rank_zero( + logger, + msg="The full model checkpoint, including all the weights and configurations, has been saved successfully " + "by the DistributedCheckpointer. You can now use this checkpoint for further training.", + ) diff --git a/torchtune/training/checkpointing/_utils.py b/torchtune/training/checkpointing/_utils.py index 770a3f889c..1d8a63daab 100644 --- a/torchtune/training/checkpointing/_utils.py +++ b/torchtune/training/checkpointing/_utils.py @@ -387,7 +387,7 @@ def copy_files( def get_recipe_checkpoint_path( output_dir: Path, recipe_checkpoint: Optional[str] = None, - resume_from_checkpoint: bool = False, + should_load_recipe_state: bool = False, ) -> Optional[Path]: """ If recipe_checkpoint is None, look for recipe_state.pt in {output_dir}/{RECIPE_STATE_DIRNAME}/recipe_state.pt. @@ -396,13 +396,13 @@ def get_recipe_checkpoint_path( Args: output_dir (Path): Directory containing the recipe checkpoint. recipe_checkpoint (Optional[str]): Name of the recipe checkpoint file. Defaults to None. - resume_from_checkpoint (bool): Whether to resume from a checkpoint. + should_load_recipe_state (bool): Whether to load the recipe state from the checkpoint. Returns: - Optional[Path]: Path to the recipe checkpoint file if resume_from_checkpoint is True, otherwise None. + Optional[Path]: Path to the recipe checkpoint file if should_load_recipe_state is True, otherwise None. Raises: - ValueError: If resume_from_checkpoint is True and the recipe checkpoint file is missing. + ValueError: If should_load_recipe_state is True and the recipe checkpoint file is missing. """ - if not resume_from_checkpoint: + if not should_load_recipe_state: return None recipe_checkpoint_path = None @@ -416,7 +416,7 @@ def get_recipe_checkpoint_path( # TODO: improve this msg if not recipe_checkpoint_path or not os.path.exists(recipe_checkpoint_path): raise ValueError( - "If resume_from_checkpoint is True, recipe_checkpoint file must be provided." + "If should_load_recipe_state is True, recipe_checkpoint file must be provided." 
) return Path(recipe_checkpoint_path) @@ -425,7 +425,7 @@ def get_recipe_checkpoint_path( def get_adapter_checkpoint_path( output_dir: Path, adapter_checkpoint: Optional[str] = None, - resume_from_checkpoint: bool = False, + should_load_recipe_state: bool = False, pattern: str = r"^epoch_(\d+)", ) -> Optional[Path]: r""" @@ -435,13 +435,13 @@ def get_adapter_checkpoint_path( Args: output_dir (Path): Directory containing the adapter checkpoint. adapter_checkpoint (Optional[str]): Name of the adapter checkpoint file. Defaults to None. - resume_from_checkpoint (bool): Whether to resume from a checkpoint. + should_load_recipe_state (bool): Whether to load the recipe state from checkpoint. pattern (str): Regex pattern to match the epoch folder. Defaults to "epoch_(\d+)". Returns: Optional[Path]: Path to the adapter checkpoint file, or None if not applicable. """ - if not resume_from_checkpoint: + if not should_load_recipe_state: return None adapter_checkpoint_path = None @@ -466,7 +466,7 @@ def get_model_checkpoint_path( checkpoint_files: Union[List[str], Dict[str, str]], checkpoint_dir: Union[str, Path], output_dir: Union[str, Path], - resume_from_checkpoint: bool, + should_load_recipe_state: bool, has_adapter_checkpoint: bool, ) -> list[Path]: """ @@ -484,7 +484,7 @@ def get_model_checkpoint_path( it is converted to a list of formatted checkpoint filenames. checkpoint_dir (Union[str, Path]): Directory containing the checkpoint files. output_dir (Union[str, Path]): Directory to use when resuming from a checkpoint. - resume_from_checkpoint (bool): Whether to resume from a checkpoint. + should_load_recipe_state (bool): Whether to resume from a checkpoint. has_adapter_checkpoint (bool): Indicates if there is an adapter checkpoint. Returns: list[Path]: Sorted list of paths to the checkpoint files. @@ -492,13 +492,13 @@ def get_model_checkpoint_path( >>> checkpoint_files = ["model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"] >>> checkpoint_dir = "/path/to/checkpoints" >>> output_dir = "/path/to/output" - >>> resume_from_checkpoint = True + >>> should_load_recipe_state = True >>> has_adapter_checkpoint = False >>> paths = get_model_checkpoint_path( ... checkpoint_files, ... checkpoint_dir, ... output_dir, - ... resume_from_checkpoint, + ... should_load_recipe_state, ... has_adapter_checkpoint ... ) >>> print(paths) @@ -536,15 +536,15 @@ def validate_checkpoint_files( ) checkpoint_files = formatted_checkpoint_files.build_checkpoint_filenames() - # Case 1: no resuming from ckpt - if not resume_from_checkpoint: + # Case 1: not loading the recipe state + if not should_load_recipe_state: input_dir = checkpoint_dir - # Case 2: Resuming from ckpt, but its full finetuning (no adapter) + # Case 2: Loading the recipe state, but its full finetuning (no adapter) elif not has_adapter_checkpoint: input_dir = output_dir - # Case 3: Resuming from ckpt and has an adapter. + # Case 3: Loading the recipe state and has an adapter. else: # FIXME # TODO: if the model has lora + trained weights, e.g. 
embeddings, diff --git a/torchtune/training/metric_logging.py b/torchtune/training/metric_logging.py index 42882afa8b..a6189f10e1 100644 --- a/torchtune/training/metric_logging.py +++ b/torchtune/training/metric_logging.py @@ -14,9 +14,8 @@ from numpy import ndarray from omegaconf import DictConfig, OmegaConf -from torchtune.training._distributed import get_world_size_and_rank -from torchtune.utils import get_logger +from torchtune.utils import get_logger, get_world_size_and_rank from typing_extensions import Protocol Scalar = Union[torch.Tensor, ndarray, int, float] diff --git a/torchtune/training/seed.py b/torchtune/training/seed.py index 5c3d8d4db5..a5e2e4b4f8 100644 --- a/torchtune/training/seed.py +++ b/torchtune/training/seed.py @@ -13,8 +13,8 @@ import numpy as np import torch -from torchtune.training._distributed import _broadcast_tensor, get_world_size_and_rank -from torchtune.utils import get_logger +from torchtune.training._distributed import _broadcast_tensor +from torchtune.utils import get_logger, get_world_size_and_rank _log: logging.Logger = get_logger() diff --git a/torchtune/utils/__init__.py b/torchtune/utils/__init__.py index 59de1b5aa7..f7bbf35852 100644 --- a/torchtune/utils/__init__.py +++ b/torchtune/utils/__init__.py @@ -10,12 +10,14 @@ get_device, get_device_support, get_torch_device_namespace, + get_world_size_and_rank, ) from ._logging import get_logger, log_rank_zero from ._version import torch_version_ge __all__ = [ + "get_world_size_and_rank", "batch_to_device", "get_device", "get_logger", diff --git a/torchtune/utils/_device.py b/torchtune/utils/_device.py index d4f84cd63e..10d5e62a05 100644 --- a/torchtune/utils/_device.py +++ b/torchtune/utils/_device.py @@ -6,7 +6,7 @@ import os from enum import Enum -from typing import Optional +from typing import Optional, Tuple import torch @@ -21,6 +21,19 @@ BlockMask = torch.Tensor +def get_world_size_and_rank() -> Tuple[int, int]: + """Function that gets the current world size (aka total number + of ranks) and rank number of the current process in the default process group. + + Returns: + Tuple[int, int]: world size, rank + """ + if torch.distributed.is_available() and torch.distributed.is_initialized(): + return torch.distributed.get_world_size(), torch.distributed.get_rank() + else: + return 1, 0 + + def is_torch_npu_available() -> bool: """Check the availability of NPU""" try:
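As a usage illustration for the new DistributedCheckpointer introduced above (not part of the patch): the class and method signatures below are taken from the diff, while the temporary directories, the toy model, and the choice of state-dict keys are placeholder assumptions. The class is imported from its defining module here; in a real multi-rank run a (gloo-capable) process group would already have been initialized by the recipe, whereas this sketch relies on DCP's single-process fallback.

import tempfile

import torch

from torchtune import training
from torchtune.training.checkpointing._checkpointer import DistributedCheckpointer

# Toy state to checkpoint; a recipe would pass the sharded model/optimizer
# state dicts under training.MODEL_KEY / training.OPT_KEY instead.
model = torch.nn.Linear(8, 8)
optim = torch.optim.AdamW(model.parameters(), lr=1e-4)

output_dir = tempfile.mkdtemp()  # placeholder output directory
checkpointer = DistributedCheckpointer(
    checkpoint_dir=output_dir,  # only used for loading; placeholder here
    output_dir=output_dir,
)

state_dict = {
    training.MODEL_KEY: model.state_dict(),
    training.OPT_KEY: optim.state_dict(),
}

# Intermediate checkpoints may be saved with save_async=True so the trainer is
# unblocked sooner; the final checkpoint should stay synchronous, as noted above.
checkpointer.save_checkpoint(state_dict, epoch=0, save_async=False)

# Loading is in-place: pass a state dict with the same structure and DCP fills it
# from the latest dist_epoch_* directory when no explicit path is given.
resumed = checkpointer.load_checkpoint(
    state_dict={
        training.MODEL_KEY: model.state_dict(),
        training.OPT_KEY: optim.state_dict(),
    }
)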
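To make the checkpoint-selection behaviour of _get_latest_intermediate_checkpoint concrete, here is a small, self-contained sketch of the same pattern using placeholder directories: only directories that already contain the .metadata marker count as valid checkpoints, and the numeric sort key avoids the lexicographic trap where "dist_epoch_9" would beat "dist_epoch_10".

import os
import re
import tempfile

prefix, metadata_file = "dist_epoch", ".metadata"
output_dir = tempfile.mkdtemp()

# Fake checkpoint layout: epoch 11 is still being written (no .metadata yet).
for epoch, finished in [(2, True), (9, True), (10, True), (11, False)]:
    ckpt_dir = os.path.join(output_dir, f"{prefix}_{epoch}")
    os.makedirs(ckpt_dir)
    if finished:
        open(os.path.join(ckpt_dir, metadata_file), "w").close()

pattern = re.compile(f"{prefix}_(\\d+)")
valid = [
    name
    for name in os.listdir(output_dir)
    if re.match(pattern, name)
    and os.path.isfile(os.path.join(output_dir, name, metadata_file))
]

# Numeric key: dist_epoch_10 > dist_epoch_9, unlike a plain string sort.
latest = sorted(valid, key=lambda name: int(name.split("_")[-1]))[-1]
print(latest)  # dist_epoch_10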
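The resume_from_checkpoint to should_load_recipe_state rename keeps backward compatibility: the old flag is still accepted but only emits a deprecation warning and flips the new one. A minimal sketch, assuming an HF-format Llama3 checkpoint directory and a previous run that left recipe state under output_dir; the paths and shard names below are placeholders.

from torchtune.training import FullModelHFCheckpointer

common = dict(
    checkpoint_dir="/tmp/Meta-Llama-3-8B-Instruct",  # placeholder path
    checkpoint_files=[
        "model-00001-of-00004.safetensors",
        "model-00002-of-00004.safetensors",
        "model-00003-of-00004.safetensors",
        "model-00004-of-00004.safetensors",
    ],
    model_type="LLAMA3",
    output_dir="/tmp/finetune_output",  # placeholder; must contain recipe_state.pt to load recipe state
)

# Preferred spelling: also load optimizer/recipe state from a previous run.
checkpointer = FullModelHFCheckpointer(**common, should_load_recipe_state=True)

# Deprecated spelling: still works, but logs a warning and internally
# sets should_load_recipe_state=True.
legacy_checkpointer = FullModelHFCheckpointer(**common, resume_from_checkpoint=True)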
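Finally, a small sketch of the relocated helper: torchtune.utils.get_world_size_and_rank works whether or not torch.distributed has been initialized, which is why recipes can use it for rank-zero gating both before and after init_distributed. The print statement is purely illustrative.

from torchtune import utils

# Returns (1, 0) when torch.distributed is unavailable or uninitialized,
# otherwise (WORLD_SIZE, RANK) of the default process group.
world_size, rank = utils.get_world_size_and_rank()
is_rank_zero = rank == 0

if is_rank_zero:
    print(f"world_size={world_size}, rank={rank}")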