remove is_rank_zero from load_from_full_model_state_dict, add torch version check
mori360 committed Dec 20, 2024
1 parent f14f90b commit 1e7f47e
Showing 11 changed files with 139 additions and 47 deletions.
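Across the recipes, the call-site change is mechanical; here is a composite before/after sketch (assembled from the diffs below, assuming the `training` namespace the recipes use, not an exact excerpt):

# Before: callers passed rank information explicitly.
training.load_from_full_model_state_dict(
    model,
    model_state_dict,
    self._device,
    self._is_rank_zero,  # removed by this commit
    strict=True,
    cpu_offload=fsdp_cpu_offload,
)

# After: the helper handles rank-0 broadcast internally.
training.load_from_full_model_state_dict(
    model,
    model_state_dict,
    self._device,
    strict=True,
    cpu_offload=fsdp_cpu_offload,
)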
1 change: 0 additions & 1 deletion recipes/dev/early_exit_finetune_distributed.py
@@ -556,7 +556,6 @@ def _setup_model(
model,
model_state_dict,
self._device,
self._is_rank_zero,
strict=True,
cpu_offload=fsdp_cpu_offload,
)
1 change: 0 additions & 1 deletion recipes/full_finetune_distributed.py
@@ -547,7 +547,6 @@ def _setup_model(
model,
model_state_dict,
self._device,
self._is_rank_zero,
strict=True,
cpu_offload=fsdp_cpu_offload,
)
3 changes: 0 additions & 3 deletions recipes/knowledge_distillation_distributed.py
@@ -461,7 +461,6 @@ def _setup_model(
model,
lora_weights_state_dict,
self._device,
self._is_rank_zero,
cpu_offload=fsdp_cpu_offload,
)
else:
@@ -486,7 +485,6 @@ def _setup_model(
model,
base_model_state_dict,
self._device,
self._is_rank_zero,
cpu_offload=fsdp_cpu_offload,
)
for m in model.modules():
@@ -574,7 +572,6 @@ def _setup_teacher_model(
model,
model_state_dict,
self._device,
self._is_rank_zero,
strict=True,
cpu_offload=fsdp_cpu_offload,
)
2 changes: 0 additions & 2 deletions recipes/lora_dpo_distributed.py
@@ -385,7 +385,6 @@ def _setup_model(
model,
lora_weights_state_dict,
self._device,
self._is_rank_zero,
cpu_offload=fsdp_cpu_offload,
)
else:
@@ -410,7 +409,6 @@ def _setup_model(
model,
base_model_state_dict,
self._device,
self._is_rank_zero,
cpu_offload=fsdp_cpu_offload,
)
is_dora = False
2 changes: 0 additions & 2 deletions recipes/lora_finetune_distributed.py
@@ -480,7 +480,6 @@ def _setup_model(
model,
lora_weights_state_dict,
self._device,
self._is_rank_zero,
cpu_offload=fsdp_cpu_offload,
)
else:
@@ -505,7 +504,6 @@ def _setup_model(
model,
base_model_state_dict,
self._device,
self._is_rank_zero,
cpu_offload=fsdp_cpu_offload,
)
for m in model.modules():
2 changes: 0 additions & 2 deletions recipes/lora_finetune_distributed_multi_dataset.py
@@ -473,7 +473,6 @@ def _setup_model(
model,
lora_weights_state_dict,
self._device,
self._is_rank_zero,
cpu_offload=fsdp_cpu_offload,
)
else:
@@ -500,7 +499,6 @@ def _setup_model(
model,
base_model_state_dict,
self._device,
self._is_rank_zero,
cpu_offload=fsdp_cpu_offload,
)
for m in model.modules():
1 change: 0 additions & 1 deletion recipes/qat_distributed.py
@@ -508,7 +508,6 @@ def _setup_model(
model,
model_state_dict,
self._device,
self._is_rank_zero,
strict=True,
cpu_offload=fsdp_cpu_offload,
)
2 changes: 0 additions & 2 deletions recipes/qat_lora_finetune_distributed.py
@@ -525,7 +525,6 @@ def _setup_model(
model,
lora_weights_state_dict,
self._device,
self._is_rank_zero,
cpu_offload=fsdp_cpu_offload,
)
else:
@@ -550,7 +549,6 @@ def _setup_model(
model,
base_model_state_dict,
self._device,
self._is_rank_zero,
cpu_offload=fsdp_cpu_offload,
)
validate_missing_and_unexpected_for_lora(
2 changes: 0 additions & 2 deletions tests/torchtune/modules/peft/test_dora.py
@@ -347,7 +347,6 @@ def _test_dora_distributed_init(self, load_dora_weights):
ffn,
adapter_state_dict,
device,
is_rank_zero,
)
if is_rank_zero:
for dora_linear in [ffn.w1, ffn.w2, ffn.w3]:
@@ -377,7 +376,6 @@ def _test_dora_distributed_init(self, load_dora_weights):
ffn,
base_model_state_dict,
device,
is_rank_zero,
)

# After this, everything should be off meta device
3 changes: 1 addition & 2 deletions tests/torchtune/training/test_distributed.py
@@ -221,7 +221,6 @@ def test_lora_state_dict(self):
fsdp_model_to_load,
copy.deepcopy(base_model.state_dict()),
torch.device("cuda"),
is_rank_zero,
)
fsdp_optim_to_load = torch.optim.Adam(
fsdp_model_to_load.parameters(), weight_decay=0.01, lr=0.01
@@ -355,7 +354,7 @@ def _test_qlora_state_dict(self, enable_activation_checkpointing: bool):
fully_shard(m)
fully_shard(fsdp_model_to_load)
training.load_from_full_model_state_dict(
fsdp_model_to_load, expected_model_sd, torch.device("cuda"), is_rank_zero
fsdp_model_to_load, expected_model_sd, torch.device("cuda")
)
fsdp_model_to_load(inp)
sharded_model_sd = fsdp_model_to_load.state_dict()
167 changes: 138 additions & 29 deletions torchtune/training/_distributed.py
@@ -31,6 +31,7 @@
from torchtune.modules.peft import get_adapter_state_dict
from torchtune.utils import get_device, get_logger
from torchtune.utils._logging import deprecated
from torchtune.utils._version import torch_version_ge

_log: logging.Logger = get_logger()

@@ -161,7 +162,6 @@ def load_from_full_model_state_dict(
model: "FSDPModule", # noqa
full_sd: Dict[str, Any],
device: torch.device,
is_rank_zero: bool,
strict: bool = False,
cpu_offload: bool = False,
):
@@ -173,13 +173,97 @@
- `is_rank_zero` matters if only rank 0 passes in a non-empty `full_sd` and
we need to broadcast from rank 0
"""
has_nf4 = any(
hasattr(param, "_local_tensor") and isinstance(param._local_tensor, NF4Tensor)
for param in model.parameters()
)
# has_nf4 = True
meta_sharded_sd = model.state_dict()
if has_nf4:
# `set_model_state_dict` in recent PyTorch nightlies handles the multi-device local state used by torchtune;
# keep this version check until those changes land in a stable PyTorch release.
if torch_version_ge("2.6.0.dev20241220"):
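# (For reference: `torch_version_ge` is the helper imported above from
# `torchtune.utils._version`; it compares `torch.__version__` against the
# given string, so a specific nightly such as "2.6.0.dev20241220" can be
# targeted before the fix reaches a stable release.)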
has_nf4 = any(
hasattr(param, "_local_tensor")
and isinstance(param._local_tensor, NF4Tensor)
for param in model.parameters()
)
# has_nf4 = True
meta_sharded_sd = model.state_dict()
if has_nf4:
sharded_sd = {}
for param_name, full_tensor in full_sd.items():
sharded_meta_param = meta_sharded_sd.get(param_name)
full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device)
if hasattr(sharded_meta_param, "_local_tensor") and isinstance(
sharded_meta_param._local_tensor, NF4Tensor
):
block_size = sharded_meta_param._local_tensor.block_size
scaler_block_size = (
sharded_meta_param._local_tensor.scaler_block_size
)
full_tensor = to_nf4(
full_tensor,
block_size=block_size,
scaler_block_size=scaler_block_size,
)
# replicating logic from `_fsdp_param.py` `_init_sharded_param`
# otherwise `distribute_tensor(DTensor(local=NF4))`
# requires dispatching `c10d.scatter_`
# long-term solution is `swap_tensor`
mesh = sharded_meta_param.device_mesh
if mesh.ndim > 1:
raise NotImplementedError(
f"only support 1D FSDP but got {mesh.ndim=}"
)
shard_mesh_dim = 0
shard_world_size = mesh.size(shard_mesh_dim)
shard_rank = cast(
torch.distributed.ProcessGroup, mesh.get_group(shard_mesh_dim)
).rank()
chunk = list(torch.chunk(full_tensor, shard_world_size, dim=0))[
shard_rank
]
sharded_param = full_tensor.new_zeros(chunk.size())
sharded_param[: chunk.size(0)].copy_(chunk)

# TODO: change to from_local API (need to add view support for NF4)
sharded_tensor = DTensor(
local_tensor=sharded_param,
spec=DTensorSpec(
mesh=sharded_meta_param.device_mesh,
placements=sharded_meta_param.placements,
tensor_meta=TensorMeta(
shape=sharded_meta_param.size(),
dtype=sharded_meta_param.dtype,
stride=sharded_meta_param.stride(),
),
),
requires_grad=sharded_meta_param.requires_grad,
)

elif not hasattr(sharded_meta_param, "device_mesh"):
# In cases where parts of the model aren't sharded, some parameters will be plain tensors
sharded_tensor = full_tensor
else:
sharded_tensor = distribute_tensor(
full_tensor,
sharded_meta_param.device_mesh,
sharded_meta_param.placements,
)
if cpu_offload:
sharded_tensor = sharded_tensor.cpu()
sharded_sd[param_name] = nn.Parameter(sharded_tensor)
# choose `assign=True` since we cannot call `copy_` on meta tensor
return model.load_state_dict(sharded_sd, strict=strict, assign=True)
else:
for param_name in full_sd.keys():
sharded_meta_param = meta_sharded_sd.get(param_name)
full_sd[param_name] = full_sd[param_name].to(sharded_meta_param.dtype)
options = StateDictOptions(
full_state_dict=True,
broadcast_from_rank0=True,
strict=strict,
cpu_offload=cpu_offload,
)
return set_model_state_dict(
model=model, model_state_dict=full_sd, options=options
)
else:
meta_sharded_sd = model.state_dict()
sharded_sd = {}
for param_name, full_tensor in full_sd.items():
sharded_meta_param = meta_sharded_sd.get(param_name)
@@ -243,19 +327,6 @@ def load_from_full_model_state_dict(
sharded_sd[param_name] = nn.Parameter(sharded_tensor)
# choose `assign=True` since we cannot call `copy_` on meta tensor
return model.load_state_dict(sharded_sd, strict=strict, assign=True)
else:
for param_name in full_sd.keys():
sharded_meta_param = meta_sharded_sd.get(param_name)
full_sd[param_name] = full_sd[param_name].to(sharded_meta_param.dtype)
options = StateDictOptions(
full_state_dict=True,
broadcast_from_rank0=True,
strict=strict,
cpu_offload=cpu_offload,
)
return set_model_state_dict(
model=model, model_state_dict=full_sd, options=options
)
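
The `broadcast_from_rank0=True` option is what makes the explicit `is_rank_zero` argument redundant. A minimal standalone sketch of that contract (variable names and checkpoint path are assumptions, not taken from this diff):

import torch
import torch.distributed as dist
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    set_model_state_dict,
)

# Only rank 0 needs to materialize the full state dict; the other ranks
# pass an empty dict and receive their shards via broadcast from rank 0.
full_sd = (
    torch.load("model.pt", mmap=True, weights_only=True)
    if dist.get_rank() == 0
    else {}
)
options = StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)
set_model_state_dict(model=model, model_state_dict=full_sd, options=options)

Here `model` is assumed to be an already-sharded FSDP module on the meta or target device.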


def _gather_nf4_tensor(sharded_param: nn.Parameter) -> nn.Parameter:
@@ -376,14 +447,52 @@ def load_from_full_optimizer_state_dict(
Converting a full optimizer state dict to a sharded state dict
and loading it into the optimizer
"""
options = StateDictOptions(
full_state_dict=True,
broadcast_from_rank0=True,
cpu_offload=device == torch.device("cpu"),
)
set_optimizer_state_dict(
model=model, optimizers=opt, optim_state_dict=full_sd, options=options
)
if torch_version_ge("2.6.0.dev20241220"):
options = StateDictOptions(
full_state_dict=True,
broadcast_from_rank0=True,
cpu_offload=device == torch.device("cpu"),
)
set_optimizer_state_dict(
model=model, optimizers=opt, optim_state_dict=full_sd, options=options
)
else:
PARAMS = "params" # noqa: N806
_init_optim_state(opt)
param_groups = opt.state_dict()["param_groups"]
state = opt.state_dict()["state"]

full_param_groups = full_sd["param_groups"]
full_state = full_sd["state"]

for param_group, full_param_group in zip(param_groups, full_param_groups):
for key, value in full_param_group.items():
if key == PARAMS:
continue
param_group[key] = value
for pid, full_pid in zip(param_group[PARAMS], full_param_group[PARAMS]):
if pid not in state:
continue
param_state = state[pid]
full_param_state = full_state[full_pid]
for attr, full_tensor in full_param_state.items():
sharded_tensor = param_state[attr]
if isinstance(sharded_tensor, DTensor):
# exp_avg is DTensor
param_state[attr] = distribute_tensor(
full_tensor,
sharded_tensor.device_mesh,
sharded_tensor.placements,
)
else:
# step is plain tensor
param_state[attr] = full_tensor
opt.load_state_dict(
{
"param_groups": param_groups,
"state": state,
}
)
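
A hedged usage sketch for this helper, assuming the post-commit parameter order implied by the body above (`model`, `opt`, `full_sd`, `device`) and a hypothetical checkpoint path:

import torch
from torchtune import training

# Every rank loads the full optimizer state here, which is valid for both
# branches: the set_optimizer_state_dict path and the manual DTensor
# redistribution fallback.
full_osd = torch.load("optimizer.pt", mmap=True, weights_only=True)
training.load_from_full_optimizer_state_dict(model, opt, full_osd, device)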


def get_shard_conditions(
