Support Mixtral quantization using INC (HabanaAI#267)
dudilester authored and zhouyu5 committed Sep 20, 2024
1 parent 029658d commit a5904b6
Showing 4 changed files with 96 additions and 42 deletions.
88 changes: 58 additions & 30 deletions vllm/hpu/ops.py
@@ -86,36 +86,6 @@ def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
return F.silu(x[..., :d]) * x[..., d:]


def static_fused_moe(hidden_states, w1, w2, score, topk):
B, D = hidden_states.shape
num_experts = w1.shape[0]
routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
routing_weights, selected_experts = torch.topk(routing_weights,
topk,
dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(hidden_states.dtype)
final_hidden_states = torch.zeros((1, B, D),
dtype=hidden_states.dtype,
device=hidden_states.device)
padded_weights = torch.zeros((B, num_experts),
dtype=hidden_states.dtype,
device=hidden_states.device)
padded_weights.scatter_(-1, selected_experts, routing_weights)
padded_weights = padded_weights.reshape(-1, B, w1.shape[0])
padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1)

htorch.core.mark_step()

for expert_idx in range(num_experts):
w_output = torch.matmul(hidden_states, w1[expert_idx].transpose(0, 1))
w_output = silu_and_mul(w_output)
w_output = torch.matmul(w_output, w2[expert_idx].transpose(0, 1))
final_hidden_states += w_output * padded_weights[expert_idx]

return final_hidden_states.view(-1, D)


#TODO: remove after fusedsdpa fix for query_head != kv_head
def repeat_kv(kv: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
@@ -252,3 +222,61 @@ def dispatch_bgmv_embedding(
wb = wb.reshape(wb.shape[0] * wb.shape[1], wb.shape[2])
out = x @ wb
y += out * scale


class MoeMatmul(torch.nn.Module):

def __init__(self):
super().__init__()

def set_weight(self, w):
self.weight = w

def calc(self, state, expert_id, w):
self.weight = w[expert_id].transpose(0, 1)
return self.forward(state)

def forward(self, state):
return torch.matmul(state, self.weight)


class StaticFusedMOE(torch.nn.Module):

def __init__(self, num_total_experts):
super().__init__()
self.w13_list = torch.nn.ModuleList(
[MoeMatmul() for _ in range(num_total_experts)])
self.w2_list = torch.nn.ModuleList(
[MoeMatmul() for _ in range(num_total_experts)])
self.num_total_experts = num_total_experts

def forward(self, hidden_states, w1, w2, score, topk):
B, D = hidden_states.shape
routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
routing_weights, selected_experts = torch.topk(routing_weights,
topk,
dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(hidden_states.dtype)
final_hidden_states = torch.zeros((1, B, D),
dtype=hidden_states.dtype,
device=hidden_states.device)
padded_weights = torch.zeros((B, self.num_total_experts),
dtype=hidden_states.dtype,
device=hidden_states.device)
padded_weights.scatter_(-1, selected_experts, routing_weights)
padded_weights = padded_weights.reshape(-1, B, self.num_total_experts)
padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1)
htorch.core.mark_step()

for expert_idx in range(self.num_total_experts):
padded_weight = padded_weights[expert_idx]
current_state_static = hidden_states.reshape(-1, D)
w_output = self.w13_list[expert_idx].calc(current_state_static,
expert_idx, w1)
w_output = silu_and_mul(w_output)
w_output = self.w2_list[expert_idx].calc(w_output, expert_idx, w2)
current_hidden_states_static = w_output * padded_weight
final_hidden_states += current_hidden_states_static

return final_hidden_states.view(-1, D)
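
For orientation, a minimal usage sketch of the new module follows. It assumes a Gaudi/HPU build of vLLM (vllm.hpu.ops pulls in Habana's torch bridge) and uses toy shapes chosen purely for illustration; none of it is part of the commit. The point of the refactor is that every expert matmul now lives in its own MoeMatmul nn.Module, giving INC a module boundary to hook for measurement and quantization, which the removed free function static_fused_moe did not provide.

```python
# Illustrative sketch only: toy shapes, random weights, HPU device assumed.
import torch
from vllm.hpu.ops import StaticFusedMOE

num_experts, hidden, intermediate, tokens, top_k = 8, 64, 128, 4, 2
device, dtype = "hpu", torch.bfloat16

moe = StaticFusedMOE(num_experts)

# w1 stacks gate_proj and up_proj per expert (the fused "w13"); w2 is down_proj.
hidden_states = torch.randn(tokens, hidden, device=device, dtype=dtype)
w1 = torch.randn(num_experts, 2 * intermediate, hidden, device=device, dtype=dtype)
w2 = torch.randn(num_experts, hidden, intermediate, device=device, dtype=dtype)
router_logits = torch.randn(tokens, num_experts, device=device, dtype=dtype)

out = moe(hidden_states, w1, w2, router_logits, top_k)
assert out.shape == (tokens, hidden)
```
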
42 changes: 32 additions & 10 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -13,9 +13,6 @@
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import is_hpu

if is_hpu():
from vllm.hpu.ops import static_fused_moe

logger = init_logger(__name__)


@@ -78,7 +75,8 @@ def apply(
) -> torch.Tensor:
return self.forward(x, layer.w13_weight, layer.w2_weight,
router_logits, top_k, renormalize,
use_grouped_topk, num_expert_group, topk_group)
use_grouped_topk, num_expert_group, topk_group,
layer)

def forward_cuda(
self,
@@ -91,6 +89,7 @@ def forward_cuda(
use_grouped_topk: bool,
num_expert_group: Optional[int],
topk_group: Optional[int],
layer: Optional[torch.nn.Module],
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe
return fused_moe(x,
@@ -104,15 +103,25 @@
num_expert_group=num_expert_group,
topk_group=topk_group)

def forward_hpu(self, x: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
router_logits: torch.Tensor, top_k: int, renormalize: bool,
use_grouped_topk: bool, num_expert_group: Optional[int],
topk_group: Optional[int]):
def forward_hpu(
self,
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
num_expert_group: Optional[int],
topk_group: Optional[int],
layer: Optional[torch.nn.Module],
):
assert not use_grouped_topk, 'use_grouped_topk must be False on HPU'
assert num_expert_group is None, ('num_expert_group is '
'not supported on HPU')
assert topk_group is None, 'topk_group is not supported on HPU'
return static_fused_moe(x, w1, w2, router_logits, top_k)
if layer is not None:
return layer.hpu_static_fused_moe(x, w1, w2, router_logits, top_k)

def forward_cpu(self, *args, **kwargs):
raise NotImplementedError(
@@ -129,6 +138,7 @@ def forward_tpu(
use_grouped_topk: bool,
num_expert_group: Optional[int],
topk_group: Optional[int],
layer: Optional[torch.nn.Module],
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
assert not use_grouped_topk
@@ -140,7 +150,7 @@
class FusedMoE(torch.nn.Module):
"""FusedMoE layer for MoE models.
This layer contains both MergedColumnParallel weights (gate_up_proj /
w13) and RowParallelLinear weights (down_proj/ w2).
Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
@@ -191,6 +201,9 @@ def __init__(
assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group
self.topk_group = topk_group
if is_hpu():
from vllm.hpu.ops import StaticFusedMOE
self.hpu_static_fused_moe = StaticFusedMOE(self.num_experts)

if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = (
@@ -245,13 +258,22 @@ def weight_loader(self, param: torch.nn.Parameter,
if shard_id == 0:
param_data[expert_id,
0:shard_size, :] = loaded_weight[shard, :]
if is_hpu():
self.hpu_static_fused_moe.w13_list[expert_id].set_weight(
param_data[expert_id])
# w3, up_proj case: Load into second shard of w13.
elif shard_id == 2:
param_data[expert_id, shard_size:2 *
shard_size, :] = loaded_weight[shard, :]
if is_hpu():
self.hpu_static_fused_moe.w13_list[expert_id].set_weight(
param_data[expert_id])
# w2, down_proj case: Load into only shard of w2.
elif shard_id == 1:
param_data[expert_id, :, :] = loaded_weight[:, shard]
if is_hpu():
self.hpu_static_fused_moe.w2_list[expert_id].set_weight(
param_data[expert_id])
else:
raise ValueError(
f"Shard id must be in [0,1,2] but got {shard_id}")
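
The HPU branches added to weight_loader above mirror each finished expert slice into the corresponding MoeMatmul, so the module that forward_hpu later reaches through layer.hpu_static_fused_moe always holds the same tensors as the fused w13/w2 parameters. Below is a standalone sketch of that mirroring with toy shapes, random stand-in shards, and no tensor parallelism; only the set_weight calls follow the diff, the rest is illustrative and needs an HPU build of vLLM to import.

```python
import torch
from vllm.hpu.ops import StaticFusedMOE  # HPU build of vLLM assumed

num_experts, hidden, intermediate = 8, 64, 128
moe = StaticFusedMOE(num_experts)

w13 = torch.empty(num_experts, 2 * intermediate, hidden)   # fused gate+up
w2 = torch.empty(num_experts, hidden, intermediate)        # down_proj

for expert_id in range(num_experts):
    gate = torch.randn(intermediate, hidden)   # shard_id == 0 (w1)
    up = torch.randn(intermediate, hidden)     # shard_id == 2 (w3)
    down = torch.randn(hidden, intermediate)   # shard_id == 1 (w2)

    w13[expert_id, 0:intermediate, :] = gate
    w13[expert_id, intermediate:2 * intermediate, :] = up
    w2[expert_id, :, :] = down

    # Keep the INC-visible per-expert modules in sync with the fused
    # parameters, as the is_hpu() branches of weight_loader do.
    moe.w13_list[expert_id].set_weight(w13[expert_id])
    moe.w2_list[expert_id].set_weight(w2[expert_id])
```
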
6 changes: 5 additions & 1 deletion vllm/model_executor/layers/quantization/inc.py
@@ -5,6 +5,8 @@
from torch.nn.parameter import Parameter

from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
@@ -52,6 +54,8 @@ def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["INCLinearMethod"]:
if isinstance(layer, LinearBase):
return INCLinearMethod(self)
elif isinstance(layer, FusedMoE):
return UnquantizedFusedMoEMethod()
return None

def get_scaled_act_names(self) -> List[str]:
@@ -78,7 +82,7 @@ class INCLinearMethod(LinearMethodBase):
1. Only support per-tensor quantization due to torch._scaled_mm support.
2. Only support float8_e4m3fn data type due to the limitation of
torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
Args:
quant_config: The quantization config.
"""
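
With this change, an INC-configured model hands out two kinds of methods: LinearBase layers get INCLinearMethod, while FusedMoE layers get the plain UnquantizedFusedMoEMethod, whose HPU path runs through the StaticFusedMOE / MoeMatmul modules that INC can then instrument. A small helper sketch of that dispatch follows; the helper itself is hypothetical, the isinstance logic mirrors the hunk above, and an HPU build of vLLM is assumed for the imports.

```python
import torch
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.quantization.inc import INCLinearMethod


def describe_quant_method(quant_config: QuantizationConfig,
                          layer: torch.nn.Module,
                          prefix: str = "") -> str:
    """Report which method an INC config assigns to a given layer."""
    method = quant_config.get_quant_method(layer, prefix)
    if isinstance(method, INCLinearMethod):
        return "LinearBase -> INCLinearMethod"
    if isinstance(method, UnquantizedFusedMoEMethod):
        return "FusedMoE -> UnquantizedFusedMoEMethod (HPU StaticFusedMOE path)"
    return "no quantization method"
```
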
2 changes: 1 addition & 1 deletion vllm/model_executor/model_loader/utils.py
@@ -24,7 +24,7 @@ def get_model_architecture(
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
if (model_config.quantization is not None
and model_config.quantization != "fp8"
and model_config.quantization not in ["fp8", "inc"]
and "MixtralForCausalLM" in architectures):
architectures = ["QuantMixtralForCausalLM"]

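
Restated outside get_model_architecture for clarity (the helper name below is illustrative, the condition is copied from the diff): only quantization schemes other than "fp8" and "inc" reroute Mixtral to the QuantMixtralForCausalLM variant, so INC-quantized Mixtral now keeps the stock, FusedMoE-based MixtralForCausalLM.

```python
from typing import List, Optional


def resolve_mixtral_architectures(quantization: Optional[str],
                                  architectures: List[str]) -> List[str]:
    # fp8 and (after this commit) inc keep the stock MixtralForCausalLM;
    # other quantization schemes still get the QuantMixtral fallback.
    if (quantization is not None
            and quantization not in ["fp8", "inc"]
            and "MixtralForCausalLM" in architectures):
        return ["QuantMixtralForCausalLM"]
    return architectures


assert resolve_mixtral_architectures("inc", ["MixtralForCausalLM"]) == [
    "MixtralForCausalLM"
]
assert resolve_mixtral_architectures("awq", ["MixtralForCausalLM"]) == [
    "QuantMixtralForCausalLM"
]
```
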
