Skip to content

Commit

Permalink
Move get_world_size_and_rank to utils (#2155)
Browse files Browse the repository at this point in the history
  • Loading branch information
joecummings authored Dec 13, 2024
1 parent cdaece1 commit 096881d
Show file tree
Hide file tree
Showing 17 changed files with 69 additions and 48 deletions.
1 change: 0 additions & 1 deletion docs/source/api_ref_training.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ Utilities for enabling and working with distributed training.

init_distributed
is_distributed
get_world_size_and_rank
gather_cpu_state_dict

.. _ac_label:
Expand Down
1 change: 1 addition & 0 deletions docs/source/api_ref_utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ Miscellaneous
get_device
get_logger
torch_version_ge
get_world_size_and_rank
6 changes: 3 additions & 3 deletions recipes/dev/early_exit_finetune_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def __init__(self, cfg: DictConfig) -> None:

# _is_rank_zero is used primarily for logging. In the future, the logger
# should directly take care of this
_, rank = training.get_world_size_and_rank()
_, rank = utils.get_world_size_and_rank()
self._is_rank_zero = rank == 0

# Training cfg
Expand Down Expand Up @@ -646,7 +646,7 @@ def _setup_data(
DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
iterable datasets and streaming datasets are not supported.
"""
world_size, rank = training.get_world_size_and_rank()
world_size, rank = utils.get_world_size_and_rank()

if isinstance(cfg_dataset, ListConfig):
datasets = [
Expand Down Expand Up @@ -826,7 +826,7 @@ def train(self) -> None:
# clean up before training begins
training.cleanup_before_training()

world_size, rank = training.get_world_size_and_rank()
world_size, rank = utils.get_world_size_and_rank()

# zero out the gradients before starting training
if not self._optimizer_in_bwd:
Expand Down
6 changes: 3 additions & 3 deletions recipes/full_finetune_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def __init__(self, cfg: DictConfig) -> None:
)
self._log_peak_memory_stats = False

_, rank = training.get_world_size_and_rank()
_, rank = utils.get_world_size_and_rank()
self._is_rank_zero = rank == 0

# Training cfg
Expand Down Expand Up @@ -619,7 +619,7 @@ def _setup_data(
DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
iterable datasets and streaming datasets are not supported.
"""
world_size, rank = training.get_world_size_and_rank()
world_size, rank = utils.get_world_size_and_rank()

if isinstance(cfg_dataset, ListConfig):
datasets = [
Expand Down Expand Up @@ -757,7 +757,7 @@ def train(self) -> None:
# clean up before training begins
training.cleanup_before_training()

world_size, rank = training.get_world_size_and_rank()
world_size, rank = utils.get_world_size_and_rank()

# zero out the gradients before starting training
if not self._optimizer_in_bwd:
Expand Down
6 changes: 3 additions & 3 deletions recipes/knowledge_distillation_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def __init__(self, cfg: DictConfig) -> None:
"fp16 precision is not supported in this recipe. Please use fp32 or bf16."
)

_, rank = training.get_world_size_and_rank()
_, rank = utils.get_world_size_and_rank()

self._is_rank_zero = rank == 0

Expand Down Expand Up @@ -646,7 +646,7 @@ def _setup_data(
Map-style Datasets which fit into memory and an option for random shuffling.
Samplers, iterable datasets, and streaming datasets are not supported.
"""
world_size, rank = training.get_world_size_and_rank()
world_size, rank = utils.get_world_size_and_rank()

if isinstance(cfg_dataset, ListConfig):
datasets = [
Expand Down Expand Up @@ -815,7 +815,7 @@ def train(self) -> None:
# clean up before training begins
training.cleanup_before_training()

world_size, rank = training.get_world_size_and_rank()
world_size, rank = utils.get_world_size_and_rank()

# zero out the gradients before starting training
self._optimizer.zero_grad()
Expand Down
6 changes: 3 additions & 3 deletions recipes/lora_dpo_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def __init__(self, cfg: DictConfig) -> None:
"full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
)

_, rank = training.get_world_size_and_rank()
_, rank = utils.get_world_size_and_rank()

self._is_rank_zero = rank == 0

Expand Down Expand Up @@ -492,7 +492,7 @@ def _setup_data(
DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
iterable datasets and streaming datasets are not supported.
"""
world_size, rank = training.get_world_size_and_rank()
world_size, rank = utils.get_world_size_and_rank()

if isinstance(cfg_dataset, ListConfig):
datasets = [
Expand Down Expand Up @@ -642,7 +642,7 @@ def train(self) -> None:
# clean up before training begins
training.cleanup_before_training()

_, rank = training.get_world_size_and_rank()
_, rank = utils.get_world_size_and_rank()

# zero out the gradients before starting training
self._optimizer.zero_grad()
Expand Down
6 changes: 3 additions & 3 deletions recipes/lora_finetune_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def __init__(self, cfg: DictConfig) -> None:
"full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
)

_, rank = training.get_world_size_and_rank()
_, rank = utils.get_world_size_and_rank()

self._is_rank_zero = rank == 0

Expand Down Expand Up @@ -584,7 +584,7 @@ def _setup_data(
DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
iterable datasets and streaming datasets are not supported.
"""
world_size, rank = training.get_world_size_and_rank()
world_size, rank = utils.get_world_size_and_rank()

if isinstance(cfg_dataset, ListConfig):
datasets = [
Expand Down Expand Up @@ -746,7 +746,7 @@ def train(self) -> None:
# clean up before training begins
training.cleanup_before_training()

world_size, rank = training.get_world_size_and_rank()
world_size, rank = utils.get_world_size_and_rank()

# zero out the gradients before starting training
self._optimizer.zero_grad()
Expand Down
6 changes: 3 additions & 3 deletions recipes/qat_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def __init__(self, cfg: DictConfig) -> None:
)
self._log_peak_memory_stats = False

_, rank = training.get_world_size_and_rank()
_, rank = utils.get_world_size_and_rank()
self._is_rank_zero = rank == 0

# Training cfg
Expand Down Expand Up @@ -591,7 +591,7 @@ def _setup_data(
DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
iterable datasets and streaming datasets are not supported.
"""
world_size, rank = training.get_world_size_and_rank()
world_size, rank = utils.get_world_size_and_rank()

if isinstance(cfg_dataset, ListConfig):
datasets = [
Expand Down Expand Up @@ -729,7 +729,7 @@ def train(self) -> None:
# clean up before training begins
training.cleanup_before_training()

world_size, rank = training.get_world_size_and_rank()
world_size, rank = utils.get_world_size_and_rank()

# zero out the gradients before starting training
if not self._optimizer_in_bwd:
Expand Down
6 changes: 3 additions & 3 deletions recipes/qat_lora_finetune_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def __init__(self, cfg: DictConfig) -> None:
"full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
)

_, rank = training.get_world_size_and_rank()
_, rank = utils.get_world_size_and_rank()

# _is_rank_zero is used primarily for logging. In the future, the logger
# should directly take care of this
Expand Down Expand Up @@ -620,7 +620,7 @@ def _setup_data(
DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
iterable datasets and streaming datasets are not supported.
"""
world_size, rank = training.get_world_size_and_rank()
world_size, rank = utils.get_world_size_and_rank()

if isinstance(cfg_dataset, ListConfig):
datasets = [
Expand Down Expand Up @@ -784,7 +784,7 @@ def train(self) -> None:
# clean up before training begins
training.cleanup_before_training()

world_size, rank = training.get_world_size_and_rank()
world_size, rank = utils.get_world_size_and_rank()

# zero out the gradients before starting training
self._optimizer.zero_grad()
Expand Down
16 changes: 0 additions & 16 deletions tests/torchtune/training/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,6 @@ def _test_worker_fn(init_pg_explicit: bool) -> None:
pg_backend == "gloo"
), f"Expected 'gloo' backend, but received {pg_backend}"

@staticmethod
def _test_world_size_with_cpu_device(expected_world_size: int) -> None:
    """Worker entrypoint that checks the world size seen after distributed init.

    Initializes distributed training with the ``gloo`` backend via
    ``training.init_distributed`` and compares the world size reported by
    ``training.get_world_size_and_rank`` against the expected value.

    Args:
        expected_world_size (int): world size the caller expects to observe.

    Raises:
        AssertionError: if the reported world size differs from
            ``expected_world_size``.
    """
    training.init_distributed(backend="gloo")
    # Only the world size matters here; the rank is ignored.
    world_size = training.get_world_size_and_rank()[0]
    if world_size == expected_world_size:
        return
    raise AssertionError(
        f"Expected different world size: received {world_size}, expected {expected_world_size}"
    )

def _test_launch_worker(
self,
get_pet_launch_config,
Expand All @@ -84,13 +75,6 @@ def test_init_from_env_dup(self, get_pet_launch_config) -> None:
# trivial test case to ensure test passes with no exceptions
assert True

def test_world_size_with_cpu(self, get_pet_launch_config) -> None:
    """Launch workers through ``elastic_launch`` and let each verify the world size.

    The launch config is built for a world size of 4, and the same value is
    passed to the worker entrypoint as the expected world size.
    """
    desired_world_size = 4
    launch_config = get_pet_launch_config(desired_world_size)
    run_workers = launcher.elastic_launch(
        launch_config, entrypoint=self._test_world_size_with_cpu_device
    )
    run_workers(desired_world_size)

def test_validate_no_params_on_meta_device(self) -> None:
with torch.device("meta"):
model = torch.nn.Linear(3, 3)
Expand Down
21 changes: 21 additions & 0 deletions tests/torchtune/utils/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import pytest

import torch

from torch.distributed import launcher
from torchtune.utils._device import (
_get_device_type_from_env,
_setup_device,
Expand All @@ -20,13 +22,32 @@
get_device,
get_device_support,
get_torch_device_namespace,
get_world_size_and_rank,
)


class TestDevice:

cuda_available: bool = torch.cuda.is_available()

def _create_world(self, expected_world_size: int) -> None:
    """Worker entrypoint that checks the world size seen after process-group init.

    Initializes the default process group with the ``gloo`` backend and
    compares the world size returned by ``get_world_size_and_rank`` with
    the expected value.

    Args:
        expected_world_size (int): world size the caller expects to observe.

    Raises:
        AssertionError: if the reported world size differs from
            ``expected_world_size``.
    """
    torch.distributed.init_process_group(backend="gloo")
    # Only the world size matters here; the rank is ignored.
    world_size = get_world_size_and_rank()[0]
    if world_size == expected_world_size:
        return
    raise AssertionError(
        f"Expected different world size: received {world_size}, expected {expected_world_size}"
    )

def test_world_size_with_cpu(self, get_pet_launch_config) -> None:
    """Launch workers through ``elastic_launch`` and let each verify the world size.

    The launch config is built for a world size of 4, and the same value is
    passed to ``_create_world`` as the expected world size.
    """
    desired_world_size = 4
    launch_config = get_pet_launch_config(desired_world_size)
    run_workers = launcher.elastic_launch(
        launch_config, entrypoint=self._create_world
    )
    run_workers(desired_world_size)

def test_rank_with_cpu_device(self) -> None:
    """Check that ``get_world_size_and_rank`` reports rank 0 when called directly.

    This test does not set up a process group itself.
    NOTE(review): presumably none is initialized in the pytest process, so
    this exercises the non-distributed fallback - confirm.
    """
    rank = get_world_size_and_rank()[1]
    assert rank == 0

@patch("torch.cuda.is_available", return_value=False)
def test_get_cpu_device(self, mock_cuda):
devices = [None, "cpu", "meta"]
Expand Down
9 changes: 6 additions & 3 deletions torchtune/training/_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@
from torch.optim import Optimizer
from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4
from torchtune.modules import TransformerDecoder
from torchtune.utils import get_logger

from torchtune.utils._device import get_device
from torchtune.utils import get_device, get_logger
from torchtune.utils._logging import deprecated

_log: logging.Logger = get_logger()

Expand Down Expand Up @@ -117,6 +116,10 @@ def set_torch_num_threads() -> None:
_log.info(f"Set intra op parallelism no. of threads to {num_threads}")


@deprecated(
msg="`get_world_size_and_rank` will move to `torchtune.utils._device` in future releases. "
"Please use `torchtune.utils.get_world_size_and_rank` instead."
)
def get_world_size_and_rank() -> Tuple[int, int]:
"""Function that gets the current world size (aka total number
of ranks) and rank number of the current process in the default process group.
Expand Down
3 changes: 1 addition & 2 deletions torchtune/training/_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@
from omegaconf import DictConfig
from torch._C._profiler import _ExperimentalConfig
from torch.profiler import tensorboard_trace_handler
from torchtune.training import get_world_size_and_rank

from torchtune.utils import get_logger
from torchtune.utils import get_logger, get_world_size_and_rank

log = get_logger("INFO")

Expand Down
3 changes: 1 addition & 2 deletions torchtune/training/metric_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@

from numpy import ndarray
from omegaconf import DictConfig, OmegaConf
from torchtune.training._distributed import get_world_size_and_rank

from torchtune.utils import get_logger
from torchtune.utils import get_logger, get_world_size_and_rank
from typing_extensions import Protocol

Scalar = Union[torch.Tensor, ndarray, int, float]
Expand Down
4 changes: 2 additions & 2 deletions torchtune/training/seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
import numpy as np
import torch

from torchtune.training._distributed import _broadcast_tensor, get_world_size_and_rank
from torchtune.utils import get_logger
from torchtune.training._distributed import _broadcast_tensor
from torchtune.utils import get_logger, get_world_size_and_rank

_log: logging.Logger = get_logger()

Expand Down
2 changes: 2 additions & 0 deletions torchtune/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
get_device,
get_device_support,
get_torch_device_namespace,
get_world_size_and_rank,
)
from ._logging import get_logger, log_rank_zero

from ._version import torch_version_ge

__all__ = [
"get_world_size_and_rank",
"batch_to_device",
"get_device",
"get_logger",
Expand Down
15 changes: 14 additions & 1 deletion torchtune/utils/_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import os
from enum import Enum
from typing import Optional
from typing import Optional, Tuple

import torch

Expand All @@ -21,6 +21,19 @@
BlockMask = torch.Tensor


def get_world_size_and_rank() -> Tuple[int, int]:
    """Return the world size and the current process's rank in the default process group.

    When ``torch.distributed`` is unavailable, or no process group has been
    initialized, a world size of 1 and a rank of 0 are returned.

    Returns:
        Tuple[int, int]: world size, rank
    """
    dist = torch.distributed
    # Guard clause: is_initialized() is only consulted when distributed support exists.
    if not (dist.is_available() and dist.is_initialized()):
        return 1, 0
    return dist.get_world_size(), dist.get_rank()


def is_torch_npu_available() -> bool:
"""Check the availability of NPU"""
try:
Expand Down

0 comments on commit 096881d

Please sign in to comment.