From 8bab5ae47fa9825bde6025ca0ef32eb5ab26fa00 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Wed, 25 Dec 2024 10:06:50 +0800
Subject: [PATCH 1/3] add compresskv back for mistral

---
 .../ipex_llm/transformers/models/mistral.py  | 27 +++++++++++++++----
 .../ipex_llm/transformers/models/mixtral.py  |  4 +--
 2 files changed, 24 insertions(+), 7 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/mistral.py b/python/llm/src/ipex_llm/transformers/models/mistral.py
index 983f2508e45..983cb0cd304 100644
--- a/python/llm/src/ipex_llm/transformers/models/mistral.py
+++ b/python/llm/src/ipex_llm/transformers/models/mistral.py
@@ -37,6 +37,7 @@
 
 from typing import Optional, Tuple, Union, List
 
+import os
 import torch
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -45,8 +46,10 @@
 from ipex_llm.transformers.models.common import merge_qkv_base
 from ipex_llm.transformers.models.common import scaled_dot_product_attention
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
+from ipex_llm.transformers.models.utils import should_use_compresskv, is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache
-from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
+from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache, DynamicNormalCache
+KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 
 
 def mistral_model_forward(
@@ -69,11 +72,17 @@ def mistral_model_forward(
     use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
                                             self.config.num_attention_heads //
                                             self.config.num_key_value_heads)
+    use_compress_kv = should_use_compresskv(inputs, inputs.size(-2)) or \
+        isinstance(past_key_values, DynamicCompressCache)
 
     if use_cache:
-        if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
+        if use_compress_kv and not isinstance(past_key_values, DynamicCompressCache):
+            past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
+        elif use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
+                                                                        DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        elif not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
+        elif not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
+                                                                            DynamicNormalCache):
             past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
     # ipex-llm changes end
 
@@ -127,8 +136,16 @@ def mistral_attention_forward(
     query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
                                                     cos, sin, position_ids, "mistral")
 
-    key_states, value_states = past_key_value.update(key_states, value_states,
-                                                     self.layer_idx, None)
+    if isinstance(past_key_value, DynamicCompressCache):
+        enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, q_len)
+        key_states, value_states = past_key_value.update(
+            key_states, value_states, self.layer_idx,
+            query_states, attention_mask, self.num_key_value_groups,
+            self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH
+        )
+    else:
+        key_states, value_states = past_key_value.update(key_states, value_states,
+                                                         self.layer_idx, None)
 
     # IPEX-LLM OPT: sdpa
     attn_weights = None
diff --git a/python/llm/src/ipex_llm/transformers/models/mixtral.py b/python/llm/src/ipex_llm/transformers/models/mixtral.py
index b63772a8e41..25083827ce1 100644
--- a/python/llm/src/ipex_llm/transformers/models/mixtral.py
+++ b/python/llm/src/ipex_llm/transformers/models/mixtral.py
@@ -52,7 +52,7 @@
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
-from ipex_llm.transformers.models.mistral import should_use_fuse_rope
+from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_decoding_fast_path
 from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
 from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
@@ -171,7 +171,7 @@ def mixtral_attention_forward(
 
     # for flash attention
     original_dtype = hidden_states.dtype
-    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
+    use_fuse_rope = should_use_fuse_rope(hidden_states, position_ids, self.training)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
                                                 use_fuse_rope,

From 6c490f25017337b27ea1870b37304c2d1d7e6461 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Wed, 25 Dec 2024 10:52:45 +0800
Subject: [PATCH 2/3] fix

---
 python/llm/src/ipex_llm/transformers/models/mistral.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/mistral.py b/python/llm/src/ipex_llm/transformers/models/mistral.py
index 983cb0cd304..1a844931123 100644
--- a/python/llm/src/ipex_llm/transformers/models/mistral.py
+++ b/python/llm/src/ipex_llm/transformers/models/mistral.py
@@ -72,7 +72,7 @@ def mistral_model_forward(
     use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
                                             self.config.num_attention_heads //
                                             self.config.num_key_value_heads)
-    use_compress_kv = should_use_compresskv(inputs, inputs.size(-2)) or \
+    use_compress_kv = should_use_compresskv(inputs, inputs.size(1)) or \
         isinstance(past_key_values, DynamicCompressCache)
 
     if use_cache:

From 0483bfad7d8a6aa4abaade87d89839c95a4f2af7 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Wed, 25 Dec 2024 11:01:46 +0800
Subject: [PATCH 3/3] fix

---
 .../ipex_llm/transformers/models/mistral.py  | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/mistral.py b/python/llm/src/ipex_llm/transformers/models/mistral.py
index 1a844931123..4534f735aaa 100644
--- a/python/llm/src/ipex_llm/transformers/models/mistral.py
+++ b/python/llm/src/ipex_llm/transformers/models/mistral.py
@@ -48,7 +48,8 @@
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import should_use_compresskv, is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache
-from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache, DynamicNormalCache
+from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
+from ipex_llm.transformers.kv import DynamicCompressCache, DynamicCompressFp8Cache
 KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 
 
@@ -77,12 +78,17 @@ def mistral_model_forward(
 
     if use_cache:
         if use_compress_kv and not isinstance(past_key_values, DynamicCompressCache):
-            past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
-        elif use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
-                                                                        DynamicFp8Cache):
+            if use_quantize_kv:
+                past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
+            else:
+                past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
+        elif use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        elif not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
-                                                                            DynamicNormalCache):
+        elif (
+            not use_quantize_kv
+            and not use_compress_kv
+            and not isinstance(past_key_values, DynamicNormalCache)
+        ):
             past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
     # ipex-llm changes end
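
Note (not part of the patch series): the net effect of PATCH 3/3 on mistral_model_forward is a four-way cache choice, with KV compression taking precedence and pairing with FP8 when quantization is also requested. The standalone Python sketch below only restates that precedence for readers skimming the series; the helper name pick_kv_cache_class and its boolean arguments are hypothetical stand-ins for should_use_compresskv(...) and use_quantize_kv_cache(...), and it returns class names as strings instead of touching the real ipex_llm cache classes.

# Standalone sketch -- mirrors the selection order PATCH 3/3 settles on:
# compressed KV first (FP8-compressed when quantization is also on),
# then plain FP8, then the normal cache. Illustration only.

def pick_kv_cache_class(use_compress_kv: bool, use_quantize_kv: bool) -> str:
    if use_compress_kv:
        return "DynamicCompressFp8Cache" if use_quantize_kv else "DynamicCompressCache"
    if use_quantize_kv:
        return "DynamicFp8Cache"
    return "DynamicNormalCache"


if __name__ == "__main__":
    # Enumerate the four combinations to show the precedence.
    for compress in (False, True):
        for quantize in (False, True):
            print(f"compress={compress}, quantize={quantize} -> "
                  f"{pick_kv_cache_class(compress, quantize)}")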