Adding bf16 training for XPU (#1953)
Co-authored-by: salman <[email protected]>
songhappy and SalmanMohammadi authored Dec 3, 2024
1 parent 32e265d · commit efa91bf
Showing 2 changed files with 14 additions and 8 deletions.
torchtune/training/precision.py: 4 changes (3 additions, 1 deletion)
@@ -53,6 +53,7 @@ def verify_bf16_support() -> bool:
- NCCL is available and version >= 2.10
- MPS is available and torch was built with MPS
- NPU is available and supports bf16
- XPU is available and supports bf16
Returns:
bool: True if bf16 is available, False otherwise.
@@ -66,7 +67,8 @@ def verify_bf16_support() -> bool:
)
mps_support = torch.backends.mps.is_available() and torch.backends.mps.is_built()
npu_support = is_npu_available and torch.npu.is_bf16_supported()
return cuda_support or mps_support or npu_support
xpu_support = torch.xpu.is_available() and torch.xpu.is_bf16_supported()
return cuda_support or mps_support or npu_support or xpu_support


def get_dtype(
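
For context, a minimal sketch of how the updated check can be exercised. It imports from the module touched above; the "bf16" string key passed to get_dtype is an assumption about torchtune's precision API rather than something shown in this diff.

import torch

from torchtune.training.precision import get_dtype, verify_bf16_support

# With this commit, verify_bf16_support() also returns True on Intel XPUs that
# report bf16 support, so bf16 can be selected there instead of falling back.
if verify_bf16_support():
    dtype = get_dtype("bf16")  # assumed string key; expected to resolve to torch.bfloat16
else:
    dtype = torch.float32      # fall back when no backend supports bf16

model = torch.nn.Linear(16, 16).to(dtype=dtype)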
torchtune/utils/_device.py: 18 changes (11 additions, 7 deletions)
@@ -89,6 +89,8 @@ def _get_device_type_from_env() -> str:
device = "cuda"
elif is_npu_available:
device = "npu"
elif torch.xpu.is_available():
device = "xpu"
else:
device = "cpu"
return device
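
The hunk above gives the auto-detection precedence CUDA, then NPU, then XPU, then CPU. Below is a standalone sketch of that ordering using guarded public torch checks; the hasattr guards are added here because torch.npu and torch.xpu only exist when the corresponding backend or plugin is installed.

import torch

def detect_accelerator() -> str:
    # Mirrors the precedence in _get_device_type_from_env: cuda > npu > xpu > cpu.
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "npu") and torch.npu.is_available():  # Ascend torch_npu plugin
        return "npu"
    if hasattr(torch, "xpu") and torch.xpu.is_available():  # Intel XPU runtime
        return "xpu"
    return "cpu"

print(detect_accelerator())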
@@ -136,7 +138,7 @@ def get_device(device: Optional[str] = None) -> torch.device:
If CUDA-like is available and being used, this function also sets the CUDA-like device.
Args:
device (Optional[str]): The name of the device to use, e.g. "cuda" or "cpu" or "npu".
device (Optional[str]): The name of the device to use, e.g. "cuda" or "cpu" or "npu" or "xpu".
Example:
>>> device = get_device("cuda")
@@ -149,7 +151,7 @@ def get_device(device: Optional[str] = None) -> torch.device:
if device is None:
device = _get_device_type_from_env()
device = torch.device(device)
if device.type in ["cuda", "npu"]:
if device.type in ["cuda", "npu", "xpu"]:
device = _setup_device(device)
_validate_device_from_env(device)
return device
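
A short usage sketch of get_device after this change, importing straight from the module in this diff. Requesting "xpu" explicitly is assumed to need a working XPU runtime, since _setup_device now runs for it just as for "cuda" and "npu".

from torchtune.utils._device import get_device

# Let torchtune infer the device from the environment (cuda, then npu, then xpu, then cpu).
device = get_device()

# Or request XPU explicitly; the returned device is expected to carry an index
# (e.g. "xpu:0") once _setup_device has run. Without XPU hardware this should raise.
xpu_device = get_device("xpu")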
@@ -184,16 +186,18 @@ def batch_to_device(batch: dict, device: torch.device) -> None:
class DeviceSupport(Enum):
"""
This is a simple enum for compute devices,
This currently only supports CPU, CUDA, NPU.
This currently only supports CPU, CUDA, NPU, and XPU.
The following enumeration defines various device configurations with attributes:
1. `device_type` (str): The type of device (e.g., "cpu", "cuda", "npu").
2. `device_name` (str): A user-friendly name for the device (e.g., "CPU", "GPU", "NPU").
3. `communication_backend` (str): Specifies the backend used for communication on this device (e.g., "gloo", "nccl", "hccl").
1. `device_type` (str): The type of device (e.g., "cpu", "cuda", "npu", "xpu").
2. `device_name` (str): A user-friendly name for the device (e.g., "CPU", "GPU", "NPU", "XPU").
3. `communication_backend` (str): Specifies the backend used for communication on this device
(e.g., "gloo", "nccl", "hccl", "ccl").
"""

CPU = ("cpu", "CPU", "gloo")
CUDA = ("cuda", "GPU", "nccl")
NPU = ("npu", "NPU", "hccl")
XPU = ("xpu", "XPU", "ccl")

def __init__(
self,
@@ -216,7 +220,7 @@ def from_type(device_type: str):
def get_device_support() -> DeviceSupport:
"""function that gets the DeviceSupport with compute devices based on the current machine.
This currently only supports CPU, CUDA, NPU.
This currently only supports CPU, CUDA, NPU, XPU.
Returns:
device_support: DeviceSupport
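
Finally, a sketch of how the new enum member might be consumed; DeviceSupport.from_type and get_device_support both live in the module shown above. The commented init_process_group line is illustrative only, since the "ccl" backend needs the oneCCL bindings for PyTorch installed.

from torchtune.utils._device import DeviceSupport, get_device_support

# Look the new member up directly, or let torchtune pick one for this machine.
xpu = DeviceSupport.from_type("xpu")
print(xpu.device_type, xpu.device_name, xpu.communication_backend)  # xpu XPU ccl

support = get_device_support()  # resolves to DeviceSupport.XPU only on an XPU host

# Illustrative only (requires oneCCL bindings for PyTorch):
# torch.distributed.init_process_group(backend=support.communication_backend)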
