diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 6389a876c9e76..6344c3d39eb7e 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -27,6 +27,9 @@
 import torch.nn as nn
 
 from vllm.model_executor.custom_op import CustomOp
+from vllm.platforms import current_platform
+
+is_hpu = current_platform.is_hpu()
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -664,9 +667,12 @@ def __init__(
                          is_neox_style, dtype)
 
     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
-        pos_freqs = self.base**(torch.arange(
-            0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
-                                self.rotary_dim)
+        pos_freqs = self.base**(
+            torch.arange(0,
+                         self.rotary_dim,
+                         2,
+                         dtype=torch.float,
+                         device="hpu" if is_hpu else "cuda") / self.rotary_dim)
         inv_freq_extrapolation = 1.0 / pos_freqs
         inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
 
@@ -684,7 +690,7 @@ def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
     def _compute_cos_sin_cache(self) -> torch.Tensor:
         inv_freq = self._compute_inv_freq(self.scaling_factor)
         t = torch.arange(self.max_position_embeddings * self.scaling_factor,
-                         device="cuda",
+                         device="hpu" if is_hpu else "cuda",
                          dtype=torch.float32)
         freqs = torch.einsum("i,j -> ij", t, inv_freq)
         cos = (freqs.cos() * self.mscale)
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index d2c4ca0bf85e9..029cdc52d37f3 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -47,6 +47,7 @@
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsPP
@@ -54,6 +55,8 @@
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
+is_hpu = current_platform.is_hpu()
+
 
 class DeepseekV2MLP(nn.Module):
 
@@ -111,18 +114,30 @@ def __init__(
         if config.hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                              "Only silu is supported for now.")
-
-        self.experts = FusedMoE(num_experts=config.n_routed_experts,
-                                top_k=config.num_experts_per_tok,
-                                hidden_size=config.hidden_size,
-                                intermediate_size=config.moe_intermediate_size,
-                                reduce_results=False,
-                                renormalize=config.norm_topk_prob,
-                                quant_config=quant_config,
-                                use_grouped_topk=True,
-                                num_expert_group=config.n_group,
-                                topk_group=config.topk_group,
-                                prefix=f"{prefix}.experts")
+        if is_hpu:
+            self.experts = FusedMoE(
+                num_experts=config.n_routed_experts,
+                top_k=config.num_experts_per_tok,
+                hidden_size=config.hidden_size,
+                intermediate_size=config.moe_intermediate_size,
+                reduce_results=False,
+                renormalize=config.norm_topk_prob,
+                quant_config=quant_config,
+                use_grouped_topk=False,
+                prefix=f"{prefix}.experts")
+        else:
+            self.experts = FusedMoE(
+                num_experts=config.n_routed_experts,
+                top_k=config.num_experts_per_tok,
+                hidden_size=config.hidden_size,
+                intermediate_size=config.moe_intermediate_size,
+                reduce_results=False,
+                renormalize=config.norm_topk_prob,
+                quant_config=quant_config,
+                use_grouped_topk=True,
+                num_expert_group=config.n_group,
+                topk_group=config.topk_group,
+                prefix=f"{prefix}.experts")
 
         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.n_routed_experts,
@@ -277,9 +292,22 @@ def forward(
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
+        if is_hpu:
+            # need reshape from tensor(x0, y0) to tensor(x1) for hpu
+            _batch_size = positions.shape[0]
+            positions = positions.reshape(positions.shape[0] *
+                                          positions.shape[1])
+            hidden_states = hidden_states.reshape(
+                hidden_states.shape[0] * hidden_states.shape[1],
+                hidden_states.shape[2])
         if self.q_lora_rank is not None:
-            q = self.q_a_proj(hidden_states)[0]
-            q = self.q_a_layernorm(q)
+            if is_hpu:
+                # w/a of SW-208144
+                q = self.q_a_proj(hidden_states)[0].unsqueeze(0)
+                q = self.q_a_layernorm(q).squeeze(0)
+            else:
+                q = self.q_a_proj(hidden_states)[0]
+                q = self.q_a_layernorm(q)
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads,
                                          self.qk_head_dim)
         else:
@@ -291,7 +319,11 @@ def forward(
         kv_a, _ = latent_cache.split(
             [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         latent_cache = latent_cache.unsqueeze(1)
-        kv_a = self.kv_a_layernorm(kv_a.contiguous())
+        if is_hpu:
+            kv_a = self.kv_a_layernorm(kv_a.contiguous().unsqueeze(0)).squeeze(
+                0)  # w/a of SW-208144
+        else:
+            kv_a = self.kv_a_layernorm(kv_a.contiguous())
         kv = self.kv_b_proj(kv_a)[0]
         kv = kv.view(-1, self.num_local_heads,
                      self.qk_nope_head_dim + self.v_head_dim)
@@ -311,11 +343,25 @@ def forward(
         v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim],
                                     value=0).view(-1,
                                                   self.num_local_heads * 256)
+        if is_hpu:
+            # need restore from tensor(x0, y0) to tensor(x1, y1, z1) for hpu
+            q = q.reshape(_batch_size, q.shape[0] // _batch_size, q.shape[1])
+            k = k.reshape(_batch_size, k.shape[0] // _batch_size, k.shape[1])
+            v = v.reshape(_batch_size, v.shape[0] // _batch_size, v.shape[1])
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        if is_hpu:
+            # need restore from tensor(x0, y0, z0) to tensor(x1, y1) for hpu
+            attn_output = attn_output.reshape(
+                attn_output.shape[0] * attn_output.shape[1],
+                attn_output.shape[2])
         attn_output = attn_output.view(
             -1, self.num_local_heads, 256)[..., :self.v_head_dim].reshape(
                 -1, self.num_local_heads * self.v_head_dim)
         output, _ = self.o_proj(attn_output)
+        if is_hpu:
+            output = output.reshape(_batch_size,
+                                    output.shape[0] // _batch_size,
+                                    output.shape[1])
         return output
 
 
@@ -383,6 +429,8 @@ def forward(
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
+        if is_hpu:
+            _batch_size = positions.shape[0]
         # Self Attention
         if residual is None:
             residual = hidden_states
@@ -400,7 +448,16 @@ def forward(
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
+        if is_hpu:
+            # need reshape from tensor(x0, y0) to tensor(x1) for hpu
+            hidden_states = hidden_states.reshape(
+                hidden_states.shape[0] * hidden_states.shape[1],
+                hidden_states.shape[2])
         hidden_states = self.mlp(hidden_states)
+        if is_hpu:
+            hidden_states = hidden_states.reshape(
+                _batch_size, hidden_states.shape[0] // _batch_size,
+                hidden_states.shape[1])
         return hidden_states, residual
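
Note on the recurring pattern: the HPU branches above repeatedly flatten a [batch, seq, hidden] activation to [batch * seq, hidden] before the projection/MoE ops and restore the batch dimension afterwards. The following is a minimal standalone sketch of that round trip in plain PyTorch (illustrative shapes only, not part of the patch):

import torch

# Illustrative sizes; the patch applies the same reshape round trip to
# hidden_states, q/k/v and the attention output.
batch_size, seq_len, hidden = 2, 4, 8
hidden_states = torch.randn(batch_size, seq_len, hidden)

# Flatten to 2D, i.e. one row per token, before the 2D-only ops.
flat = hidden_states.reshape(batch_size * seq_len, hidden)

# ... run projections / MoE / attention on `flat` here ...

# Restore the original [batch, seq, hidden] layout afterwards.
restored = flat.reshape(batch_size, flat.shape[0] // batch_size, flat.shape[1])
assert torch.equal(restored, hidden_states)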