Skip to content

Commit

Permalink
xpu support api
Browse files Browse the repository at this point in the history
  • Loading branch information
songhappy committed Nov 27, 2024
1 parent b5d2e63 commit 71a4589
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion torchtune/training/precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def verify_bf16_support() -> bool:
- NCCL is available and version >= 2.10
- MPS is available and torch was built with MPS
- NPU is available and supports bf16
- XPU is available and supports bf16
Returns:
bool: True if bf16 is available, False otherwise.
Expand All @@ -66,7 +67,8 @@ def verify_bf16_support() -> bool:
)
mps_support = torch.backends.mps.is_available() and torch.backends.mps.is_built()
npu_support = is_npu_available and torch.npu.is_bf16_supported()
return cuda_support or mps_support or npu_support
xpu_support = torch.xpu.is_available() and torch.xpu.is_bf16_supported()
return cuda_support or mps_support or npu_support or xpu_support


def get_dtype(
Expand Down

0 comments on commit 71a4589

Please sign in to comment.