modify load_from_full_model_state_dict to optimize memory cost as before
mori360 committed Dec 12, 2024
1 parent 23c8e10 commit d243883
Showing 1 changed file with 63 additions and 80 deletions.
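The commit title refers to restoring a per-parameter conversion pattern: each tensor in `full_sd` is cast and moved to the device inside the loop, so only the parameter currently being sharded needs a full-size copy on the device. A minimal sketch of that pattern (not part of the diff; the function name and the `shard_fn` callback are hypothetical stand-ins for the NF4/DTensor handling shown in the hunks below):

import torch

def convert_one_param_at_a_time(meta_sharded_sd, full_sd, device, shard_fn):
    # Cast/move a single tensor per iteration instead of materializing the
    # whole state dict on the device before sharding.
    sharded_sd = {}
    for name, full_tensor in full_sd.items():
        meta_param = meta_sharded_sd[name]
        full_tensor = full_tensor.to(meta_param.dtype).to(device)
        sharded_sd[name] = torch.nn.Parameter(shard_fn(full_tensor, meta_param))
    return sharded_sd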
143 changes: 63 additions & 80 deletions torchtune/training/_distributed.py
@@ -15,12 +15,12 @@
 from torch import nn
 
 from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
-from torch.distributed._tensor import distribute_tensor, DTensor
+from torch.distributed._state_dict_utils import _broadcast_state_dict
+from torch.distributed._tensor import DTensor
 from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
 from torch.distributed.checkpoint.state_dict import (
     get_model_state_dict,
     get_optimizer_state_dict,
-    set_model_state_dict,
     set_optimizer_state_dict,
     StateDictOptions,
 )
@@ -171,84 +171,69 @@ def load_from_full_model_state_dict(
"""
meta_sharded_sd = model.state_dict()
sharded_sd = {}
has_nf4 = any(
hasattr(param, "_local_tensor") and isinstance(param._local_tensor, NF4Tensor)
for param in model.parameters()
)
for param_name in full_sd.keys():
for param_name, full_tensor in full_sd.items():
sharded_meta_param = meta_sharded_sd.get(param_name)
full_sd[param_name] = (
full_sd[param_name].to(sharded_meta_param.dtype).to(device)
)
if has_nf4:
for param_name, full_tensor in full_sd.items():
sharded_meta_param = meta_sharded_sd.get(param_name)
if hasattr(sharded_meta_param, "_local_tensor") and isinstance(
sharded_meta_param._local_tensor, NF4Tensor
):
block_size = sharded_meta_param._local_tensor.block_size
scaler_block_size = sharded_meta_param._local_tensor.scaler_block_size
full_tensor = to_nf4(
full_tensor,
block_size=block_size,
scaler_block_size=scaler_block_size,
)
# replicating logic from `_fsdp_param.py`` `_init_sharded_param`
# otherwise `distribute_tensor(DTensor(local=NF4))`
# requires dispatching `c10d.scatter_``
# long-term solution is `swap_tensor`
mesh = sharded_meta_param.device_mesh
if mesh.ndim > 1:
raise NotImplementedError(
f"only support 1D FSDP but got {mesh.ndim=}"
)
shard_mesh_dim = 0
shard_world_size = mesh.size(shard_mesh_dim)
shard_rank = cast(
torch.distributed.ProcessGroup, mesh.get_group(shard_mesh_dim)
).rank()
chunk = list(torch.chunk(full_tensor, shard_world_size, dim=0))[
shard_rank
]
sharded_param = full_tensor.new_zeros(chunk.size())
sharded_param[: chunk.size(0)].copy_(chunk)

# TODO: change to from_local API (need to add view support for NF4)
sharded_tensor = DTensor(
local_tensor=sharded_param,
spec=DTensorSpec(
mesh=sharded_meta_param.device_mesh,
placements=sharded_meta_param.placements,
tensor_meta=TensorMeta(
shape=sharded_meta_param.size(),
dtype=sharded_meta_param.dtype,
stride=sharded_meta_param.stride(),
),
full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device)
if hasattr(sharded_meta_param, "_local_tensor") and isinstance(
sharded_meta_param._local_tensor, NF4Tensor
):
block_size = sharded_meta_param._local_tensor.block_size
scaler_block_size = sharded_meta_param._local_tensor.scaler_block_size
full_tensor = to_nf4(
full_tensor, block_size=block_size, scaler_block_size=scaler_block_size
)
# replicating logic from `_fsdp_param.py`` `_init_sharded_param`
# otherwise `distribute_tensor(DTensor(local=NF4))`
# requires dispatching `c10d.scatter_``
# long-term solution is `swap_tensor`
mesh = sharded_meta_param.device_mesh
if mesh.ndim > 1:
raise NotImplementedError(f"only support 1D FSDP but got {mesh.ndim=}")
shard_mesh_dim = 0
shard_world_size = mesh.size(shard_mesh_dim)
shard_rank = cast(
torch.distributed.ProcessGroup, mesh.get_group(shard_mesh_dim)
).rank()
chunk = list(torch.chunk(full_tensor, shard_world_size, dim=0))[shard_rank]
sharded_param = full_tensor.new_zeros(chunk.size())
sharded_param[: chunk.size(0)].copy_(chunk)

# TODO: change to from_local API (need to add view support for NF4)
sharded_tensor = DTensor(
local_tensor=sharded_param,
spec=DTensorSpec(
mesh=sharded_meta_param.device_mesh,
placements=sharded_meta_param.placements,
tensor_meta=TensorMeta(
shape=sharded_meta_param.size(),
dtype=sharded_meta_param.dtype,
stride=sharded_meta_param.stride(),
),
requires_grad=sharded_meta_param.requires_grad,
)
elif not hasattr(sharded_meta_param, "device_mesh"):
# In cases where parts of the model aren't sharded, some parameters will be plain tensors
sharded_tensor = full_tensor
else:
sharded_tensor = distribute_tensor(
full_tensor,
sharded_meta_param.device_mesh,
sharded_meta_param.placements,
)
if cpu_offload:
sharded_tensor = sharded_tensor.cpu()
sharded_sd[param_name] = nn.Parameter(sharded_tensor)
return model.load_state_dict(sharded_sd, strict=strict, assign=True)
),
requires_grad=sharded_meta_param.requires_grad,
)

else:
options = StateDictOptions(
full_state_dict=True,
broadcast_from_rank0=True,
strict=strict,
cpu_offload=cpu_offload,
)
set_model_state_dict(model=model, model_state_dict=full_sd, options=options)
elif not hasattr(sharded_meta_param, "device_mesh"):
# In cases where parts of the model aren't sharded, some parameters will be plain tensors
sharded_tensor = full_tensor
else:
local_state_dict = {param_name: sharded_meta_param}
_broadcast_state_dict(
{param_name: full_tensor},
local_state_dict,
device=(
device
if device != torch.device("meta")
else dist.distributed_c10d._get_pg_default_device()
),
strict=strict,
)
sharded_tensor = local_state_dict[param_name]
if cpu_offload:
sharded_tensor = sharded_tensor.cpu()
sharded_sd[param_name] = nn.Parameter(sharded_tensor)
# choose `assign=True` since we cannot call `copy_` on meta tensor
return model.load_state_dict(sharded_sd, strict=strict, assign=True)


def _gather_nf4_tensor(sharded_param: nn.Parameter) -> nn.Parameter:
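For readers following the diff, a rough usage sketch of the rewritten helper (not from the commit; the parameter order and keyword names are inferred from how `full_sd`, `device`, `strict`, and `cpu_offload` are used in the hunk above, and `checkpoint_path` is illustrative):

import torch
from torchtune import training

def load_pretrained_weights(model, checkpoint_path, device):
    # Load the full (unsharded) checkpoint to CPU, then let the helper cast,
    # shard, and assign each parameter one at a time, as in the hunk above.
    full_sd = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    training.load_from_full_model_state_dict(
        model,  # FSDP-sharded model (parameters may still be on the meta device)
        full_sd,
        device,
        strict=True,
        cpu_offload=False,
    )

Note that `_broadcast_state_dict`, which replaces the previous `distribute_tensor` call for non-NF4 parameters, is a private `torch.distributed` utility; treat it as an implementation detail rather than a public API.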
@@ -363,9 +348,7 @@ def load_from_full_optimizer_state_dict(
     Converting full optimizer state to sharded state dict
     and loading it into optimizer
     """
-    options = StateDictOptions(
-        full_state_dict=True, broadcast_from_rank0=True, cpu_offload=True
-    )
+    options = StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)
     set_optimizer_state_dict(
         model=model, optimizers=opt, optim_state_dict=full_sd, options=options
     )
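A short sketch of the rank-0 broadcast pattern these options enable (not from the commit; `optim_ckpt_path` is illustrative, and the ability of non-zero ranks to pass an empty dict follows the documented behavior of `broadcast_from_rank0` in `torch.distributed.checkpoint.state_dict`):

import torch
import torch.distributed as dist
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    set_optimizer_state_dict,
)

def load_optim_from_rank0(model, optimizer, optim_ckpt_path):
    # Only rank 0 materializes the full optimizer state dict; the option tells
    # the state-dict API to broadcast it and reshard it onto every rank.
    full_osd = (
        torch.load(optim_ckpt_path, weights_only=True) if dist.get_rank() == 0 else {}
    )
    set_optimizer_state_dict(
        model=model,
        optimizers=optimizer,
        optim_state_dict=full_osd,
        options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
    )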
