diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py
index 8ce219d20..340185577 100644
--- a/torchtune/training/_distributed.py
+++ b/torchtune/training/_distributed.py
@@ -39,6 +39,11 @@
 _valid_distributed_single_node_nnodes = ["1:1", "1"]
 
+torch_version = torch.__version__
+_USE_DISTRIBUTED_STATE_DICT_API = (
+    "dev" not in torch_version and torch_version_ge("2.6.0")
+) or ("dev" in torch_version and torch_version.split("dev")[1] >= "20241220")
+
 
 def _get_sharding_strategy(strategy: str) -> ShardingStrategy:
     """Helper function to convert sharding strategy strings to ShardingStrategy enum."""
@@ -176,7 +181,7 @@ def load_from_full_model_state_dict(
     """
     # There are some changes at `set_model_state_dict` to adjust multiple devices from local_state in TorchTune,
     # keey version check until PyTorch changes are on stable.
-    if torch_version_ge("2.6.0.dev20241220"):
+    if _USE_DISTRIBUTED_STATE_DICT_API:
         has_nf4 = any(
             hasattr(param, "_local_tensor")
             and isinstance(param._local_tensor, NF4Tensor)
@@ -448,7 +453,7 @@ def load_from_full_optimizer_state_dict(
     """
     Converting full optimizer state to sharded state dict and loading it into optimizer
     """
-    if torch_version_ge("2.6.0.dev20241220"):
+    if _USE_DISTRIBUTED_STATE_DICT_API:
         options = StateDictOptions(
             full_state_dict=True,
             broadcast_from_rank0=True,
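
Note (not part of the patch): a minimal standalone sketch of the version gate introduced above, so its two branches can be sanity-checked against sample version strings. The function name `_use_distributed_state_dict_api` is hypothetical, and `packaging.version.Version` is used here only as a stand-in for torchtune's `torch_version_ge` helper.

```python
# Hypothetical mirror of the _USE_DISTRIBUTED_STATE_DICT_API gate added in this
# diff; packaging.version stands in for torchtune's torch_version_ge helper.
from packaging.version import Version


def _use_distributed_state_dict_api(torch_version: str) -> bool:
    if "dev" not in torch_version:
        # Stable / release-candidate builds: require torch >= 2.6.0.
        return Version(torch_version) >= Version("2.6.0")
    # Nightly builds ("2.X.Y.devYYYYMMDD"): compare the 8-digit date suffix
    # lexicographically, which matches numeric order for fixed-width dates.
    return torch_version.split("dev")[1] >= "20241220"


for v in ("2.5.1", "2.6.0", "2.6.0.dev20241215", "2.6.0.dev20241220"):
    print(v, _use_distributed_state_dict_api(v))
# Expected output: False, True, False, True
```

Splitting the stable and nightly cases is presumably needed because a nightly such as `2.6.0.dev20241220` compares as *less than* `2.6.0` under standard pre-release version semantics, so a single `>= 2.6.0` check would reject nightlies that already include the required `set_model_state_dict` changes.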