diff --git a/modules/intel/ipex/hijacks.py b/modules/intel/ipex/hijacks.py index b1c9a1182..5440b7b68 100644 --- a/modules/intel/ipex/hijacks.py +++ b/modules/intel/ipex/hijacks.py @@ -75,7 +75,7 @@ def as_tensor(data, dtype=None, device=None): return original_as_tensor(data, dtype=dtype, device=device) -if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None: +if os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '-1' or (device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '0'): original_torch_bmm = torch.bmm original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention else: