diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py
index 0cca6d129fe..833e3e2461c 100644
--- a/optimum/exporters/onnx/convert.py
+++ b/optimum/exporters/onnx/convert.py
@@ -16,6 +16,7 @@
 import copy
 import gc
+import importlib
 import multiprocessing as mp
 import os
 import traceback
@@ -532,7 +533,11 @@ def export_pytorch(
         # Check that inputs match, and order them properly
         dummy_inputs = config.generate_dummy_inputs(framework="pt", **input_shapes)
-        device = torch.device(device)
+        if device == "dml" and importlib.util.find_spec("torch_directml"):
+            torch_directml = importlib.import_module("torch_directml")
+            device = torch_directml.device()
+        else:
+            device = torch.device(device)
 
         def remap(value):
             if isinstance(value, torch.Tensor):
                 value = value.to(device)
@@ -540,7 +545,7 @@
             return value
 
-        if device.type == "cuda" and torch.cuda.is_available():
+        if device.type == "cuda" and torch.cuda.is_available() or device.type == "privateuseone":
             model.to(device)
             dummy_inputs = tree_map(remap, dummy_inputs)
diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py
index 3793a56068a..c09b07223e0 100644
--- a/optimum/exporters/tasks.py
+++ b/optimum/exporters/tasks.py
@@ -2270,12 +2270,16 @@ def get_model_from_task(
                 kwargs["torch_dtype"] = torch_dtype
 
             if isinstance(device, str):
-                device = torch.device(device)
+                if device == "dml" and importlib.util.find_spec("torch_directml"):
+                    torch_directml = importlib.import_module("torch_directml")
+                    device = torch_directml.device()
+                else:
+                    device = torch.device(device)
             elif device is None:
                 device = torch.device("cpu")
 
             # TODO : fix EulerDiscreteScheduler loading to enable for SD models
-            if version.parse(torch.__version__) >= version.parse("2.0") and library_name != "diffusers":
+            if version.parse(torch.__version__) >= version.parse("2.0") and library_name != "diffusers" and device.type != "privateuseone":
                 with device:
                     # Initialize directly in the requested device, to save allocation time. Especially useful for large
                     # models to initialize on cuda device.
diff --git a/optimum/onnxruntime/io_binding/io_binding_helper.py b/optimum/onnxruntime/io_binding/io_binding_helper.py
index f32ecc56e6e..30fe264dbc3 100644
--- a/optimum/onnxruntime/io_binding/io_binding_helper.py
+++ b/optimum/onnxruntime/io_binding/io_binding_helper.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import importlib
 import logging
 import traceback
 from typing import TYPE_CHECKING
@@ -145,8 +146,12 @@ def to_pytorch_via_dlpack(ort_value: OrtValue) -> torch.Tensor:
     @staticmethod
     def get_device_index(device):
         if isinstance(device, str):
-            # could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0
-            device = torch.device(device)
+            if device == "dml" and importlib.util.find_spec("torch_directml"):
+                torch_directml = importlib.import_module("torch_directml")
+                device = torch_directml.device()
+            else:
+                # could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0
+                device = torch.device(device)
         elif isinstance(device, int):
             return device
         return 0 if device.index is None else device.index
diff --git a/setup.py b/setup.py
index 197d632b434..2078d2c7f10 100644
--- a/setup.py
+++ b/setup.py
@@ -62,6 +62,15 @@
         "accelerate",  # ORTTrainer requires it.
         "transformers>=4.36,<4.48.0",
     ],
+    "onnxruntime-directml": [
+        "onnx",
+        "onnxruntime-directml>=1.11.0",
+        "datasets>=1.2.1",
+        "evaluate",
+        "protobuf>=3.20.1",
+        "accelerate",  # ORTTrainer requires it.
+        "transformers>=4.36,<4.48.0",
+    ],
     "exporters": [
         "onnx",
         "onnxruntime",
@@ -74,6 +83,12 @@
         "timm",
         "transformers>=4.36,<4.48.0",
     ],
+    "exporters-directml": [
+        "onnx",
+        "onnxruntime-directml",
+        "timm",
+        "transformers>=4.36,<4.48.0",
+    ],
     "exporters-tf": [
         "tensorflow>=2.4,<=2.12.1",
         "tf2onnx",