diff --git a/tests/torchtune/training/checkpointing/test_distributed_checkpointer.py b/tests/torchtune/training/checkpointing/test_distributed_checkpointer.py
new file mode 100644
index 0000000000..e325499a58
--- /dev/null
+++ b/tests/torchtune/training/checkpointing/test_distributed_checkpointer.py
@@ -0,0 +1,153 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import os
+import shutil
+from pathlib import Path
+
+import pytest
+import torch
+from torch import randn, zeros
+
+from torchtune.training.checkpointing import DistributedCheckpointer
+from torchtune.training.seed import set_seed
+
+_VOCAB_SIZE = 100
+_DIM = 64
+_HIDDEN_DIM = 256
+
+
+@pytest.fixture(autouse=True)
+def random():
+    set_seed(16)
+
+
+class TestDistributedCheckpointer:
+    @pytest.fixture
+    def weight_dtype(self):
+        return torch.float16
+
+    @pytest.fixture
+    def state_dict(self, weight_dtype):
+        """
+        State dict with randomly initialized weights.
+        """
+        state_dict = {
+            "model.embed_tokens.weight": randn(_VOCAB_SIZE, _DIM, dtype=weight_dtype),
+            "model.layers.0.input_layernorm.weight": randn(_DIM, dtype=weight_dtype),
+            "model.layers.0.self_attn.q_proj.weight": randn(
+                _DIM, _DIM, dtype=weight_dtype
+            ),
+            "model.layers.0.self_attn.k_proj.weight": randn(
+                _DIM, _DIM, dtype=weight_dtype
+            ),
+            "model.layers.0.self_attn.v_proj.weight": randn(
+                _DIM, _DIM, dtype=weight_dtype
+            ),
+            "model.layers.0.self_attn.o_proj.weight": randn(
+                _DIM, _DIM, dtype=weight_dtype
+            ),
+            "model.layers.0.post_attention_layernorm.weight": randn(
+                _DIM, dtype=weight_dtype
+            ),
+            "model.layers.0.self_attn.rotary_emb.inv_freq": randn(
+                _DIM, dtype=weight_dtype
+            ),
+            "model.layers.0.mlp.gate_proj.weight": randn(
+                _HIDDEN_DIM, _DIM, dtype=weight_dtype
+            ),
+            "model.layers.0.mlp.down_proj.weight": randn(
+                _DIM, _HIDDEN_DIM, dtype=weight_dtype
+            ),
+            "model.layers.0.mlp.up_proj.weight": randn(
+                _HIDDEN_DIM, _DIM, dtype=weight_dtype
+            ),
+            "model.norm.weight": randn(_DIM, dtype=weight_dtype),
+            "lm_head.weight": randn(_VOCAB_SIZE, _DIM, dtype=weight_dtype),
+        }
+
+        return state_dict
+
+    @pytest.fixture
+    def empty_state_dict(self, weight_dtype):
+        """
+        Zero-initialized state dict to load the checkpoint into.
+        """
+        state_dict = {
+            "model.embed_tokens.weight": zeros(_VOCAB_SIZE, _DIM, dtype=weight_dtype),
+            "model.layers.0.input_layernorm.weight": zeros(_DIM, dtype=weight_dtype),
+            "model.layers.0.self_attn.q_proj.weight": zeros(
+                _DIM, _DIM, dtype=weight_dtype
+            ),
+            "model.layers.0.self_attn.k_proj.weight": zeros(
+                _DIM, _DIM, dtype=weight_dtype
+            ),
+            "model.layers.0.self_attn.v_proj.weight": zeros(
+                _DIM, _DIM, dtype=weight_dtype
+            ),
+            "model.layers.0.self_attn.o_proj.weight": zeros(
+                _DIM, _DIM, dtype=weight_dtype
+            ),
+            "model.layers.0.post_attention_layernorm.weight": zeros(
+                _DIM, dtype=weight_dtype
+            ),
+            "model.layers.0.self_attn.rotary_emb.inv_freq": zeros(
+                _DIM, dtype=weight_dtype
+            ),
+            "model.layers.0.mlp.gate_proj.weight": zeros(
+                _HIDDEN_DIM, _DIM, dtype=weight_dtype
+            ),
+            "model.layers.0.mlp.down_proj.weight": zeros(
+                _DIM, _HIDDEN_DIM, dtype=weight_dtype
+            ),
+            "model.layers.0.mlp.up_proj.weight": zeros(
+                _HIDDEN_DIM, _DIM, dtype=weight_dtype
+            ),
+            "model.norm.weight": zeros(_DIM, dtype=weight_dtype),
+            "lm_head.weight": zeros(_VOCAB_SIZE, _DIM, dtype=weight_dtype),
+        }
+
+        return state_dict
+
+    @pytest.fixture
+    def distributed_checkpointer(self, tmp_path) -> DistributedCheckpointer:
+        return DistributedCheckpointer(
+            checkpoint_dir=tmp_path,
+            output_dir=tmp_path,
+        )
+
+    def test_save_load_checkpoint(
+        self, distributed_checkpointer, state_dict, empty_state_dict
+    ):
+        """
+        Test ``save_checkpoint`` and ``load_checkpoint`` on the DistributedCheckpointer.
+
+        We test:
+        * ``save_checkpoint`` writes the checkpoint to the expected output path
+        * ``load_checkpoint`` restores the saved values into an empty state dict
+        """
+
+        distributed_checkpointer.save_checkpoint(
+            state_dict=state_dict, epoch=1, save_async=False
+        )
+
+        checkpoint_path = Path.joinpath(
+            distributed_checkpointer._output_dir,
+            f"{distributed_checkpointer._checkpoint_dir_prefix}_1",
+        )
+
+        assert os.path.exists(checkpoint_path)
+
+        distributed_checkpointer.load_checkpoint(
+            state_dict=empty_state_dict,
+        )
+
+        for key in state_dict.keys():
+            assert torch.equal(state_dict[key], empty_state_dict[key])
+
+        # clean up
+        shutil.rmtree(checkpoint_path)
diff --git a/torchtune/training/checkpointing/_checkpointer.py b/torchtune/training/checkpointing/_checkpointer.py
index 676707b184..8cb52e991a 100644
--- a/torchtune/training/checkpointing/_checkpointer.py
+++ b/torchtune/training/checkpointing/_checkpointer.py
@@ -14,6 +14,7 @@
 from typing import Any, Dict, List, Optional, Protocol, Union
 
 import torch
+import torch.distributed as dist
 from safetensors.torch import save_file
 from torch.distributed.checkpoint import (
     async_save,
@@ -951,12 +952,17 @@ class DistributedCheckpointer(_CheckpointerInterface):
     Args:
         checkpoint_dir (str): Directory containing the checkpoint files
         output_dir (str): Directory to save the checkpoint files
+        process_group (Optional[dist.ProcessGroup]): Optional process group to use
+            for distributed saving/loading.
+            If None, the default process group will be used.
+            For checkpointing, a gloo CPU-based backend is needed.
     """
 
     def __init__(
         self,
         checkpoint_dir: str,
         output_dir: str,
+        process_group: Optional[dist.ProcessGroup] = None,
     ) -> None:
         self._checkpoint_dir = Path(checkpoint_dir)
         self._output_dir = Path(output_dir)
@@ -964,6 +970,7 @@ def __init__(
         self._checkpoint_dir_prefix = "checkpoint"
         self._metadata_file = ".metadata"
         _, self._rank = training.get_world_size_and_rank()
+        self._process_group: Optional[dist.ProcessGroup] = process_group
 
     def _get_latest_intermediate_checkpoint(self) -> Optional[str]:
         """
@@ -1012,8 +1019,9 @@ def load_checkpoint(
         log_rank_zero(logger, msg=f"Loading checkpoint from {checkpoint_path}")
 
         load(
-            state_dict,
+            state_dict=state_dict,
             storage_reader=FileSystemReader(checkpoint_path),
+            process_group=self._process_group,
         )
 
         return state_dict
@@ -1078,13 +1086,14 @@ def callback(
                     )
 
             self._checkpoint_future = async_save(
-                state_dict,
+                state_dict=state_dict,
                 storage_writer=FileSystemWriter(
                     checkpoint_path,
                     thread_count=16,
                     single_file_per_rank=False,
                     sync_files=False,
                 ),
+                process_group=self._process_group,
             )
 
             logger.info(
@@ -1100,13 +1109,14 @@
             )
 
             save(
-                state_dict,
+                state_dict=state_dict,
                 storage_writer=FileSystemWriter(
                     checkpoint_path,
                     thread_count=8,
                     single_file_per_rank=False,
                     sync_files=False,
                 ),
+                process_group=self._process_group,
            )
 
         log_rank_zero(
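The new ``process_group`` argument is forwarded to ``load``, ``save``, and ``async_save``. As a rough usage sketch (not part of this PR), a recipe launched with ``torchrun`` could pass a dedicated gloo group for checkpointing alongside an NCCL default group, per the docstring's note that a CPU-based gloo backend is needed; the group variable, directories, and init pattern below are illustrative assumptions, not code from this change:

```python
# Illustrative sketch only: wiring a gloo process group into DistributedCheckpointer.
# Assumes a torchrun launch (RANK / WORLD_SIZE / MASTER_ADDR already set) and that the
# default group uses NCCL for GPU collectives; names and paths are hypothetical.
import torch.distributed as dist

from torchtune.training.checkpointing import DistributedCheckpointer

dist.init_process_group(backend="nccl")          # default group for training collectives
checkpoint_pg = dist.new_group(backend="gloo")   # CPU-based group for checkpoint save/load

checkpointer = DistributedCheckpointer(
    checkpoint_dir="/tmp/ckpt",                  # hypothetical directories
    output_dir="/tmp/ckpt",
    process_group=checkpoint_pg,                 # forwarded to load()/save()/async_save()
)

# e.g. checkpointer.save_checkpoint(state_dict=model_state_dict, epoch=0, save_async=True)
```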