diff --git a/intel_extension_for_pytorch/_inductor/xpu/triton_heuristics.py b/intel_extension_for_pytorch/_inductor/xpu/triton_heuristics.py index 8c7a7ee13..a9141dc84 100644 --- a/intel_extension_for_pytorch/_inductor/xpu/triton_heuristics.py +++ b/intel_extension_for_pytorch/_inductor/xpu/triton_heuristics.py @@ -31,6 +31,7 @@ try: from triton.compiler.compiler import ASTSource + from triton.backends.compiler import GPUTarget except ImportError: warnings.warn( "XPU: Import error on ASTSource, if this is not the case, \ @@ -128,7 +129,8 @@ def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: Dict): ), ) - target = (compile_meta["device_type"], cc) + warp_size = 32 + target = GPUTarget(compile_meta["device_type"], cc, warp_size) options = { "num_warps": compile_meta["num_warps"], "num_stages": compile_meta["num_stages"],