Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Experimental DirectML support via torch-directml #1702

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from

Conversation

kazssym
Copy link

@kazssym kazssym commented Feb 19, 2024

What does this PR do?

This PR adds experimental DirectML support via torch-directml, which is still in preview and lacks several PyTorch functions such as microsoft/DirectML#449.

If you are interested in this PR, please leave comments below.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

This commit introduces two improvements:

1. DirectML acceleration:

    - Added support for running optimum commands on DirectML hardware (Windows only) using the --device dml flag.
    - Automatically sets the device to torch_directml.device() when the flag is specified.

2. Improved device handling:

    - Ensures the model is directly initialized in the device only when applicable.
This commit refines the device handling in optimum/exporters/tasks.py for the following improvements:

  - More precise device check: Instead of checking for not device.type, the condition is updated to device.type != "privateuseone". This ensures the initialization happens on the requested device only if it's not a private use device (e.g., DirectML).
  - Improved clarity: The code comments are updated to better explain the purpose of the device initialization and its benefits for large models.
  - Extends device compatibility to "privateuseone" in export_pytorch for exporting models usable on specific hardware.

This commit allows exporting PyTorch models compatible with the "privateuseone" device, potentially enabling inference on specialized hardware platforms.
This commit adds support for running PyTorch models on the DML device within the Optimum framework.

  - Dynamic DML device handling: Introduces dynamic import of torch_directml for improved maintainability.
  - Consistent device selection: Ensures consistent device selection across optimum/exporters/onnx/convert.py, optimum/exporters/tasks.py, and optimum/onnxruntime/io_binding/io_binding_helper.py.

This change allows users to leverage DML capabilities for efficient PyTorch model inference with Optimum.
This commit removes unnecessary code for handling the DML device in optimum/commands/optimum_cli.py.

  - Redundant import: The code previously imported torch_directml conditionally, which is no longer needed as DML device support is handled in other parts of the codebase.

This change simplifies the code and avoids potential conflicts.
This commit updates `setup.py` to include the following changes:

  - Introduces a new conditional section "exporters-directml" with dependencies required for exporting models for DML inference.
  - This section mirrors the existing "exporters" and "exporters-gpu" sections, adding `onnxruntime-directml` as a dependency.

This update ensures users have the necessary libraries for working with DML devices when installing Optimum with DML support.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant