fix custom kernel registration (#12674)
MeouSker77 authored Jan 8, 2025
1 parent a22a8c2 commit 5c24276
Showing 1 changed file with 25 additions and 19 deletions.
44 changes: 25 additions & 19 deletions python/llm/src/ipex_llm/transformers/xpu_ops.py
@@ -20,9 +20,9 @@
 import xe_addons
 
 
-@torch.library.register_fake("ipex_llm::forward_new")
-def _(x, weight, qtype, input_size):
-    return torch.empty_like(x)
+# @torch.library.register_fake("ipex_llm::forward_new")
+# def _(x, weight, qtype, input_size):
+#     return ???
 
 
 # @torch.library.register_fake("ipex_llm::dequant")
@@ -32,32 +32,38 @@ def _(x, weight, qtype, input_size):
 
 @torch.library.register_fake("ipex_llm::mlp_forward_xpu")
 def _(x, weight1, weight2, batch_size, state_size, output_size, act_type, qtype):
-    return torch.empty_like(x)
+    return torch.empty([batch_size, output_size],
+                       dtype=x.dtype, device=x.device)
 
 
-# @torch.library.register_fake("ipex_llm::rwkv_linear_attention_v4")
-# def _(time_decay, time_first, key, value, num_state, den_state, max_state)
-#     return ???
+@torch.library.register_fake("ipex_llm::rwkv_linear_attention_v4")
+def _(time_decay, time_first, key, value, num_state, den_state, max_state):
+    return torch.empty_like(key)
 
 
-# @torch.library.register_fake("ipex_llm::rwkv_linear_attention_v5")
-# def _(time_decay, time_first, receptance, key, value, state)
-#     return ???
+@torch.library.register_fake("ipex_llm::rwkv_linear_attention_v5")
+def _(time_decay, time_first, receptance, key, value, state):
+    bsz, n_heads, seq_len, head_dim = key.shape
+    return torch.empty([bsz, seq_len, n_heads, head_dim],
+                       dtype=key.dtype, device=key.device)
 
 
-# @torch.library.register_fake("ipex_llm::rwkv_time_shift")
-# def _(hidden, shifted, mix):
-#     return ???
+@torch.library.register_fake("ipex_llm::rwkv_time_shift")
+def _(hidden, shifted, mix):
+    bsz, seq_len, hidden_size = hidden.shape
+    return torch.empty([mix.size(0), bsz, seq_len, hidden_size],
+                       dtype=hidden.dtype, device=hidden.device)
 
 
-# @torch.library.register_fake("ipex_llm::dequantize_rows")
-# def _(x, weight, qtype, state_size, output_size):
-#     return ???
+@torch.library.register_fake("ipex_llm::dequantize_rows")
+def _(x, weight, qtype, state_size, output_size):
+    return torch.empty([x.size(0), x.size(1), state_size],
+                       dtype=torch.float, device=weight.device)
 
 
-@torch.library.register_fake("ipex_llm::batch_forward")
-def _(x, weight, qtype):
-    return torch.empty_like(x)
+# @torch.library.register_fake("ipex_llm::batch_forward")
+# def _(x, weight, qtype):
+#     return ???
 
 
 @torch.library.register_fake("ipex_llm::sdp")
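For context (not part of the commit): torch.library.register_fake registers a "fake" (meta) implementation for a custom operator. It runs during tracing, e.g. under torch.compile, where the real XPU kernel cannot execute, so it must return an empty tensor whose shape, dtype, and device match the real kernel's output. Returning torch.empty_like(x) is only correct when the output has the same metadata as the input, which is why this commit replaces it with explicit output shapes (as for mlp_forward_xpu) or comments the registration out until the shape is known. A minimal runnable sketch of the pattern, assuming PyTorch >= 2.4; demo::row_sum is a hypothetical op, not part of ipex_llm:

import torch

@torch.library.custom_op("demo::row_sum", mutates_args=())
def row_sum(x: torch.Tensor) -> torch.Tensor:
    # Real implementation: reduces the last dimension.
    return x.sum(dim=-1)

@torch.library.register_fake("demo::row_sum")
def _(x):
    # Fake implementation: no computation, only output metadata.
    # The shape is x.shape[:-1], NOT x.shape -- torch.empty_like(x)
    # here would break shape propagation under torch.compile.
    return x.new_empty(x.shape[:-1])

out = torch.ops.demo.row_sum(torch.randn(4, 8))
print(out.shape)  # torch.Size([4])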
