From ecd8e5fcc62147b38d80a7d0a5bc98b50266c13a Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Thu, 24 Oct 2024 08:52:27 -0700 Subject: [PATCH 01/13] add moe design --- RFC.md | 207 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 207 insertions(+) create mode 100644 RFC.md diff --git a/RFC.md b/RFC.md new file mode 100644 index 0000000000..26b20ecc27 --- /dev/null +++ b/RFC.md @@ -0,0 +1,207 @@ +# [RFC] MOE design in Torchtune + +## Background +This RFC proposes adding the MOE support in Torchtune. We want to design in a general way so that components can be easily swapped when implementing different MOE models. An MOE layer directly replaces the dense FFN layer in the transformer decoder layer and has two main components: router and experts. + +## Expert +An expert is essentially an FFN layer similar to the original dense FFN layer in the transformer decoder layer. There are two kinds of experts: routed experts and shared experts. Each expert in the routed experts specializes in learning certain patterns/aspects, and only part of the routed experts will be activated. On the other hand, shared experts are always activated, aiming at capturing and consolidating common knowledge across varying contexts. + +**Here's the proposed Experts design in torchtune:** +```python +class Experts(nn.Module): + def __init__(self, dim_in, dim_out, nonlinearity, num_experts=1, swiglu=True): + self.gate_proj = nn.Parameter(torch.empty(num_experts, dim_in, dim_out)) + self.down_proj = nn.Parameter(torch.empty(num_experts, dim_out, dim_in)) + if swiglu: + self.up_proj = nn.Parameter(torch.empty(num_experts, dim_in, dim_out)) + self.act_fn = F.silu() + else: + self.up_proj = None + self.act_fn = nonlinearity + + def forward(self, x): + x = x.view(num_experts, -1, dim_in) + h = self.act_fn(torch.bmm(x, self.gate_proj)) + if self.up_proj is not None: + h = h * torch.bmm(x, self.up_proj) + h = torch.bmm(x, self.down_proj).view(-1, dim_in) + return h + +# Routed Experts(num_experts) +def moe_experts(hidden_dim, model_dim, num_experts, swiglu, nonlinearity) -> FeedForward: + return Experts(dim_in=hidden_dim, dim_out=model_dim, nonlinearity=nonlinearity, num_experts=num_experts, swiglu=swiglu) + +# Shared expert(single) +def moe_expert(hidden_dim, model_dim, swiglu, nonlinearity) -> FeedForward: + return Experts(dim_in=hidden_dim, dim_out=model_dim, nonlinearity=nonlinearity, num_experts=1, swiglu=swiglu) + +# For example, the Mixtral expert could be implemented like this +def mixtral_expert(hidden_dim, model_dim, nonlinearity) -> FeedForward: + return Experts(dim_in=hidden_dim, dim_out=model_dim, nonlinearity=nonlinearity, num_experts=1, swiglu=True) + +``` + +## Router and Moe Layer +Router is a gating network that calculates router scores and learns token-to-expert affinity, and an MOE layer consists of experts and routers. There are two types of routing: token choice routing and expert choice routing. + +Mixtral uses *token choice* topK routing, which means each token will select its topK experts. The router is implemented through a learnable gate function, whose outputs will go through softmax and topK. The TokenChoiceMoeLayer class then defines how tokens select experts based on router scores. + +**Here's the proposed Token Choice Routing and TokenChoiceMoeLayer design in torchtune:** +```python +class TokenChoiceTopKRouter(nn.Module): + def __init__(self, hidden_dim, num_experts, experts_per_token): + self.gate = nn.Linear(hidden_dim, num_experts) + self.experts_per_token = experts_per_token + + def forward(self, x): + ''' + input: + x shape [bs*slen, hidden_dim] + outputs: + top_scores shape [bs*slen, experts_per_token] + top_indices shape [bs*slen, experts_per_token] + ''' + # scores shape [bs*slen, num_experts] + scores = self.gate(x) + scores = F.softmax(scores, dim=1) + top_scores, top_indices = torch.topk(scores, k=self.experts_per_token, dim=1) + top_scores /= top_scores.sum(dim=-1, keep_dim=True).to(x.dtype) + return top_scores, top_indices + +# For example, Mixtral uses TokenChoiceMoeLayer +class TokenChoiceMoeLayer(nn.Module): + def __init__(self): + self.experts = nn.ModuleList(moe_expert() for _ in range(num_experts)) + self.router = TokenChoiceTopKRouter(hidden_dim, num_experts, experts_per_token) + + def forward(self, x): + # x shape [bs*slen, hidden_dim] + # router scores/indices shape [bs*slen, experts_per_token] + top_scores, selected_experts_indices = self.router(x) + + # expert_mask shape [num_experts, experts_per_token, bs*slen] + expert_mask = torch.nn.functional.one_hot(selected_experts_indices, num_class=num_experts).permute(2,1,0) + out = torch.zeros((batch_size * seq_len, hidden_dim)) + for i in range(num_experts): + expert = self.experts[i] + expert_idx, token_idx = torch.where(expert_mask[i]) + # compute hidden state for the each selected expert and multiply by the routing weights + hidden_states = expert(x[token_idx]) * top_scores[token_idx, expert_idx] + out.index_add_(0, token_idx, hidden_states) + return out + ``` + +However, token choice routing has several pitfalls according to the expert choice [paper](https://arxiv.org/pdf/2002.05202). +1. Poor load balance. Experts can become under or over-specialized. Load imbalance can hurt step latency / inference time. +2. Experts under specialization. Ideally the gating network will learn token-to-expert affinity such that similar or relevant tokens are routed to the same expert. However, a sub-optimal strategy can produce redundant experts and/or experts that are not sufficiently specialized. +3. Same compute for each token. Token choice will allocate a fixed number of experts to each token regardless of the importance of different tokens. Ideally an MOE model should flexibly allocate compute resources based on the complexity of the input. + +Compared to **token choice**, **expert choice** topK routing lets experts select its top-k tokens. The ExpertChoiceMoeLayer class routes input tokens to different experts based on the routing algorithm, processes them through the experts and the shared expert, and then combines the output. + +**Here's the proposed Expert Choice Routing and ExpertChoiceMoeLayer design in torchtune:** +```python +class ExpertChoiceTopKRouter(nn.Module): + def __init__(self, hidden_dim, num_experts): + self.gate = nn.Linear(hidden_dim, num_experts) + self.tokens_per_expert = tokens_per_expert + + def forward(self, x): + ''' + input: + x shape [bs*slen, hidden_dim] + outputs: + top_scores shape [num_experts, tokens_per_expert] + top_indices shape [num_experts, tokens_per_expert] + ''' + # scores shape [num_experts, bs*slen] + scores = self.gate(x).transpose(0,1) + scores = F.softmax(scores.to(softmax_dtype), dim=0).to(scores.dtype) + # [num_experts, tokens_per_expert] + top_scores, top_indices = torch.topk(scores, k=self.tokens_per_expert, dim=1) + return top_scores, top_indices + + +class ExpertChoiceMoeLayer(nn.Module): + def __init__(self): + self.experts = moe_experts(hidden_dim, model_dim, num_experts) + self.shared_expert = moe_shared_expert(hidden_dim, model_dim) + self.router = ExpertChoiceTopKRouter(hidden_dim, num_experts, tokens_per_expert) + + def forward(self, x): + # x shape [bs*slen, hidden_dim] + # router scores/indices shape [num_experts, tokens_per_expert] + top_scores, top_indices = self.router(x) + # apply the token preprocess function and then run experts forward + top_indices_expanded = top_indices.reshape(-1, 1).expand(-1, D) + # routed input shape [num_experts*tokens_per_expert, hidden_dim] + routed_input = torch.gather(x, dim=0, index=top_indices_expanded) + routed_input = routed_input * top_scores.reshape(-1, 1) + # routed output shape [num_experts*tokens_per_expert, hidden_dim] + routed_output = self.experts(routed_input) + + # shared expert + if use_shared_expert: + out = self.shared_expert(x) + else: + out = torch.zeros_like(x) + + # add experts output + out.data = scatter_add_( + # [bs*slen, hidden_dim] + out.data, + # [num_experts*tokens_per_expert, hidden_dim] + routed_output, + # [num_experts*tokens_per_expert, hidden_dim] + top_indices_expanded, + ) + return out + ``` + +## Model builder +Besides the above components: experts, routers, and MOE layers, we would need a model builder to pull all pieces together to form the Transformer decoder layer and then Transformer decoder: + +**Here's the proposed MOE model builder design in torchtune:** +```python +def moe(...) -> TransformerDecoder: + # Build the decoder associated with the moe model. This includes + # - Token embeddings + # - num_layers number of TransfomerDecoderLayer block + # - RMS Norm layer applied to the ouput of the transfomer + # - Final projection into the token space' + token_embeddings = nn.Embedding(vocab_size, embed_dim) + self_attn = MultiHeadAttention() + moe_layer = ExpertsChoiceMoeLayer() # or TokenChoiceMoeLayer() + norm = RMSNorm(dim=embed_dim) + layer = TransformerSelfAttentionLayer(attn=self_attn, mlp=moe_layer, sa_norm=norm, mlp_norm=norm) + output_proj = nn.Linear(embed_dim, vocab_size) + return TransformerDecoder( + tok_embeddings=tok_embeddings, + layers=layer, + num_layers=num_layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=head_dim, + norm=RMSNorm(dim=embed_dim), + output=output_proj, + ) +``` + +**File changes for new modules/functions** +``` +torchtune/ + modules/ + moe/ + moe_layers.py + TokenChoiceTopKRouter() + ExpertChoiceTopKRouter() + TokenChoiceMoeLayer() + ExpertChoiceMoeLayer() + experts.py + Experts() + models/ + moe/ + _component_builders.py + moe() + moe_experts() + moe_expert() +``` From 9c6cc1cd8aac8f94ce5a7fbcae7bddd28a7052c6 Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Thu, 24 Oct 2024 17:27:38 -0700 Subject: [PATCH 02/13] draft --- RFC.md | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/RFC.md b/RFC.md index 26b20ecc27..42c8c8666c 100644 --- a/RFC.md +++ b/RFC.md @@ -89,6 +89,34 @@ class TokenChoiceMoeLayer(nn.Module): hidden_states = expert(x[token_idx]) * top_scores[token_idx, expert_idx] out.index_add_(0, token_idx, hidden_states) return out + + def forward(self, x): + # x shape [bs*slen, hidden_dim] + # router scores/indices shape [bs*slen, experts_per_token] + top_scores, selected_experts_indices = self.router(x) + # [bs*slen*experts_per_token, hidden_dim] + selected_experts_indices_expanded = selected_experts_indices.reshape(-1, 1).expand(-1, D) + # [bs*slen*experts_per_token, hidden_dim] + routed_input = torch.gather(x, dim=0, index=selected_experts_indices_expanded) + routed_input = routed_input * top_scores.reshape(-1, 1) + # [bs*slen*experts_per_token, hidden_dim] + routed_output = self.experts(routed_input) + + # shared expert + if use_shared_expert: + out = self.shared_expert(x) + else: + out = torch.zeros_like(x) + + # add experts output + out.data = scatter_add_( + # [bs*slen, hidden_dim] + out.data, + # [bs*slen, hidden_dim] + routed_output.reshape(-1, experts_per_token, hidden_dim).sum(dim=1), + # [bs*slen*experts_per_token, hidden_dim] + selected_experts_indices_expanded, + ) ``` However, token choice routing has several pitfalls according to the expert choice [paper](https://arxiv.org/pdf/2002.05202). From 8da01d849c9f17544500fef414497f7d689425be Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Thu, 24 Oct 2024 22:50:34 -0700 Subject: [PATCH 03/13] debug --- RFC.md | 43 +++++++++++++++++++++++++++---------------- 1 file changed, 27 insertions(+), 16 deletions(-) diff --git a/RFC.md b/RFC.md index 42c8c8666c..70e3186c56 100644 --- a/RFC.md +++ b/RFC.md @@ -19,12 +19,18 @@ class Experts(nn.Module): self.up_proj = None self.act_fn = nonlinearity - def forward(self, x): - x = x.view(num_experts, -1, dim_in) - h = self.act_fn(torch.bmm(x, self.gate_proj)) - if self.up_proj is not None: - h = h * torch.bmm(x, self.up_proj) - h = torch.bmm(x, self.down_proj).view(-1, dim_in) + def forward(self, x, use_token_choice=False): + # token choice forward, token choose topK experts + if use_token_choice: + # TODO: + # expert choice forward, expert choice topK tokens + else: + # TODO: implement clamp() + x = x.view(num_experts, -1, dim_in) + h = self.act_fn(torch.bmm(x, self.gate_proj)) + if self.up_proj is not None: + h = h * torch.bmm(x, self.up_proj) + h = torch.bmm(h, self.down_proj).view(-1, dim_in) return h # Routed Experts(num_experts) @@ -68,12 +74,15 @@ class TokenChoiceTopKRouter(nn.Module): top_scores /= top_scores.sum(dim=-1, keep_dim=True).to(x.dtype) return top_scores, top_indices + # For example, Mixtral uses TokenChoiceMoeLayer class TokenChoiceMoeLayer(nn.Module): def __init__(self): - self.experts = nn.ModuleList(moe_expert() for _ in range(num_experts)) + # self.experts = nn.ModuleList(moe_expert() for _ in range(num_experts)) + self.experts = moe_experts(hidden_dim, model_dim, num_experts) self.router = TokenChoiceTopKRouter(hidden_dim, num_experts, experts_per_token) + # Mixtral token choice forward implementation assuming we have a list of moe_expert() def forward(self, x): # x shape [bs*slen, hidden_dim] # router scores/indices shape [bs*slen, experts_per_token] @@ -90,17 +99,19 @@ class TokenChoiceMoeLayer(nn.Module): out.index_add_(0, token_idx, hidden_states) return out + # unifying TC and EC MoeLayer forward function def forward(self, x): # x shape [bs*slen, hidden_dim] # router scores/indices shape [bs*slen, experts_per_token] top_scores, selected_experts_indices = self.router(x) # [bs*slen*experts_per_token, hidden_dim] selected_experts_indices_expanded = selected_experts_indices.reshape(-1, 1).expand(-1, D) - # [bs*slen*experts_per_token, hidden_dim] + + # [bs*slen*experts_per_token, hidden_dim], why do we gather x based on selected expert indices? routed_input = torch.gather(x, dim=0, index=selected_experts_indices_expanded) routed_input = routed_input * top_scores.reshape(-1, 1) - # [bs*slen*experts_per_token, hidden_dim] - routed_output = self.experts(routed_input) + # [bs*slen*experts_per_token, hidden_dim], only experts_per_token will be activated for each token + routed_output = self.experts(routed_input, use_token_choice=True) # shared expert if use_shared_expert: @@ -112,8 +123,8 @@ class TokenChoiceMoeLayer(nn.Module): out.data = scatter_add_( # [bs*slen, hidden_dim] out.data, - # [bs*slen, hidden_dim] - routed_output.reshape(-1, experts_per_token, hidden_dim).sum(dim=1), + # [bs*slen*experts_per_token, hidden_dim] + routed_output, # [bs*slen*experts_per_token, hidden_dim] selected_experts_indices_expanded, ) @@ -158,11 +169,11 @@ class ExpertChoiceMoeLayer(nn.Module): def forward(self, x): # x shape [bs*slen, hidden_dim] # router scores/indices shape [num_experts, tokens_per_expert] - top_scores, top_indices = self.router(x) + top_scores, selected_token_indices = self.router(x) # apply the token preprocess function and then run experts forward - top_indices_expanded = top_indices.reshape(-1, 1).expand(-1, D) + selected_token_indices_expanded = selected_token_indices.reshape(-1, 1).expand(-1, D) # routed input shape [num_experts*tokens_per_expert, hidden_dim] - routed_input = torch.gather(x, dim=0, index=top_indices_expanded) + routed_input = torch.gather(x, dim=0, index=selected_token_indices_expanded) routed_input = routed_input * top_scores.reshape(-1, 1) # routed output shape [num_experts*tokens_per_expert, hidden_dim] routed_output = self.experts(routed_input) @@ -180,7 +191,7 @@ class ExpertChoiceMoeLayer(nn.Module): # [num_experts*tokens_per_expert, hidden_dim] routed_output, # [num_experts*tokens_per_expert, hidden_dim] - top_indices_expanded, + selected_token_indices_expanded, ) return out ``` From 6422e12fdddf34bd595cd3e9818e84717c0563a5 Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Fri, 25 Oct 2024 16:51:28 -0700 Subject: [PATCH 04/13] finish first draft --- RFC.md | 108 +++++++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 85 insertions(+), 23 deletions(-) diff --git a/RFC.md b/RFC.md index 70e3186c56..3590b63f30 100644 --- a/RFC.md +++ b/RFC.md @@ -19,13 +19,33 @@ class Experts(nn.Module): self.up_proj = None self.act_fn = nonlinearity - def forward(self, x, use_token_choice=False): - # token choice forward, token choose topK experts + def forward(self, x, use_token_choice=False, num_local_tokens_per_expert=None): + # token choice forward, token choose topK experts, TODO: use cutlass groupGEMM instead of torch.matmul() if use_token_choice: - # TODO: + assert num_local_tokens_per_expert is not None, "num_local_tokens_per_expert is needed for token choice expert forward" + # x shape [bs*slen*experts_per_expert, hidden_dim] + # x_e_GG_D_splits shape [num_experts, tokens_per_expert(varying), hidden_dim] + x_e_GG_D_splits = torch.split(x, split_size_or_sections=num_local_tokens_per_expert, dim=0) + out_GG_D_splits = [] + for expert_index, x_GG_D in enumerate(x_e_GG_D_splits): + gate_proj = self.gate_proj[expert_index] + down_proj = self.down_proj[expert_index] + up_proj = None + if self.up_proj is not None: + up_proj = self.up_proj[expert_index] + + h = self.act_fn(torch.matmul(x_GG_D, gate_proj)) + if up_proj is not None: + h = h * torch.matmul(x_GG_D, down_proj) + # [tokens_per_expert, hidden_dim] + h = torch.matmul(h, down_proj) + out_GG_D_splits.append(h) + # shape [num_experts * tokens_per_expert(varying), hidden_dim] + out_EGG_D = torch.cat(out_GG_D_splits, dim=0) + return out_EGG_D # expert choice forward, expert choice topK tokens else: - # TODO: implement clamp() + # x shape [num_experts, tokens_per_expert, hidden_dim] x = x.view(num_experts, -1, dim_in) h = self.act_fn(torch.bmm(x, self.gate_proj)) if self.up_proj is not None: @@ -37,14 +57,9 @@ class Experts(nn.Module): def moe_experts(hidden_dim, model_dim, num_experts, swiglu, nonlinearity) -> FeedForward: return Experts(dim_in=hidden_dim, dim_out=model_dim, nonlinearity=nonlinearity, num_experts=num_experts, swiglu=swiglu) -# Shared expert(single) +# Shared expert / single expert def moe_expert(hidden_dim, model_dim, swiglu, nonlinearity) -> FeedForward: return Experts(dim_in=hidden_dim, dim_out=model_dim, nonlinearity=nonlinearity, num_experts=1, swiglu=swiglu) - -# For example, the Mixtral expert could be implemented like this -def mixtral_expert(hidden_dim, model_dim, nonlinearity) -> FeedForward: - return Experts(dim_in=hidden_dim, dim_out=model_dim, nonlinearity=nonlinearity, num_experts=1, swiglu=True) - ``` ## Router and Moe Layer @@ -75,11 +90,11 @@ class TokenChoiceTopKRouter(nn.Module): return top_scores, top_indices -# For example, Mixtral uses TokenChoiceMoeLayer +# Option 1: Least efficient approach: looping over experts class TokenChoiceMoeLayer(nn.Module): def __init__(self): - # self.experts = nn.ModuleList(moe_expert() for _ in range(num_experts)) - self.experts = moe_experts(hidden_dim, model_dim, num_experts) + self.experts = nn.ModuleList(moe_expert() for _ in range(num_experts)) + self.shared_expert = moe_expert(hidden_dim, model_dim) self.router = TokenChoiceTopKRouter(hidden_dim, num_experts, experts_per_token) # Mixtral token choice forward implementation assuming we have a list of moe_expert() @@ -97,21 +112,41 @@ class TokenChoiceMoeLayer(nn.Module): # compute hidden state for the each selected expert and multiply by the routing weights hidden_states = expert(x[token_idx]) * top_scores[token_idx, expert_idx] out.index_add_(0, token_idx, hidden_states) + + out += self.shared_expert(x) return out + +# Option 2: More efficient approach using Cutlass Grouped GEMM +class TokenChoiceMoeLayer(nn.Module): + def __init__(self): + self.experts = moe_experts(hidden_dim, model_dim, num_experts) + self.shared_expert = moe_expert(hidden_dim, model_dim) + self.router = TokenChoiceTopKRouter(hidden_dim, num_experts, experts_per_token) + # unifying TC and EC MoeLayer forward function def forward(self, x): # x shape [bs*slen, hidden_dim] # router scores/indices shape [bs*slen, experts_per_token] top_scores, selected_experts_indices = self.router(x) - # [bs*slen*experts_per_token, hidden_dim] - selected_experts_indices_expanded = selected_experts_indices.reshape(-1, 1).expand(-1, D) - # [bs*slen*experts_per_token, hidden_dim], why do we gather x based on selected expert indices? - routed_input = torch.gather(x, dim=0, index=selected_experts_indices_expanded) - routed_input = routed_input * top_scores.reshape(-1, 1) - # [bs*slen*experts_per_token, hidden_dim], only experts_per_token will be activated for each token - routed_output = self.experts(routed_input, use_token_choice=True) + # arg sort experts, and group together tokens for each expert + # num_local_tokens_per_expert shape [num_experts,]: how many tokens for each expert + num_local_tokens_per_expert = torch.histc(selected_expert_indices.view(-1), bins=num_experts, min=0, max=num_experts) + + # shape [bs*slen*experts_per_token,] + token_indices_experts_sorted = torch.argsort(selected_experts_indices.view(-1), stable=True) + # top_scores shape [bs*slen*experts_per_token,] + top_scores = top_scores.view(-1)[token_indices_sorted] + # token_indices_experts_sorted_expanded shape [bs*slen*experts_per_token, hidden_dim] + token_indices_experts_sorted_expanded = token_indices_experts_sorted.reshape(-1, 1).expand(-1, hidden_dim) + + # routed_input shape [bs*slen*experts_per_token, hidden_dim] + routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted_expanded) + routed_input = routed_input * top_scores + + # [bs*slen*experts_per_token, hidden_dim], only experts_per_token experts will be activated for each token + routed_output = self.experts(routed_input, use_token_choice=True, num_local_tokens_per_expert=num_local_tokens_per_expert) # shared expert if use_shared_expert: @@ -126,7 +161,7 @@ class TokenChoiceMoeLayer(nn.Module): # [bs*slen*experts_per_token, hidden_dim] routed_output, # [bs*slen*experts_per_token, hidden_dim] - selected_experts_indices_expanded, + token_indices_experts_sorted_expanded, ) ``` @@ -160,10 +195,37 @@ class ExpertChoiceTopKRouter(nn.Module): return top_scores, top_indices +# Option 1: Least efficient approach: looping over experts similar to TokenChoiceMoeLayer +class ExpertChoiceMoeLayer(nn.Module): + def __init__(self): + self.experts = nn.ModuleList(moe_expert() for _ in range(num_experts)) + self.shared_expert = moe_expert(hidden_dim, model_dim) + self.router = ExpertChoiceTopKRouter(hidden_dim, num_experts, tokens_per_expert) + + def forward(self, x): + # x shape [bs*slen, hidden_dim] + # router scores/indices shape [num_experts, tokens_per_expert] + top_scores, selected_token_indices = self.router(x) + + # out shape [bs*slen, hidden_dim] + out = torch.zeros((batch_size * seq_len, hidden_dim)) + for i in range(num_experts): + expert = self.experts[i] + # selected tokens [tokens_per_expert, hidden_dim] + selected_tokens = x[selected_token_indices[i]] + # compute hidden state for the each selected expert and multiply by the routing weights [tokens_per_expert, hidden_dim] + hidden_states = expert(selected_tokens) * top_scores[i] + out.index_add_(0, selected_token_indices[i], hidden_states) + + out += self.shared_expert(x) + return out + + +# Option 2: More efficient approach with GEMM class ExpertChoiceMoeLayer(nn.Module): def __init__(self): self.experts = moe_experts(hidden_dim, model_dim, num_experts) - self.shared_expert = moe_shared_expert(hidden_dim, model_dim) + self.shared_expert = moe_expert(hidden_dim, model_dim) self.router = ExpertChoiceTopKRouter(hidden_dim, num_experts, tokens_per_expert) def forward(self, x): @@ -176,7 +238,7 @@ class ExpertChoiceMoeLayer(nn.Module): routed_input = torch.gather(x, dim=0, index=selected_token_indices_expanded) routed_input = routed_input * top_scores.reshape(-1, 1) # routed output shape [num_experts*tokens_per_expert, hidden_dim] - routed_output = self.experts(routed_input) + routed_output = self.experts(routed_input, use_token_choice=False) # shared expert if use_shared_expert: From 321e2629440d92b0c608033ba6298ed76240a77f Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Fri, 25 Oct 2024 16:56:07 -0700 Subject: [PATCH 05/13] finish first draft --- RFC.md | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/RFC.md b/RFC.md index 3590b63f30..80c8ce4bb1 100644 --- a/RFC.md +++ b/RFC.md @@ -39,6 +39,7 @@ class Experts(nn.Module): h = h * torch.matmul(x_GG_D, down_proj) # [tokens_per_expert, hidden_dim] h = torch.matmul(h, down_proj) + out_GG_D_splits.append(h) # shape [num_experts * tokens_per_expert(varying), hidden_dim] out_EGG_D = torch.cat(out_GG_D_splits, dim=0) @@ -57,7 +58,7 @@ class Experts(nn.Module): def moe_experts(hidden_dim, model_dim, num_experts, swiglu, nonlinearity) -> FeedForward: return Experts(dim_in=hidden_dim, dim_out=model_dim, nonlinearity=nonlinearity, num_experts=num_experts, swiglu=swiglu) -# Shared expert / single expert +# Shared expert / Single expert def moe_expert(hidden_dim, model_dim, swiglu, nonlinearity) -> FeedForward: return Experts(dim_in=hidden_dim, dim_out=model_dim, nonlinearity=nonlinearity, num_experts=1, swiglu=swiglu) ``` @@ -117,7 +118,7 @@ class TokenChoiceMoeLayer(nn.Module): return out -# Option 2: More efficient approach using Cutlass Grouped GEMM +# Option 2: More efficient approach: without looping over experts, using bmm class TokenChoiceMoeLayer(nn.Module): def __init__(self): self.experts = moe_experts(hidden_dim, model_dim, num_experts) @@ -195,7 +196,7 @@ class ExpertChoiceTopKRouter(nn.Module): return top_scores, top_indices -# Option 1: Least efficient approach: looping over experts similar to TokenChoiceMoeLayer +# Option 1: Least efficient approach: looping over experts class ExpertChoiceMoeLayer(nn.Module): def __init__(self): self.experts = nn.ModuleList(moe_expert() for _ in range(num_experts)) @@ -221,7 +222,7 @@ class ExpertChoiceMoeLayer(nn.Module): return out -# Option 2: More efficient approach with GEMM +# Option 2: More efficient approach: without looping over experts, using bmm class ExpertChoiceMoeLayer(nn.Module): def __init__(self): self.experts = moe_experts(hidden_dim, model_dim, num_experts) @@ -238,7 +239,7 @@ class ExpertChoiceMoeLayer(nn.Module): routed_input = torch.gather(x, dim=0, index=selected_token_indices_expanded) routed_input = routed_input * top_scores.reshape(-1, 1) # routed output shape [num_experts*tokens_per_expert, hidden_dim] - routed_output = self.experts(routed_input, use_token_choice=False) + routed_output = self.experts(routed_input) # shared expert if use_shared_expert: From cc4446b6761cfcdd8d4a9414e495a766dbcee3b8 Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Fri, 25 Oct 2024 17:04:15 -0700 Subject: [PATCH 06/13] finish first draft --- RFC.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/RFC.md b/RFC.md index 80c8ce4bb1..b79ad20030 100644 --- a/RFC.md +++ b/RFC.md @@ -118,7 +118,7 @@ class TokenChoiceMoeLayer(nn.Module): return out -# Option 2: More efficient approach: without looping over experts, using bmm +# Option 2: More efficient approach: without looping over experts class TokenChoiceMoeLayer(nn.Module): def __init__(self): self.experts = moe_experts(hidden_dim, model_dim, num_experts) @@ -222,7 +222,7 @@ class ExpertChoiceMoeLayer(nn.Module): return out -# Option 2: More efficient approach: without looping over experts, using bmm +# Option 2: More efficient approach: without looping over experts, using bmm class ExpertChoiceMoeLayer(nn.Module): def __init__(self): self.experts = moe_experts(hidden_dim, model_dim, num_experts) From 8190a627655569a52b95b06623b70985298598a2 Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Sat, 26 Oct 2024 13:58:04 -0700 Subject: [PATCH 07/13] finish first draft --- RFC.md | 79 +++++++++++++++++++++++++++------------------------------- 1 file changed, 37 insertions(+), 42 deletions(-) diff --git a/RFC.md b/RFC.md index b79ad20030..fd7259bb08 100644 --- a/RFC.md +++ b/RFC.md @@ -20,47 +20,49 @@ class Experts(nn.Module): self.act_fn = nonlinearity def forward(self, x, use_token_choice=False, num_local_tokens_per_expert=None): - # token choice forward, token choose topK experts, TODO: use cutlass groupGEMM instead of torch.matmul() + ''' + inputs: + x: input tokens, shape [bs*slen*experts_per_token, hidden_dim] for TC, shape [num_experts*tokens_per_expert, hidden_dim] for EC + use_token_choice: True only for TokenChoiceMoeLayer's expert forward + num_local_tokens_per_expert: not None when use_token_choice=True, number of tokens for each expert + outputs: + out: shape [bs*slen*experts_per_token, hidden_dim] for TC, shape [num_experts*tokens_per_expert, hidden_dim] for EC + ''' + # token choice forward TODO: use cutlass groupGEMM instead of torch.matmul() to optimize performance if use_token_choice: assert num_local_tokens_per_expert is not None, "num_local_tokens_per_expert is needed for token choice expert forward" - # x shape [bs*slen*experts_per_expert, hidden_dim] - # x_e_GG_D_splits shape [num_experts, tokens_per_expert(varying), hidden_dim] - x_e_GG_D_splits = torch.split(x, split_size_or_sections=num_local_tokens_per_expert, dim=0) - out_GG_D_splits = [] - for expert_index, x_GG_D in enumerate(x_e_GG_D_splits): + # x_expert_splits shape [num_experts, tokens_per_expert(varying), hidden_dim] + x_expert_splits = torch.split(x, split_size_or_sections=num_local_tokens_per_expert, dim=0) + out_expert_splits = [] + for expert_index, x_expert_split in enumerate(x_expert_splits): gate_proj = self.gate_proj[expert_index] down_proj = self.down_proj[expert_index] up_proj = None if self.up_proj is not None: up_proj = self.up_proj[expert_index] - h = self.act_fn(torch.matmul(x_GG_D, gate_proj)) + h = self.act_fn(torch.matmul(x_expert_split, gate_proj)) if up_proj is not None: - h = h * torch.matmul(x_GG_D, down_proj) + h = h * torch.matmul(x_expert_split, up_proj) # [tokens_per_expert, hidden_dim] h = torch.matmul(h, down_proj) - out_GG_D_splits.append(h) - # shape [num_experts * tokens_per_expert(varying), hidden_dim] - out_EGG_D = torch.cat(out_GG_D_splits, dim=0) - return out_EGG_D - # expert choice forward, expert choice topK tokens + out_expert_splits.append(h) + # shape [num_experts * tokens_per_expert(varying), hidden_dim] = [bs*slen*experts_per_token, hidden_dim] + out = torch.cat(out_expert_splits, dim=0) else: + # expert choice forward # x shape [num_experts, tokens_per_expert, hidden_dim] x = x.view(num_experts, -1, dim_in) h = self.act_fn(torch.bmm(x, self.gate_proj)) if self.up_proj is not None: h = h * torch.bmm(x, self.up_proj) - h = torch.bmm(h, self.down_proj).view(-1, dim_in) + out = torch.bmm(h, self.down_proj).view(-1, dim_in) return h -# Routed Experts(num_experts) -def moe_experts(hidden_dim, model_dim, num_experts, swiglu, nonlinearity) -> FeedForward: +# Expert builder for both routed experts and shared expert +def moe_experts(hidden_dim, model_dim, num_experts=1, swiglu=True, nonlinearity=None) -> FeedForward: return Experts(dim_in=hidden_dim, dim_out=model_dim, nonlinearity=nonlinearity, num_experts=num_experts, swiglu=swiglu) - -# Shared expert / Single expert -def moe_expert(hidden_dim, model_dim, swiglu, nonlinearity) -> FeedForward: - return Experts(dim_in=hidden_dim, dim_out=model_dim, nonlinearity=nonlinearity, num_experts=1, swiglu=swiglu) ``` ## Router and Moe Layer @@ -94,11 +96,10 @@ class TokenChoiceTopKRouter(nn.Module): # Option 1: Least efficient approach: looping over experts class TokenChoiceMoeLayer(nn.Module): def __init__(self): - self.experts = nn.ModuleList(moe_expert() for _ in range(num_experts)) - self.shared_expert = moe_expert(hidden_dim, model_dim) + self.experts = nn.ModuleList(moe_experts(hidden_dim, model_dim, num_experts=1) for _ in range(num_experts)) + self.shared_expert = moe_experts(hidden_dim, model_dim, num_experts=1) self.router = TokenChoiceTopKRouter(hidden_dim, num_experts, experts_per_token) - # Mixtral token choice forward implementation assuming we have a list of moe_expert() def forward(self, x): # x shape [bs*slen, hidden_dim] # router scores/indices shape [bs*slen, experts_per_token] @@ -118,35 +119,31 @@ class TokenChoiceMoeLayer(nn.Module): return out -# Option 2: More efficient approach: without looping over experts +# Option 2: More efficient approach: without explicitly looping over experts, use_token_choice=True for expert's forward class TokenChoiceMoeLayer(nn.Module): def __init__(self): - self.experts = moe_experts(hidden_dim, model_dim, num_experts) - self.shared_expert = moe_expert(hidden_dim, model_dim) + self.experts = moe_experts(hidden_dim, model_dim, num_experts=num_experts) + self.shared_expert = moe_experts(hidden_dim, model_dim, num_experts=1) self.router = TokenChoiceTopKRouter(hidden_dim, num_experts, experts_per_token) - # unifying TC and EC MoeLayer forward function def forward(self, x): # x shape [bs*slen, hidden_dim] # router scores/indices shape [bs*slen, experts_per_token] top_scores, selected_experts_indices = self.router(x) - # arg sort experts, and group together tokens for each expert - # num_local_tokens_per_expert shape [num_experts,]: how many tokens for each expert + # shape [num_experts,]: how many tokens for each expert num_local_tokens_per_expert = torch.histc(selected_expert_indices.view(-1), bins=num_experts, min=0, max=num_experts) - # shape [bs*slen*experts_per_token,] token_indices_experts_sorted = torch.argsort(selected_experts_indices.view(-1), stable=True) # top_scores shape [bs*slen*experts_per_token,] - top_scores = top_scores.view(-1)[token_indices_sorted] + top_scores = top_scores.view(-1)[token_indices_experts_sorted] + # token_indices_experts_sorted_expanded shape [bs*slen*experts_per_token, hidden_dim] token_indices_experts_sorted_expanded = token_indices_experts_sorted.reshape(-1, 1).expand(-1, hidden_dim) - # routed_input shape [bs*slen*experts_per_token, hidden_dim] routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted_expanded) routed_input = routed_input * top_scores - - # [bs*slen*experts_per_token, hidden_dim], only experts_per_token experts will be activated for each token + # output [bs*slen*experts_per_token, hidden_dim] routed_output = self.experts(routed_input, use_token_choice=True, num_local_tokens_per_expert=num_local_tokens_per_expert) # shared expert @@ -191,7 +188,6 @@ class ExpertChoiceTopKRouter(nn.Module): # scores shape [num_experts, bs*slen] scores = self.gate(x).transpose(0,1) scores = F.softmax(scores.to(softmax_dtype), dim=0).to(scores.dtype) - # [num_experts, tokens_per_expert] top_scores, top_indices = torch.topk(scores, k=self.tokens_per_expert, dim=1) return top_scores, top_indices @@ -199,8 +195,8 @@ class ExpertChoiceTopKRouter(nn.Module): # Option 1: Least efficient approach: looping over experts class ExpertChoiceMoeLayer(nn.Module): def __init__(self): - self.experts = nn.ModuleList(moe_expert() for _ in range(num_experts)) - self.shared_expert = moe_expert(hidden_dim, model_dim) + self.experts = nn.ModuleList(moe_experts(hidden_dim, model_dim, num_experts=1) for _ in range(num_experts)) + self.shared_expert = moe_experts(hidden_dim, model_dim, num_experts=1) self.router = ExpertChoiceTopKRouter(hidden_dim, num_experts, tokens_per_expert) def forward(self, x): @@ -212,7 +208,7 @@ class ExpertChoiceMoeLayer(nn.Module): out = torch.zeros((batch_size * seq_len, hidden_dim)) for i in range(num_experts): expert = self.experts[i] - # selected tokens [tokens_per_expert, hidden_dim] + # selected_tokens [tokens_per_expert, hidden_dim] selected_tokens = x[selected_token_indices[i]] # compute hidden state for the each selected expert and multiply by the routing weights [tokens_per_expert, hidden_dim] hidden_states = expert(selected_tokens) * top_scores[i] @@ -222,11 +218,11 @@ class ExpertChoiceMoeLayer(nn.Module): return out -# Option 2: More efficient approach: without looping over experts, using bmm +# Option 2: More efficient approach: without looping over experts using torch.bmm class ExpertChoiceMoeLayer(nn.Module): def __init__(self): - self.experts = moe_experts(hidden_dim, model_dim, num_experts) - self.shared_expert = moe_expert(hidden_dim, model_dim) + self.experts = moe_experts(hidden_dim, model_dim, num_experts=num_experts) + self.shared_expert = moe_expert(hidden_dim, model_dim, num_experts=1) self.router = ExpertChoiceTopKRouter(hidden_dim, num_experts, tokens_per_expert) def forward(self, x): @@ -305,5 +301,4 @@ torchtune/ _component_builders.py moe() moe_experts() - moe_expert() ``` From 0c2060a38b0236211cd421bee341472f652c4628 Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Sat, 26 Oct 2024 14:20:01 -0700 Subject: [PATCH 08/13] misc --- RFC.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/RFC.md b/RFC.md index fd7259bb08..df1dca3205 100644 --- a/RFC.md +++ b/RFC.md @@ -93,7 +93,7 @@ class TokenChoiceTopKRouter(nn.Module): return top_scores, top_indices -# Option 1: Least efficient approach: looping over experts +# Implementation 1: Least efficient approach: looping over experts class TokenChoiceMoeLayer(nn.Module): def __init__(self): self.experts = nn.ModuleList(moe_experts(hidden_dim, model_dim, num_experts=1) for _ in range(num_experts)) @@ -119,7 +119,7 @@ class TokenChoiceMoeLayer(nn.Module): return out -# Option 2: More efficient approach: without explicitly looping over experts, use_token_choice=True for expert's forward +# Implementation 2: More efficient approach: without explicitly looping over experts, use_token_choice=True for expert's forward class TokenChoiceMoeLayer(nn.Module): def __init__(self): self.experts = moe_experts(hidden_dim, model_dim, num_experts=num_experts) @@ -192,7 +192,7 @@ class ExpertChoiceTopKRouter(nn.Module): return top_scores, top_indices -# Option 1: Least efficient approach: looping over experts +# Implementation 1: Least efficient approach: looping over experts class ExpertChoiceMoeLayer(nn.Module): def __init__(self): self.experts = nn.ModuleList(moe_experts(hidden_dim, model_dim, num_experts=1) for _ in range(num_experts)) @@ -218,14 +218,14 @@ class ExpertChoiceMoeLayer(nn.Module): return out -# Option 2: More efficient approach: without looping over experts using torch.bmm +# Implementation 2: More efficient approach: without looping over experts using torch.bmm class ExpertChoiceMoeLayer(nn.Module): def __init__(self): self.experts = moe_experts(hidden_dim, model_dim, num_experts=num_experts) - self.shared_expert = moe_expert(hidden_dim, model_dim, num_experts=1) + self.shared_expert = moe_experts(hidden_dim, model_dim, num_experts=1) self.router = ExpertChoiceTopKRouter(hidden_dim, num_experts, tokens_per_expert) - def forward(self, x): + def forward(self, x, infernece=False): # x shape [bs*slen, hidden_dim] # router scores/indices shape [num_experts, tokens_per_expert] top_scores, selected_token_indices = self.router(x) From e12dab9530c6df06279b1b67f30670bc0b67ca32 Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Mon, 28 Oct 2024 13:12:04 -0700 Subject: [PATCH 09/13] separate TC and EC experts --- RFC.md | 116 +++++++++++++++++++++++++++++++++------------------------ 1 file changed, 68 insertions(+), 48 deletions(-) diff --git a/RFC.md b/RFC.md index df1dca3205..4a9168506e 100644 --- a/RFC.md +++ b/RFC.md @@ -9,7 +9,7 @@ An expert is essentially an FFN layer similar to the original dense FFN layer in **Here's the proposed Experts design in torchtune:** ```python class Experts(nn.Module): - def __init__(self, dim_in, dim_out, nonlinearity, num_experts=1, swiglu=True): + def __init__(self, dim_in, dim_out, num_experts=1, swiglu=True, nonlinearity=None): self.gate_proj = nn.Parameter(torch.empty(num_experts, dim_in, dim_out)) self.down_proj = nn.Parameter(torch.empty(num_experts, dim_out, dim_in)) if swiglu: @@ -19,50 +19,67 @@ class Experts(nn.Module): self.up_proj = None self.act_fn = nonlinearity - def forward(self, x, use_token_choice=False, num_local_tokens_per_expert=None): + def forward(self, x): + raise NotImplementedError("Subclasses must implement their own forward method.") + + +class TokenChoiceExperts(Experts): + def forward(self, x, num_local_tokens_per_expert): ''' inputs: - x: input tokens, shape [bs*slen*experts_per_token, hidden_dim] for TC, shape [num_experts*tokens_per_expert, hidden_dim] for EC - use_token_choice: True only for TokenChoiceMoeLayer's expert forward - num_local_tokens_per_expert: not None when use_token_choice=True, number of tokens for each expert + x: input tokens, shape [bs*slen*experts_per_token, hidden_dim] + num_local_tokens_per_expert: number of tokens for each expert outputs: - out: shape [bs*slen*experts_per_token, hidden_dim] for TC, shape [num_experts*tokens_per_expert, hidden_dim] for EC + out: output tokens, shape [bs*slen*experts_per_token, hidden_dim] ''' - # token choice forward TODO: use cutlass groupGEMM instead of torch.matmul() to optimize performance - if use_token_choice: - assert num_local_tokens_per_expert is not None, "num_local_tokens_per_expert is needed for token choice expert forward" - # x_expert_splits shape [num_experts, tokens_per_expert(varying), hidden_dim] - x_expert_splits = torch.split(x, split_size_or_sections=num_local_tokens_per_expert, dim=0) - out_expert_splits = [] - for expert_index, x_expert_split in enumerate(x_expert_splits): - gate_proj = self.gate_proj[expert_index] - down_proj = self.down_proj[expert_index] - up_proj = None - if self.up_proj is not None: - up_proj = self.up_proj[expert_index] - - h = self.act_fn(torch.matmul(x_expert_split, gate_proj)) - if up_proj is not None: - h = h * torch.matmul(x_expert_split, up_proj) - # [tokens_per_expert, hidden_dim] - h = torch.matmul(h, down_proj) - - out_expert_splits.append(h) - # shape [num_experts * tokens_per_expert(varying), hidden_dim] = [bs*slen*experts_per_token, hidden_dim] - out = torch.cat(out_expert_splits, dim=0) - else: - # expert choice forward - # x shape [num_experts, tokens_per_expert, hidden_dim] - x = x.view(num_experts, -1, dim_in) - h = self.act_fn(torch.bmm(x, self.gate_proj)) + # TODO: use cutlass groupGEMM instead of torch.matmul() to optimize performance + assert num_local_tokens_per_expert is not None, "num_local_tokens_per_expert is needed for token choice expert forward" + # x_expert_splits shape [num_experts, tokens_per_expert(varying), hidden_dim] + x_expert_splits = torch.split(x, split_size_or_sections=num_local_tokens_per_expert, dim=0) + out_expert_splits = [] + for expert_index, x_expert_split in enumerate(x_expert_splits): + gate_proj = self.gate_proj[expert_index] + down_proj = self.down_proj[expert_index] + up_proj = None if self.up_proj is not None: - h = h * torch.bmm(x, self.up_proj) - out = torch.bmm(h, self.down_proj).view(-1, dim_in) - return h + up_proj = self.up_proj[expert_index] + + h = self.act_fn(torch.matmul(x_expert_split, gate_proj)) + if up_proj is not None: + h = h * torch.matmul(x_expert_split, up_proj) + # [tokens_per_expert, hidden_dim] + h = torch.matmul(h, down_proj) + + out_expert_splits.append(h) + # shape [num_experts * tokens_per_expert(varying), hidden_dim] = [bs*slen*experts_per_token, hidden_dim] + return torch.cat(out_expert_splits, dim=0) -# Expert builder for both routed experts and shared expert -def moe_experts(hidden_dim, model_dim, num_experts=1, swiglu=True, nonlinearity=None) -> FeedForward: - return Experts(dim_in=hidden_dim, dim_out=model_dim, nonlinearity=nonlinearity, num_experts=num_experts, swiglu=swiglu) + +class ExpertChoiceExperts(Experts): + def forward(self, x): + ''' + inputs: + x: input tokens, shape [num_experts*tokens_per_expert, hidden_dim] + outputs: + out: output tokens, shape [num_experts*tokens_per_expert, hidden_dim] + ''' + # x shape [num_experts, tokens_per_expert, hidden_dim] + x = x.view(num_experts, -1, dim_in) + h = self.act_fn(torch.bmm(x, self.gate_proj)) + if self.up_proj is not None: + h = h * torch.bmm(x, self.up_proj) + return torch.bmm(h, self.down_proj).view(-1, dim_in) + +# Expert builder for routed experts +def moe_experts(hidden_dim, model_dim, num_experts, swiglu=True, nonlinearity=None, expert_choice=True): + if expert_choice: + return ExpertChoiceExperts(dim_in=hidden_dim, dim_out=model_dim, num_experts=num_experts, swiglu=swiglu, nonlinearity=nonlinearity) + else: + return TokenChoiceExperts(dim_in=hidden_dim, dim_out=model_dim, num_experts=num_experts, swiglu=swiglu, nonlinearity=nonlinearity) + +# Single expert / shared expert +def moe_expert(hidden_dim, model_dim, swiglu=True, nonlinearity=None): + return ExpertChoiceExperts(dim_in=hidden_dim, dim_out=model_dim, num_experts=1, swiglu=swiglu, nonlinearity=nonlinearity) ``` ## Router and Moe Layer @@ -96,8 +113,8 @@ class TokenChoiceTopKRouter(nn.Module): # Implementation 1: Least efficient approach: looping over experts class TokenChoiceMoeLayer(nn.Module): def __init__(self): - self.experts = nn.ModuleList(moe_experts(hidden_dim, model_dim, num_experts=1) for _ in range(num_experts)) - self.shared_expert = moe_experts(hidden_dim, model_dim, num_experts=1) + self.experts = nn.ModuleList(moe_expert(hidden_dim, model_dim) for _ in range(num_experts)) + self.shared_expert = moe_expert(hidden_dim, model_dim) self.router = TokenChoiceTopKRouter(hidden_dim, num_experts, experts_per_token) def forward(self, x): @@ -122,8 +139,8 @@ class TokenChoiceMoeLayer(nn.Module): # Implementation 2: More efficient approach: without explicitly looping over experts, use_token_choice=True for expert's forward class TokenChoiceMoeLayer(nn.Module): def __init__(self): - self.experts = moe_experts(hidden_dim, model_dim, num_experts=num_experts) - self.shared_expert = moe_experts(hidden_dim, model_dim, num_experts=1) + self.experts = moe_experts(hidden_dim, model_dim, num_experts=num_experts, expert_choice=False) + self.shared_expert = moe_expert(hidden_dim, model_dim) self.router = TokenChoiceTopKRouter(hidden_dim, num_experts, experts_per_token) def forward(self, x): @@ -144,7 +161,7 @@ class TokenChoiceMoeLayer(nn.Module): routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted_expanded) routed_input = routed_input * top_scores # output [bs*slen*experts_per_token, hidden_dim] - routed_output = self.experts(routed_input, use_token_choice=True, num_local_tokens_per_expert=num_local_tokens_per_expert) + routed_output = self.experts(routed_input, num_local_tokens_per_expert=num_local_tokens_per_expert) # shared expert if use_shared_expert: @@ -195,8 +212,8 @@ class ExpertChoiceTopKRouter(nn.Module): # Implementation 1: Least efficient approach: looping over experts class ExpertChoiceMoeLayer(nn.Module): def __init__(self): - self.experts = nn.ModuleList(moe_experts(hidden_dim, model_dim, num_experts=1) for _ in range(num_experts)) - self.shared_expert = moe_experts(hidden_dim, model_dim, num_experts=1) + self.experts = nn.ModuleList(moe_expert(hidden_dim, model_dim) for _ in range(num_experts)) + self.shared_expert = moe_expert(hidden_dim, model_dim) self.router = ExpertChoiceTopKRouter(hidden_dim, num_experts, tokens_per_expert) def forward(self, x): @@ -221,8 +238,8 @@ class ExpertChoiceMoeLayer(nn.Module): # Implementation 2: More efficient approach: without looping over experts using torch.bmm class ExpertChoiceMoeLayer(nn.Module): def __init__(self): - self.experts = moe_experts(hidden_dim, model_dim, num_experts=num_experts) - self.shared_expert = moe_experts(hidden_dim, model_dim, num_experts=1) + self.experts = moe_experts(hidden_dim, model_dim, num_experts=num_experts, expert_choice=True) + self.shared_expert = moe_expert(hidden_dim, model_dim) self.router = ExpertChoiceTopKRouter(hidden_dim, num_experts, tokens_per_expert) def forward(self, x, infernece=False): @@ -296,9 +313,12 @@ torchtune/ ExpertChoiceMoeLayer() experts.py Experts() + TokenChoiceExperts + ExpertChoiceExperts models/ moe/ _component_builders.py moe() + moe_expert() moe_experts() ``` From cfd37643da9385e3b0682fcbb5a3f2753e162318 Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Mon, 28 Oct 2024 14:57:13 -0700 Subject: [PATCH 10/13] add sigmoid --- RFC.md | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/RFC.md b/RFC.md index 4a9168506e..276b99040b 100644 --- a/RFC.md +++ b/RFC.md @@ -30,7 +30,7 @@ class TokenChoiceExperts(Experts): x: input tokens, shape [bs*slen*experts_per_token, hidden_dim] num_local_tokens_per_expert: number of tokens for each expert outputs: - out: output tokens, shape [bs*slen*experts_per_token, hidden_dim] + out: output tokens, shape [bs*slen*experts_per_token, hidden_dim] ''' # TODO: use cutlass groupGEMM instead of torch.matmul() to optimize performance assert num_local_tokens_per_expert is not None, "num_local_tokens_per_expert is needed for token choice expert forward" @@ -94,7 +94,7 @@ class TokenChoiceTopKRouter(nn.Module): self.gate = nn.Linear(hidden_dim, num_experts) self.experts_per_token = experts_per_token - def forward(self, x): + def forward(self, x, use_sigmoid=False): ''' input: x shape [bs*slen, hidden_dim] @@ -104,7 +104,10 @@ class TokenChoiceTopKRouter(nn.Module): ''' # scores shape [bs*slen, num_experts] scores = self.gate(x) - scores = F.softmax(scores, dim=1) + if use_sigmoid: + scores = torch.sigmoid(scores.to(sigmoid_dtype)).to(x.dtype) + else: + scores = F.softmax(scores.to(softmax_dtype), dim=1).to(x.dtype) top_scores, top_indices = torch.topk(scores, k=self.experts_per_token, dim=1) top_scores /= top_scores.sum(dim=-1, keep_dim=True).to(x.dtype) return top_scores, top_indices @@ -194,7 +197,7 @@ class ExpertChoiceTopKRouter(nn.Module): self.gate = nn.Linear(hidden_dim, num_experts) self.tokens_per_expert = tokens_per_expert - def forward(self, x): + def forward(self, x, use_sigmoid=False): ''' input: x shape [bs*slen, hidden_dim] @@ -204,7 +207,10 @@ class ExpertChoiceTopKRouter(nn.Module): ''' # scores shape [num_experts, bs*slen] scores = self.gate(x).transpose(0,1) - scores = F.softmax(scores.to(softmax_dtype), dim=0).to(scores.dtype) + if use_sigmoid: + scores = torch.sigmoid(scores.to(sigmoid_dtype)).to(x.dtype) + else: + scores = F.softmax(scores.to(softmax_dtype), dim=0).to(x.dtype) top_scores, top_indices = torch.topk(scores, k=self.tokens_per_expert, dim=1) return top_scores, top_indices From c76aef3daef5ae98c45df738c34dd386c629ec3b Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Mon, 28 Oct 2024 17:12:41 -0700 Subject: [PATCH 11/13] unify Experts class and MoeLayer --- RFC.md | 241 ++++++++++++++++++--------------------------------------- 1 file changed, 77 insertions(+), 164 deletions(-) diff --git a/RFC.md b/RFC.md index 276b99040b..75b2a2d30a 100644 --- a/RFC.md +++ b/RFC.md @@ -19,73 +19,59 @@ class Experts(nn.Module): self.up_proj = None self.act_fn = nonlinearity - def forward(self, x): - raise NotImplementedError("Subclasses must implement their own forward method.") - - -class TokenChoiceExperts(Experts): - def forward(self, x, num_local_tokens_per_expert): + def forward(self, x, use_token_choice=False, num_local_tokens_per_expert=None): ''' inputs: - x: input tokens, shape [bs*slen*experts_per_token, hidden_dim] - num_local_tokens_per_expert: number of tokens for each expert + x: input tokens, shape [bs*slen*experts_per_token, hidden_dim] for TC or [num_experts*tokens_per_expert, hidden_dim] for EC + use_token_choice: if we use token choice forward + num_local_tokens_per_expert: number of tokens for each expert, only used for token choice forward outputs: out: output tokens, shape [bs*slen*experts_per_token, hidden_dim] ''' - # TODO: use cutlass groupGEMM instead of torch.matmul() to optimize performance - assert num_local_tokens_per_expert is not None, "num_local_tokens_per_expert is needed for token choice expert forward" - # x_expert_splits shape [num_experts, tokens_per_expert(varying), hidden_dim] - x_expert_splits = torch.split(x, split_size_or_sections=num_local_tokens_per_expert, dim=0) - out_expert_splits = [] - for expert_index, x_expert_split in enumerate(x_expert_splits): - gate_proj = self.gate_proj[expert_index] - down_proj = self.down_proj[expert_index] - up_proj = None + if use_token_choice: + # TODO: use cutlass groupGEMM instead of torch.matmul() to optimize performance + assert num_local_tokens_per_expert is not None, "num_local_tokens_per_expert is needed for token choice expert forward" + # x_expert_splits shape [num_experts, tokens_per_expert(varying), hidden_dim] + x_expert_splits = torch.split(x, split_size_or_sections=num_local_tokens_per_expert, dim=0) + out_expert_splits = [] + for expert_index, x_expert_split in enumerate(x_expert_splits): + gate_proj = self.gate_proj[expert_index] + down_proj = self.down_proj[expert_index] + up_proj = None + if self.up_proj is not None: + up_proj = self.up_proj[expert_index] + + h = self.act_fn(torch.matmul(x_expert_split, gate_proj)) + if up_proj is not None: + h = h * torch.matmul(x_expert_split, up_proj) + # [tokens_per_expert, hidden_dim] + h = torch.matmul(h, down_proj) + + out_expert_splits.append(h) + # shape [num_experts * tokens_per_expert(varying), hidden_dim] = [bs*slen*experts_per_token, hidden_dim] + out = torch.cat(out_expert_splits, dim=0) + else: + # x shape [num_experts, tokens_per_expert, hidden_dim] + x = x.view(num_experts, -1, dim_in) + h = self.act_fn(torch.bmm(x, self.gate_proj)) if self.up_proj is not None: - up_proj = self.up_proj[expert_index] - - h = self.act_fn(torch.matmul(x_expert_split, gate_proj)) - if up_proj is not None: - h = h * torch.matmul(x_expert_split, up_proj) - # [tokens_per_expert, hidden_dim] - h = torch.matmul(h, down_proj) - - out_expert_splits.append(h) - # shape [num_experts * tokens_per_expert(varying), hidden_dim] = [bs*slen*experts_per_token, hidden_dim] - return torch.cat(out_expert_splits, dim=0) - - -class ExpertChoiceExperts(Experts): - def forward(self, x): - ''' - inputs: - x: input tokens, shape [num_experts*tokens_per_expert, hidden_dim] - outputs: - out: output tokens, shape [num_experts*tokens_per_expert, hidden_dim] - ''' - # x shape [num_experts, tokens_per_expert, hidden_dim] - x = x.view(num_experts, -1, dim_in) - h = self.act_fn(torch.bmm(x, self.gate_proj)) - if self.up_proj is not None: - h = h * torch.bmm(x, self.up_proj) - return torch.bmm(h, self.down_proj).view(-1, dim_in) + h = h * torch.bmm(x, self.up_proj) + out = torch.bmm(h, self.down_proj).view(-1, dim_in) + return out # Expert builder for routed experts -def moe_experts(hidden_dim, model_dim, num_experts, swiglu=True, nonlinearity=None, expert_choice=True): - if expert_choice: - return ExpertChoiceExperts(dim_in=hidden_dim, dim_out=model_dim, num_experts=num_experts, swiglu=swiglu, nonlinearity=nonlinearity) - else: - return TokenChoiceExperts(dim_in=hidden_dim, dim_out=model_dim, num_experts=num_experts, swiglu=swiglu, nonlinearity=nonlinearity) +def moe_experts(hidden_dim, model_dim, num_experts, swiglu=True, nonlinearity=None): + return Experts(dim_in=hidden_dim, dim_out=model_dim, num_experts=num_experts, swiglu=swiglu, nonlinearity=nonlinearity) # Single expert / shared expert def moe_expert(hidden_dim, model_dim, swiglu=True, nonlinearity=None): - return ExpertChoiceExperts(dim_in=hidden_dim, dim_out=model_dim, num_experts=1, swiglu=swiglu, nonlinearity=nonlinearity) + return Experts(dim_in=hidden_dim, dim_out=model_dim, num_experts=1, swiglu=swiglu, nonlinearity=nonlinearity) ``` -## Router and Moe Layer -Router is a gating network that calculates router scores and learns token-to-expert affinity, and an MOE layer consists of experts and routers. There are two types of routing: token choice routing and expert choice routing. +## Router +Router is a gating network that calculates router scores and learns token-to-expert affinity. There are two types of routing: token choice routing and expert choice routing. -Mixtral uses *token choice* topK routing, which means each token will select its topK experts. The router is implemented through a learnable gate function, whose outputs will go through softmax and topK. The TokenChoiceMoeLayer class then defines how tokens select experts based on router scores. +Mixtral uses *token choice* topK routing, which means each token will select its topK experts. The router is implemented through a learnable gate function, whose outputs will go through softmax and topK. The router defines how tokens select experts / experts select tokens based on router scores. **Here's the proposed Token Choice Routing and TokenChoiceMoeLayer design in torchtune:** ```python @@ -99,8 +85,8 @@ class TokenChoiceTopKRouter(nn.Module): input: x shape [bs*slen, hidden_dim] outputs: - top_scores shape [bs*slen, experts_per_token] - top_indices shape [bs*slen, experts_per_token] + routed_input shape [bs*slen*experts_per_token, hidden_dim] + num_local_tokens_per_expert shape [num_experts,] ''' # scores shape [bs*slen, num_experts] scores = self.gate(x) @@ -108,48 +94,13 @@ class TokenChoiceTopKRouter(nn.Module): scores = torch.sigmoid(scores.to(sigmoid_dtype)).to(x.dtype) else: scores = F.softmax(scores.to(softmax_dtype), dim=1).to(x.dtype) - top_scores, top_indices = torch.topk(scores, k=self.experts_per_token, dim=1) - top_scores /= top_scores.sum(dim=-1, keep_dim=True).to(x.dtype) - return top_scores, top_indices - - -# Implementation 1: Least efficient approach: looping over experts -class TokenChoiceMoeLayer(nn.Module): - def __init__(self): - self.experts = nn.ModuleList(moe_expert(hidden_dim, model_dim) for _ in range(num_experts)) - self.shared_expert = moe_expert(hidden_dim, model_dim) - self.router = TokenChoiceTopKRouter(hidden_dim, num_experts, experts_per_token) - - def forward(self, x): - # x shape [bs*slen, hidden_dim] - # router scores/indices shape [bs*slen, experts_per_token] - top_scores, selected_experts_indices = self.router(x) - - # expert_mask shape [num_experts, experts_per_token, bs*slen] - expert_mask = torch.nn.functional.one_hot(selected_experts_indices, num_class=num_experts).permute(2,1,0) - out = torch.zeros((batch_size * seq_len, hidden_dim)) - for i in range(num_experts): - expert = self.experts[i] - expert_idx, token_idx = torch.where(expert_mask[i]) - # compute hidden state for the each selected expert and multiply by the routing weights - hidden_states = expert(x[token_idx]) * top_scores[token_idx, expert_idx] - out.index_add_(0, token_idx, hidden_states) - out += self.shared_expert(x) - return out - - -# Implementation 2: More efficient approach: without explicitly looping over experts, use_token_choice=True for expert's forward -class TokenChoiceMoeLayer(nn.Module): - def __init__(self): - self.experts = moe_experts(hidden_dim, model_dim, num_experts=num_experts, expert_choice=False) - self.shared_expert = moe_expert(hidden_dim, model_dim) - self.router = TokenChoiceTopKRouter(hidden_dim, num_experts, experts_per_token) + # TODO: implement load balancing auxiliary loss for token choice routing + # https://github.com/NVIDIA/Megatron-LM/blob/f1f039224584f0bc6ba89c21ef4f491d7136e3ce/megatron/core/transformer/moe/router.py#L162 - def forward(self, x): - # x shape [bs*slen, hidden_dim] # router scores/indices shape [bs*slen, experts_per_token] - top_scores, selected_experts_indices = self.router(x) + top_scores, selected_experts_indices = torch.topk(scores, k=self.experts_per_token, dim=1) + top_scores /= top_scores.sum(dim=-1, keep_dim=True).to(x.dtype) # shape [num_experts,]: how many tokens for each expert num_local_tokens_per_expert = torch.histc(selected_expert_indices.view(-1), bins=num_experts, min=0, max=num_experts) @@ -163,24 +114,8 @@ class TokenChoiceMoeLayer(nn.Module): # routed_input shape [bs*slen*experts_per_token, hidden_dim] routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted_expanded) routed_input = routed_input * top_scores - # output [bs*slen*experts_per_token, hidden_dim] - routed_output = self.experts(routed_input, num_local_tokens_per_expert=num_local_tokens_per_expert) - # shared expert - if use_shared_expert: - out = self.shared_expert(x) - else: - out = torch.zeros_like(x) - - # add experts output - out.data = scatter_add_( - # [bs*slen, hidden_dim] - out.data, - # [bs*slen*experts_per_token, hidden_dim] - routed_output, - # [bs*slen*experts_per_token, hidden_dim] - token_indices_experts_sorted_expanded, - ) + return routed_input, token_indices_experts_sorted_expanded, num_local_tokens_per_expert ``` However, token choice routing has several pitfalls according to the expert choice [paper](https://arxiv.org/pdf/2002.05202). @@ -200,10 +135,10 @@ class ExpertChoiceTopKRouter(nn.Module): def forward(self, x, use_sigmoid=False): ''' input: - x shape [bs*slen, hidden_dim] + x: shape [bs*slen, hidden_dim] outputs: - top_scores shape [num_experts, tokens_per_expert] - top_indices shape [num_experts, tokens_per_expert] + routed_input: shape [num_experts*tokens_per_expert, hidden_dim] + num_local_tokens_per_expert: None ''' # scores shape [num_experts, bs*slen] scores = self.gate(x).transpose(0,1) @@ -211,54 +146,38 @@ class ExpertChoiceTopKRouter(nn.Module): scores = torch.sigmoid(scores.to(sigmoid_dtype)).to(x.dtype) else: scores = F.softmax(scores.to(softmax_dtype), dim=0).to(x.dtype) - top_scores, top_indices = torch.topk(scores, k=self.tokens_per_expert, dim=1) - return top_scores, top_indices - - -# Implementation 1: Least efficient approach: looping over experts -class ExpertChoiceMoeLayer(nn.Module): - def __init__(self): - self.experts = nn.ModuleList(moe_expert(hidden_dim, model_dim) for _ in range(num_experts)) - self.shared_expert = moe_expert(hidden_dim, model_dim) - self.router = ExpertChoiceTopKRouter(hidden_dim, num_experts, tokens_per_expert) - - def forward(self, x): - # x shape [bs*slen, hidden_dim] # router scores/indices shape [num_experts, tokens_per_expert] - top_scores, selected_token_indices = self.router(x) - - # out shape [bs*slen, hidden_dim] - out = torch.zeros((batch_size * seq_len, hidden_dim)) - for i in range(num_experts): - expert = self.experts[i] - # selected_tokens [tokens_per_expert, hidden_dim] - selected_tokens = x[selected_token_indices[i]] - # compute hidden state for the each selected expert and multiply by the routing weights [tokens_per_expert, hidden_dim] - hidden_states = expert(selected_tokens) * top_scores[i] - out.index_add_(0, selected_token_indices[i], hidden_states) - - out += self.shared_expert(x) - return out - + top_scores, selected_token_indices = torch.topk(scores, k=self.tokens_per_expert, dim=1) -# Implementation 2: More efficient approach: without looping over experts using torch.bmm -class ExpertChoiceMoeLayer(nn.Module): - def __init__(self): - self.experts = moe_experts(hidden_dim, model_dim, num_experts=num_experts, expert_choice=True) - self.shared_expert = moe_expert(hidden_dim, model_dim) - self.router = ExpertChoiceTopKRouter(hidden_dim, num_experts, tokens_per_expert) - - def forward(self, x, infernece=False): - # x shape [bs*slen, hidden_dim] - # router scores/indices shape [num_experts, tokens_per_expert] - top_scores, selected_token_indices = self.router(x) # apply the token preprocess function and then run experts forward selected_token_indices_expanded = selected_token_indices.reshape(-1, 1).expand(-1, D) # routed input shape [num_experts*tokens_per_expert, hidden_dim] routed_input = torch.gather(x, dim=0, index=selected_token_indices_expanded) routed_input = routed_input * top_scores.reshape(-1, 1) - # routed output shape [num_experts*tokens_per_expert, hidden_dim] - routed_output = self.experts(routed_input) + return routed_input, selected_token_indices_expanded, None, +``` + +## Moe Layer +An MOE layer consists of experts and routers. + +**Here's the proposed MoeLayer design in torchtune:** +```python +class MoeLayer(nn.Module): + def __init__(self, router="token_choice"): + self.experts = moe_experts(hidden_dim, model_dim, num_experts=num_experts) + self.shared_expert = moe_expert(hidden_dim, model_dim) + if router == "token_choice": + self.router = TokenChoiceTopKRouter(hidden_dim, num_experts, experts_per_token) + elif router == "expert_choice": + self.router = ExpertChoiceTopKRouter(hidden_dim, num_experts, tokens_per_expert) + else: + raise NotImplementedError("This router is not supported yet!") + + def forward(self, x, infernece=False): + routed_input, token_indices, num_local_tokens_per_expert = self.router(x) + + # routed output shape [num_experts*tokens_per_expert, hidden_dim] for EC, [bs*slen*experts_per_token, hidden_dim] for TC + routed_output = self.experts(routed_input, num_local_tokens_per_expert=num_local_tokens_per_expert) # shared expert if use_shared_expert: @@ -268,12 +187,9 @@ class ExpertChoiceMoeLayer(nn.Module): # add experts output out.data = scatter_add_( - # [bs*slen, hidden_dim] out.data, - # [num_experts*tokens_per_expert, hidden_dim] routed_output, - # [num_experts*tokens_per_expert, hidden_dim] - selected_token_indices_expanded, + selected_indices, ) return out ``` @@ -291,7 +207,7 @@ def moe(...) -> TransformerDecoder: # - Final projection into the token space' token_embeddings = nn.Embedding(vocab_size, embed_dim) self_attn = MultiHeadAttention() - moe_layer = ExpertsChoiceMoeLayer() # or TokenChoiceMoeLayer() + moe_layer = MoeLayer(router="token_choice") # or MoeLayer(router="expert_choice") norm = RMSNorm(dim=embed_dim) layer = TransformerSelfAttentionLayer(attn=self_attn, mlp=moe_layer, sa_norm=norm, mlp_norm=norm) output_proj = nn.Linear(embed_dim, vocab_size) @@ -315,12 +231,9 @@ torchtune/ moe_layers.py TokenChoiceTopKRouter() ExpertChoiceTopKRouter() - TokenChoiceMoeLayer() - ExpertChoiceMoeLayer() + MoeLayer() experts.py Experts() - TokenChoiceExperts - ExpertChoiceExperts models/ moe/ _component_builders.py From 5d855bdac5a6177c44976e2e97a4fd7f35dfb8cf Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Mon, 28 Oct 2024 17:34:50 -0700 Subject: [PATCH 12/13] unify Experts class and MoeLayer --- RFC.md | 48 +++++++++++++++++++++++++++++------------------- 1 file changed, 29 insertions(+), 19 deletions(-) diff --git a/RFC.md b/RFC.md index 75b2a2d30a..725918e81b 100644 --- a/RFC.md +++ b/RFC.md @@ -19,18 +19,22 @@ class Experts(nn.Module): self.up_proj = None self.act_fn = nonlinearity - def forward(self, x, use_token_choice=False, num_local_tokens_per_expert=None): + def forward(self, x, num_local_tokens_per_expert=None): ''' inputs: - x: input tokens, shape [bs*slen*experts_per_token, hidden_dim] for TC or [num_experts*tokens_per_expert, hidden_dim] for EC - use_token_choice: if we use token choice forward - num_local_tokens_per_expert: number of tokens for each expert, only used for token choice forward + x: input tokens + shape [bs*slen*experts_per_token, hidden_dim] for TC forward + shape [num_experts*tokens_per_expert, hidden_dim] for EC forward + num_local_tokens_per_expert: number of tokens for each expert, only used for TC forward outputs: - out: output tokens, shape [bs*slen*experts_per_token, hidden_dim] + out: output tokens + shape [bs*slen*experts_per_token, hidden_dim] for TC forward + shape [num_experts*tokens_per_expert, hidden_dim] for EC forward ''' - if use_token_choice: + # TC forward + if num_local_tokens_per_expert is not None: # TODO: use cutlass groupGEMM instead of torch.matmul() to optimize performance - assert num_local_tokens_per_expert is not None, "num_local_tokens_per_expert is needed for token choice expert forward" + # x shape [bs*slen*experts_per_token, hidden_dim] # x_expert_splits shape [num_experts, tokens_per_expert(varying), hidden_dim] x_expert_splits = torch.split(x, split_size_or_sections=num_local_tokens_per_expert, dim=0) out_expert_splits = [] @@ -50,6 +54,7 @@ class Experts(nn.Module): out_expert_splits.append(h) # shape [num_experts * tokens_per_expert(varying), hidden_dim] = [bs*slen*experts_per_token, hidden_dim] out = torch.cat(out_expert_splits, dim=0) + # EC forward else: # x shape [num_experts, tokens_per_expert, hidden_dim] x = x.view(num_experts, -1, dim_in) @@ -83,10 +88,14 @@ class TokenChoiceTopKRouter(nn.Module): def forward(self, x, use_sigmoid=False): ''' input: - x shape [bs*slen, hidden_dim] + x: input tokens + shape [bs*slen, hidden_dim] outputs: - routed_input shape [bs*slen*experts_per_token, hidden_dim] - num_local_tokens_per_expert shape [num_experts,] + routed_input: tokens gather by selected experts + shape [bs*slen*experts_per_token, hidden_dim] + token_indices: token indices sorted by selected experts indices + num_local_tokens_per_expert: number of tokens assigned to each expert + shape [num_experts,] ''' # scores shape [bs*slen, num_experts] scores = self.gate(x) @@ -109,13 +118,13 @@ class TokenChoiceTopKRouter(nn.Module): # top_scores shape [bs*slen*experts_per_token,] top_scores = top_scores.view(-1)[token_indices_experts_sorted] - # token_indices_experts_sorted_expanded shape [bs*slen*experts_per_token, hidden_dim] - token_indices_experts_sorted_expanded = token_indices_experts_sorted.reshape(-1, 1).expand(-1, hidden_dim) + # token_indices shape [bs*slen*experts_per_token, hidden_dim] + token_indices = token_indices_experts_sorted.reshape(-1, 1).expand(-1, hidden_dim) # routed_input shape [bs*slen*experts_per_token, hidden_dim] - routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted_expanded) + routed_input = torch.gather(x, dim=0, index=token_indices) routed_input = routed_input * top_scores - return routed_input, token_indices_experts_sorted_expanded, num_local_tokens_per_expert + return routed_input, token_indices, num_local_tokens_per_expert ``` However, token choice routing has several pitfalls according to the expert choice [paper](https://arxiv.org/pdf/2002.05202). @@ -137,7 +146,9 @@ class ExpertChoiceTopKRouter(nn.Module): input: x: shape [bs*slen, hidden_dim] outputs: - routed_input: shape [num_experts*tokens_per_expert, hidden_dim] + routed_input: selected tokens + shape [num_experts*tokens_per_expert, hidden_dim] + token_indices: selected token indices num_local_tokens_per_expert: None ''' # scores shape [num_experts, bs*slen] @@ -150,11 +161,11 @@ class ExpertChoiceTopKRouter(nn.Module): top_scores, selected_token_indices = torch.topk(scores, k=self.tokens_per_expert, dim=1) # apply the token preprocess function and then run experts forward - selected_token_indices_expanded = selected_token_indices.reshape(-1, 1).expand(-1, D) + token_indices = selected_token_indices.reshape(-1, 1).expand(-1, D) # routed input shape [num_experts*tokens_per_expert, hidden_dim] - routed_input = torch.gather(x, dim=0, index=selected_token_indices_expanded) + routed_input = torch.gather(x, dim=0, index=token_indices) routed_input = routed_input * top_scores.reshape(-1, 1) - return routed_input, selected_token_indices_expanded, None, + return routed_input, token_indices, None, ``` ## Moe Layer @@ -175,7 +186,6 @@ class MoeLayer(nn.Module): def forward(self, x, infernece=False): routed_input, token_indices, num_local_tokens_per_expert = self.router(x) - # routed output shape [num_experts*tokens_per_expert, hidden_dim] for EC, [bs*slen*experts_per_token, hidden_dim] for TC routed_output = self.experts(routed_input, num_local_tokens_per_expert=num_local_tokens_per_expert) From ed424d8bf5a7ccce3052f0fcb7a5cac6661f2fa2 Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Tue, 29 Oct 2024 10:09:29 -0700 Subject: [PATCH 13/13] unify Experts class and MoeLayer --- RFC.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/RFC.md b/RFC.md index 725918e81b..2968ecc944 100644 --- a/RFC.md +++ b/RFC.md @@ -76,9 +76,9 @@ def moe_expert(hidden_dim, model_dim, swiglu=True, nonlinearity=None): ## Router Router is a gating network that calculates router scores and learns token-to-expert affinity. There are two types of routing: token choice routing and expert choice routing. -Mixtral uses *token choice* topK routing, which means each token will select its topK experts. The router is implemented through a learnable gate function, whose outputs will go through softmax and topK. The router defines how tokens select experts / experts select tokens based on router scores. +Mixtral uses *token choice* topK routing, which means each token will select its topK experts. The router is implemented through a learnable gate function, whose outputs will go through softmax and topK. The router then defines how tokens select experts based on router scores. -**Here's the proposed Token Choice Routing and TokenChoiceMoeLayer design in torchtune:** +**Here's the proposed Token Choice Routing design in torchtune:** ```python class TokenChoiceTopKRouter(nn.Module): def __init__(self, hidden_dim, num_experts, experts_per_token): @@ -132,16 +132,16 @@ However, token choice routing has several pitfalls according to the expert choic 2. Experts under specialization. Ideally the gating network will learn token-to-expert affinity such that similar or relevant tokens are routed to the same expert. However, a sub-optimal strategy can produce redundant experts and/or experts that are not sufficiently specialized. 3. Same compute for each token. Token choice will allocate a fixed number of experts to each token regardless of the importance of different tokens. Ideally an MOE model should flexibly allocate compute resources based on the complexity of the input. -Compared to **token choice**, **expert choice** topK routing lets experts select its top-k tokens. The ExpertChoiceMoeLayer class routes input tokens to different experts based on the routing algorithm, processes them through the experts and the shared expert, and then combines the output. +Compared to **token choice**, **expert choice** topK routing lets experts select its top-k tokens. The ExpertChoiceTopKRouter class routes input tokens to different experts based on the router scores. -**Here's the proposed Expert Choice Routing and ExpertChoiceMoeLayer design in torchtune:** +**Here's the proposed Expert Choice Routing design in torchtune:** ```python class ExpertChoiceTopKRouter(nn.Module): - def __init__(self, hidden_dim, num_experts): - self.gate = nn.Linear(hidden_dim, num_experts) - self.tokens_per_expert = tokens_per_expert + def __init__(self, hidden_dim, num_experts): + self.gate = nn.Linear(hidden_dim, num_experts) + self.tokens_per_expert = tokens_per_expert - def forward(self, x, use_sigmoid=False): + def forward(self, x, use_sigmoid=False): ''' input: x: shape [bs*slen, hidden_dim]