Skip to content

Commit

Permalink
[MoE][PoC] Expert Parallel: dp2ep
Browse files Browse the repository at this point in the history
ghstack-source-id: 17160930f23950b91faca7b822cd3e7f9d075f7d
Pull Request resolved: #732
  • Loading branch information
tianyu-l committed Dec 12, 2024
1 parent 83d1714 commit 25cfe6d
Show file tree
Hide file tree
Showing 7 changed files with 306 additions and 22 deletions.
18 changes: 13 additions & 5 deletions torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,15 +364,23 @@ def __init__(self):
default=1,
help="Context parallelism degree. 1 means disabled.",
)
self.parser.add_argument(
"--experimental.expert_parallel_degree",
type=int,
default=1,
help="""
Expert parallelism degree. 1 means disabled.
When expert_parallel_mode is 'tp' or 'tp2ep', it has to be equal to tensor_parallel_degree.
When expert_parallel_mode is 'dp2ep', it has to be k * context_parallel_degree,
where k >= 1 and k | data_parallel_shard_degree.
""",
)
self.parser.add_argument(
"--experimental.expert_parallel_mode",
type=str,
default="none",
choices=["none", "tp", "tp2ep"],
help="""
Expert Parallel mode.
'tp2ep' would use the entire TP mesh to shard non-shared experts on the num_experts dimension.
""",
choices=["none", "tp", "tp2ep", "dp2ep"],
help="Expert Parallel mode",
)
self.parser.add_argument(
"--training.mixed_precision_param",
Expand Down
2 changes: 1 addition & 1 deletion torchtitan/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def build_optimizers(model_parts, job_config: JobConfig):
"betas": (0.9, 0.95),
"weight_decay": 0.1,
"fused": fused,
"foreach": not fused,
"foreach": False,
}

return (
Expand Down
119 changes: 119 additions & 0 deletions torchtitan/parallelisms/expert_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,3 +325,122 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
self._prepare_output_fn, self.output_layouts, self.use_local_output
),
)


# This class is for dp2ep with TP (without TP we can just use ExpertParallel)
class ExpertTensorParallel(ParallelStyle):
def __init__(
self,
*,
tp_mesh: DeviceMesh,
ep_mesh: DeviceMesh,
):
super().__init__()
# TODO: has to pass in the meshes in addition to device_mesh,
# as there's an issue from DeviceMesh that
# "Cannot create a submesh from a submesh."
self.tp_mesh = tp_mesh
self.ep_mesh = ep_mesh

@staticmethod
def _prepare_input_fn(tp_mesh, ep_mesh, mod, inputs, device_mesh):
input_tensor = inputs[0]
# input_tensor of placements Shard(1) on the tp mesh
assert not isinstance(input_tensor, DTensor)

# a2a(ep)
input_tensor = DTensor.from_local(input_tensor, ep_mesh, (Shard(1),))
input_tensor = input_tensor.redistribute(placements=(Shard(0),)).to_local()
# ag(tp)
input_tensor = DTensor.from_local(input_tensor, tp_mesh, (Shard(1),))
input_tensor = input_tensor.redistribute(placements=(Replicate(),))

return input_tensor

@staticmethod
def _partition_fn(tp_mesh, ep_mesh, name, module, device_mesh):
# TODO: FSDP doesn't support sharding a 2D Tensor into a 3D one yet
# module.register_parameter(
# "gate_proj",
# nn.Parameter(
# distribute_tensor(module.gate_proj, device_mesh, [Shard(0), Shard(2)])
# ),
# ) # Column-wise sharding
# module.register_parameter(
# "down_proj",
# nn.Parameter(
# distribute_tensor(module.down_proj, device_mesh, [Shard(0), Shard(1)])
# ),
# ) # Row-wise sharding
# module.register_parameter(
# "up_proj",
# nn.Parameter(
# distribute_tensor(module.up_proj, device_mesh, [Shard(0), Shard(2)])
# ),
# ) # Column-wise sharding

# TODO: Instead, for MoE experts, we shard on the EP mesh and then "forget" it.
# This would become an issue from DCP resharding perspective.
module.register_parameter(
"gate_proj",
nn.Parameter(
DTensor.from_local(
(
distribute_tensor(
module.gate_proj, device_mesh, [Shard(0), Shard(2)]
).to_local()
),
tp_mesh,
(Shard(2),),
)
),
) # Column-wise sharding
module.register_parameter(
"down_proj",
nn.Parameter(
DTensor.from_local(
(
distribute_tensor(
module.down_proj, device_mesh, [Shard(0), Shard(1)]
).to_local()
),
tp_mesh,
(Shard(1),),
)
),
) # Row-wise sharding
module.register_parameter(
"up_proj",
nn.Parameter(
DTensor.from_local(
(
distribute_tensor(
module.up_proj, device_mesh, [Shard(0), Shard(2)]
).to_local()
),
tp_mesh,
(Shard(2),),
)
),
) # Column-wise sharding

