diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py index 2c3ec255f..ff7290933 100644 --- a/torchtune/training/_distributed.py +++ b/torchtune/training/_distributed.py @@ -15,12 +15,12 @@ from torch import nn from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard -from torch.distributed._tensor import distribute_tensor, DTensor +from torch.distributed._state_dict_utils import _broadcast_state_dict +from torch.distributed._tensor import DTensor from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta from torch.distributed.checkpoint.state_dict import ( get_model_state_dict, get_optimizer_state_dict, - set_model_state_dict, set_optimizer_state_dict, StateDictOptions, ) @@ -171,84 +171,69 @@ def load_from_full_model_state_dict( """ meta_sharded_sd = model.state_dict() sharded_sd = {} - has_nf4 = any( - hasattr(param, "_local_tensor") and isinstance(param._local_tensor, NF4Tensor) - for param in model.parameters() - ) - for param_name in full_sd.keys(): + for param_name, full_tensor in full_sd.items(): sharded_meta_param = meta_sharded_sd.get(param_name) - full_sd[param_name] = ( - full_sd[param_name].to(sharded_meta_param.dtype).to(device) - ) - if has_nf4: - for param_name, full_tensor in full_sd.items(): - sharded_meta_param = meta_sharded_sd.get(param_name) - if hasattr(sharded_meta_param, "_local_tensor") and isinstance( - sharded_meta_param._local_tensor, NF4Tensor - ): - block_size = sharded_meta_param._local_tensor.block_size - scaler_block_size = sharded_meta_param._local_tensor.scaler_block_size - full_tensor = to_nf4( - full_tensor, - block_size=block_size, - scaler_block_size=scaler_block_size, - ) - # replicating logic from `_fsdp_param.py`` `_init_sharded_param` - # otherwise `distribute_tensor(DTensor(local=NF4))` - # requires dispatching `c10d.scatter_`` - # long-term solution is `swap_tensor` - mesh = sharded_meta_param.device_mesh - if mesh.ndim > 1: - raise NotImplementedError( - f"only support 1D FSDP but got {mesh.ndim=}" - ) - shard_mesh_dim = 0 - shard_world_size = mesh.size(shard_mesh_dim) - shard_rank = cast( - torch.distributed.ProcessGroup, mesh.get_group(shard_mesh_dim) - ).rank() - chunk = list(torch.chunk(full_tensor, shard_world_size, dim=0))[ - shard_rank - ] - sharded_param = full_tensor.new_zeros(chunk.size()) - sharded_param[: chunk.size(0)].copy_(chunk) - - # TODO: change to from_local API (need to add view support for NF4) - sharded_tensor = DTensor( - local_tensor=sharded_param, - spec=DTensorSpec( - mesh=sharded_meta_param.device_mesh, - placements=sharded_meta_param.placements, - tensor_meta=TensorMeta( - shape=sharded_meta_param.size(), - dtype=sharded_meta_param.dtype, - stride=sharded_meta_param.stride(), - ), + full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device) + if hasattr(sharded_meta_param, "_local_tensor") and isinstance( + sharded_meta_param._local_tensor, NF4Tensor + ): + block_size = sharded_meta_param._local_tensor.block_size + scaler_block_size = sharded_meta_param._local_tensor.scaler_block_size + full_tensor = to_nf4( + full_tensor, block_size=block_size, scaler_block_size=scaler_block_size + ) + # replicating logic from `_fsdp_param.py`` `_init_sharded_param` + # otherwise `distribute_tensor(DTensor(local=NF4))` + # requires dispatching `c10d.scatter_`` + # long-term solution is `swap_tensor` + mesh = sharded_meta_param.device_mesh + if mesh.ndim > 1: + raise NotImplementedError(f"only support 1D FSDP but got {mesh.ndim=}") + shard_mesh_dim = 0 + shard_world_size = mesh.size(shard_mesh_dim) + shard_rank = cast( + torch.distributed.ProcessGroup, mesh.get_group(shard_mesh_dim) + ).rank() + chunk = list(torch.chunk(full_tensor, shard_world_size, dim=0))[shard_rank] + sharded_param = full_tensor.new_zeros(chunk.size()) + sharded_param[: chunk.size(0)].copy_(chunk) + + # TODO: change to from_local API (need to add view support for NF4) + sharded_tensor = DTensor( + local_tensor=sharded_param, + spec=DTensorSpec( + mesh=sharded_meta_param.device_mesh, + placements=sharded_meta_param.placements, + tensor_meta=TensorMeta( + shape=sharded_meta_param.size(), + dtype=sharded_meta_param.dtype, + stride=sharded_meta_param.stride(), ), - requires_grad=sharded_meta_param.requires_grad, - ) - elif not hasattr(sharded_meta_param, "device_mesh"): - # In cases where parts of the model aren't sharded, some parameters will be plain tensors - sharded_tensor = full_tensor - else: - sharded_tensor = distribute_tensor( - full_tensor, - sharded_meta_param.device_mesh, - sharded_meta_param.placements, - ) - if cpu_offload: - sharded_tensor = sharded_tensor.cpu() - sharded_sd[param_name] = nn.Parameter(sharded_tensor) - return model.load_state_dict(sharded_sd, strict=strict, assign=True) + ), + requires_grad=sharded_meta_param.requires_grad, + ) - else: - options = StateDictOptions( - full_state_dict=True, - broadcast_from_rank0=True, - strict=strict, - cpu_offload=cpu_offload, - ) - set_model_state_dict(model=model, model_state_dict=full_sd, options=options) + elif not hasattr(sharded_meta_param, "device_mesh"): + # In cases where parts of the model aren't sharded, some parameters will be plain tensors + sharded_tensor = full_tensor + else: + local_state_dict = {param_name: sharded_meta_param} + _broadcast_state_dict( + {param_name: full_tensor}, + local_state_dict, + device=( + device + if device != torch.device("meta") + else dist.distributed_c10d._get_pg_default_device() + ), + strict=strict, + ) + sharded_tensor = local_state_dict[param_name] + if cpu_offload: + sharded_tensor = sharded_tensor.cpu() + sharded_sd[param_name] = nn.Parameter(sharded_tensor) + # choose `assign=True` since we cannot call `copy_` on meta tensor + return model.load_state_dict(sharded_sd, strict=strict, assign=True) def _gather_nf4_tensor(sharded_param: nn.Parameter) -> nn.Parameter: @@ -363,9 +348,7 @@ def load_from_full_optimizer_state_dict( Converting full optimizer state to sharded state dict and loading it into optimizer """ - options = StateDictOptions( - full_state_dict=True, broadcast_from_rank0=True, cpu_offload=True - ) + options = StateDictOptions(full_state_dict=True, broadcast_from_rank0=True) set_optimizer_state_dict( model=model, optimizers=opt, optim_state_dict=full_sd, options=options )