diff --git a/requirements-hpu.txt b/requirements-hpu.txt index f4fb89ef42834..3f4cf33f105d6 100644 --- a/requirements-hpu.txt +++ b/requirements-hpu.txt @@ -8,4 +8,4 @@ pandas tabulate setuptools>=61 setuptools-scm>=8 -vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@4312768 +vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@0766759 diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index aed04361e5fb4..1150412c83b95 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -233,6 +233,8 @@ def __init__( kv_cache_dtype: str = "auto", blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, + tp_rank: Optional[int] = None, + prev_attn: Optional[torch.nn.Module] = None, ) -> None: raise NotImplementedError diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 99cb84346d84e..5b0a37ca7fa77 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -300,6 +300,7 @@ def __init__( kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, + **kwargs, ) -> None: assert blocksparse_params is not None assert alibi_slopes is None, ValueError( diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index c69e12ad78c44..8f253f9e0ae7c 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -600,6 +600,7 @@ def __init__( kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, + **kwargs, ) -> None: if blocksparse_params is not None: raise ValueError( diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index e367468d05d26..304500dc601c6 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -748,6 +748,7 @@ def __init__( kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, + **kwargs, ) -> None: self.num_heads = num_heads self.head_size = head_size diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index 96dafe8c2fcb1..a374237ef1ddd 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -125,6 +125,9 @@ def __init__( kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, max_seq_len: int = 4096, + tp_rank: Optional[int] = None, + prev_attn: Optional[torch.nn.Module] = None, + **kwargs, ) -> None: super(AttentionImpl, self).__init__() self.kv_cache_dtype = kv_cache_dtype @@ -142,11 +145,42 @@ def __init__( else ModuleFusedSDPA(HPUFusedSDPA) self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.sliding_window = sliding_window - self.alibi_slopes = alibi_slopes + self.alibi_slopes = None + self.prompt_position_bias = None + # Set upper bound on sequence length + self.max_seq_len = int( + os.getenv( + 'VLLM_PROMPT_ALIBI_MAX_SEQ_LEN', + max_seq_len, + )) + # Set lower bound on sequence length + self.max_seq_len = max([ + self.max_seq_len, + int(os.getenv('VLLM_PROMPT_SEQ_BUCKET_MAX', '0')), + ]) + self.tp_rank = tp_rank + self.prev_attn = None if prev_attn is None else prev_attn.impl if alibi_slopes is not None: - alibi_slopes_tensor = torch.tensor(alibi_slopes, - dtype=torch.bfloat16) - 
self.alibi_slopes = alibi_slopes_tensor + if (self.prev_attn is not None + and self.prev_attn.tp_rank == self.tp_rank): + self.alibi_slopes = self.prev_attn.alibi_slopes + self.prompt_position_bias = self.prev_attn.prompt_position_bias + else: + slope_tensor_dtype = { + True: torch.float32, + False: torch.bfloat16, + }[os.getenv('VLLM_ALIBI_USE_FLOAT32_BIASES', '1').lower() + in ['1', 'true']] + alibi_slopes_tensor = torch.tensor(alibi_slopes, + dtype=slope_tensor_dtype) + self.alibi_slopes = alibi_slopes_tensor + # Creating the prompt_position_bias once and reusing it + # if seq_len permits. + self.prompt_position_bias = _make_prompt_alibi_bias( + alibi_slopes=self.alibi_slopes, + seq_len=self.max_seq_len, + dtype=self.alibi_slopes.dtype, + ) assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads @@ -157,6 +191,12 @@ def __init__( assert alibi_slopes is None, \ 'Prefill with FusedSDPA not supported with alibi slopes!' + self.use_contiguous_pa = os.environ.get('VLLM_CONTIGUOUS_PA', + 'true').lower() == 'true' + if not self.use_contiguous_pa: + assert alibi_slopes is None, \ + 'Non-contiguous PA not supported with alibi slopes!' + suppored_head_sizes = HPUPagedAttention.get_supported_head_sizes() if head_size not in suppored_head_sizes: raise ValueError( @@ -230,27 +270,58 @@ def forward( query_shape = (batch_size, seq_len, self.num_heads, self.head_size) kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, self.head_size) + if attn_metadata is None or attn_metadata.block_list is None: if not self.prefill_use_fusedsdpa: # TODO: move this outside of model assert attn_metadata.attn_bias is not None, \ 'attn_bias must be set before calling model.forward' + # If we have alibi_slopes, incorporate them with + # position_bias and position_bias_offset. attn_bias = attn_metadata.attn_bias - if self.alibi_slopes is not None: - position_bias = _make_alibi_bias( - self.alibi_slopes, self.num_kv_heads, - attn_bias.dtype, attn_bias.shape[-1]) - attn_bias = attn_bias.tile( - (1, self.num_kv_heads, 1, 1)) - attn_bias.add_(position_bias) + seq_lens_tensor = attn_metadata.seq_lens_tensor + position_bias = None + position_bias_offset = None + if (self.prompt_position_bias is not None + and self.alibi_slopes is not None): + if self.max_seq_len >= max(attn_bias.size(-2), + attn_bias.size(-1)): + # Using pre-computed prompt_position_bias subset. + position_bias = self.prompt_position_bias[:, :, + -attn_bias.size(-2):, + -attn_bias.size(-1):] + else: + # For longer sequences than precomputed, + # recreate the bias. This is memory inefficient. + position_bias = _make_prompt_alibi_bias( + alibi_slopes=self.alibi_slopes, + seq_len=max(attn_bias.size(-2), + attn_bias.size(-1)), + dtype=self.alibi_slopes.dtype, + ) + # If seq_lens_tensor is provided, we create a + # position_bias_offset. This offset helps handle + # sequences of varying lengths in a batch. 
+ if seq_lens_tensor is not None: + position_bias_offset = seq_lens_tensor.unsqueeze( + 1).tile(1, self.num_heads).to( + dtype=self.alibi_slopes.dtype) + position_bias_offset.mul_( + self.alibi_slopes[None, :]) + position_bias_offset = position_bias_offset \ + - position_bias[:, :, -1, 0] else: attn_bias = None + position_bias = None + position_bias_offset = None out = ops.prompt_attention( query.view(query_shape), key.view(kv_shape), value.view(kv_shape), attn_bias=attn_bias, + position_bias=position_bias, + position_bias_offset=position_bias_offset, p=0.0, scale=self.scale, matmul_qk_op=self.matmul_qk, @@ -278,6 +349,20 @@ def forward( output = out.reshape(batch_size, seq_len, hidden_size) else: # Decoding run. + self.position_bias = None + alibi_blocks = attn_metadata.alibi_blocks + if self.alibi_slopes is not None and alibi_blocks is not None: + if (self.prev_attn is not None + and self.prev_attn.tp_rank == self.tp_rank): + self.position_bias = self.prev_attn.position_bias + else: + # For decoding, compute position bias using alibi_blocks. + self.position_bias = _make_decode_alibi_bias( + alibi_blocks=alibi_blocks, + alibi_slopes=self.alibi_slopes, + dtype=self.alibi_slopes.dtype, + ) + output = HPUPagedAttention.forward_decode( query=query, key_cache=key_cache, @@ -288,14 +373,18 @@ def forward( block_scales=attn_metadata.block_scales, block_groups=attn_metadata.block_groups, scale=self.scale, + position_bias=self.position_bias, matmul_qk_op=self.matmul_qk, matmul_av_op=self.matmul_av, batch2block_matmul_op=self.batch2block_matmul, block2batch_matmul_op=self.block2batch_matmul, keys_fetch_func=self.k_cache.fetch_from_cache, - values_fetch_func=self.v_cache.fetch_from_cache) + values_fetch_func=self.v_cache.fetch_from_cache, + ) + # Reshape the output tensor. - return output.view(batch_size, seq_len, hidden_size) + output = output.view(batch_size, seq_len, hidden_size) + return output def forward_encoder_decoder( self, @@ -409,12 +498,25 @@ def forward_encoder_decoder( return output.view(batch_size, -1, hidden_size) -def _make_alibi_bias( +def _make_prompt_alibi_bias( alibi_slopes: torch.Tensor, - num_kv_heads: int, - dtype: torch.dtype, seq_len: int, + dtype: torch.dtype, ) -> torch.Tensor: + """ + Create the ALiBi position bias tensor for prompt stage. + This tensor is reused or tiled as needed for each forward pass. + Does not scale with batch size or number of blocks. + + Args: + alibi_slopes: shape = [num_heads] + seq_len: int + dtype: torch.dtype + + Returns: + A per-head bias tensor of shape [1, num_heads, seq_len, seq_len]. + This bias encodes positional information via ALiBi slopes. + """ bias = torch.arange(seq_len, dtype=dtype) # NOTE(zhuohan): HF uses # `bias = bias[None, :].repeat(seq_len, 1)` @@ -427,15 +529,54 @@ def _make_alibi_bias( padded_len = (seq_len + 7) // 8 * 8 num_heads = alibi_slopes.shape[0] - bias = torch.empty( - 1, # batch size + per_head_bias = torch.empty( + 1, num_heads, seq_len, padded_len, device=alibi_slopes.device, dtype=dtype, - )[:, :, :, :seq_len].copy_(bias) - bias.mul_(alibi_slopes[:, None, None]) - if num_heads != num_kv_heads: - bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) - return bias + )[:, :, :, :seq_len] + # NOTE(Tanner): + # .copy_ was not performing broadcasting of bias + # to all 32 heads in Eager mode. 
+ per_head_bias[:, :] = bias + per_head_bias.mul_(alibi_slopes[:, None, None]) + + return per_head_bias + + +def _make_decode_alibi_bias( + alibi_blocks: torch.Tensor, + alibi_slopes: torch.Tensor, + dtype: torch.dtype, +) -> torch.Tensor: + """ + Create the ALiBi position bias tensor for decode stage. + Uses stored alibi_blocks and slopes for final scaling. + Scales with number of blocks, not with batch size. + + Args: + alibi_blocks: shape = [num_blocks, block_size] + alibi_slopes: shape = [num_heads] + dtype: torch.dtype + + Returns: + A per-head bias tensor of shape [num_blocks, num_heads, block_size]. + Each row encodes position-dependent ALiBi slopes for decoding steps. + """ + num_heads = alibi_slopes.shape[0] + per_head_bias = torch.empty( + alibi_blocks.size(0), + num_heads, + alibi_blocks.size(-1), + device=alibi_slopes.device, + dtype=dtype, + ) + # NOTE(Tanner): + # .copy_ was not performing broadcasting of bias + # to all 32 heads in Eager mode. + per_head_bias[:, :] = alibi_blocks.unsqueeze(-2) + per_head_bias.mul_(alibi_slopes[None, :, None]) + + return per_head_bias diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index 21949874bea47..faff256c3de3f 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -115,6 +115,7 @@ def __init__( kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, + **kwargs, ) -> None: if blocksparse_params is not None: raise ValueError( diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index 9809aed0e66f9..cb39e4dca5fee 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -100,6 +100,7 @@ def __init__( kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, + **kwargs, ) -> None: self.num_heads = num_heads self.head_size = head_size diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 19daeb729ee61..9e0540f0ff3db 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -338,6 +338,7 @@ def __init__( kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, + **kwargs, ) -> None: if blocksparse_params is not None: raise ValueError( diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 86e952a903f36..1098211e4b5f3 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -390,6 +390,7 @@ def __init__( kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, + **kwargs, ) -> None: if blocksparse_params is not None: raise ValueError( diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index e2e989efb020c..ba061d09b01e6 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -381,6 +381,7 @@ def __init__( kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, + **kwargs, ) -> None: if blocksparse_params is not None: raise ValueError( diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 05d997279893b..c90c801766525 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -38,9 +38,11 @@ def __init__( cache_config: 
Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, blocksparse_params: Optional[Dict[str, Any]] = None, - logits_soft_cap: Optional[float] = None, + logits_soft_cap: Optional[int] = 4096, per_layer_sliding_window: Optional[int] = None, + tp_rank: Optional[int] = None, prefix: str = "", + prev_attn: Optional[nn.Module] = None, ) -> None: super().__init__() if per_layer_sliding_window is not None: @@ -96,7 +98,8 @@ def __init__( impl_cls = attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, - blocksparse_params, logits_soft_cap) + blocksparse_params, logits_soft_cap, + tp_rank=tp_rank, prev_attn=prev_attn) self.num_heads = num_heads self.head_size = head_size self.num_kv_heads = num_kv_heads diff --git a/vllm/attention/ops/hpu_paged_attn.py b/vllm/attention/ops/hpu_paged_attn.py index e55a4de11fd6c..d1235e6ec7aa7 100644 --- a/vllm/attention/ops/hpu_paged_attn.py +++ b/vllm/attention/ops/hpu_paged_attn.py @@ -22,6 +22,7 @@ class HPUPagedAttentionMetadata: block_offsets: Optional[torch.Tensor] block_scales: Optional[torch.Tensor] block_groups: Optional[torch.Tensor] + alibi_blocks: Optional[torch.Tensor] class HPUPagedAttention: diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 5e68b7f165bf4..e1fb3c85ba10b 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -117,6 +117,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + prev_attn: Optional[nn.Module] = None, ): super().__init__() self.hidden_size = hidden_size @@ -127,7 +128,7 @@ def __init__( self.num_heads = (self.total_num_heads // tensor_model_parallel_world_size) self.head_dim = hidden_size // self.total_num_heads - self.postion_embedding = position_embedding + self.position_embedding = position_embedding self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -147,20 +148,26 @@ def __init__( quant_config=quant_config, ) # Create the alibi slopes and slice them. 
- if self.postion_embedding == "ALIBI": + if self.position_embedding == "ALIBI": tp_rank = get_tensor_model_parallel_rank() head_start = tp_rank * self.num_heads head_end = (tp_rank + 1) * self.num_heads alibi_slopes = _get_alibi_slopes(self.total_num_heads) alibi_slopes = alibi_slopes[head_start:head_end].tolist() + prev_attn = None if prev_attn is None else prev_attn.attn scaling = self.head_dim**-0.5 - self.attn = Attention(self.num_heads, - self.head_dim, - scaling, - alibi_slopes=alibi_slopes, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + scaling, + alibi_slopes=alibi_slopes, + quant_config=quant_config, + logits_soft_cap=self.max_position_embeddings, + tp_rank=tp_rank, + prefix=f"{prefix}.attn", + prev_attn=prev_attn, + ) else: self.rotary_emb = get_rope( self.head_dim, @@ -169,12 +176,14 @@ def __init__( base=self.rope_theta, ) self.scaling = self.head_dim**-0.5 - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -185,7 +194,7 @@ def forward( ) -> torch.Tensor: qkv, _ = self.W_pack(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) - if self.postion_embedding != "ALIBI": + if self.position_embedding != "ALIBI": q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.o_proj(attn_output) @@ -194,17 +203,21 @@ def forward( class BaiChuanDecoderLayer(nn.Module): - def __init__(self, - config: PretrainedConfig, - position_embedding: str, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + config: PretrainedConfig, + position_embedding: str, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + prev_layer: Optional[nn.Module] = None, + ): super().__init__() self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + prev_attn = None if prev_layer is None else prev_layer.self_attn self.self_attn = BaiChuanAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -214,6 +227,7 @@ def __init__(self, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", + prev_attn=prev_attn, ) self.mlp = BaiChuanMLP( hidden_size=self.hidden_size, @@ -280,12 +294,16 @@ def __init__( ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: BaiChuanDecoderLayer(config, - position_embedding, - cache_config, - quant_config, - prefix=prefix), + lambda prefix, prev_layer: BaiChuanDecoderLayer( + config, + position_embedding, + cache_config, + quant_config, + prefix=prefix, + prev_layer=prev_layer, + ), prefix=f"{prefix}.layers", + use_layer_sharing=True, ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.make_empty_intermediate_tensors = ( @@ -372,6 +390,7 @@ def __init__( self.lora_config = lora_config self.quant_config = quant_config + self.use_alibi = position_embedding == "ALIBI" self.model = BaiChuanModel(vllm_config=vllm_config, prefix=prefix, position_embedding=position_embedding) diff --git 
a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index fee74f491acc1..e963b61be9d27 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -79,6 +79,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + prev_attn: Optional[nn.Module] = None, ): super().__init__() self.hidden_size = config.hidden_size @@ -110,15 +111,20 @@ def __init__( head_end = (tp_rank + 1) * self.num_heads alibi_slopes = _get_alibi_slopes(self.total_num_heads) alibi_slopes = alibi_slopes[head_start:head_end].tolist() + prev_attn = None if prev_attn is None else prev_attn.attn scaling = self.head_dim**-0.5 - self.attn = Attention(self.num_heads, - self.head_dim, - scaling, - alibi_slopes=alibi_slopes, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + scaling, + alibi_slopes=alibi_slopes, + cache_config=cache_config, + quant_config=quant_config, + tp_rank=tp_rank, + prefix=f"{prefix}.attn", + prev_attn=prev_attn, + ) def forward( self, @@ -171,16 +177,21 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + prev_layer: Optional[nn.Module] = None, ): super().__init__() hidden_size = config.hidden_size + prev_attn = None if prev_layer is None else prev_layer.self_attention self.input_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.self_attention = BloomAttention(config, - cache_config, - quant_config, - prefix=f"{prefix}.self_attention") + self.self_attention = BloomAttention( + config, + cache_config, + quant_config, + prefix=f"{prefix}.self_attention", + prev_attn=prev_attn, + ) self.post_attention_layernorm = nn.LayerNorm( hidden_size, eps=config.layer_norm_epsilon) self.mlp = BloomMLP(config, quant_config) @@ -247,9 +258,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Transformer blocks self.start_layer, self.end_layer, self.h = make_layers( config.num_hidden_layers, - lambda prefix: BloomBlock( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.h") + lambda prefix, prev_layer: BloomBlock(config, + cache_config, + quant_config, + prefix=prefix, + prev_layer=prev_layer), + prefix=f"{prefix}.h", + use_layer_sharing=True, + ) # Final Layer Norm self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) @@ -299,6 +315,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config + self.use_alibi = True self.transformer = BloomModel(vllm_config=vllm_config, prefix=maybe_prefix( prefix, "transformer")) diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 8660cf79b9cdb..16964016a9633 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -85,6 +85,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + prev_attn: Optional[nn.Module] = None, ): super().__init__() @@ -155,12 +156,15 @@ def __init__( max_position=max_position_embeddings, base=rope_theta, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.inv_norm_factor, - num_kv_heads=self.num_kv_heads, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, 
+ self.head_dim, + self.inv_norm_factor, + num_kv_heads=self.num_kv_heads, + quant_config=quant_config, + logits_soft_cap=max_position_embeddings, + prefix=f"{prefix}.attn", + ) elif self.use_alibi: tp_rank = get_tensor_model_parallel_rank() head_start = tp_rank * self.num_heads @@ -168,21 +172,28 @@ def __init__( alibi_slopes = (_get_alibi_slopes(self.total_num_heads) * self.inv_norm_factor) alibi_slopes = alibi_slopes[head_start:head_end].tolist() - self.attn = Attention(self.num_heads, - self.head_dim, - self.inv_norm_factor, - num_kv_heads=self.num_kv_heads, - alibi_slopes=alibi_slopes, - quant_config=quant_config, - prefix=f"{prefix}.attn") + prev_attn = None if prev_attn is None else prev_attn.attn + self.attn = Attention( + self.num_heads, + self.head_dim, + self.inv_norm_factor, + num_kv_heads=self.num_kv_heads, + alibi_slopes=alibi_slopes, + quant_config=quant_config, + tp_rank=tp_rank, + prefix=f"{prefix}.attn", + prev_attn=prev_attn, + ) else: - self.attn = Attention(self.num_heads, - self.head_dim, - scale=self.inv_norm_factor, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + scale=self.inv_norm_factor, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -246,15 +257,19 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + prev_layer: Optional[nn.Module] = None, ): super().__init__() hidden_size = config.hidden_size self.num_heads = config.num_attention_heads + prev_attn = None if prev_layer is None else prev_layer.self_attention self.self_attention = FalconAttention( config, cache_config, quant_config, - prefix=f"{prefix}.self_attention") + prefix=f"{prefix}.self_attention", + prev_attn=prev_attn, + ) self.mlp = FalconMLP(config, quant_config) self.config = config @@ -354,7 +369,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads - self.use_alibi = config.alibi # Embedding + LN Embedding self.word_embeddings = VocabParallelEmbedding( @@ -365,9 +379,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Transformer blocks self.start_layer, self.end_layer, self.h = make_layers( config.num_hidden_layers, - lambda prefix: FalconDecoderLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.h") + lambda prefix, prev_layer: FalconDecoderLayer( + config, + cache_config, + quant_config, + prefix=prefix, + prev_layer=prev_layer, + ), + prefix=f"{prefix}.h", + use_layer_sharing=True, + ) # Final Layer Norm self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) @@ -419,6 +440,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config + self.use_alibi = config.alibi self.transformer = FalconModel(vllm_config=vllm_config, prefix=maybe_prefix( prefix, "transformer")) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 8c81dff6b5768..6e86cb3c9ba52 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -77,6 +77,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", 
+ prev_attn: Optional[nn.Module] = None, ): super().__init__() self.hidden_size = config.hidden_size @@ -110,13 +111,19 @@ def __init__( head_end = (tp_rank + 1) * self.num_heads alibi_slopes = _get_alibi_slopes(total_num_heads) alibi_slopes = alibi_slopes[head_start:head_end] - self.attn = Attention(self.num_heads, - self.head_dim, - scale=self.scale, - alibi_slopes=alibi_slopes, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + prev_attn = None if prev_attn is None else prev_attn.attn + self.attn = Attention( + self.num_heads, + self.head_dim, + scale=self.scale, + alibi_slopes=alibi_slopes, + cache_config=cache_config, + quant_config=quant_config, + logits_soft_cap=config.max_position_embeddings, + tp_rank=tp_rank, + prefix=f"{prefix}.attn", + prev_attn=prev_attn, + ) def forward( self, @@ -181,6 +188,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + prev_layer: Optional[nn.Module] = None, ): super().__init__() hidden_size = config.hidden_size @@ -188,10 +196,14 @@ def __init__( hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = JAISAttention(config, - cache_config, - quant_config, - prefix=f"{prefix}.attn") + prev_attn = None if prev_layer is None else prev_layer.attn + self.attn = JAISAttention( + config, + cache_config, + quant_config, + prefix=f"{prefix}.attn", + prev_attn=prev_attn, + ) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = JAISMLP(inner_dim, config, quant_config) @@ -245,11 +257,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.h = make_layers( config.num_hidden_layers, - lambda prefix: JAISBlock(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), + lambda prefix, prev_layer: JAISBlock( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + prev_layer=prev_layer, + ), prefix=f"{prefix}.h", + use_layer_sharing=True, ) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) @@ -304,6 +320,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config + self.use_alibi = config.position_embedding_type == "alibi" self.transformer = JAISModel(vllm_config=vllm_config, prefix=maybe_prefix( prefix, "transformer")) diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 1235816413a44..88ee6c0d9124d 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -51,6 +51,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + prev_attn: Optional[nn.Module] = None, ): super().__init__() self.d_model = config.d_model @@ -59,6 +60,7 @@ def __init__( self.clip_qkv = config.attn_config["clip_qkv"] self.qk_ln = config.attn_config["qk_ln"] self.alibi_bias_max = config.attn_config["alibi_bias_max"] + self.max_seq_len = config.max_seq_len if "kv_n_heads" in config.attn_config: self.total_num_kv_heads = config.attn_config['kv_n_heads'] else: @@ -107,17 +109,23 @@ def __init__( alibi_slopes = _get_alibi_slopes(self.total_num_heads, self.alibi_bias_max) alibi_slopes = alibi_slopes[head_start:head_end].tolist() + prev_attn = None if prev_attn is None else prev_attn.attn self.head_dim = self.d_model // 
self.total_num_heads scaling = self.head_dim**-0.5 - self.attn = Attention(self.num_heads, - self.head_dim, - scaling, - alibi_slopes=alibi_slopes, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + scaling, + alibi_slopes=alibi_slopes, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + logits_soft_cap=self.max_seq_len, + tp_rank=tp_rank, + prefix=f"{prefix}.attn", + prev_attn=prev_attn, + ) def forward( self, @@ -173,20 +181,23 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class MPTBlock(nn.Module): - def __init__( - self, - config: MPTConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): + def __init__(self, + config: MPTConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + prev_layer: Optional[nn.Module] = None): super().__init__() hidden_size = config.d_model self.norm_1 = nn.LayerNorm(hidden_size) - self.attn = MPTAttention(config, - cache_config, - quant_config, - prefix=f"{prefix}.attn") + prev_attn = None if prev_layer is None else prev_layer.attn + self.attn = MPTAttention( + config, + cache_config, + quant_config, + prefix=f"{prefix}.attn", + prev_attn=prev_attn, + ) self.norm_2 = nn.LayerNorm(hidden_size) self.ffn = MPTMLP(config, quant_config) @@ -230,9 +241,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.start_layer, self.end_layer, self.blocks = make_layers( config.n_layers, - lambda prefix: MPTBlock( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.blocks") + lambda prefix, prev_layer: MPTBlock(config, + cache_config, + quant_config, + prefix=prefix, + prev_layer=prev_layer), + prefix=f"{prefix}.blocks", + use_layer_sharing=True, + ) self.norm_f = nn.LayerNorm(config.d_model) if config.no_bias: for module in self.modules(): @@ -288,7 +304,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config assert config.tie_word_embeddings self.quant_config = quant_config - + self.use_alibi = config.attn_config['alibi'] self.transformer = MPTModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer")) self.lm_head = self.transformer.wte diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 269b66806adf4..19d2368ba931f 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,7 +1,7 @@ import itertools from dataclasses import dataclass, field -from typing import (Callable, Dict, Iterable, List, Literal, Mapping, Optional, - Protocol, Set, Tuple, Union, overload) +from typing import (Callable, Dict, Iterable, List, Literal, Mapping, + Optional, Protocol, Set, Tuple, Union, overload) import torch import torch.nn as nn @@ -330,7 +330,7 @@ def merge_multimodal_embeddings_from_map( inputs_embeds: torch.Tensor, multimodal_embeddings: NestedTensors, placeholder_map: MultiModalPlaceholderMap.IndexMap) -> torch.Tensor: """ - Merge ``multimodal_embeddings`` into ``inputs_embeds`` using the provided + Merge ``multimodal_embeddings`` into ``inputs_embeds`` using the provided placeholder map . 
Note: @@ -415,11 +415,11 @@ def merge_multimodal_embeddings( Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the positions in ``inputs_embeds`` corresponding to placeholder tokens in ``input_ids``. - - ``placeholder_token_id`` can be a list of token ids (e.g, token ids - of img_start, img_break, and img_end tokens) when needed: This means - the order of these tokens in the ``input_ids`` MUST MATCH the order of - their embeddings in ``multimodal_embeddings`` since we need to + + ``placeholder_token_id`` can be a list of token ids (e.g, token ids + of img_start, img_break, and img_end tokens) when needed: This means + the order of these tokens in the ``input_ids`` MUST MATCH the order of + their embeddings in ``multimodal_embeddings`` since we need to slice-merge instead of individually scattering. For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where @@ -428,9 +428,9 @@ def merge_multimodal_embeddings( - I is image embedding token - B is image break token - E is image end token. - - Then the image embeddings (that correspond to I's) from vision encoder - must be padded with embeddings of S, B, and E in the same order of + + Then the image embeddings (that correspond to I's) from vision encoder + must be padded with embeddings of S, B, and E in the same order of input_ids for a correct embedding merge. Note: @@ -454,7 +454,10 @@ def merge_multimodal_embeddings( class LayerFn(Protocol): - def __call__(self, prefix: str) -> torch.nn.Module: + def __call__( + self, + prefix: str, + prev_layer: Optional[torch.nn.Module] = None) -> torch.nn.Module: ... @@ -537,6 +540,7 @@ def make_layers( num_hidden_layers: int, layer_fn: LayerFn, prefix: str, + use_layer_sharing: bool = False, ) -> Tuple[int, int, torch.nn.ModuleList]: """Make a list of layers with the given layer function, taking pipeline parallelism into account. 
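Note on the `make_layers` change: the hunk that follows builds the layer list iteratively so that, when `use_layer_sharing=True`, each `layer_fn` call receives the layer constructed just before it as `prev_layer`; this is how the model files above thread `prev_attn` down to the HPU attention backend so precomputed ALiBi tensors are built once and shared across layers. A minimal single-rank sketch of that construction order, using hypothetical `make_layers_sketch`/`ToyLayer` stand-ins and omitting the pipeline-parallel bookkeeping (`PPMissingLayer`, `maybe_offload_to_cpu`):

from typing import Callable, Optional, Tuple

import torch
import torch.nn as nn


def make_layers_sketch(
    num_hidden_layers: int,
    layer_fn: Callable[..., nn.Module],
    prefix: str,
    use_layer_sharing: bool = False,
) -> Tuple[int, int, nn.ModuleList]:
    """Single-rank simplification: no pipeline-parallel ranges, no CPU offload."""
    layers = []
    prev_layer: Optional[nn.Module] = None
    for idx in range(num_hidden_layers):
        if use_layer_sharing:
            # Hand the previously built layer to layer_fn so it can reuse
            # tensors (e.g. a precomputed ALiBi bias) instead of recreating them.
            prev_layer = layer_fn(prefix=f"{prefix}.{idx}", prev_layer=prev_layer)
            layers.append(prev_layer)
        else:
            layers.append(layer_fn(prefix=f"{prefix}.{idx}"))
    return 0, num_hidden_layers, nn.ModuleList(layers)


class ToyLayer(nn.Module):
    """Stand-in decoder layer that shares one bias tensor across instances."""

    def __init__(self, prefix: str, prev_layer: Optional[nn.Module] = None):
        super().__init__()
        self.bias = (prev_layer.bias if prev_layer is not None
                     else torch.zeros(1, 8, 128, 128))


_, _, layers = make_layers_sketch(4, ToyLayer, "model.layers",
                                  use_layer_sharing=True)
assert all(layer.bias is layers[0].bias for layer in layers)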
@@ -546,11 +550,26 @@ def make_layers( start_layer, end_layer = get_pp_indices(num_hidden_layers, get_pp_group().rank_in_group, get_pp_group().world_size) - modules = torch.nn.ModuleList( - [PPMissingLayer() for _ in range(start_layer)] + [ - maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}")) - for idx in range(start_layer, end_layer) - ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]) + layers = [] + for _ in range(start_layer): + curr_layer = PPMissingLayer() + layers.append(curr_layer) + + curr_layer = None + for idx in range(start_layer, end_layer): + if use_layer_sharing: + curr_layer = layer_fn(prefix=f"{prefix}.{idx}", + prev_layer=curr_layer) + else: + curr_layer = layer_fn(prefix=f"{prefix}.{idx}") + layers.append(maybe_offload_to_cpu(curr_layer)) + + for _ in range(end_layer, num_hidden_layers): + curr_layer = PPMissingLayer() + layers.append(curr_layer) + + modules = nn.ModuleList(layers) + return start_layer, end_layer, modules diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 251a103e60f06..f127551a0d9af 100644 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -71,6 +71,7 @@ def __init__( kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, + **kwargs, ) -> None: if blocksparse_params is not None: raise ValueError( diff --git a/vllm/worker/hpu_enc_dec_model_runner.py b/vllm/worker/hpu_enc_dec_model_runner.py index 2b8acb502822d..15e5ece47c7b8 100644 --- a/vllm/worker/hpu_enc_dec_model_runner.py +++ b/vllm/worker/hpu_enc_dec_model_runner.py @@ -411,7 +411,8 @@ def create_dummy_seq_group_metadata(self, seq_len, is_prompt, lora_request=None, - temperature=0): + temperature=0, + last_block_assigned=0): sampling_params = SamplingParams(temperature=0) num_blocks = math.ceil(seq_len / self.block_size) cross_block_table: Optional[List[int]] = None diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 7c3679d40546d..9d6c60713e42b 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -736,6 +736,8 @@ def load_model(self) -> None: self.model = self.model.to("hpu") htcore.mark_step() + self.use_alibi = hasattr(self.model, + "use_alibi") and self.model.use_alibi hidden_layer_markstep_interval = int( os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1')) model_config = getattr(self.model, "config", None) @@ -978,10 +980,14 @@ def _prepare_prompt( block_list=prefix_block_list_tensor, block_mapping=None, block_usage=None, + # Set by later "precompute_indices_and_offsets" function call block_indices=None, + # Set by later "precompute_indices_and_offsets" function call block_offsets=None, + # Set by later "_set_block_scales" function call block_scales=None, block_groups=None, + # Set by later "_set_attn_bias" function call attn_bias=None, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, @@ -990,8 +996,9 @@ def _prepare_prompt( num_prefill_tokens=num_prefill_tokens, num_decode_tokens=0, slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps= - None # FIXME(kzawora): mutli-modality will not work here + alibi_blocks=None, + # FIXME(kzawora): mutli-modality will not work here + multi_modal_placeholder_index_maps=None, ) multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) for t in multi_modal_kwargs: @@ -1130,6 +1137,13 @@ def _prepare_decode( block_groups = padding_fn(block_groups, -1) block_usage = padding_fn(block_usage, 1) + 
alibi_blocks = None + if self.use_alibi: + alibi_blocks = self._compute_alibi_block(block_tables, seq_lens, + len(block_groups)) + alibi_blocks = alibi_blocks.to( # type: ignore + self.device, non_blocking=True) + block_list = torch.tensor(block_list, dtype=torch.int, device='cpu') block_groups = torch.tensor(block_groups, dtype=torch.int, @@ -1157,12 +1171,17 @@ def _prepare_decode( attn_metadata = self.attn_backend.make_metadata( is_prompt=False, block_list=block_list, + # Set by later "_set_block_mapping" function call block_mapping=None, block_usage=block_usage, + # Set by later "precompute_indices_and_offsets" function call block_indices=None, + # Set by later "precompute_indices_and_offsets" function call block_offsets=None, + # Set by later "_set_block_scales" function call block_scales=None, block_groups=block_groups, + # Set by later "_set_block_mapping" function call attn_bias=None, seq_lens_tensor=None, context_lens_tensor=None, @@ -1170,7 +1189,9 @@ def _prepare_decode( num_prefill_tokens=0, num_decode_tokens=num_decode_tokens, slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None) + alibi_blocks=alibi_blocks, + multi_modal_placeholder_index_maps=None, + ) return PrepareDecodeMetadata(input_tokens=input_tokens, input_positions=input_positions, attn_metadata=attn_metadata, @@ -1180,6 +1201,63 @@ def _prepare_decode( slot_mapping=slot_mapping, lora_ids=lora_ids) + def _compute_alibi_block(self, block_tables, seq_lens, num_blocks): + """ + Compute the ALiBi offsets for each block during decoding. + + For each block in each sequence, this function assigns position-based + offsets according to ALiBi logic. It returns a tensor that captures + these offsets for all sequences and blocks, which is then used for + decode-time ALiBi bias creation. + + Args: + block_tables: + A list of lists, where each inner list contains block indices + assigned to a particular sequence. + seq_lens: + A list of sequence lengths corresponding to each sequence. + num_blocks: + The total number of blocks across all sequences for which + ALiBi offsets need to be computed. + + Returns: + A torch.Tensor of shape [num_blocks, block_size], containing ALiBi + offsets for each block. 
+ """ + # Create intermediary and output structures + max_block_table_len = max( + len(block_table) for block_table in block_tables) + alibi_offsets = torch.arange(-max_block_table_len * self.block_size + + 1, + 1, + dtype=torch.long, + device='cpu') + alibi_blocks = torch.zeros((num_blocks, self.block_size), + dtype=torch.long, + device='cpu') + + # Assign biases per token + for batch_idx in range(len(block_tables)): + seq_len = seq_lens[batch_idx] + for seq_idx in range(len(block_tables[batch_idx])): + block_idx = block_tables[batch_idx][seq_idx] + + # Calculate the number of valid positions in the current block + valid_length = seq_len - seq_idx * self.block_size + if valid_length > 0: + current_block_length = min(valid_length, self.block_size) + offset_end = current_block_length - valid_length + if offset_end == 0: + alibi_blocks[ + block_idx][:current_block_length] = alibi_offsets[ + -valid_length:] + else: + alibi_blocks[ + block_idx][:current_block_length] = alibi_offsets[ + -valid_length:offset_end] + + return alibi_blocks + def prepare_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -1384,6 +1462,7 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: 'block_offsets', 'block_scales', 'block_groups', + 'alibi_blocks', ]) return attention_metadata @@ -1392,18 +1471,31 @@ def create_dummy_seq_group_metadata(self, seq_len, is_prompt, lora_request=None, - temperature=0): + temperature=0, + last_block_assigned=0): sampling_params = SamplingParams(temperature=temperature) num_blocks = math.ceil(seq_len / self.block_size) - seq_len = max(seq_len, 1) + # FIXME(Tanner): + # When num_scheduler_steps>1 an additional + # token gets appended to dummy groups at some point + # This causes an RTE during warmup. Hence, subtracting 1 from seq_len. + seq_len = max(seq_len - 1, 1) + block_tables: Optional[dict[Any, Any]] = None if is_prompt: input_len = seq_len output_len = 0 - block_tables = None else: input_len = seq_len - 1 output_len = 1 - block_tables = {group_id: [_PAD_BLOCK_ID] * num_blocks} + # NOTE(Tanner): + # ALiBI biases fail if block_tables for + # dummy sequences are all zeros. + # By default "_PAD_BLOCK_ID" is "0" and this + # is not a realistic value for block tables. 
+ block_tables = {group_id: []} + for block_idx in range(num_blocks): + last_block_assigned += 1 + block_tables[group_id] += [last_block_assigned] prompt_token_ids = [0] * input_len output_token_ids = [1] * output_len prompt_token_ids_array = array('l', prompt_token_ids) # noqa: F821 @@ -1477,18 +1569,31 @@ def warmup_scenario(self, temperature=temperature) for i in range(batch_size) ] else: - # FIXME: seq_len is actually number of blocks - blocks = [seq_len // batch_size for _ in range(batch_size)] - blocks[0] += seq_len % batch_size - seqs = [ - self.create_dummy_seq_group_metadata( - i, - b * self.block_size - 1, - is_prompt, - lora_request=dummy_lora_requests_per_seq[i] - if dummy_lora_requests_per_seq else None, - temperature=temperature) for i, b in enumerate(blocks) + # NOTE(Tanner): + # seq_len is num blocks + # Here we assign as many blocks to each sequence as we can + blocks_per_seq = (seq_len - 1) // batch_size + extra_blocks = (seq_len - 1) % batch_size + blocks = [ + blocks_per_seq + (1 if i < extra_blocks else 0) + for i in range(batch_size) ] + seqs = [] + last_block_assigned = 0 + for i, b in enumerate(blocks): + seqs += [ + self.create_dummy_seq_group_metadata( + i, + b * self.block_size, + is_prompt, + lora_request=dummy_lora_requests_per_seq[i] + if dummy_lora_requests_per_seq else None, + temperature=temperature, + last_block_assigned=last_block_assigned, + ) + ] + if len(seqs[-1].block_tables[i]) > 0: + last_block_assigned = seqs[-1].block_tables[i][-1] torch.hpu.synchronize() profiler = None if is_pt_profiler_run and self.is_driver_worker: @@ -1955,7 +2060,7 @@ def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int], This is a helper function to create the mask for lora computations. Lora Mask is needed to ensure we match the correct lora weights for the for the request. - For Prompt phase we have + For Prompt phase we have lora_mask with shape (batch_size * seq_len, max_loras * max_rank) lora_logits_mask with shape (batch_size, max_loras * max_rank) For Decode phase we have both
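For reference, the prompt-stage ALiBi path added in vllm/attention/backends/hpu_attn.py precomputes a single [1, num_heads, max_seq_len, max_seq_len] bias in __init__ via _make_prompt_alibi_bias and, in forward, slices its bottom-right corner whenever the running shape fits, rebuilding only when a sequence exceeds the precomputed length. The standalone sketch below uses a hypothetical make_prompt_alibi_bias (not the vLLM helper; the pad-to-multiple-of-8 allocation and FusedSDPA plumbing are skipped) to show the construction and why the slice is equivalent to rebuilding:

import torch


def make_prompt_alibi_bias(alibi_slopes: torch.Tensor,
                           seq_len: int,
                           dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Return a [1, num_heads, seq_len, seq_len] ALiBi bias for prompt attention."""
    num_heads = alibi_slopes.shape[0]
    pos = torch.arange(seq_len, dtype=dtype)
    # Relative distance (key_pos - query_pos); the causal mask hides key_pos > query_pos.
    rel = pos[None, :] - pos[:, None]
    bias = rel[None, None, :, :].expand(1, num_heads, -1, -1).clone()
    bias.mul_(alibi_slopes.to(dtype)[None, :, None, None])
    return bias


# Build once for the largest expected prompt length, then reuse a
# bottom-right slice for shorter prompts instead of rebuilding per step.
slopes = torch.tensor([1 / 2, 1 / 4, 1 / 8, 1 / 16])
full = make_prompt_alibi_bias(slopes, seq_len=16)
short = full[:, :, -5:, -5:]
assert torch.equal(short, make_prompt_alibi_bias(slopes, seq_len=5))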
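The decode path splits the same idea across worker and backend: _compute_alibi_block in vllm/worker/hpu_model_runner.py records, for every paged-KV block, each slot's (negative) distance from the sequence's last token, and _make_decode_alibi_bias in hpu_attn.py scales those offsets by the per-head slopes into the [num_blocks, num_heads, block_size] bias consumed by HPUPagedAttention.forward_decode. A condensed sketch of that two-step computation, with hypothetical compute_alibi_blocks/make_decode_alibi_bias helpers written in terms of explicit token positions (which should yield the same offsets as the diff's alibi_offsets indexing):

from typing import List

import torch


def compute_alibi_blocks(block_tables: List[List[int]],
                         seq_lens: List[int],
                         block_size: int,
                         num_blocks: int) -> torch.Tensor:
    """Per-block ALiBi offsets: the slot holding token p gets p - (seq_len - 1)."""
    alibi_blocks = torch.zeros(num_blocks, block_size, dtype=torch.long)
    for blocks, seq_len in zip(block_tables, seq_lens):
        for nth, block_id in enumerate(blocks):
            start = nth * block_size                  # first token in this block
            valid = min(seq_len - start, block_size)  # slots actually in use
            if valid <= 0:
                continue
            positions = torch.arange(start, start + valid)
            alibi_blocks[block_id, :valid] = positions - (seq_len - 1)
    return alibi_blocks


def make_decode_alibi_bias(alibi_blocks: torch.Tensor,
                           alibi_slopes: torch.Tensor,
                           dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Expand [num_blocks, block_size] offsets into [num_blocks, num_heads, block_size]."""
    return alibi_blocks.unsqueeze(1).to(dtype) * alibi_slopes.to(dtype)[None, :, None]


# One sequence of length 5 spread over blocks 0 and 1 with block_size=4.
blocks = compute_alibi_blocks([[0, 1]], [5], block_size=4, num_blocks=2)
print(blocks)  # tensor([[-4, -3, -2, -1], [ 0,  0,  0,  0]])
slopes = torch.tensor([0.5, 0.25])
print(make_decode_alibi_bias(blocks, slopes).shape)  # torch.Size([2, 2, 4])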