From 71a4589c970aa1222fd3c513733c7871e9e58fe0 Mon Sep 17 00:00:00 2001 From: Guoqiong Date: Wed, 27 Nov 2024 01:02:08 +0000 Subject: [PATCH] xpu support api --- torchtune/training/precision.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchtune/training/precision.py b/torchtune/training/precision.py index 0828baa7ea..85a2c07e4f 100644 --- a/torchtune/training/precision.py +++ b/torchtune/training/precision.py @@ -53,6 +53,7 @@ def verify_bf16_support() -> bool: - NCCL is available and version >= 2.10 - MPS is available and torch was built with MPS - NPU is available and supports bf16 + - XPU is available and supports bf16 Returns: bool: True if bf16 is available, False otherwise. @@ -66,7 +67,8 @@ def verify_bf16_support() -> bool: ) mps_support = torch.backends.mps.is_available() and torch.backends.mps.is_built() npu_support = is_npu_available and torch.npu.is_bf16_supported() - return cuda_support or mps_support or npu_support + xpu_support = torch.xpu.is_available() and torch.xpu.is_bf16_supported() + return cuda_support or mps_support or npu_support or xpu_support def get_dtype(