
Commit

[bug] fix sharding multimodal (#1889)
Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: ebsmothers <[email protected]>
3 people authored Oct 24, 2024
1 parent 74139c9 · commit bc486d4
Showing 7 changed files with 124 additions and 110 deletions.
recipes/configs/llama3_2_vision/11B_full.yaml (2 changes: 1 addition & 1 deletion)
@@ -67,7 +67,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
- custom_sharded_layers: ['tok_embeddings', 'output']
+ custom_sharded_layers: ['decoder.tok_embeddings']
dtype: bf16

# Logging
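
Note on the config change: in the Llama 3.2 Vision model the text decoder and its token embedding live under the `decoder` submodule, and the new `training.get_shard_conditions` helper added later in this diff compares `names_to_match` entries against the full module name. A minimal sketch of the intended matching, assuming the helper behaves as defined below:

    from torchtune import training

    # The fully qualified name matches, so the embedding is sharded separately:
    training.get_shard_conditions(
        "decoder.tok_embeddings", None, names_to_match=["decoder.tok_embeddings"]
    )  # True
    # A bare "tok_embeddings" entry would no longer match the nested module:
    training.get_shard_conditions(
        "decoder.tok_embeddings", None, names_to_match=["tok_embeddings"]
    )  # False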
recipes/full_finetune_distributed.py (40 changes: 13 additions & 27 deletions)
@@ -225,9 +225,11 @@ def setup(self, cfg: DictConfig) -> None:
self._optimizer = self._setup_optimizer(
cfg_optimizer=cfg.optimizer,
optimizer_in_bwd=self._optimizer_in_bwd,
- opt_state_dict=checkpoint_dict[training.OPT_KEY]
- if self._resume_from_checkpoint
- else None,
+ opt_state_dict=(
+     checkpoint_dict[training.OPT_KEY]
+     if self._resume_from_checkpoint
+     else None
+ ),
)

# initialize loss
@@ -350,10 +352,10 @@ def _setup_model(
self,
cfg_model: DictConfig,
enable_activation_checkpointing: bool,
- custom_sharded_layers: Optional[List[str]],
fsdp_cpu_offload: bool,
reshard_after_forward: bool,
model_state_dict: Dict[str, Any],
+ custom_sharded_layers: Optional[List[str]] = None,
ac_mode: Optional[str] = None,
ac_option: Optional[int] = None,
) -> nn.Module:
@@ -396,29 +398,13 @@ def _setup_model(
model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
)

- # For FSDP sharding, we can condition on either the module or its name
- # Shard conditions should be callables taking name (relative to model root)
- # and the module itself and returning a bool on whether to shard the given module
- fsdp_shard_conditions = []
-
- # Shard transformer decoder layers (or AC-wrapped versions)
- # Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper)
- # But directly using the name is more concise
- def _is_layer_fqn(s: str) -> bool:
-     """
-     Return True for layers.i and False for all other module names
-     Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot
-     """
-     s_list = s.split(".")
-     return len(s_list) == 2 and s_list[0] == "layers" and str.isdigit(s_list[1])
-
- fsdp_shard_conditions = [lambda n, m: _is_layer_fqn(n)]
-
- # If wrapping any layers separately, we can add another shard condition
- # A layer will be sharded if any of the fsdp_shard_conditions are met
- if custom_sharded_layers:
-     fsdp_shard_conditions += [lambda n, m: n in custom_sharded_layers]
-
+ # For FSDP sharding
+ fsdp_shard_conditions = [
+     partial(
+         training.get_shard_conditions,
+         names_to_match=custom_sharded_layers,
+     )
+ ]
training.shard_model(
model=model,
shard_conditions=fsdp_shard_conditions,
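
The same refactor is applied in the LoRA, DPO, and QAT recipes below: the per-recipe `_is_layer_fqn`/`_is_layer_name` closures are replaced by the shared `training.get_shard_conditions` helper. Assembled outside the diff, the new wiring reduces to roughly this sketch, where `model` stands for the recipe's already-instantiated model and the layer list is a hypothetical config value:

    from functools import partial

    from torchtune import training

    custom_sharded_layers = ["decoder.tok_embeddings"]  # hypothetical config value

    # One callable decides, per module name, whether to shard that module separately.
    fsdp_shard_conditions = [
        partial(training.get_shard_conditions, names_to_match=custom_sharded_layers)
    ]
    training.shard_model(
        model=model,  # the recipe's instantiated model
        shard_conditions=fsdp_shard_conditions,
        cpu_offload=False,
        reshard_after_forward=True,
    )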
recipes/lora_dpo_distributed.py (29 changes: 9 additions & 20 deletions)
@@ -8,7 +8,7 @@
import time

from functools import partial
- from typing import Any, Dict, Optional, Tuple
+ from typing import Any, Dict, List, Optional, Tuple
from warnings import warn

import torch
@@ -290,6 +290,7 @@ def _setup_model(
fsdp_cpu_offload: bool,
reshard_after_forward: bool,
base_model_state_dict: Dict[str, Any],
+ custom_sharded_layers: Optional[List[str]] = None,
lora_weights_state_dict: Optional[Dict[str, Any]] = None,
) -> nn.Module:
"""
@@ -323,28 +324,16 @@ def _setup_model(
model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
)

- # For FSDP sharding, we can condition on either the module or its name
- # Shard conditions should be callables taking name (relative to model root)
- # and the module itself and returning a bool on whether to shard the given module
-
- # Shard transformer decoder layers (or AC-wrapped versions)
- # Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper)
- # But directly using the name is more concise
- def _is_layer_name(name: str, module: nn.Module) -> bool:
-     """
-     Return True for layers.i and False for all other module names
-     Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot
-     """
-     name_list = name.split(".")
-     return (
-         len(name_list) == 2
-         and name_list[0] == "layers"
-         and str.isdigit(name_list[1])
+ # For FSDP sharding
+ fsdp_shard_conditions = [
+     partial(
+         training.get_shard_conditions,
+         names_to_match=custom_sharded_layers,
)
-
+ ]
training.shard_model(
model=model,
- shard_conditions=[_is_layer_name],
+ shard_conditions=fsdp_shard_conditions,
cpu_offload=fsdp_cpu_offload,
reshard_after_forward=reshard_after_forward,
)
recipes/lora_finetune_distributed.py (45 changes: 18 additions & 27 deletions)
@@ -9,7 +9,7 @@
import time

from functools import partial
- from typing import Any, Dict, Optional, Tuple, Union
+ from typing import Any, Dict, List, Optional, Tuple, Union
from warnings import warn

import torch
@@ -408,6 +408,7 @@ def _setup_model(
fsdp_cpu_offload: bool,
reshard_after_forward: bool,
base_model_state_dict: Dict[str, Any],
+ custom_sharded_layers: Optional[List[str]] = None,
lora_weights_state_dict: Optional[Dict[str, Any]] = None,
) -> nn.Module:
"""
@@ -445,28 +446,16 @@ def _setup_model(
model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
)

- # For FSDP sharding, we can condition on either the module or its name
- # Shard conditions should be callables taking name (relative to model root)
- # and the module itself and returning a bool on whether to shard the given module
-
- # Shard transformer decoder layers (or AC-wrapped versions)
- # Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper)
- # But directly using the name is more concise
- def _is_layer_name(name: str, module: nn.Module) -> bool:
-     """
-     Return True for layers.i and False for all other module names
-     Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot
-     """
-     name_list = name.split(".")
-     return (
-         len(name_list) == 2
-         and name_list[0] == "layers"
-         and str.isdigit(name_list[1])
+ # For FSDP sharding
+ fsdp_shard_conditions = [
+     partial(
+         training.get_shard_conditions,
+         names_to_match=custom_sharded_layers,
)
-
+ ]
training.shard_model(
model=model,
- shard_conditions=[_is_layer_name],
+ shard_conditions=fsdp_shard_conditions,
cpu_offload=fsdp_cpu_offload,
reshard_after_forward=reshard_after_forward,
)
@@ -624,13 +613,15 @@ def _setup_data(
sampler=sampler,
# dropping last avoids shape issues with compile + flex attention
drop_last=True,
- collate_fn=partial(
-     collate_fn,
-     padding_idx=self._tokenizer.pad_id,
-     ignore_idx=self._loss_fn.ignore_index,
- )
- if not packed
- else padded_collate_packed,
+ collate_fn=(
+     partial(
+         collate_fn,
+         padding_idx=self._tokenizer.pad_id,
+         ignore_idx=self._loss_fn.ignore_index,
+     )
+     if not packed
+     else padded_collate_packed
+ ),
)

if self._is_rank_zero:
recipes/qat_distributed.py (58 changes: 23 additions & 35 deletions)
@@ -233,9 +233,11 @@ def setup(self, cfg: DictConfig) -> None:

self._optimizer = self._setup_optimizer(
cfg_optimizer=cfg.optimizer,
- opt_state_dict=checkpoint_dict[training.OPT_KEY]
- if self._resume_from_checkpoint
- else None,
+ opt_state_dict=(
+     checkpoint_dict[training.OPT_KEY]
+     if self._resume_from_checkpoint
+     else None
+ ),
)

# initialize loss
@@ -363,10 +365,10 @@ def _setup_model(
self,
cfg_model: DictConfig,
enable_activation_checkpointing: bool,
- custom_sharded_layers: Optional[List[str]],
fsdp_cpu_offload: bool,
reshard_after_forward: bool,
model_state_dict: Dict[str, Any],
+ custom_sharded_layers: Optional[List[str]] = None,
ac_mode: Optional[str] = None,
ac_option: Optional[int] = None,
quantizer_cfg: Optional[DictConfig] = None,
@@ -420,29 +422,13 @@ def _setup_model(
self._quantizer_mode = quantizer_mode
model = quantizer.prepare(model)

- # For FSDP sharding, we can condition on either the module or its name
- # Shard conditions should be callables taking name (relative to model root)
- # and the module itself and returning a bool on whether to shard the given module
- fsdp_shard_conditions = []
-
- # Shard transformer decoder layers (or AC-wrapped versions)
- # Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper)
- # But directly using the name is more concise
- def _is_layer_fqn(s: str) -> bool:
-     """
-     Return True for layers.i and False for all other module names
-     Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot
-     """
-     s_list = s.split(".")
-     return len(s_list) == 2 and s_list[0] == "layers" and str.isdigit(s_list[1])
-
- fsdp_shard_conditions = [lambda n, m: _is_layer_fqn(n)]
-
- # If wrapping any layers separately, we can add another shard condition
- # A layer will be sharded if any of the fsdp_shard_conditions are met
- if custom_sharded_layers:
-     fsdp_shard_conditions += [lambda n, m: n in custom_sharded_layers]
-
+ # For FSDP sharding
+ fsdp_shard_conditions = [
+     partial(
+         training.get_shard_conditions,
+         names_to_match=custom_sharded_layers,
+     )
+ ]
training.shard_model(
model=model,
shard_conditions=fsdp_shard_conditions,
@@ -525,14 +511,16 @@ def _setup_data(
sampler=sampler,
# dropping last avoids shape issues with compile + flex attention
drop_last=True,
- collate_fn=partial(
-     padded_collate_sft,
-     padding_idx=self._tokenizer.pad_id,
-     ignore_idx=self._loss_fn.ignore_index,
- )
- if not packed
- else partial(
-     padded_collate_packed,
+ collate_fn=(
+     partial(
+         padded_collate_sft,
+         padding_idx=self._tokenizer.pad_id,
+         ignore_idx=self._loss_fn.ignore_index,
+     )
+     if not packed
+     else partial(
+         padded_collate_packed,
+     )
+ ),
)

torchtune/training/__init__.py (2 changes: 2 additions & 0 deletions)
@@ -11,6 +11,7 @@
get_full_finetune_fsdp_wrap_policy,
get_full_model_state_dict,
get_full_optimizer_state_dict,
+ get_shard_conditions,
get_world_size_and_rank,
init_distributed,
is_distributed,
@@ -106,6 +107,7 @@
"get_world_size_and_rank",
"set_torch_num_threads",
"shard_model",
"get_shard_conditions",
"prepare_model_for_fsdp_with_meta_device",
"validate_no_params_on_meta_device",
"contains_fsdp",
torchtune/training/_distributed.py (58 changes: 58 additions & 0 deletions)
@@ -583,6 +583,55 @@ def llama3_wrap(module: nn.Module, recurse: bool, **kwargs):
return llama3_wrap


+ def get_shard_conditions(
+     name: str,
+     module: nn.Module,
+     names_to_match: Optional[List[str]] = None,
+     *args,
+     **kwargs,
+ ) -> bool:
+     """
+     Returns True for layers named {}.layers.i or layers that exactly match names_to_match;
+     otherwise, returns False. This is a helper function for sharding a model with FSDP.
+     In :func:`~torchtune.training.shard_model`, we iterate over the model's named modules
+     and apply fully_shard using this condition.
+
+     As part of our sharding strategy, we want each layer to be sharded separately, as this is
+     generally efficient. We may also want to shard certain modules that are not layers, such as
+     the embedding module.
+
+     #TODO: a more robust way would be to shard on the module type, not the name.
+
+     Args:
+         name (str): Name of the module.
+         module (nn.Module): Module to be sharded.
+         names_to_match (Optional[List[str]]): List of names to match, if any.
+         *args: Variable length argument list, accepted for compatibility and ignored.
+         **kwargs: Arbitrary keyword arguments, accepted for compatibility and ignored.
+
+     Returns:
+         bool: True if the module name matches the condition, False otherwise.
+
+     Examples:
+         >>> names_to_match = ["embedding"]
+         >>> layer_names = ["layers.0", "decoder.layers.1", "encoder.layers.2.attention",
+             "my_wrapper.layer.1.something", "embedding"]
+         >>> matches = []
+         >>> for name in layer_names:
+         >>>     if get_shard_conditions(name, None, names_to_match=names_to_match): matches.append(name)
+         >>> print(matches)
+         >>> ["layers.0", "decoder.layers.1", "embedding"]
+     """
+     if names_to_match and name in names_to_match:
+         return True
+
+     name_list = name.split(".")
+     if len(name_list) >= 2:
+         return name_list[-2] == "layers" and str.isdigit(name_list[-1])
+
+     return False
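
Compared with the old `_is_layer_fqn`/`_is_layer_name` checks, which only matched a top-level `layers.<i>` name, the new condition inspects the last two components of the fully qualified name, so nested decoder and encoder layers in the multimodal model are now sharded as well. A short sketch of the matching behavior, assuming the function as added above:

    # Old recipes: only "layers.<i>" at the model root matched, so the vision
    # model's "decoder.layers.<i>" blocks were never sharded.
    get_shard_conditions("layers.0", None)                    # True
    get_shard_conditions("decoder.layers.3", None)            # True (newly matched)
    get_shard_conditions("encoder.layers.2.attention", None)  # False: not a layer module itself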


def shard_model(
model: TransformerDecoder,
shard_conditions: List[Callable[[str, nn.Module], bool]],
@@ -608,16 +657,25 @@ def shard_model(
the forward pass. Setting this to True corresponds to the FULL_SHARD sharding strategy
from FSDP1, while setting it to False corresponds to the SHARD_GRAD_OP sharding strategy.
+ Raises:
+     ValueError: If no layer modules were sharded, indicating that no shard_condition was triggered.
"""
fsdp_kwargs = {"reshard_after_forward": reshard_after_forward}
if cpu_offload:
fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()

# Shard the model with FSDP, iterating in reverse to start with
# lowest-level modules first
+ num_layers_sharded = 0
for n, m in reversed(list(model.named_modules())):
if any([shard_condition(n, m) for shard_condition in shard_conditions]):
fully_shard(m, **fsdp_kwargs)
+ num_layers_sharded += 1

+ if num_layers_sharded == 0:
+     raise ValueError(
+         "No layer modules were sharded. Please check if shard conditions are working as expected."
+     )

# Finally shard the entire model to account for any stragglers
fully_shard(model, **fsdp_kwargs)
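
One behavioral consequence of the new counter: if no shard condition fires, `shard_model` now fails at setup instead of silently leaving the model unsharded. A minimal sketch of that failure path, assuming only the behavior added in this hunk:

    from torchtune import training

    try:
        # A condition that never matches: the ValueError is raised before the
        # final fully_shard(model, ...) call.
        training.shard_model(
            model=model,  # any instantiated model, e.g. a TransformerDecoder
            shard_conditions=[lambda n, m: False],
            cpu_offload=False,
            reshard_after_forward=True,
        )
    except ValueError as err:
        print(err)  # "No layer modules were sharded. ..."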
