Skip to content

Commit

Permalink
fix lnl perf (#12700)
Browse files Browse the repository at this point in the history
  • Loading branch information
MeouSker77 authored Jan 10, 2025
1 parent 4bf93c6 commit db9db51
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions python/llm/src/ipex_llm/transformers/low_bit_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
import torch
import torch.distributed
import torch.nn.functional as F
from torch import Tensor, device, dtype, nn
from torch import Tensor, dtype, nn
from operator import mul
from functools import reduce
from ipex_llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd
Expand Down Expand Up @@ -294,10 +294,10 @@ def use_batch_forward(x: torch.Tensor, qtype: int, output_len: int):
if hard_condition:
return (
batch_size > 1
or (device in ["arc"] and qtype in [SYM_INT8, FP4])
or (device in ["arc", "mtl"] and qtype in [FP8E4])
or (device in ["lnl"] and qtype in [SYM_INT4] and x.shape[1] % 512 == 0)
or (device in ["bmg"] and qtype in [SYM_INT4, FP8E5])
or (device_name in ["arc"] and qtype in [SYM_INT8, FP4])
or (device_name in ["arc", "mtl"] and qtype in [FP8E4])
or (device_name in ["lnl"] and qtype in [SYM_INT4] and x.shape[1] % 512 == 0)
or (device_name in ["bmg"] and qtype in [SYM_INT4, FP8E5])
)
return False

Expand Down

0 comments on commit db9db51

Please sign in to comment.