
Enable DeepseekV2 Lite/Chat models (#516)
hlin99 authored Dec 4, 2024
1 parent 8c76728 commit f6865f4
Showing 2 changed files with 82 additions and 19 deletions.
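
For context, a minimal usage sketch of the models this commit enables, through vLLM's offline LLM API; the checkpoint id, trust_remote_code flag, and sampling settings are illustrative assumptions, not part of the commit:

# Hypothetical usage sketch (not part of this commit): load a DeepseekV2 Lite
# checkpoint through vLLM's offline API and run a single greedy completion.
from vllm import LLM, SamplingParams

llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True)
params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["What is DeepseekV2?"], params)
print(outputs[0].outputs[0].text)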
14 changes: 10 additions & 4 deletions vllm/model_executor/layers/rotary_embedding.py
@@ -27,6 +27,9 @@
import torch.nn as nn

from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform

is_hpu = current_platform.is_hpu()


def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -664,9 +667,12 @@ def __init__(
is_neox_style, dtype)

def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
pos_freqs = self.base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
self.rotary_dim)
pos_freqs = self.base**(
torch.arange(0,
self.rotary_dim,
2,
dtype=torch.float,
device="hpu" if is_hpu else "cuda") / self.rotary_dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

@@ -684,7 +690,7 @@ def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = torch.arange(self.max_position_embeddings * self.scaling_factor,
device="cuda",
device="hpu" if is_hpu else "cuda",
dtype=torch.float32)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = (freqs.cos() * self.mscale)
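The rotary-embedding change above only routes tensor allocation to the active accelerator. A minimal sketch of that device-selection pattern, assuming current_platform.is_hpu() as imported at the top of the file (base and rotary dimension values are illustrative):

# Sketch: pick the accelerator device once, then allocate on it.
import torch
from vllm.platforms import current_platform

device = "hpu" if current_platform.is_hpu() else "cuda"
rotary_dim, base = 64, 10000.0  # illustrative values
inv_freq = 1.0 / (base ** (
    torch.arange(0, rotary_dim, 2, dtype=torch.float, device=device) / rotary_dim))
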
87 changes: 72 additions & 15 deletions vllm/model_executor/models/deepseek_v2.py
@@ -47,13 +47,16 @@
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)

is_hpu = current_platform.is_hpu()


class DeepseekV2MLP(nn.Module):

@@ -111,18 +114,30 @@ def __init__(
if config.hidden_act != "silu":
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
"Only silu is supported for now.")

self.experts = FusedMoE(num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts")
if is_hpu:
self.experts = FusedMoE(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=False,
prefix=f"{prefix}.experts")
else:
self.experts = FusedMoE(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts")

self.gate = ReplicatedLinear(config.hidden_size,
config.n_routed_experts,
@@ -277,9 +292,22 @@ def forward(
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
if is_hpu:
            # HPU: flatten positions (batch, seq) -> (batch * seq) and
            # hidden_states (batch, seq, hidden) -> (batch * seq, hidden)
_batch_size = positions.shape[0]
positions = positions.reshape(positions.shape[0] *
positions.shape[1])
hidden_states = hidden_states.reshape(
hidden_states.shape[0] * hidden_states.shape[1],
hidden_states.shape[2])
if self.q_lora_rank is not None:
q = self.q_a_proj(hidden_states)[0]
q = self.q_a_layernorm(q)
if is_hpu:
                # workaround for SW-208144: run the layernorm on a temporary 3-D view
q = self.q_a_proj(hidden_states)[0].unsqueeze(0)
q = self.q_a_layernorm(q).squeeze(0)
else:
q = self.q_a_proj(hidden_states)[0]
q = self.q_a_layernorm(q)
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads,
self.qk_head_dim)
else:
@@ -291,7 +319,11 @@ def forward(
kv_a, _ = latent_cache.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
latent_cache = latent_cache.unsqueeze(1)
kv_a = self.kv_a_layernorm(kv_a.contiguous())
if is_hpu:
            kv_a = self.kv_a_layernorm(kv_a.contiguous().unsqueeze(0)).squeeze(
                0)  # workaround for SW-208144
else:
kv_a = self.kv_a_layernorm(kv_a.contiguous())
kv = self.kv_b_proj(kv_a)[0]
kv = kv.view(-1, self.num_local_heads,
self.qk_nope_head_dim + self.v_head_dim)
@@ -311,11 +343,25 @@
v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim],
value=0).view(-1,
self.num_local_heads * 256)
if is_hpu:
            # HPU: restore q/k/v from (batch * seq, dim) to (batch, seq, dim) before attention
q = q.reshape(_batch_size, q.shape[0] // _batch_size, q.shape[1])
k = k.reshape(_batch_size, k.shape[0] // _batch_size, k.shape[1])
v = v.reshape(_batch_size, v.shape[0] // _batch_size, v.shape[1])
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
if is_hpu:
            # HPU: flatten attention output from (batch, seq, dim) back to (batch * seq, dim)
attn_output = attn_output.reshape(
attn_output.shape[0] * attn_output.shape[1],
attn_output.shape[2])
attn_output = attn_output.view(
-1, self.num_local_heads, 256)[..., :self.v_head_dim].reshape(
-1, self.num_local_heads * self.v_head_dim)
output, _ = self.o_proj(attn_output)
if is_hpu:
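            # HPU: restore output from (batch * seq, hidden) to (batch, seq, hidden)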
output = output.reshape(_batch_size,
output.shape[0] // _batch_size,
output.shape[1])
return output
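
The unsqueeze/squeeze pairs around q_a_layernorm and kv_a_layernorm above work around an HPU issue tracked as SW-208144 by giving the norm a temporary leading dimension. A minimal sketch of the pattern, with torch.nn.LayerNorm standing in for vLLM's norm layer (an assumption for illustration only):

import torch

def norm_with_hpu_workaround(norm: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Add a temporary leading dim for the norm, then drop it so downstream
    # shapes are unchanged (the SW-208144 workaround pattern used above).
    return norm(x.unsqueeze(0)).squeeze(0)

# Illustrative usage with a stand-in norm layer:
out = norm_with_hpu_workaround(torch.nn.LayerNorm(16), torch.randn(4, 16))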


@@ -383,6 +429,8 @@ def forward(
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
if is_hpu:
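            # HPU: remember the batch size so hidden_states can be restored after the MLP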
_batch_size = positions.shape[0]
# Self Attention
if residual is None:
residual = hidden_states
@@ -400,7 +448,16 @@ def forward(
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
if is_hpu:
            # HPU: flatten hidden_states from (batch, seq, hidden) to (batch * seq, hidden) for the MoE MLP
hidden_states = hidden_states.reshape(
hidden_states.shape[0] * hidden_states.shape[1],
hidden_states.shape[2])
hidden_states = self.mlp(hidden_states)
if is_hpu:
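            # HPU: restore hidden_states to (batch, seq, hidden) for the next layer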
hidden_states = hidden_states.reshape(
_batch_size, hidden_states.shape[0] // _batch_size,
hidden_states.shape[1])
return hidden_states, residual


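Most HPU-specific edits in this file follow one flatten-then-restore pattern around ops that expect 2-D activations. A hedged sketch of that pattern; the helper names are illustrative and do not appear in the commit:

import torch

def flatten_bs(x: torch.Tensor) -> torch.Tensor:
    # (batch, seq, hidden) -> (batch * seq, hidden)
    return x.reshape(x.shape[0] * x.shape[1], x.shape[2])

def restore_bs(x: torch.Tensor, batch_size: int) -> torch.Tensor:
    # (batch * seq, hidden) -> (batch, seq, hidden)
    return x.reshape(batch_size, x.shape[0] // batch_size, x.shape[1])

hidden = torch.randn(2, 8, 128)
assert torch.equal(restore_bs(flatten_bs(hidden), batch_size=2), hidden)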
