From 4dadef583818e9e2b0d80aebf984573cbe85f472 Mon Sep 17 00:00:00 2001
From: Konrad Zawora
Date: Wed, 13 Nov 2024 16:21:26 +0200
Subject: [PATCH] i am very much struggling

---
 vllm/attention/selector.py         |  8 +--
 vllm/v1/worker/hpu_model_runner.py | 88 ++++++++++++++++++++++++++----
 2 files changed, 81 insertions(+), 15 deletions(-)

diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index b0ccdaae9bf31..7b02b6247b652 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -173,12 +173,12 @@ def _cached_get_attn_backend(
         return FlashInferBackend
     elif backend == _Backend.HPU_ATTN:
         logger.info("Using HPUAttention backend.")
-        from vllm.v1.attention.backends.hpu_attn import HPUAttentionBackendV1
-        return HPUAttentionBackendV1
-    elif backend == _Backend.HPU_ATTN_V1:
-        logger.info("Using HPUAttentionV1 backend.")
         from vllm.attention.backends.hpu_attn import HPUAttentionBackend
         return HPUAttentionBackend
+    elif backend == _Backend.HPU_ATTN_V1:
+        logger.info("Using HPUAttentionV1 backend.")
+        from vllm.v1.attention.backends.hpu_attn import HPUAttentionBackendV1
+        return HPUAttentionBackendV1
     elif backend == _Backend.PALLAS:
         logger.info("Using Pallas backend.")
         from vllm.attention.backends.pallas import PallasAttentionBackend
diff --git a/vllm/v1/worker/hpu_model_runner.py b/vllm/v1/worker/hpu_model_runner.py
index 4716dc7c1318d..3fbd5b2fe70b2 100755
--- a/vllm/v1/worker/hpu_model_runner.py
+++ b/vllm/v1/worker/hpu_model_runner.py
@@ -370,6 +370,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                 block_ids=req_data.block_ids,
                 num_computed_tokens=req_data.num_computed_tokens,
                 output_token_ids=[],
+                num_output_tokens=0  # NOTE(kzawora): this assumes that all new requests contain no output tokens out of the box
             )
             req_ids_to_add.append(req_id)
@@ -404,6 +405,11 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         assert total_num_scheduled_tokens > 0
         num_reqs = self.input_batch.num_reqs
         assert num_reqs > 0
+        assert not self.input_batch.mixed_batch, "Prefill chunking is not supported on HPU"
+        is_prompt = self.input_batch.all_prefill
+        is_decode = self.input_batch.all_decode
+        num_prefills = num_reqs if is_prompt else 0
+        num_decodes = num_reqs if is_decode else 0

         # OPTIMIZATION: Start copying the block table first.
         # This way, we can overlap the copy with the following CPU operations.
@@ -423,6 +429,11 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         num_scheduled_tokens = np.array(num_scheduled_tokens, dtype=np.int32)
         assert max_num_scheduled_tokens > 0

+        seq_lens = (self.input_batch.num_computed_tokens_cpu[:num_reqs] +
+                    num_scheduled_tokens)
+        context_lens = self.input_batch.num_output_tokens_cpu[:num_reqs] + self.input_batch.num_prompt_tokens_cpu[:num_reqs]
+        max_seq_len = seq_lens.max()
+
         # Get request indices.
         # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
         indices = np.arange(num_reqs)
@@ -482,9 +493,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         query_start_loc_np[0] = 0
         np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1:])

-        seq_lens = (self.input_batch.num_computed_tokens_cpu[:num_reqs] +
-                    num_scheduled_tokens)
-        max_seq_len = seq_lens.max()
+
         seq_start_loc = torch.empty((num_reqs + 1, ),
                                     dtype=torch.int32,
                                     device="cpu",
@@ -493,23 +502,63 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         seq_start_loc_np[0] = 0
         np.cumsum(seq_lens, out=seq_start_loc_np[1:])

-        self.input_ids[:total_num_scheduled_tokens].copy_(input_ids,
+        import pdb; pdb.set_trace()
+        # NOTE(kzawora): this is probably dumb
+        input_ids = self.input_batch.token_ids_cpu[:num_reqs, :max_seq_len]
+        positions = [list(range(context_len, seq_len)) for context_len, seq_len in zip(context_lens, seq_lens)]  # idk what to do here
+        self.input_ids[:num_reqs, :max_seq_len].copy_(torch.from_numpy(input_ids),
+                                                      non_blocking=True)
+        self.positions[:num_reqs, :max_seq_len].copy_(torch.from_numpy(positions),
+                                                      non_blocking=True)
+        seq_lens_tensor = torch.empty((num_reqs, ),
+                                      dtype=torch.int32,
+                                      device="cpu",
+                                      pin_memory=self.pin_memory)
+        context_lens_tensor = torch.empty((num_reqs, ),
+                                          dtype=torch.int32,
+                                          device="cpu",
+                                          pin_memory=self.pin_memory)
+        seq_lens_tensor[:num_reqs].copy_(torch.from_numpy(seq_lens),
                                          non_blocking=True)
-        self.positions[:total_num_scheduled_tokens].copy_(positions,
+        context_lens_tensor[:num_reqs].copy_(torch.from_numpy(context_lens),
                                              non_blocking=True)
-        query_start_loc = query_start_loc.to(self.device, non_blocking=True)
         seq_start_loc = seq_start_loc.to(self.device, non_blocking=True)
         slot_mapping = slot_mapping.to(self.device, non_blocking=True).long()
+        seq_lens = self.seq_lens.to(self.device, non_blocking=True)
+        import pdb; pdb.set_trace()
+        attn_metadata = None
+
+        prefix_block_list_tensor = []  # FIXME(kzawora)
+        block_indices, block_offsets = precompute_indices_and_offsets(
+            self.block_size, slot_mapping, True)
         attn_metadata = HPUAttentionMetadata(
-            max_query_len=max_num_scheduled_tokens,
-            query_start_loc=query_start_loc,
-            max_seq_len=max_seq_len,
-            seq_start_loc=seq_start_loc,
-            block_table=self.input_batch.block_table[:num_reqs],
+            is_prompt=is_prompt,
+            block_list=prefix_block_list_tensor,
+            block_mapping=None,
+            block_usage=None,
+            block_indices=block_indices,
+            block_offsets=block_offsets,
+            block_scales=None,
+            block_groups=None,
+            attn_bias=None,
+            seq_lens_tensor=seq_lens_tensor,  # FIXME(kzawora)
+            context_lens_tensor=context_lens_tensor,  # FIXME(kzawora)
+            num_prefills=num_prefills,
+            num_prefill_tokens=total_num_scheduled_tokens if is_prompt else 0,
+            num_decode_tokens=total_num_scheduled_tokens if is_decode else 0,
             slot_mapping=slot_mapping,
+            multi_modal_placeholder_index_maps=None  # FIXME(kzawora): multi-modality will not work here
         )
+        #HPUAttentionMetadata(
+        #    max_query_len=max_num_scheduled_tokens,
+        #    query_start_loc=query_start_loc,
+        #    max_seq_len=max_seq_len,
+        #    seq_start_loc=seq_start_loc,
+        #    block_table=self.input_batch.block_table[:num_reqs],
+        #    slot_mapping=slot_mapping,
+        #)
         # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
         # request in the batch. While we should not sample any token from this
         # partial request, we do so for simplicity. We will ignore the sampled
@@ -792,6 +841,7 @@ class CachedRequestState:
     block_ids: List[int]
     num_computed_tokens: int
+    num_output_tokens: int
     output_token_ids: List[int]

     @property
@@ -821,6 +871,8 @@ def __init__(
         self.token_ids_cpu = np.empty((max_num_reqs, max_model_len),
                                       dtype=np.int32)
         self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)
+        self.num_output_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)
+        self.num_prompt_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)

         # Attention-related.
         self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req),
@@ -895,6 +947,8 @@ def add_request(
                            start_idx:end_idx] = request.output_token_ids

         self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
+        self.num_output_tokens_cpu[req_index] = request.num_output_tokens
+        self.num_prompt_tokens_cpu[req_index] = len(request.prompt_token_ids)
         num_blocks = len(request.block_ids)
         self.block_table_cpu[req_index, :num_blocks] = request.block_ids
@@ -1024,6 +1078,18 @@ def all_greedy(self) -> bool:
     def all_random(self) -> bool:
         return len(self.greedy_reqs) == 0

+    @property
+    def all_prefill(self) -> bool:
+        return all(output_tokens == 0 for output_tokens in self.num_output_tokens_cpu[:self.num_reqs])
+
+    @property
+    def all_decode(self) -> bool:
+        return all(output_tokens > 0 for output_tokens in self.num_output_tokens_cpu[:self.num_reqs])
+
+    @property
+    def mixed_batch(self) -> bool:
+        return not self.all_prefill and not self.all_decode
+
     @property
     def no_top_p(self) -> bool:
         return len(self.top_p_reqs) == 0
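
For reference, a minimal standalone sketch (not part of the patch) of the batch-classification rule the new InputBatch properties encode: a request with zero generated tokens counts as prefill, a request with any generated tokens counts as decode, and a mixed batch is rejected by the new assert in _prepare_inputs. The BatchView class is hypothetical and exists only to make the example runnable; only num_output_tokens_cpu and num_reqs mirror names from the patch, and the checks are written with vectorized numpy comparisons rather than the patch's Python-level all().

import numpy as np


class BatchView:
    """Hypothetical stand-in for InputBatch, holding only the fields used here."""

    def __init__(self, num_output_tokens_cpu: np.ndarray, num_reqs: int):
        self.num_output_tokens_cpu = num_output_tokens_cpu
        self.num_reqs = num_reqs

    @property
    def all_prefill(self) -> bool:
        # No request has produced any output tokens yet -> pure prefill batch.
        return bool((self.num_output_tokens_cpu[:self.num_reqs] == 0).all())

    @property
    def all_decode(self) -> bool:
        # Every request has already produced at least one token -> pure decode batch.
        return bool((self.num_output_tokens_cpu[:self.num_reqs] > 0).all())

    @property
    def mixed_batch(self) -> bool:
        # Anything in between mixes prefills and decodes, which the HPU runner
        # asserts against ("Prefill chunking is not supported on HPU").
        return not self.all_prefill and not self.all_decode


batch = BatchView(np.array([0, 0, 3], dtype=np.int32), num_reqs=3)
assert not batch.all_prefill and not batch.all_decode and batch.mixed_batch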