@staticmethod
def _prepare_output_fn(tp_mesh, ep_mesh, mod, outputs, device_mesh):
# outputs of placements Partial() on the tp mesh

# rs(tp)
outputs = outputs.redistribute(placements=(Shard(1),)).to_local()
# a2a(ep)
outputs = DTensor.from_local(outputs, ep_mesh, (Shard(0),))
outputs = outputs.redistribute(placements=(Shard(1),)).to_local()

return outputs

def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
return distribute_module(
module,
device_mesh,
partial(self._partition_fn, self.tp_mesh, self.ep_mesh),
partial(self._prepare_input_fn, self.tp_mesh, self.ep_mesh),
partial(self._prepare_output_fn, self.tp_mesh, self.ep_mesh),
)
74 changes: 72 additions & 2 deletions torchtitan/parallelisms/parallel_dims.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,24 @@ class ParallelDims:
cp: int
tp: int
pp: int
ep: int
ep_mode: str
world_size: int
enable_loss_parallel: bool

def __post_init__(self):
self._validate()

def _validate(self):
dp_replicate, dp_shard, cp, tp, pp = (
dp_replicate, dp_shard, cp, tp, pp, ep = (
self.dp_replicate,
self.dp_shard,
self.cp,
self.tp,
self.pp,
self.ep,
)
for d in (dp_replicate, cp, tp, pp):
for d in (dp_replicate, cp, tp, pp, ep):
assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"

assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1."
Expand All @@ -45,7 +48,74 @@ def _validate(self):
f"cp({cp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
)

if ep > 1:
assert self.ep_mode in ["tp", "tp2ep", "dp2ep"]
if self.ep_mode == "tp" or self.ep_mode == "tp2ep":
assert ep == tp
elif self.ep_mode == "dp2ep":
# EP would borrow all cp and some dp_shard degree
assert ep % cp == 0 and (dp_shard * cp) % ep == 0
else:
self.ep_mode = "none"

def _build_mesh_with_dp2ep(self, device_type):
# In dp2ep, dp_shard and ep are derived submeshes:
# dp_shard = dp_shard_1 * dp_shard_2
# ep = dp_shard_2 * cp
dp_shard_1 = self.dp_shard * self.cp // self.ep
dp_shard_2 = self.ep // self.cp

dims = []
names = []
for d, name in zip(
[self.pp, self.dp_replicate, dp_shard_1, dp_shard_2, self.cp, self.tp],
["pp", "dp_replicate", "dp_shard_1", "dp_shard_2", "cp", "tp"],
):
# dp_shard_1 is needed even if it's 1, whose FSDP wrapping
# helps the MoE layers do mixed precision training
if d > 1 or name == "dp_shard_1":
dims.append(d)
names.append(name)

logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
names = tuple(names)
mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)

# Create all the submesh here to ensure all required process groups are
# initialized:
# Mesh for data loading
dp_mesh_dim_names = []
if self.dp_replicate_enabled:
dp_mesh_dim_names.append("dp_replicate")
dp_mesh_dim_names.append("dp_shard_1")
if "dp_shard_2" in names:
dp_mesh_dim_names.append("dp_shard_2")
mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")

# Mesh for param sharding
dp_shard_cp_mesh_dim_name = []
dp_shard_cp_mesh_dim_name.append("dp_shard_1")
if "dp_shard_2" in names:
dp_shard_cp_mesh_dim_name.append("dp_shard_2")
if self.cp_enabled:
dp_shard_cp_mesh_dim_name.append("cp")
mesh[tuple(dp_shard_cp_mesh_dim_name)]._flatten(mesh_dim_name="dp_shard_cp")

# Mesh for ep
ep_mesh_dim_names = []
if "dp_shard_2" in names:
ep_mesh_dim_names.append("dp_shard_2")
if self.cp_enabled:
ep_mesh_dim_names.append("cp")
assert len(ep_mesh_dim_names) > 0
mesh[tuple(ep_mesh_dim_names)]._flatten(mesh_dim_name="ep")

return mesh

def build_mesh(self, device_type):
if self.ep_mode == "dp2ep":
return self._build_mesh_with_dp2ep(device_type)

dims = []
names = []
for d, name in zip(
Expand Down
Loading

0 comments on commit 25cfe6d

Please sign in to comment.