Commit

[torchtune][dcp] Unit test for the DistributedCheckpointer
Saurabh Mishra committed Nov 19, 2024
1 parent daaf4b8 commit 1b1342e
Showing 2 changed files with 166 additions and 3 deletions.
@@ -0,0 +1,153 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import os
import shutil
from pathlib import Path

import pytest
import torch
from torch import randn, zeros

from torchtune.training.checkpointing import DistributedCheckpointer
from torchtune.training.seed import set_seed

_VOCAB_SIZE = 100
_DIM = 64
_HIDDEN_DIM = 256


@pytest.fixture(autouse=True)
def random():
    set_seed(16)


class TestDistributedCheckpointer:
    @pytest.fixture
    def weight_dtype(self):
        return torch.float16

    @pytest.fixture
    def state_dict(self, weight_dtype):
        """
        Randomly initialized state dict for a small single-layer model.
        """
        state_dict = {
            "model.embed_tokens.weight": randn(_VOCAB_SIZE, _DIM, dtype=weight_dtype),
            "model.layers.0.input_layernorm.weight": randn(_DIM, dtype=weight_dtype),
            "model.layers.0.self_attn.q_proj.weight": randn(
                _DIM, _DIM, dtype=weight_dtype
            ),
            "model.layers.0.self_attn.k_proj.weight": randn(
                _DIM, _DIM, dtype=weight_dtype
            ),
            "model.layers.0.self_attn.v_proj.weight": randn(
                _DIM, _DIM, dtype=weight_dtype
            ),
            "model.layers.0.self_attn.o_proj.weight": randn(
                _DIM, _DIM, dtype=weight_dtype
            ),
            "model.layers.0.post_attention_layernorm.weight": randn(
                _DIM, dtype=weight_dtype
            ),
            "model.layers.0.self_attn.rotary_emb.inv_freq": randn(
                _DIM, dtype=weight_dtype
            ),
            "model.layers.0.mlp.gate_proj.weight": randn(
                _HIDDEN_DIM, _DIM, dtype=weight_dtype
            ),
            "model.layers.0.mlp.down_proj.weight": randn(
                _DIM, _HIDDEN_DIM, dtype=weight_dtype
            ),
            "model.layers.0.mlp.up_proj.weight": randn(
                _HIDDEN_DIM, _DIM, dtype=weight_dtype
            ),
            "model.norm.weight": randn(_DIM, dtype=weight_dtype),
            "lm_head.weight": randn(_VOCAB_SIZE, _DIM, dtype=weight_dtype),
        }

        return state_dict

    @pytest.fixture
    def empty_state_dict(self, weight_dtype):
        """
        Zero-initialized state dict with the same keys, used as the in-place load target.
        """
        state_dict = {
            "model.embed_tokens.weight": zeros(_VOCAB_SIZE, _DIM, dtype=weight_dtype),
            "model.layers.0.input_layernorm.weight": zeros(_DIM, dtype=weight_dtype),
            "model.layers.0.self_attn.q_proj.weight": zeros(
                _DIM, _DIM, dtype=weight_dtype
            ),
            "model.layers.0.self_attn.k_proj.weight": zeros(
                _DIM, _DIM, dtype=weight_dtype
            ),
            "model.layers.0.self_attn.v_proj.weight": zeros(
                _DIM, _DIM, dtype=weight_dtype
            ),
            "model.layers.0.self_attn.o_proj.weight": zeros(
                _DIM, _DIM, dtype=weight_dtype
            ),
            "model.layers.0.post_attention_layernorm.weight": zeros(
                _DIM, dtype=weight_dtype
            ),
            "model.layers.0.self_attn.rotary_emb.inv_freq": zeros(
                _DIM, dtype=weight_dtype
            ),
            "model.layers.0.mlp.gate_proj.weight": zeros(
                _HIDDEN_DIM, _DIM, dtype=weight_dtype
            ),
            "model.layers.0.mlp.down_proj.weight": zeros(
                _DIM, _HIDDEN_DIM, dtype=weight_dtype
            ),
            "model.layers.0.mlp.up_proj.weight": zeros(
                _HIDDEN_DIM, _DIM, dtype=weight_dtype
            ),
            "model.norm.weight": zeros(_DIM, dtype=weight_dtype),
            "lm_head.weight": zeros(_VOCAB_SIZE, _DIM, dtype=weight_dtype),
        }

        return state_dict

    @pytest.fixture
    def distributed_checkpointer(self, tmp_path) -> DistributedCheckpointer:
        return DistributedCheckpointer(
            checkpoint_dir=tmp_path,
            output_dir=tmp_path,
        )

    def test_save_load_checkpoint(
        self, distributed_checkpointer, state_dict, empty_state_dict
    ):
        """
        Test ``save_checkpoint`` and ``load_checkpoint`` on the DistributedCheckpointer.
        We test:
        * ``save_checkpoint`` writes the checkpoint under the expected directory
        * ``load_checkpoint`` restores every saved tensor into the provided state dict
        """

        distributed_checkpointer.save_checkpoint(
            state_dict=state_dict, epoch=1, save_async=False
        )

        checkpoint_path = Path.joinpath(
            distributed_checkpointer._output_dir,
            f"{distributed_checkpointer._checkpoint_dir_prefix}_1",
        )

        assert os.path.exists(checkpoint_path)

        distributed_checkpointer.load_checkpoint(
            state_dict=empty_state_dict,
        )

        for key in state_dict.keys():
            assert torch.equal(state_dict[key], empty_state_dict[key])

        # Clean up the checkpoint directory written by the test.
        shutil.rmtree(checkpoint_path)
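
As an illustration of the pattern this test exercises, a minimal standalone sketch of the same save/load round trip (not part of this commit) could look like the following; the one-tensor state dict and the temporary directory are placeholders.

import tempfile

import torch

from torchtune.training.checkpointing import DistributedCheckpointer

with tempfile.TemporaryDirectory() as tmp_dir:
    checkpointer = DistributedCheckpointer(
        checkpoint_dir=tmp_dir,
        output_dir=tmp_dir,
    )

    # Save a tiny state dict, then load it back into a zero-filled tensor
    # of the same shape (DCP loads into the provided state dict in place).
    saved = {"model.norm.weight": torch.randn(64, dtype=torch.float16)}
    loaded = {"model.norm.weight": torch.zeros(64, dtype=torch.float16)}

    checkpointer.save_checkpoint(state_dict=saved, epoch=1, save_async=False)
    checkpointer.load_checkpoint(state_dict=loaded)

    assert torch.equal(saved["model.norm.weight"], loaded["model.norm.weight"])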
torchtune/training/checkpointing/_checkpointer.py (16 changes: 13 additions & 3 deletions)
@@ -14,6 +14,7 @@
 from typing import Any, Dict, List, Optional, Protocol, Union
 
 import torch
+import torch.distributed as dist
 from safetensors.torch import save_file
 from torch.distributed.checkpoint import (
     async_save,
@@ -951,19 +952,25 @@ class DistributedCheckpointer(_CheckpointerInterface):
     Args:
         checkpoint_dir (str): Directory containing the checkpoint files
         output_dir (str): Directory to save the checkpoint files
+        process_group (Optional[dist.ProcessGroup]): Optional process group to use
+            for distributed saving/loading.
+            If None, the default process group will be used.
+            For checkpointing, gloo CPU-based backend is needed.
     """
 
     def __init__(
         self,
         checkpoint_dir: str,
         output_dir: str,
+        process_group: Optional[dist.ProcessGroup] = None,
     ) -> None:
         self._checkpoint_dir = Path(checkpoint_dir)
         self._output_dir = Path(output_dir)
         self._checkpoint_future = None
         self._checkpoint_dir_prefix = "checkpoint"
         self._metadata_file = ".metadata"
         _, self._rank = training.get_world_size_and_rank()
+        self._process_group: Optional[dist.ProcessGroup] = process_group
 
     def _get_latest_intermediate_checkpoint(self) -> Optional[str]:
         """
@@ -1012,8 +1019,9 @@ def load_checkpoint(
         log_rank_zero(logger, msg=f"Loading checkpoint from {checkpoint_path}")
 
         load(
-            state_dict,
+            state_dict=state_dict,
             storage_reader=FileSystemReader(checkpoint_path),
+            process_group=self._process_group,
         )
 
         return state_dict
@@ -1078,13 +1086,14 @@ def callback(
             )
 
             self._checkpoint_future = async_save(
-                state_dict,
+                state_dict=state_dict,
                 storage_writer=FileSystemWriter(
                     checkpoint_path,
                     thread_count=16,
                     single_file_per_rank=False,
                     sync_files=False,
                 ),
+                process_group=self._process_group,
             )
 
             logger.info(
@@ -1100,13 +1109,14 @@ def callback(
             )
 
             save(
-                state_dict,
+                state_dict=state_dict,
                 storage_writer=FileSystemWriter(
                     checkpoint_path,
                     thread_count=8,
                     single_file_per_rank=False,
                     sync_files=False,
                 ),
+                process_group=self._process_group,
            )
 
             log_rank_zero(
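
The new process_group argument lets callers pick which group the checkpoint collectives run on; per the docstring added above, checkpointing wants a gloo (CPU-based) backend. A minimal sketch of passing a dedicated gloo group follows, assuming torch.distributed has already been initialized (for example under torchrun); the paths and the toy state dict are placeholders, not part of the commit.

import torch
import torch.distributed as dist

from torchtune.training.checkpointing import DistributedCheckpointer

# Assumes dist.init_process_group(...) has already run, e.g. via torchrun.
# A separate gloo group is created for the CPU-based checkpoint collectives.
gloo_pg = dist.new_group(backend="gloo")

checkpointer = DistributedCheckpointer(
    checkpoint_dir="/tmp/ckpts",  # placeholder paths
    output_dir="/tmp/ckpts",
    process_group=gloo_pg,
)

# The checkpointer forwards this group to DCP's save/async_save/load calls.
state = {"model.norm.weight": torch.randn(64)}
checkpointer.save_checkpoint(state_dict=state, epoch=0, save_async=True)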
