diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py
index b2705429906c4..3d76c36f2648b 100644
--- a/vllm/hpu/ops.py
+++ b/vllm/hpu/ops.py
@@ -86,36 +86,6 @@ def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
     return F.silu(x[..., :d]) * x[..., d:]
 
 
-def static_fused_moe(hidden_states, w1, w2, score, topk):
-    B, D = hidden_states.shape
-    num_experts = w1.shape[0]
-    routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
-    routing_weights, selected_experts = torch.topk(routing_weights,
-                                                   topk,
-                                                   dim=-1)
-    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-    routing_weights = routing_weights.to(hidden_states.dtype)
-    final_hidden_states = torch.zeros((1, B, D),
-                                      dtype=hidden_states.dtype,
-                                      device=hidden_states.device)
-    padded_weights = torch.zeros((B, num_experts),
-                                 dtype=hidden_states.dtype,
-                                 device=hidden_states.device)
-    padded_weights.scatter_(-1, selected_experts, routing_weights)
-    padded_weights = padded_weights.reshape(-1, B, w1.shape[0])
-    padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1)
-
-    htorch.core.mark_step()
-
-    for expert_idx in range(num_experts):
-        w_output = torch.matmul(hidden_states, w1[expert_idx].transpose(0, 1))
-        w_output = silu_and_mul(w_output)
-        w_output = torch.matmul(w_output, w2[expert_idx].transpose(0, 1))
-        final_hidden_states += w_output * padded_weights[expert_idx]
-
-    return final_hidden_states.view(-1, D)
-
-
 #TODO: remove after fusedsdpa fix for query_head != kv_head
 def repeat_kv(kv: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
@@ -252,3 +222,61 @@ def dispatch_bgmv_embedding(
     wb = wb.reshape(wb.shape[0] * wb.shape[1], wb.shape[2])
     out = x @ wb
     y += out * scale
+
+
+class MoeMatmul(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    def set_weight(self, w):
+        self.weight = w
+
+    def calc(self, state, expert_id, w):
+        self.weight = w[expert_id].transpose(0, 1)
+        return self.forward(state)
+
+    def forward(self, state):
+        return torch.matmul(state, self.weight)
+
+
+class StaticFusedMOE(torch.nn.Module):
+
+    def __init__(self, num_total_experts):
+        super().__init__()
+        self.w13_list = torch.nn.ModuleList(
+            [MoeMatmul() for _ in range(num_total_experts)])
+        self.w2_list = torch.nn.ModuleList(
+            [MoeMatmul() for _ in range(num_total_experts)])
+        self.num_total_experts = num_total_experts
+
+    def forward(self, hidden_states, w1, w2, score, topk):
+        B, D = hidden_states.shape
+        routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
+        routing_weights, selected_experts = torch.topk(routing_weights,
+                                                       topk,
+                                                       dim=-1)
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        routing_weights = routing_weights.to(hidden_states.dtype)
+        final_hidden_states = torch.zeros((1, B, D),
+                                          dtype=hidden_states.dtype,
+                                          device=hidden_states.device)
+        padded_weights = torch.zeros((B, self.num_total_experts),
+                                     dtype=hidden_states.dtype,
+                                     device=hidden_states.device)
+        padded_weights.scatter_(-1, selected_experts, routing_weights)
+        padded_weights = padded_weights.reshape(-1, B, self.num_total_experts)
+        padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1)
+        htorch.core.mark_step()
+
+        for expert_idx in range(self.num_total_experts):
+            padded_weight = padded_weights[expert_idx]
+            current_state_static = hidden_states.reshape(-1, D)
+            w_output = self.w13_list[expert_idx].calc(current_state_static,
+                                                      expert_idx, w1)
+            w_output = silu_and_mul(w_output)
+            w_output = self.w2_list[expert_idx].calc(w_output, expert_idx, w2)
+            current_hidden_states_static = w_output * padded_weight
+            final_hidden_states += current_hidden_states_static
+
+        return final_hidden_states.view(-1, D)
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index b49bf40d4746e..cf0d5f98f1b01 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -13,9 +13,6 @@
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.utils import is_hpu
 
-if is_hpu():
-    from vllm.hpu.ops import static_fused_moe
-
 logger = init_logger(__name__)
 
 
@@ -78,7 +75,8 @@ def apply(
     ) -> torch.Tensor:
         return self.forward(x, layer.w13_weight, layer.w2_weight,
                             router_logits, top_k, renormalize,
-                            use_grouped_topk, num_expert_group, topk_group)
+                            use_grouped_topk, num_expert_group, topk_group,
+                            layer)
 
     def forward_cuda(
         self,
@@ -91,6 +89,7 @@ def forward_cuda(
         use_grouped_topk: bool,
         num_expert_group: Optional[int],
         topk_group: Optional[int],
+        layer: Optional[torch.nn.Module],
     ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe
         return fused_moe(x,
@@ -104,15 +103,25 @@ def forward_cuda(
                          num_expert_group=num_expert_group,
                          topk_group=topk_group)
 
-    def forward_hpu(self, x: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
-                    router_logits: torch.Tensor, top_k: int, renormalize: bool,
-                    use_grouped_topk: bool, num_expert_group: Optional[int],
-                    topk_group: Optional[int]):
+    def forward_hpu(
+        self,
+        x: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        num_expert_group: Optional[int],
+        topk_group: Optional[int],
+        layer: Optional[torch.nn.Module],
+    ):
         assert not use_grouped_topk, 'use_grouped_topk must be False on HPU'
         assert num_expert_group is None, ('num_expert_group is '
                                           'not supported on HPU')
         assert topk_group is None, 'topk_group is not supported on HPU'
-        return static_fused_moe(x, w1, w2, router_logits, top_k)
+        if layer is not None:
+            return layer.hpu_static_fused_moe(x, w1, w2, router_logits, top_k)
 
     def forward_cpu(self, *args, **kwargs):
         raise NotImplementedError(
@@ -129,6 +138,7 @@ def forward_tpu(
         use_grouped_topk: bool,
         num_expert_group: Optional[int],
         topk_group: Optional[int],
+        layer: Optional[torch.nn.Module],
     ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
         assert not use_grouped_topk
@@ -140,7 +150,7 @@ def forward_tpu(
 class FusedMoE(torch.nn.Module):
     """FusedMoE layer for MoE models.
 
-    This layer contains both MergedColumnParallel weights (gate_up_proj / 
+    This layer contains both MergedColumnParallel weights (gate_up_proj /
     w13) and RowParallelLinear weights (down_proj/ w2).
 
     Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
@@ -191,6 +201,9 @@ def __init__(
             assert num_expert_group is not None and topk_group is not None
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
+        if is_hpu():
+            from vllm.hpu.ops import StaticFusedMOE
+            self.hpu_static_fused_moe = StaticFusedMOE(self.num_experts)
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
@@ -245,13 +258,22 @@ def weight_loader(self, param: torch.nn.Parameter,
             if shard_id == 0:
                 param_data[expert_id,
                            0:shard_size, :] = loaded_weight[shard, :]
+                if is_hpu():
+                    self.hpu_static_fused_moe.w13_list[expert_id].set_weight(
+                        param_data[expert_id])
             # w3, up_proj case: Load into second shard of w13.
             elif shard_id == 2:
                 param_data[expert_id, shard_size:2 *
                            shard_size, :] = loaded_weight[shard, :]
+                if is_hpu():
+                    self.hpu_static_fused_moe.w13_list[expert_id].set_weight(
+                        param_data[expert_id])
             # w2, down_proj case: Load into only shard of w2.
             elif shard_id == 1:
                 param_data[expert_id, :, :] = loaded_weight[:, shard]
+                if is_hpu():
+                    self.hpu_static_fused_moe.w2_list[expert_id].set_weight(
+                        param_data[expert_id])
             else:
                 raise ValueError(
                     f"Shard id must be in [0,1,2] but got {shard_id}")
diff --git a/vllm/model_executor/layers/quantization/inc.py b/vllm/model_executor/layers/quantization/inc.py
index f6718ec2ac9e7..ec0141b61f58f 100644
--- a/vllm/model_executor/layers/quantization/inc.py
+++ b/vllm/model_executor/layers/quantization/inc.py
@@ -5,6 +5,8 @@
 from torch.nn.parameter import Parameter
 
 from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.layer import (
+    FusedMoE, UnquantizedFusedMoEMethod)
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
@@ -52,6 +54,8 @@ def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["INCLinearMethod"]:
         if isinstance(layer, LinearBase):
             return INCLinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return UnquantizedFusedMoEMethod()
         return None
 
     def get_scaled_act_names(self) -> List[str]:
@@ -78,7 +82,7 @@ class INCLinearMethod(LinearMethodBase):
     1. Only support per-tensor quantization due to torch._scaled_mm support.
     2. Only support float8_e4m3fn data type due to the limitation of
        torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
-    
+
     Args:
         quant_config: The quantization config.
     """
diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py
index f7e0f56c1a46e..a8b0a7b07ed8e 100644
--- a/vllm/model_executor/model_loader/utils.py
+++ b/vllm/model_executor/model_loader/utils.py
@@ -24,7 +24,7 @@ def get_model_architecture(
     # Special handling for quantized Mixtral.
     # FIXME(woosuk): This is a temporary hack.
     if (model_config.quantization is not None
-            and model_config.quantization != "fp8"
+            and model_config.quantization not in ["fp8", "inc"]
            and "MixtralForCausalLM" in architectures):
        architectures = ["QuantMixtralForCausalLM"]
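# ---------------------------------------------------------------------------
# Reference sketch (not part of the patch): a minimal, CPU-only emulation of
# the routing math that both the removed static_fused_moe and the new
# StaticFusedMOE.forward implement -- softmax the router logits, keep the
# top-k experts per token, renormalize, scatter the weights into a
# zero-padded (tokens, experts) matrix, and accumulate each expert's SwiGLU
# MLP output scaled by its column. Names (moe_reference, w13, ...) and the
# shapes are illustrative assumptions; htorch.core.mark_step() is omitted
# because it is HPU-specific.
# ---------------------------------------------------------------------------
import torch
import torch.nn.functional as F


def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # Same contract as vllm.hpu.ops.silu_and_mul: split the last dim in half,
    # apply SiLU to the first half and gate it with the second half.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


def moe_reference(hidden_states: torch.Tensor, w13: torch.Tensor,
                  w2: torch.Tensor, score: torch.Tensor,
                  topk: int) -> torch.Tensor:
    B, D = hidden_states.shape
    num_experts = w13.shape[0]
    routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
    routing_weights, selected_experts = torch.topk(routing_weights,
                                                   topk,
                                                   dim=-1)
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    routing_weights = routing_weights.to(hidden_states.dtype)
    # Dense (token, expert) weight matrix: zero for experts a token skipped.
    padded_weights = torch.zeros((B, num_experts), dtype=hidden_states.dtype)
    padded_weights.scatter_(-1, selected_experts, routing_weights)
    final_hidden_states = torch.zeros_like(hidden_states)
    for expert_idx in range(num_experts):
        w_output = silu_and_mul(hidden_states @ w13[expert_idx].t())
        w_output = w_output @ w2[expert_idx].t()
        final_hidden_states += (w_output *
                                padded_weights[:, expert_idx:expert_idx + 1])
    return final_hidden_states


if __name__ == "__main__":
    torch.manual_seed(0)
    B, D, INTER, E, K = 4, 8, 16, 3, 2  # tokens, hidden, intermediate, experts, top-k
    x = torch.randn(B, D)
    w13 = torch.randn(E, 2 * INTER, D)  # fused gate/up projection per expert
    w2 = torch.randn(E, D, INTER)       # down projection per expert
    print(moe_reference(x, w13, w2, torch.randn(B, E), K).shape)  # torch.Size([4, 8])

# Design note (a reading of the patch, not stated in it): wrapping each
# expert's matmuls in MoeMatmul submodules gives every expert weight a stable
# module path on the FusedMoE layer, which is what allows per-expert weights
# to be registered via set_weight in weight_loader and handled by INC, in
# line with inc.py now returning UnquantizedFusedMoEMethod for FusedMoE.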