Skip to content

Commit

Permalink
handling on any arch
Browse files Browse the repository at this point in the history
  • Loading branch information
skirdey-inflection committed Nov 25, 2024
1 parent 9aec429 commit 19285e8
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions vllm/model_executor/models/internlm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ def __init__(
)

def split_qkv(self, qkv: torch.Tensor):
batch_size, seq_len, _ = qkv.shape
# Unpack all dimensions except the last one
*batch_dims, last_dim = qkv.shape

if self.tp_size > 1:
qkv_map = [self.q_size, self.kv_size, self.kv_size] * self.tp_size
Expand All @@ -155,19 +156,21 @@ def split_qkv(self, qkv: torch.Tensor):

qkv = qkv.contiguous()

qkv = qkv.reshape(batch_size, seq_len, self.total_num_kv_heads,
self.key_value_groups + 2, self.head_dim)
# Dynamically reshape based on the number of batch dimensions
qkv = qkv.view(*batch_dims, self.total_num_kv_heads,
self.key_value_groups + 2, self.head_dim)
q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=-2)
q = q.reshape(batch_size, seq_len, self.q_size * self.tp_size)
k = k.reshape(batch_size, seq_len, self.kv_size * self.tp_size)
v = v.reshape(batch_size, seq_len, self.kv_size * self.tp_size)
q = q.view(*batch_dims, self.q_size * self.tp_size)
k = k.view(*batch_dims, self.kv_size * self.tp_size)
v = v.view(*batch_dims, self.kv_size * self.tp_size)

if self.tp_size > 1:
splitter = partial(split_tensor_along_last_dim,
num_partitions=self.tp_size)
q = splitter(q)[self.tp_rank]
k = splitter(k)[self.tp_rank]
v = splitter(v)[self.tp_rank]

return q, k, v

def forward(
Expand Down

0 comments on commit 19285e8

Please sign in to comment.