Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add compresskv back for mistral #12607

Merged
merged 3 commits into from
Dec 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions python/llm/src/ipex_llm/transformers/models/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

from typing import Optional, Tuple, Union, List

import os
import torch
from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPast
Expand All @@ -45,8 +46,11 @@
from ipex_llm.transformers.models.common import merge_qkv_base
from ipex_llm.transformers.models.common import scaled_dot_product_attention
from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
from ipex_llm.transformers.models.utils import should_use_compresskv, is_enough_kv_cache_room_4_36
from ipex_llm.transformers.models.utils import use_quantize_kv_cache
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
from ipex_llm.transformers.kv import DynamicCompressCache, DynamicCompressFp8Cache
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))


def mistral_model_forward(
Expand All @@ -69,11 +73,22 @@ def mistral_model_forward(
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
self.config.num_attention_heads //
self.config.num_key_value_heads)
use_compress_kv = should_use_compresskv(inputs, inputs.size(1)) or \
isinstance(past_key_values, DynamicCompressCache)

if use_cache:
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
if use_compress_kv and not isinstance(past_key_values, DynamicCompressCache):
if use_quantize_kv:
past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
else:
past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
elif use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
elif not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
elif (
not use_quantize_kv
and not use_compress_kv
and not isinstance(past_key_values, DynamicNormalCache)
):
past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
# ipex-llm changes end

Expand Down Expand Up @@ -127,8 +142,16 @@ def mistral_attention_forward(
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids, "mistral")

key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, None)
if isinstance(past_key_value, DynamicCompressCache):
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, q_len)
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx,
query_states, attention_mask, self.num_key_value_groups,
self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH
)
else:
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, None)

# IPEX-LLM OPT: sdpa
attn_weights = None
Expand Down
4 changes: 2 additions & 2 deletions python/llm/src/ipex_llm/transformers/models/mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from ipex_llm.utils.common import invalidInputError
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
from ipex_llm.transformers.models.mistral import should_use_fuse_rope
from ipex_llm.transformers.models.utils import should_use_fuse_rope
from ipex_llm.transformers.models.utils import use_decoding_fast_path
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
Expand Down Expand Up @@ -171,7 +171,7 @@ def mixtral_attention_forward(
# for flash attention
original_dtype = hidden_states.dtype

use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
use_fuse_rope = should_use_fuse_rope(hidden_states, position_ids, self.training)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
decoding_fast_path = use_decoding_fast_path(self.q_proj,
use_fuse_rope,
Expand Down
Loading