diff --git a/torchtune/training/precision.py b/torchtune/training/precision.py
index 0828baa7ea..85a2c07e4f 100644
--- a/torchtune/training/precision.py
+++ b/torchtune/training/precision.py
@@ -53,6 +53,7 @@ def verify_bf16_support() -> bool:
         - NCCL is available and version >= 2.10
         - MPS is available and torch was built with MPS
         - NPU is available and supports bf16
+        - XPU is available and supports bf16
 
     Returns:
         bool: True if bf16 is available, False otherwise.
@@ -66,7 +67,8 @@ def verify_bf16_support() -> bool:
     )
     mps_support = torch.backends.mps.is_available() and torch.backends.mps.is_built()
     npu_support = is_npu_available and torch.npu.is_bf16_supported()
-    return cuda_support or mps_support or npu_support
+    xpu_support = torch.xpu.is_available() and torch.xpu.is_bf16_supported()
+    return cuda_support or mps_support or npu_support or xpu_support
 
 
 def get_dtype(
diff --git a/torchtune/utils/_device.py b/torchtune/utils/_device.py
index 36ca14a358..d4f84cd63e 100644
--- a/torchtune/utils/_device.py
+++ b/torchtune/utils/_device.py
@@ -89,6 +89,8 @@ def _get_device_type_from_env() -> str:
         device = "cuda"
     elif is_npu_available:
         device = "npu"
+    elif torch.xpu.is_available():
+        device = "xpu"
     else:
         device = "cpu"
     return device
@@ -136,7 +138,7 @@ def get_device(device: Optional[str] = None) -> torch.device:
     If CUDA-like is available and being used, this function also sets the CUDA-like device.
 
     Args:
-        device (Optional[str]): The name of the device to use, e.g. "cuda" or "cpu" or "npu".
+        device (Optional[str]): The name of the device to use, e.g. "cuda" or "cpu" or "npu" or "xpu".
 
     Example:
         >>> device = get_device("cuda")
@@ -149,7 +151,7 @@ def get_device(device: Optional[str] = None) -> torch.device:
     if device is None:
         device = _get_device_type_from_env()
     device = torch.device(device)
-    if device.type in ["cuda", "npu"]:
+    if device.type in ["cuda", "npu", "xpu"]:
         device = _setup_device(device)
     _validate_device_from_env(device)
     return device
@@ -184,16 +186,18 @@ def batch_to_device(batch: dict, device: torch.device) -> None:
 class DeviceSupport(Enum):
     """
     This is a simple enum for compute devices,
-    This currently only supports CPU, CUDA, NPU.
+    This currently only supports CPU, CUDA, NPU, and XPU.
     The following enumeration defines various device configurations with attributes:
-    1. `device_type` (str): The type of device (e.g., "cpu", "cuda", "npu").
-    2. `device_name` (str): A user-friendly name for the device (e.g., "CPU", "GPU", "NPU").
-    3. `communication_backend` (str): Specifies the backend used for communication on this device (e.g., "gloo", "nccl", "hccl").
+    1. `device_type` (str): The type of device (e.g., "cpu", "cuda", "npu", "xpu").
+    2. `device_name` (str): A user-friendly name for the device (e.g., "CPU", "GPU", "NPU", "XPU").
+    3. `communication_backend` (str): Specifies the backend used for communication on this device
+       (e.g., "gloo", "nccl", "hccl", "ccl").
     """
 
     CPU = ("cpu", "CPU", "gloo")
     CUDA = ("cuda", "GPU", "nccl")
     NPU = ("npu", "NPU", "hccl")
+    XPU = ("xpu", "XPU", "ccl")
 
     def __init__(
         self,
@@ -216,7 +220,7 @@ def from_type(device_type: str):
 def get_device_support() -> DeviceSupport:
     """function that gets the DeviceSupport with compute devices based on the current machine.
 
-    This currently only supports CPU, CUDA, NPU.
+    This currently only supports CPU, CUDA, NPU, XPU.
 
     Returns:
         device_support: DeviceSupport
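
To illustrate the new XPU code paths, here is a minimal usage sketch. It is not part of the change itself: it assumes a PyTorch build with Intel XPU support (so torch.xpu.is_available() returns True) and imports the names from the two files touched by this diff (verify_bf16_support, get_dtype, get_device, DeviceSupport).

# Illustrative sketch only, not part of the PR. Assumes an Intel XPU build of
# PyTorch where torch.xpu.is_available() and torch.xpu.is_bf16_supported() hold.
import torch

from torchtune.training.precision import get_dtype, verify_bf16_support
from torchtune.utils._device import DeviceSupport, get_device

# _get_device_type_from_env() now falls through cuda -> npu -> xpu -> cpu, so on
# an XPU-only machine get_device() resolves to "xpu" and _setup_device() binds
# the local device index (e.g. xpu:0).
device = get_device()            # or get_device("xpu") to request it explicitly

# verify_bf16_support() now also returns True when the XPU reports bf16 support,
# so bf16 training configs can run on XPU.
dtype = get_dtype("bf16") if verify_bf16_support() else torch.float32

# The new enum member maps the "xpu" device type to the "ccl" distributed backend.
assert DeviceSupport.from_type("xpu").communication_backend == "ccl"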