Commit 4dadef5
i am very much struggling
kzawora-intel committed Nov 13, 2024
1 parent fd77180 commit 4dadef5
Showing 2 changed files with 81 additions and 15 deletions.
8 changes: 4 additions & 4 deletions vllm/attention/selector.py
@@ -173,12 +173,12 @@ def _cached_get_attn_backend(
         return FlashInferBackend
     elif backend == _Backend.HPU_ATTN:
         logger.info("Using HPUAttention backend.")
-        from vllm.v1.attention.backends.hpu_attn import HPUAttentionBackendV1
-        return HPUAttentionBackendV1
-    elif backend == _Backend.HPU_ATTN_V1:
-        logger.info("Using HPUAttentionV1 backend.")
         from vllm.attention.backends.hpu_attn import HPUAttentionBackend
         return HPUAttentionBackend
+    elif backend == _Backend.HPU_ATTN_V1:
+        logger.info("Using HPUAttentionV1 backend.")
+        from vllm.v1.attention.backends.hpu_attn import HPUAttentionBackendV1
+        return HPUAttentionBackendV1
     elif backend == _Backend.PALLAS:
         logger.info("Using Pallas backend.")
         from vllm.attention.backends.pallas import PallasAttentionBackend
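The net effect of this hunk is to untangle the two HPU enum values: _Backend.HPU_ATTN now resolves to the v0 HPUAttentionBackend, while _Backend.HPU_ATTN_V1 resolves to the v1 HPUAttentionBackendV1. A minimal standalone restatement of that mapping (illustrative only; _resolve_hpu_backend is a made-up helper, not a vllm function, and the real selector also logs and handles the remaining backends):

    # Illustrative sketch: restates the mapping introduced by this hunk.
    # Import paths are taken verbatim from the diff; the helper name is hypothetical.
    def _resolve_hpu_backend(backend_name: str):
        if backend_name == "HPU_ATTN":
            from vllm.attention.backends.hpu_attn import HPUAttentionBackend
            return HPUAttentionBackend
        if backend_name == "HPU_ATTN_V1":
            from vllm.v1.attention.backends.hpu_attn import HPUAttentionBackendV1
            return HPUAttentionBackendV1
        raise ValueError(f"unknown HPU backend: {backend_name}")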
88 changes: 77 additions & 11 deletions vllm/v1/worker/hpu_model_runner.py
@@ -370,6 +370,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
block_ids=req_data.block_ids,
num_computed_tokens=req_data.num_computed_tokens,
output_token_ids=[],
num_output_tokens=0 # NOTE(kzawora): this assumes that all new requests contain no output tokens out of the box

[GitHub Actions / ruff (3.12): vllm/v1/worker/hpu_model_runner.py:373:81: E501 Line too long (128 > 80)]
)
req_ids_to_add.append(req_id)

@@ -404,6 +405,11 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
assert num_reqs > 0
assert not self.input_batch.mixed_batch, "Prefill chunking is not supported on HPU"

[GitHub Actions / ruff (3.12): vllm/v1/worker/hpu_model_runner.py:408:81: E501 Line too long (91 > 80)]
is_prompt = self.input_batch.all_prefill
is_decode = self.input_batch.all_decode
num_prefills = num_reqs if is_prompt else 0
num_decodes = num_reqs if is_decode else 0

[GitHub Actions / ruff (3.12): vllm/v1/worker/hpu_model_runner.py:412:9: F841 Local variable `num_decodes` is assigned to but never used]

# OPTIMIZATION: Start copying the block table first.
# This way, we can overlap the copy with the following CPU operations.
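This hunk classifies the whole batch as either prefill or decode: a request that has produced no output tokens yet counts as a prefill, anything else as a decode, and mixed batches are rejected because prefill chunking is not supported on HPU. The assert relies on the all_prefill, all_decode, and mixed_batch properties added to InputBatch further down in this diff. A self-contained sketch of that classification (toy code with made-up sample values, not the vllm class):

    import numpy as np

    # Toy restatement of the batch classification used above; num_output_tokens
    # holds, per request, how many tokens have already been generated.
    num_output_tokens = np.array([0, 0, 0], dtype=np.int32)  # e.g. three fresh prompts

    all_prefill = bool((num_output_tokens == 0).all())
    all_decode = bool((num_output_tokens > 0).all())
    mixed_batch = not all_prefill and not all_decode

    assert not mixed_batch, "Prefill chunking is not supported on HPU"
    num_reqs = len(num_output_tokens)
    num_prefills = num_reqs if all_prefill else 0
    num_decodes = num_reqs if all_decode else 0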
@@ -423,6 +429,11 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
num_scheduled_tokens = np.array(num_scheduled_tokens, dtype=np.int32)
assert max_num_scheduled_tokens > 0

seq_lens = (self.input_batch.num_computed_tokens_cpu[:num_reqs] +
num_scheduled_tokens)
context_lens = self.input_batch.num_output_tokens_cpu[:num_reqs] + self.input_batch.num_prompt_tokens_cpu[:num_reqs]

[GitHub Actions / ruff (3.12): vllm/v1/worker/hpu_model_runner.py:434:81: E501 Line too long (124 > 80)]
max_seq_len = seq_lens.max()

# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
indices = np.arange(num_reqs)
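In this hunk, seq_lens is the number of tokens each request will have covered after this step (tokens already computed plus tokens scheduled now), while context_lens is built from the prompt length plus the number of output tokens generated so far. A small worked example under those definitions (toy numbers only, not vllm code):

    import numpy as np

    # Toy example of the two length computations in this hunk.
    num_computed_tokens = np.array([4, 7], dtype=np.int32)   # tokens already in the KV cache
    num_scheduled_tokens = np.array([3, 1], dtype=np.int32)  # tokens scheduled this step
    num_prompt_tokens = np.array([7, 5], dtype=np.int32)
    num_output_tokens = np.array([0, 3], dtype=np.int32)     # request 0 is a prefill, request 1 a decode

    seq_lens = num_computed_tokens + num_scheduled_tokens    # -> [7, 8]
    context_lens = num_output_tokens + num_prompt_tokens     # -> [7, 8]
    max_seq_len = int(seq_lens.max())                        # -> 8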
@@ -482,9 +493,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
query_start_loc_np[0] = 0
np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1:])

seq_lens = (self.input_batch.num_computed_tokens_cpu[:num_reqs] +
num_scheduled_tokens)
max_seq_len = seq_lens.max()

seq_start_loc = torch.empty((num_reqs + 1, ),
dtype=torch.int32,
device="cpu",
@@ -493,23 +502,63 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
seq_start_loc_np[0] = 0
np.cumsum(seq_lens, out=seq_start_loc_np[1:])

self.input_ids[:total_num_scheduled_tokens].copy_(input_ids,
import pdb; pdb.set_trace()
# NOTE(kzawora): this is probably dumb
input_ids = self.input_batch.token_ids_cpu[:num_reqs,:max_seq_len]
positions = [list(range(context_len, seq_len)) for context_len, seq_len in zip(context_lens,seq_lens)] # idk what to do here
self.input_ids[:num_reqs,:max_seq_len].copy_(torch.from_numpy(input_ids),
non_blocking=True)
self.positions[:num_reqs,:max_seq_len].copy_(torch.from_numpy(positions),
non_blocking=True)
seq_lens_tensor = torch.empty((num_reqs, ),
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
context_lens_tensor = torch.empty((num_reqs, ),
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
seq_lens_tensor[:num_reqs].copy_(torch.from_numpy(seq_lens),
non_blocking=True)
self.positions[:total_num_scheduled_tokens].copy_(positions,
context_lens_tensor[:num_reqs].copy_(torch.from_numpy(context_lens),
non_blocking=True)

query_start_loc = query_start_loc.to(self.device, non_blocking=True)
seq_start_loc = seq_start_loc.to(self.device, non_blocking=True)
slot_mapping = slot_mapping.to(self.device, non_blocking=True).long()
seq_lens = self.seq_lens.to(self.device, non_blocking=True)

import pdb; pdb.set_trace()
attn_metadata = None

prefix_block_list_tensor = [] # FIXME(kzawora)
block_indices, block_offsets = precompute_indices_and_offsets(
self.block_size, slot_mapping, True)
attn_metadata = HPUAttentionMetadata(
max_query_len=max_num_scheduled_tokens,
query_start_loc=query_start_loc,
max_seq_len=max_seq_len,
seq_start_loc=seq_start_loc,
block_table=self.input_batch.block_table[:num_reqs],
is_prompt=is_prompt,
block_list=prefix_block_list_tensor,
block_mapping=None,
block_usage=None,
block_indices=block_indices,
block_offsets=block_offsets,
block_scales=None,
block_groups=None,
attn_bias=None,
seq_lens_tensor=seq_lens_tensor, # FIXME(kzawora)
context_lens_tensor=context_lens_tensor, # FIXME(kzawora)
num_prefills=num_prefills,
num_prefill_tokens=total_num_scheduled_tokens if is_prompt else 0,
num_decode_tokens=total_num_scheduled_tokens if is_decode else 0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None # FIXME(kzawora): mutli-modality will not work here
)
#HPUAttentionMetadata(
# max_query_len=max_num_scheduled_tokens,
# query_start_loc=query_start_loc,
# max_seq_len=max_seq_len,
# seq_start_loc=seq_start_loc,
# block_table=self.input_batch.block_table[:num_reqs],
# slot_mapping=slot_mapping,
#)
# NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
# request in the batch. While we should not sample any token from this
# partial request, we do so for simplicity. We will ignore the sampled
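Earlier in this hunk, positions is built as a ragged Python list of lists and handed to torch.from_numpy, which only accepts ndarrays. One way to make that shape-compatible with the [num_reqs, max_seq_len] copy is to pad each per-request range into a dense integer array first. A sketch under the assumption that each request's positions should run from its context length up to its sequence length (an illustration of the padding with toy values, not the final vllm implementation):

    import numpy as np
    import torch

    # Pad per-request position ranges [context_len, seq_len) into a dense
    # [num_reqs, max_seq_len] array so it can be copied into a fixed-size buffer.
    context_lens = np.array([0, 8], dtype=np.int32)
    seq_lens = np.array([7, 9], dtype=np.int32)
    num_reqs = len(seq_lens)
    max_seq_len = int(seq_lens.max())

    positions_np = np.zeros((num_reqs, max_seq_len), dtype=np.int64)
    for i, (ctx, seq) in enumerate(zip(context_lens, seq_lens)):
        positions_np[i, : seq - ctx] = np.arange(ctx, seq)

    positions = torch.from_numpy(positions_np)  # dense, padded with zeros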
@@ -792,6 +841,7 @@ class CachedRequestState:

block_ids: List[int]
num_computed_tokens: int
num_output_tokens: int
output_token_ids: List[int]

@property
@@ -821,6 +871,8 @@ def __init__(
self.token_ids_cpu = np.empty((max_num_reqs, max_model_len),
dtype=np.int32)
self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)
self.num_output_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)
self.num_prompt_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)

# Attention-related.
self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req),
@@ -895,6 +947,8 @@ def add_request(
start_idx:end_idx] = request.output_token_ids

self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
self.num_output_tokens_cpu[req_index] = request.num_output_tokens
self.num_prompt_tokens_cpu[req_index] = len(request.prompt_token_ids)
num_blocks = len(request.block_ids)
self.block_table_cpu[req_index, :num_blocks] = request.block_ids

@@ -1024,6 +1078,18 @@ def all_greedy(self) -> bool:
def all_random(self) -> bool:
return len(self.greedy_reqs) == 0

@property
def all_prefill(self) -> bool:
return all(output_tokens == 0 for output_tokens in self.num_output_tokens_cpu[:self.num_reqs])

@property
def all_decode(self) -> bool:
return all(output_tokens > 0 for output_tokens in self.num_output_tokens_cpu[:self.num_reqs])

@property
def mixed_batch(self) -> bool:
return not self.all_prefill and not self.all_decode

@property
def no_top_p(self) -> bool:
return len(self.top_p_reqs) == 0
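For completeness, the new per-request bookkeeping fits together as follows: CachedRequestState gains a num_output_tokens field, add_request mirrors it (along with the prompt length) into the per-slot num_output_tokens_cpu and num_prompt_tokens_cpu arrays, and the batch-level properties above read those counters. A condensed, self-contained sketch of that flow (toy code; only the fields added by this commit are kept, everything else is simplified):

    import numpy as np
    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class ToyCachedRequestState:
        # Only the fields relevant to this commit are included here.
        prompt_token_ids: List[int]
        num_computed_tokens: int
        num_output_tokens: int
        output_token_ids: List[int] = field(default_factory=list)

    def add_request(req: ToyCachedRequestState, req_index: int,
                    num_output_tokens_cpu: np.ndarray,
                    num_prompt_tokens_cpu: np.ndarray) -> None:
        # Mirrors what InputBatch.add_request does with the new arrays.
        num_output_tokens_cpu[req_index] = req.num_output_tokens
        num_prompt_tokens_cpu[req_index] = len(req.prompt_token_ids)

    num_output_tokens_cpu = np.zeros(8, dtype=np.int32)
    num_prompt_tokens_cpu = np.zeros(8, dtype=np.int32)
    req = ToyCachedRequestState(prompt_token_ids=[1, 2, 3],
                                num_computed_tokens=0,
                                num_output_tokens=0)
    add_request(req, 0, num_output_tokens_cpu, num_prompt_tokens_cpu)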
