unify xpu and cpu backend and use paged attention (#1009)
* add paged attention implementation, remove jit logic

Signed-off-by: Wang, Yi A <[email protected]>

* add support in transformers 4.45

Signed-off-by: Wang, Yi A <[email protected]>

* fix config (#935)

* move patch model to init

Signed-off-by: Wang, Yi A <[email protected]>

* refine class IPEXPagedCache's update method (#945)

* refine class IPEXPagedCache's update method

Signed-off-by: Liu, Kaixuan <[email protected]>

* replace tensor on xpu with List to avoid memory copy

Signed-off-by: Liu, Kaixuan <[email protected]>

* split IPEXPagedCache's update function into `update_for_prefill` and `update_for_decode`

Signed-off-by: Liu, Kaixuan <[email protected]>

---------

Signed-off-by: Liu, Kaixuan <[email protected]>

* fix bug when doing beam search (#954)

Signed-off-by: Liu, Kaixuan <[email protected]>

* enable qkv concat layer (#958)

* enable qkv

* split key value into 2 lists

* add xpu cache optimization

Signed-off-by: Wang, Yi A <[email protected]>

* xpu mlp optimization

Signed-off-by: Wang, Yi A <[email protected]>

* optimize cache ops in xpu, improve for beam search

Signed-off-by: Wang, Yi A <[email protected]>

* enable gpt2, falcon has core dump error in PagedAttention.single_query_cached_kv_attention (#979)

* enable gpt2, falcon has core dump error in PagedAttention.single_query_cached_kv_attention

* enable new_decoder_arch falcon

* only keep 1 config

* rm autocast

* fix unit test case, CPU part is OK; Enable Falcon7b for XPU (#992)

* fix bug when running IPEXCausalModel forward directly; fix bug when using `save_pretrained`

Signed-off-by: Liu, Kaixuan <[email protected]>

* add LinearGelu Op support for XPU

Signed-off-by: Liu, Kaixuan <[email protected]>

* fix unit test error

Signed-off-by: Liu, Kaixuan <[email protected]>

* adjust unit test case

Signed-off-by: Liu, Kaixuan <[email protected]>

* fix bug

Signed-off-by: Liu, Kaixuan <[email protected]>

---------

Signed-off-by: Liu, Kaixuan <[email protected]>

* skip assisted decoding unit test for models using paged attention (#998)

* skip assisted decoding unit test for models using paged attention

Signed-off-by: Liu, Kaixuan <[email protected]>

* XPU CI tests almost all pass

Signed-off-by: Liu, Kaixuan <[email protected]>

---------

Signed-off-by: Liu, Kaixuan <[email protected]>

* fix ci config (#1010)

Signed-off-by: jiqing-feng <[email protected]>

* Fix tests versions (#1011)

* fix ci config

* fix test versions

* fix ipex version

Signed-off-by: jiqing-feng <[email protected]>

* fix torch test version (#1012)

Signed-off-by: jiqing-feng <[email protected]>

* use python3.9 test (#1013)

* use python3.9 test

Signed-off-by: jiqing-feng <[email protected]>

* change ipex transformers version limit in setup (#1015)

* change ipex transformers version limit in setup
* fix inc tests

Signed-off-by: jiqing-feng <[email protected]>

* add XPU LinearAddAdd op (#1017)

Signed-off-by: Liu, Kaixuan <[email protected]>
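As a rough, unfused reference for what the XPU `LinearAddAdd` op added above computes (presumably as a single fused kernel on the device; this sketch is for illustration only and is not the IPEX implementation), the pattern is a linear projection followed by two residual additions:

```python
import torch
from torch import nn


class LinearAddAddReference(nn.Module):
    """Unfused reference of y = Linear(x) + residual_1 + residual_2."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear

    def forward(self, x, residual_1, residual_2):
        return self.linear(x) + residual_1 + residual_2


# Quick check of the reference math
layer = LinearAddAddReference(nn.Linear(8, 8))
out = layer(torch.randn(2, 8), torch.randn(2, 8), torch.randn(2, 8))
```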

* fix bert and vit patch (#1022)

* fix bert and vit patch
* fix vit and bert save


Signed-off-by: jiqing-feng <[email protected]>

* Paged attn (#1024)

* fix reorder cache for non-patched models

Signed-off-by: jiqing-feng <[email protected]>

* disable torch < 2.3 tests, we won't use torch < 2.4

Signed-off-by: jiqing-feng <[email protected]>

* fix test beam search

Signed-off-by: jiqing-feng <[email protected]>

* fix cache selection

Signed-off-by: jiqing-feng <[email protected]>

* upgrade to transformers 4.46

Signed-off-by: jiqing-feng <[email protected]>

* change ipex test yaml transformers version to 4.46

Signed-off-by: jiqing-feng <[email protected]>

---------

Signed-off-by: jiqing-feng <[email protected]>

* set device to the same as the original model (#1031)

* set device to the same as the original model
* fix device

Signed-off-by: jiqing-feng <[email protected]>

---------

Signed-off-by: jiqing-feng <[email protected]>

* Simplify IPEXModel (#1032)

* simplify forward and save_pretrained since there is no jit support

* fix format

* rm warmup because no jit mode anymore

* simplify forward for causal lm model

* fix paged pkv forward

* disable use_cache when just running forward

---------

Signed-off-by: jiqing-feng <[email protected]>

* nice code (#1035)

Signed-off-by: Liu, Kaixuan <[email protected]>

* Paged attn (#1036)

* nice code
* device type adjustment

Signed-off-by: Liu, Kaixuan <[email protected]>

* Enable torch.compile for non-generation tasks on CPU (#1037)

* enable compile for non-generation tasks
* add no_grad in forward
* warmup compiled model
* disable compile for models that are not ready for it
* set system level optimize for torch.compile
* fix typo
* add comments
* set torch minimum version for compiling

Signed-off-by: jiqing-feng <[email protected]>
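A minimal sketch of the idea behind the torch.compile change above (an illustration, not the optimum-intel implementation; the checkpoint name is just an example): compile the forward of a non-generation model once, warm it up under `no_grad`, and let later calls reuse the compiled graph.

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_id = "distilbert-base-uncased-finetuned-sst-2-english"  # example checkpoint, any classification model works
model = AutoModelForSequenceClassification.from_pretrained(model_id).eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Generation models are left uncompiled here, mirroring the change above,
# which enables torch.compile only for non-generation tasks.
model.forward = torch.compile(model.forward)

inputs = tokenizer("This movie was great!", return_tensors="pt")
with torch.no_grad():
    model(**inputs)                      # warm-up call triggers compilation
    logits = model(**inputs).logits      # subsequent calls reuse the compiled graph
```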

* Fix ipex upload and update readme. (#1045)

* fix readme and push to hub support

Signed-off-by: jiqing-feng <[email protected]>

* rm export in tests

Signed-off-by: jiqing-feng <[email protected]>

* test with torch 2.5.*

Signed-off-by: jiqing-feng <[email protected]>

---------

Signed-off-by: jiqing-feng <[email protected]>

* Fix tests (#1047)

* fix tests
* fix typo
* add patched tests

* change forward to generate

* fix tests

* fix test model name


---------

Signed-off-by: jiqing-feng <[email protected]>

* Patch gpt2 block forward for passing input_lens. (#1050)

* fix forward without pkv
* patch gpt2 block forward
* fix typo
* revert causal lm tests

Signed-off-by: jiqing-feng <[email protected]>

---------

Signed-off-by: Wang, Yi A <[email protected]>
Signed-off-by: Liu, Kaixuan <[email protected]>
Signed-off-by: jiqing-feng <[email protected]>
Co-authored-by: jiqing-feng <[email protected]>
Co-authored-by: kaixuanliu <[email protected]>
Co-authored-by: Ilyas Moutawwakil <[email protected]>
4 people authored Dec 5, 2024
1 parent c94b3f5 commit 41f0a46
Showing 13 changed files with 1,035 additions and 860 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_inc.yml
@@ -18,7 +18,7 @@ jobs:
strategy:
fail-fast: false
matrix:
torch-version: ["2.4.*", "2.5.0"]
torch-version: ["2.4.0", "2.5.*"]

runs-on: ubuntu-22.04

8 changes: 2 additions & 6 deletions .github/workflows/test_ipex.yml
@@ -18,8 +18,8 @@ jobs:
strategy:
fail-fast: false
matrix:
torch-version: ["2.2.0", "2.3.*"]
transformers-version: ["4.39.0", "4.44.*"]
transformers-version: ["4.46.0", "4.46.3"]
torch-version: ["2.4.0", "2.5.*"]

runs-on: ubuntu-22.04

Expand All @@ -38,10 +38,6 @@ jobs:
pip install torch==${{ matrix.torch-version }} torchaudio torchvision --extra-index-url https://download.pytorch.org/whl/cpu
pip install .[ipex,tests] transformers[testing]==${{ matrix.transformers-version }} intel_extension_for_pytorch==${{ matrix.torch-version }}
- - if: ${{ matrix.torch-version == '2.2.0' }}
-   name: Downgrade Numpy
-   run: pip install numpy==1.*

- name: Assert versions
run: |
python -c "import torch; print(torch.__version__); assert torch.__version__.startswith('${{ matrix.torch-version }}'.replace('.*', ''))"
6 changes: 3 additions & 3 deletions docs/source/ipex/inference.mdx
@@ -14,8 +14,8 @@ Optimum Intel can be used to load models from the [Hub](https://huggingface.co/m

## Loading

- You can load your model and apply IPEX optimizations (including weight prepacking and graph mode). For supported architectures like LLaMA, BERT and ViT, further optimizations will be applied by patching the model to use custom operators.
- For now, support is only enabled for CPUs and the original model will be exported via TorchScript. In the future `torch.compile` will be used and model exported via TorchScript will get deprecated.
+ You can load your model and apply IPEX optimizations (torch.compile is applied for non-generation tasks). For supported architectures like LLaMA, BERT and ViT, further optimizations will be applied by patching the model to use custom operators.
+ For now, support is enabled for Intel CPU/GPU. Models previously exported via TorchScript will be deprecated in v1.22.

```diff
import torch
Expand All @@ -25,7 +25,7 @@ For now, support is only enabled for CPUs and the original model will be exporte

model_id = "gpt2"
- model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
-+ model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, export=True)
++ model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_id)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
results = pipe("He's a dreadful magician and")
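The updated docs above state that support now covers Intel CPU and GPU. As a hedged sketch of the GPU path (an editorial illustration, not part of the committed docs; it assumes an XPU-enabled PyTorch/IPEX install and that the IPEX model wrapper accepts `.to()` like a regular module):

```python
import torch
from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

# Fall back to CPU when no Intel GPU is available.
device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cpu"

model_id = "gpt2"
model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("He's a dreadful magician and", return_tensors="pt").to(device)
generated = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```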
238 changes: 238 additions & 0 deletions optimum/exporters/ipex/cache_utils.py
@@ -0,0 +1,238 @@
from typing import List, Optional, Tuple

import torch
from intel_extension_for_pytorch.llm.modules import PagedAttention
from transformers import Cache, PretrainedConfig


class IPEXPagedCache(Cache):
"""
A paged cache that grows dynamically as more tokens are generated. It grows one block of memory at a time, and each vendor backend (ipex-cpu, ipex-xpu) can choose its own paged-cache memory layout.
Example:
```python
>>> from transformers import AutoTokenizer
>>> from optimum.intel import IPEXModelForCausalLM
>>> from optimum.exporters.ipex.cache_utils import IPEXPagedCache
>>> model = IPEXModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", export=True)
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
>>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")
>>> # Prepare a cache class and pass it to model's forward
>>> past_key_values = IPEXPagedCache()
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> past_key_values = outputs.past_key_values # access cache filled with key/values from generation
```
"""

def __init__(
self,
config: PretrainedConfig,
batch_size: int,
max_cache_len: int,
device,
dtype=None,
layer_device_map=None,
**kwargs,
) -> None:
super().__init__()
self.batch_size = batch_size
# Used in `generate` to keep tally of how many tokens the cache has seen
self._seen_tokens = torch.zeros([batch_size], dtype=torch.int32, device=device)
self.block_size = 16
self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * batch_size
self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
batch_size, -1
)
self.free_blocks = torch.arange(self.num_blocks, device=device)
self.max_cache_len = max_cache_len
self.num_kv_heads = config.num_key_value_heads
self.num_hidden_layers = config.num_hidden_layers
if hasattr(config, "head_dim"):
head_size = config.head_dim
else:
head_size = config.hidden_size // config.num_attention_heads
self.head_size = head_size
self.max_seq_len = 0

self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []

if device.type == "cpu":
key_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size)
value_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size)
elif device.type == "xpu":
key_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size, 1)
value_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size)
for i in range(config.num_hidden_layers):
new_layer_key_cache = torch.zeros(key_cache_shape, dtype=dtype, device=device)
new_layer_value_cache = torch.zeros(value_cache_shape, dtype=dtype, device=device)
self.key_cache.append(new_layer_key_cache)
self.value_cache.append(new_layer_value_cache)

def update_for_prefill(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
batch_size: int,
input_lens: torch.Tensor,
):
if layer_idx == 0:
all_block_indices = []
all_slot_offsets = []
num_blocks = (input_lens + self.block_size - 1) // self.block_size
for i in range(batch_size):
for b_idx in range(num_blocks[i]):
if self.block_tables[i][b_idx] == -1:
# need a free block
self.block_tables[i][b_idx] = self.free_blocks[0]
self.free_blocks = self.free_blocks[1:]

slots_range = torch.arange(input_lens[i], device=key_states.device)
block_indices = slots_range // self.block_size
slot_offsets = slots_range % self.block_size
all_block_indices.append(self.block_tables[i][block_indices])
all_slot_offsets.append(slot_offsets)

all_block_indices = torch.cat(all_block_indices)
all_slot_offsets = torch.cat(all_slot_offsets)
self.slots = all_block_indices * self.block_size + all_slot_offsets

# Update the cache
PagedAttention.reshape_and_cache(
key_states,
value_states,
self.key_cache[layer_idx],
self.value_cache[layer_idx],
self.slots,
)

# Update the number of seen tokens
if layer_idx == self.num_hidden_layers - 1:
self._seen_tokens = self._seen_tokens + input_lens
self.max_seq_len, _ = self._seen_tokens.max(dim=0)

def update_for_decode(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
batch_size: int,
):
if layer_idx == 0:
start_block_idx = self._seen_tokens // self.block_size
num_blocks = (self._seen_tokens + self.block_size) // self.block_size
slot_offset_in_block = (self._seen_tokens) % self.block_size
self.slots = torch.zeros([batch_size], device=key_states.device, dtype=torch.int32)
for i in range(batch_size):
for b_idx in range(start_block_idx[i], num_blocks[i]):
if self.block_tables[i][b_idx] == -1:
# need a free block
self.block_tables[i][b_idx] = self.free_blocks[0]
self.free_blocks = self.free_blocks[1:]

self.slots[i] = self.block_tables[i][start_block_idx[i]] * self.block_size + slot_offset_in_block[i]
# Update the cache
PagedAttention.reshape_and_cache(
key_states,
value_states,
self.key_cache[layer_idx],
self.value_cache[layer_idx],
self.slots,
)

# Update the number of seen tokens
if layer_idx == self.num_hidden_layers - 1:
self._seen_tokens = self._seen_tokens + 1
self.max_seq_len = self.max_seq_len + 1

def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
attention_mask: torch.Tensor,
input_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
Return:
A tuple containing the updated key and value states.
"""

batch_size = input_lens.shape[-1]
if self.get_seq_length() == 0:
# prefill
self.update_for_prefill(key_states, value_states, layer_idx, batch_size, input_lens)
else:
# decode
self.update_for_decode(key_states, value_states, layer_idx, batch_size)

return self.key_cache[layer_idx], self.value_cache[layer_idx]

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states that were seen by the model."""
return self.max_seq_len

def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states."""
return self.max_cache_len

def reset(self):
"""Resets the cache values while preserving the objects"""
self._seen_tokens = torch.zeros([self.batch_size], dtype=torch.int32, device=self.block_tables.device)
self.block_tables.fill_(-1)
self.free_blocks = torch.arange(self.num_blocks, device=self.block_tables.device)
self.max_seq_len = 0

def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
device = self.block_tables.device
origin_table = self.block_tables.clone()
updated_block_tables = self.block_tables.index_select(0, beam_idx.to(device))
mask = self.block_tables.masked_fill(self.block_tables != -1, 1).masked_fill(self.block_tables == -1, 0)
num_blocks = mask.cumsum(-1)[:, -1]
updated_table = []
for i in range(beam_idx.shape[0]):
self.block_tables[i, 0 : num_blocks[i] - 1] = updated_block_tables[i, 0 : num_blocks[i] - 1]
updated_table.append(self.block_tables[i : i + 1, num_blocks[i] - 1 : num_blocks[i]])
updated_table = torch.cat(tuple(updated_table), dim=0)
for layer_idx in range(self.num_hidden_layers):
self.key_cache[layer_idx][updated_table] = self.key_cache[layer_idx][updated_table[beam_idx]]
self.value_cache[layer_idx][updated_table] = self.value_cache[layer_idx][updated_table[beam_idx]]
free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1))
self.free_blocks = torch.cat((self.free_blocks, free_table))

def crop(self, maximum_length: int):
"""Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search."""

max_seq_len = self.get_seq_length()
if maximum_length < 0:
maximum_length = max_seq_len - abs(maximum_length)

if max_seq_len <= maximum_length:
return
origin_table = self.block_tables.clone()
for bs in range(self._seen_tokens.shape[0]):
new_tokens = self._seen_tokens[bs] + maximum_length - max_seq_len
num_blocks = (new_tokens + self.block_size - 1) // self.block_size
self.block_tables[bs, num_blocks:] = -1
self._seen_tokens[bs] = new_tokens
self.max_seq_len, _ = self._seen_tokens.max(dim=0)
free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1))
self.free_blocks = torch.cat((self.free_blocks, free_table))
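To make the block-table bookkeeping in `IPEXPagedCache` easier to follow, here is a small self-contained sketch (plain PyTorch, independent of the class and of IPEX) of the slot arithmetic that `update_for_prefill` performs for one sequence: each token position maps to a logical block (`position // block_size`), the block table translates that to a physical block, and the slot passed to `PagedAttention.reshape_and_cache` is `physical_block * block_size + position % block_size`.

```python
import torch

block_size = 16
# Hypothetical block table for one sequence: logical blocks 0..2 were given physical blocks 5, 9 and 2.
block_table = torch.tensor([5, 9, 2, -1], dtype=torch.int32)  # -1 marks an unallocated block

input_len = 40  # prompt length, needs ceil(40 / 16) = 3 blocks
positions = torch.arange(input_len)
block_indices = positions // block_size                          # logical block of each token
slot_offsets = positions % block_size                            # offset inside that block
slots = block_table[block_indices] * block_size + slot_offsets   # physical slot ids

print(slots[:5].tolist())   # tokens 0-4 live in physical block 5 -> slots [80, 81, 82, 83, 84]
print(slots[16].item())     # token 16 opens physical block 9 -> slot 144
```

During decode, `update_for_decode` does the same computation for a single new position per sequence, taking a fresh block from `free_blocks` whenever the position crosses a block boundary.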
39 changes: 22 additions & 17 deletions optimum/exporters/ipex/model_patcher.py
@@ -13,11 +13,10 @@
# limitations under the License.

from transformers.models.bert.modeling_bert import BertIntermediate
- from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconForCausalLM
- from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2LMHeadModel
+ from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel
+ from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer,
- LlamaForCausalLM,
LlamaModel,
LlamaRMSNorm,
)
@@ -28,7 +27,9 @@

from .modeling_utils import (
_IPEX_MINIMUM_VERSION_FOR_PATCHING,
+ _falcon_model_forward,
_gpt2_block_forward,
+ _gpt2_model_forward,
_ipex_rms_layer_norm_forward,
_IPEXFalconDecoderLayer,
_IPEXGPT2Attention,
@@ -39,8 +40,8 @@


# Please also update in the setup.py and .github/workflows/test_ipex.yml if you change the transformers version
_TRANSFORMERS_MIN_VERSION = "4.39.0"
_TRANSFORMERS_MAX_VERSION = "4.44.99"
_TRANSFORMERS_MIN_VERSION = "4.46.0"
_TRANSFORMERS_MAX_VERSION = "4.46.99"

_IPEX_EXPORTED_GENERATION_TASKS = ("text-generation",)

@@ -75,7 +76,7 @@ def patch_op(m, target_m, new_op_name, new_op):
def _patch_llama_model(model):
"""
Patch llama model:
- 1. Use IPEX Rope and IAKV cache
+ 1. Use IPEX rope and paged cache
2. Linear fusion with (2 Linears + Silu + Mul) and (Linear + Add)
"""
convert_functions(model, LlamaModel, "forward", _llama_model_forward)
@@ -87,11 +88,14 @@ def _patch_llama_model(model):
def _patch_falcon_model(model):
"""
Patch falcon model:
- 1. Disable SDPA so the attention mask will be compatible to ipex attention.
- 2. Use IPEX Rope and IAKV cache
- 3. Linear fusion with (Linear + Gelu) and (Linear + Add + Add)
+ 1. Use IPEX rope and paged cache
+ 2. Linear fusion with (Linear + Gelu) and (Linear + Add + Add)
"""
- model.transformer._use_sdpa = False
+ num_key_value_heads = (
+     model.config.num_kv_heads if (model.config.new_decoder_architecture or not model.config.multi_query) else 1
+ )
+ setattr(model.config, "num_key_value_heads", num_key_value_heads)
+ convert_functions(model, FalconModel, "forward", _falcon_model_forward)
replace_customized_linear_with_linear(model)
convert_class(model, FalconDecoderLayer, _IPEXFalconDecoderLayer, model.config)
return model
@@ -100,12 +104,13 @@ def _patch_falcon_model(model):
def _patch_gpt2_model(model):
"""
Patch gpt2 model:
- 1. Disable SDPA so the attention mask will be compatible to ipex attention.
- 2. Use IAKV cache
+ 1. Use IPEX paged attention
"""
- model.transformer._attn_implementation = "eager"
- convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.config)
+ num_key_value_heads = model.config.num_attention_heads
+ setattr(model.config, "num_key_value_heads", num_key_value_heads)
+ convert_functions(model, GPT2Model, "forward", _gpt2_model_forward)
convert_functions(model, GPT2Block, "forward", _gpt2_block_forward)
+ convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.config)
return model


@@ -136,11 +141,11 @@ def _patch_model(model):
raise ImportError(
f"Only transformers versions {_TRANSFORMERS_MIN_VERSION} ~ {_TRANSFORMERS_MAX_VERSION} are verified."
)
- if isinstance(model, LlamaForCausalLM):
+ if model.config.model_type == "llama":
model = _patch_llama_model(model)
- elif isinstance(model, FalconForCausalLM):
+ elif model.config.model_type == "falcon":
model = _patch_falcon_model(model)
- elif isinstance(model, GPT2LMHeadModel):
+ elif model.config.model_type == "gpt2":
model = _patch_gpt2_model(model)
elif model.config.model_type == "bert":
model = _patch_bert_model(model)
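`convert_functions` and `convert_class` are helpers defined elsewhere in the exporter; the diff only shows their call sites. As a rough illustration of the forward-replacement pattern used above (a sketch, not the actual optimum-intel helpers), rebinding a method on every matching submodule looks like this:

```python
import types

import torch


def convert_functions_sketch(model: torch.nn.Module, target_cls, func_name: str, new_function):
    """Rebind `func_name` on every instance of `target_cls` inside `model` (illustrative only)."""
    for module in model.modules():
        if isinstance(module, target_cls):
            setattr(module, func_name, types.MethodType(new_function, module))


# Hypothetical usage, mirroring the llama patch above:
# convert_functions_sketch(model, LlamaModel, "forward", _llama_model_forward)
```

Dispatching on `model.config.model_type` instead of `isinstance(model, LlamaForCausalLM)`, as the diff above does, keeps the patching independent of the task head class, so the same patch applies to any model built on the same backbone.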