feat(moe): impl moe with megablock kernel (#76)
blankde authored Mar 15, 2024
1 parent e9fcf55 commit def75dd
Showing 11 changed files with 1,022 additions and 9 deletions.
11 changes: 10 additions & 1 deletion configs/7B_MoE4_sft.py
@@ -149,7 +149,7 @@
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
num_experts=4,
moe_use_residual=False,
moe_type="GShard",
moe_type="GShard", # Support: "GShard", "MegaBlock", "MegaBlock-D"
)
"""
zero1 parallel (dict):
@@ -200,6 +200,7 @@
)

# custom moe impl configs
# GShard MoE config
moe = dict(
top_k=2,
capacity_factor=1.0,
@@ -210,6 +211,14 @@
use_rts=True,
)

# MegaBlock MoE config
# moe = dict(
# top_k=2,
# capacity_factor=1.0, # only used in MegaBlock(non-dmoe)
# drop_tokens=True, # only used in MegaBlock(non-dmoe)
# #parallel_mode="tensor", # only used in MegaBlock-D(dmoe), parallel_mode can be tensor or weight
# )

model_type = "INTERNLM_MoE"

# metric_dtype can be "fp32" or other string
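For reference, here is a minimal sketch of how the two settings above combine when switching this config over to the dropless kernel. The values mirror the defaults shown in this diff; treat it as illustrative rather than a recommended recipe:

model = dict(
    # ... other model fields unchanged ...
    num_experts=4,
    moe_use_residual=False,
    moe_type="MegaBlock-D",  # block-sparse dropless variant; "MegaBlock" keeps token dropping
)

moe = dict(
    top_k=2,
    # capacity_factor / drop_tokens only apply to the non-dropless "MegaBlock" type
    parallel_mode="tensor",  # MegaBlock-D only; "tensor" or "weight"
)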
8 changes: 8 additions & 0 deletions internlm/initialize/launch.py
@@ -12,6 +12,7 @@
from internlm.core.context import Config
from internlm.core.context import global_context as gpc
from internlm.core.context.process_group_initializer import ParallelMode
from internlm.moe.megablock.utils import check_megablock_installed, check_stk_installed
from internlm.monitor import initialize_light_monitor
from internlm.utils.common import get_master_node
from internlm.utils.gputest import warmup_process_group
@@ -314,6 +315,13 @@ def args_sanity_check():
model._add_item("moe_use_residual", False)
if "moe_type" not in model:
model._add_item("moe_type", "GShard")
# check dependency
if gpc.config.model.moe_type == "MegaBlock":
check_megablock_installed()
if gpc.config.model.moe_type == "MegaBlock-D":
check_megablock_installed()
check_stk_installed()

# process the parallel config
if "sequence_parallel" not in gpc.config.parallel:
gpc.config.parallel._add_item("sequence_parallel", False)
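The bodies of check_megablock_installed and check_stk_installed live in internlm/moe/megablock/utils.py, which is not expanded in this view. A plausible minimal sketch of such import guards (illustrative only, not necessarily the repository's exact code):

def check_megablock_installed():
    try:
        from megablocks import ops  # noqa: F401
    except ModuleNotFoundError as exc:
        raise ModuleNotFoundError("moe_type=MegaBlock/MegaBlock-D requires the megablocks package") from exc


def check_stk_installed():
    try:
        import stk  # noqa: F401
    except ModuleNotFoundError as exc:
        raise ModuleNotFoundError("moe_type=MegaBlock-D requires the stk package") from exc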
23 changes: 23 additions & 0 deletions internlm/moe/__init__.py
@@ -1,3 +1,26 @@
from .gshard_moe import GShardMOELayer

__all__ = ["GShardMOELayer"]

try:
from megablocks import ops # noqa # pylint: disable=W0611
except ModuleNotFoundError:
pass
else:
from internlm.moe.megablock.megablock_moe import ( # noqa # pylint: disable=W0611
MegaBlockMoE,
)

__all__ += "MegaBlockMoE"

try:
import stk # noqa # pylint: disable=W0611
from megablocks import ops # noqa # pylint: disable=W0611
except ModuleNotFoundError:
pass
else:
from internlm.moe.megablock.megablock_dmoe import ( # noqa # pylint: disable=W0611
MegaBlockdMoE,
)

__all__ += "MegaBlockdMoE"
8 changes: 4 additions & 4 deletions internlm/moe/gshard_moe.py
Expand Up @@ -385,9 +385,9 @@ class GShardMOELayer(BaseMoELayer):

def __init__(
self,
hidden_size,
hidden_size: int,
num_experts: int,
ep_group,
ep_group: Optional[torch.distributed.ProcessGroup],
ep_size: int,
top_k: int = 1,
capacity_factor: float = 1.0,
@@ -396,8 +396,8 @@ def __init__(
noisy_gate_policy: str = None,
drop_tokens: bool = True,
use_rts: bool = True,
device=None,
dtype=None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
assert noisy_gate_policy is None or noisy_gate_policy in ["None", "Jitter", "RSample"], (
"Unsupported noisy_gate_policy: " + noisy_gate_policy
Empty file.
214 changes: 214 additions & 0 deletions internlm/moe/megablock/megablock_dmoe.py
@@ -0,0 +1,214 @@
from typing import Optional, Tuple

import numpy as np
import stk
import torch
from megablocks import ops

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.moe.base_moe import BaseMoELayer
from internlm.moe.megablock.megablock_moe import MegaBlockMoE
from internlm.moe.megablock.mlp import MegaBlockGroupedFeedForward
from internlm.moe.megablock.utils import promote_scalar
from internlm.utils.registry import MODEL_INITIALIZER


@MODEL_INITIALIZER.register_module(module_name="MegaBlock-D")
class MegaBlockdMoE(MegaBlockMoE):
"""
Built on the paper and library Megablocks as described in
https://arxiv.org/abs/2211.15841. This implementation is
strictly equivalent to standard MoE with full capacity (no
dropped tokens). It's faster since it formulates MoE operations
in terms of block-sparse operations to accommodate imbalanced
assignments of tokens to experts, whereas standard MoE either
(1) drops tokens at the cost of reduced performance or (2) sets
the capacity factor to the number of experts and thus wastes computation
and memory on padding.
"""

def __init__( # pylint: disable=W0231
self,
hidden_size: int,
ep_group: Optional[torch.distributed.ProcessGroup],
ep_size: int,
num_experts: int,
top_k: int = 1,
parallel_mode: str = "tensor",
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
multiple_of: int = 256,
) -> None:
assert gpc.expert_parallel_size == 1, "MegaBlock-D does not support expert parallelism yet"
self.top_k = top_k
self.num_experts = num_experts

tp_size = gpc.get_world_size(ParallelMode.TENSOR)
self.ffn_dim = multiple_of * ((int(hidden_size * gpc.config.model.mlp_ratio) + multiple_of - 1) // multiple_of)
assert self.ffn_dim % tp_size == 0
if parallel_mode == "tensor":
self.ffn_dim_per_row = self.ffn_dim // tp_size // ep_size
else:
self.ffn_dim_per_row = self.ffn_dim // ep_size
BaseMoELayer.__init__( # pylint: disable=W0233
self,
torch.nn.Linear(hidden_size, num_experts, bias=False),
MegaBlockGroupedFeedForward(
hidden_size,
(self.ffn_dim // tp_size) * (num_experts // ep_size),
parallel_mode,
device,
dtype,
),
ep_group,
ep_size,
1,
)

# Calculate the number of bits needed to represent the expert indices
# so that we can pass it to radix sort.
self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
self.blocking = 128
self.quantize_scatter_num_bits = -1

# Calculate the number of bits needed to represent the column indices
# in the intermediate sparse matrix.
max_column_index = (self.ffn_dim * (self.num_experts // ep_size)) // self.blocking
self.transpose_sort_end_bit = max(int(np.ceil(np.log2(max_column_index))), 1)

# re-init the number of experts in each device
self.num_local_experts = num_experts // ep_size

self.forward_fn = self._forward

def sparse_transpose(
self, size: Tuple[int, int], row_indices: torch.Tensor, column_indices: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
block_columns = size[1] // self.blocking

# Sort row indices by column indices to get the transposed matrix's
# column indices.
#
# NOTE: Our sort operation uses the same width indices as the input
# values. To avoid overflow when we have large activation matrices
# we cast to 32-bit before sorting.
_, gather_indices = ops.sort(column_indices.int(), self.transpose_sort_end_bit)

# There are a constant number of blocks in every row of the sparse
# matrix. A block's offset is:
#
# row_index * blocks_per_row + column_index % blocks_per_row
#
# Once we have the block offsets ordered for transposition we can
# divide by blocks_per_row to get the transposed column indices.
column_indices_t = row_indices.gather(0, gather_indices.long())
block_offsets_t = gather_indices.int()

zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
nnz_per_column = ops.histogram(column_indices, block_columns)
nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
offsets_t = torch.cat([zero, nnz_per_column])
return column_indices_t, offsets_t, block_offsets_t

def topology(self, x: torch.Tensor, padded_bins: torch.Tensor) -> stk.Matrix:
padded_tokens, _ = x.size()
assert padded_tokens % self.blocking == 0
assert self.ffn_dim_per_row % self.blocking == 0

# Offsets for the sparse matrix. All rows have the
# same number of nonzero blocks dictated by the
# dimensionality of a single expert.
block_rows = padded_tokens // self.blocking
blocks_per_row = self.ffn_dim_per_row // self.blocking
offsets = torch.arange(
0,
block_rows * blocks_per_row + 1,
blocks_per_row,
dtype=torch.int32,
device=x.device,
)
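# Illustrative example (not from this diff): with blocking=128,
# padded_tokens=256 and ffn_dim_per_row=512, we get block_rows=2,
# blocks_per_row=4 and offsets == [0, 4, 8].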

# Indices for the sparse matrix. The indices for
# the intermediate matrix are dynamic depending
# on the mapping of tokens to experts.
column_indices = ops.topology(padded_bins, self.blocking, block_rows, blocks_per_row)

# TODO(tgale): This is unused. Remove the need for this in stk.
# For now, use meta init to save the device memory.
data = torch.empty(
column_indices.numel(),
self.blocking,
self.blocking,
dtype=x.dtype,
device="meta",
)
shape = (padded_tokens, self.ffn_dim_per_row * self.num_experts)
row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(shape, row_indices, column_indices)
return stk.Matrix(
shape,
data,
row_indices,
column_indices,
offsets,
column_indices_t,
offsets_t,
block_offsets_t,
)

def indices_and_padded_bins(
self, selected_experts: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# Sort the expert ids to produce the scatter/gather
# indices for the permutation.
selected_experts = selected_experts.int()
bin_ids, indices = ops.sort(selected_experts, self.sort_end_bit)

# Histogram the expert ids to identify the number of
# tokens routed to each expert.
tokens_per_expert = ops.histogram(selected_experts, self.num_experts)

# Round the token counts up to the block size used in
# the matrix multiplications. Calculate the starting
# position of each bin.
padded_tokens_per_expert = ops.round_up(tokens_per_expert, self.blocking)
padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
padded_bins = promote_scalar(padded_bins)

# Calculate the bin bounds for the sorted tokens.
bins = ops.inclusive_cumsum(tokens_per_expert, 0)
bins = promote_scalar(bins)
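# Illustrative example (not from this diff): with num_experts=4,
# blocking=128 and tokens_per_expert=[3, 200, 0, 60], round_up gives
# [128, 256, 0, 128], so padded_bins=[128, 384, 384, 512] and
# bins=[3, 203, 203, 263].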
return indices, bin_ids, bins, padded_bins, tokens_per_expert

def _forward(self, x, expert_weights, top_experts) -> Tuple[torch.Tensor, torch.Tensor]:
with torch.no_grad():
(indices, bin_ids, bins, padded_bins, tokens_per_expert) = self.indices_and_padded_bins(top_experts)

# Permute tokens and pad to prepare expert computation
# (top_k * sequence_length + padding, model_dim)
x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.top_k)

# Create the sparse matrix topology
with torch.no_grad():
topo = self.topology(x, padded_bins)

# Perform the expert computation
# First Dense x Dense -> Sparse for w1 and w3,
# (top_k * sequence_length + padding, ffn_dim * n_experts)
x = self.experts(x, topo=topo)

# Permute back and remove padding
# (top_k * sequence_length, model_dim)
x = ops.padded_scatter(
x,
indices,
bin_ids,
expert_weights,
bins,
padded_bins,
self.top_k,
self.quantize_scatter_num_bits,
)

return x, tokens_per_expert.flatten()
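The routing bookkeeping above can be sanity-checked on CPU with plain PyTorch. The snippet below is an illustrative stand-in for the fused megablocks ops used by indices_and_padded_bins (the equivalences noted in the comments are assumptions for illustration, not the kernels themselves):

import torch

blocking, num_experts = 128, 4
top_experts = torch.tensor([1, 3, 0, 1, 1, 3])  # flattened top-k expert ids (3 tokens, top_k=2)

bin_ids, indices = torch.sort(top_experts)                               # ~ ops.sort
tokens_per_expert = torch.bincount(top_experts, minlength=num_experts)   # ~ ops.histogram
padded = ((tokens_per_expert + blocking - 1) // blocking) * blocking     # ~ ops.round_up
padded_bins = torch.cumsum(padded, dim=0)                                # ~ ops.inclusive_cumsum
bins = torch.cumsum(tokens_per_expert, dim=0)

print(tokens_per_expert.tolist())  # [1, 3, 0, 2]
print(padded_bins.tolist())        # [128, 256, 256, 384]
print(bins.tolist())               # [1, 4, 4, 6]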