update _USE_DISTRIBUTED_STATE_DICT_API version check
mori360 committed Dec 20, 2024
1 parent 11e24f1 commit 3d0d26f
Showing 1 changed file with 7 additions and 2 deletions.
torchtune/training/_distributed.py (7 additions, 2 deletions)
@@ -39,6 +39,11 @@

 _valid_distributed_single_node_nnodes = ["1:1", "1"]

+torch_version = torch.__version__
+_USE_DISTRIBUTED_STATE_DICT_API = (
+    "dev" not in torch_version and torch_version_ge("2.6.0")
+) or ("dev" in torch_version and torch_version.split("dev")[1] >= "20241220")
+

 def _get_sharding_strategy(strategy: str) -> ShardingStrategy:
     """Helper function to convert sharding strategy strings to ShardingStrategy enum."""
@@ -176,7 +181,7 @@ def load_from_full_model_state_dict(
     """
     # There are changes to `set_model_state_dict` in PyTorch to adjust multiple devices from local_state in torchtune;
     # keep this version check until the PyTorch changes reach a stable release.
-    if torch_version_ge("2.6.0.dev20241220"):
+    if _USE_DISTRIBUTED_STATE_DICT_API:
         has_nf4 = any(
             hasattr(param, "_local_tensor")
             and isinstance(param._local_tensor, NF4Tensor)
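
For context, the branch guarded here hands loading off to PyTorch's distributed state-dict API. A hedged sketch of that call shape, assuming the torch.distributed.checkpoint.state_dict module (PyTorch 2.3+); the function name and body are illustrative, not torchtune's exact code.

    # Illustrative sketch, not torchtune's exact code: load a full (unsharded)
    # model state dict into a sharded model. With broadcast_from_rank0=True,
    # only rank 0 needs to hold the full state dict; it is broadcast and
    # resharded onto every rank.
    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        set_model_state_dict,
    )

    def _load_full_model_state(model, full_sd):
        options = StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)
        set_model_state_dict(model, full_sd, options=options)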
@@ -448,7 +453,7 @@ def load_from_full_optimizer_state_dict(
     Convert a full optimizer state dict to a sharded state dict
     and load it into the optimizer.
     """
-    if torch_version_ge("2.6.0.dev20241220"):
+    if _USE_DISTRIBUTED_STATE_DICT_API:
         options = StateDictOptions(
             full_state_dict=True,
             broadcast_from_rank0=True,
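The optimizer path mirrors the model path. A hedged sketch of how the StateDictOptions shown in the diff context is typically paired with set_optimizer_state_dict (assuming the PyTorch 2.3+ API; the helper name and body are illustrative, not torchtune's exact code):

    # Illustrative sketch, not torchtune's exact code: reshard a full optimizer
    # state dict onto each rank's optimizer. As with the model path, with
    # broadcast_from_rank0=True only rank 0 must materialize the full state.
    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        set_optimizer_state_dict,
    )

    def _load_full_optimizer_state(model, optimizer, full_osd):
        options = StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)
        set_optimizer_state_dict(
            model, optimizer, optim_state_dict=full_osd, options=options
        )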
