Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Numba-mlir integration #1194

Merged
merged 1 commit into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion .github/workflows/conda-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,9 @@ jobs:
python: ['3.9', '3.10', '3.11']
os: [ubuntu-20.04, ubuntu-latest, windows-latest]
experimental: [false]
use_mlir: [false]

continue-on-error: ${{ matrix.experimental }}
continue-on-error: ${{ matrix.experimental || matrix.use_mlir }}

steps:
- name: Setup miniconda
Expand Down Expand Up @@ -169,6 +170,10 @@ jobs:
- name: Install builded package
run: mamba install ${{ env.PACKAGE_NAME }}=${{ env.PACKAGE_VERSION }} intel::intel-opencl-rt pytest -c ${{ env.CHANNEL_PATH }}

- name: Install numba-mlir
if: matrix.use_mlir
run: mamba install numba-mlir -c dppy/label/dev -c conda-forge -c intel

- name: Setup OpenCL CPU device
if: runner.os == 'Windows'
shell: pwsh
Expand All @@ -184,9 +189,13 @@ jobs:
python -c "import dpcpp_llvm_spirv as p; print(p.get_llvm_spirv_path())"

- name: Smoke test
env:
NUMBA_DPEX_USE_MLIR: ${{ matrix.use_mlir && '1' || '0' }}
run: python -c "import dpnp, dpctl, numba_dpex; dpctl.lsplatform()"

- name: Run tests
env:
NUMBA_DPEX_USE_MLIR: ${{ matrix.use_mlir && '1' || '0' }}
run: |
pytest -q -ra --disable-warnings --pyargs ${{ env.MODULE_NAME }} -vv

Expand Down
2 changes: 2 additions & 0 deletions numba_dpex/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,5 @@ def __getattr__(name):
DPEX_OPT = _readenv("NUMBA_DPEX_OPT", int, 2)

INLINE_THRESHOLD = _readenv("NUMBA_DPEX_INLINE_THRESHOLD", int, None)

USE_MLIR = _readenv("NUMBA_DPEX_USE_MLIR", int, 0)
2 changes: 2 additions & 0 deletions numba_dpex/core/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,14 @@ class DpexTargetOptions(CPUTargetOptions):
experimental = _option_mapping("experimental")
release_gil = _option_mapping("release_gil")
no_compile = _option_mapping("no_compile")
use_mlir = _option_mapping("use_mlir")

def finalize(self, flags, options):
super().finalize(flags, options)
_inherit_if_not_set(flags, options, "experimental", False)
_inherit_if_not_set(flags, options, "release_gil", False)
_inherit_if_not_set(flags, options, "no_compile", True)
_inherit_if_not_set(flags, options, "use_mlir", False)


class DpexKernelTarget(TargetDescriptor):
Expand Down
42 changes: 35 additions & 7 deletions numba_dpex/core/pipelines/dpjit_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class _DpjitPassBuilder(object):
execution.
"""

_use_mlir = False

@staticmethod
def define_typed_pipeline(state, name="dpex_dpjit_typed"):
"""Returns the typed part of the nopython pipeline"""
Expand All @@ -55,19 +57,31 @@ def define_typed_pipeline(state, name="dpex_dpjit_typed"):
pm.add_pass(NopythonRewrites, "nopython rewrites")
pm.add_pass(ParforPass, "convert to parfors")
pm.add_pass(
ParforLegalizeCFDPass, "Legalize parfors for compute follows data"
ParforLegalizeCFDPass,
"Legalize parfors for compute follows data",
)
pm.add_pass(ParforFusionPass, "fuse parfors")
pm.add_pass(ParforPreLoweringPass, "parfor prelowering")

pm.finalize()
return pm

@staticmethod
def define_nopython_lowering_pipeline(state, name="dpex_dpjit_lowering"):
@classmethod
def define_nopython_lowering_pipeline(
cls, state, name="dpex_dpjit_lowering"
):
"""Returns an nopython mode pipeline based PassManager"""
pm = PassManager(name)

flags = state.flags
if cls._use_mlir or hasattr(flags, "use_mlir") and flags.use_mlir:
from numba_mlir.mlir.passes import MlirReplaceParfors

pm.add_pass(
MlirReplaceParfors,
"Lower parfor using MLIR pipeline",
)

# legalize
pm.add_pass(
NoPythonSupportedFeatureValidation,
Expand All @@ -85,11 +99,11 @@ def define_nopython_lowering_pipeline(state, name="dpex_dpjit_lowering"):
pm.finalize()
return pm

@staticmethod
def define_nopython_pipeline(state, name="dpex_dpjit_nopython"):
@classmethod
def define_nopython_pipeline(cls, state, name="dpex_dpjit_nopython"):
"""Returns an nopython mode pipeline based PassManager"""
# compose pipeline from untyped, typed and lowering parts
dpb = _DpjitPassBuilder
dpb = cls
pm = PassManager(name)
untyped_passes = DefaultPassBuilder.define_untyped_pipeline(state)
pm.passes.extend(untyped_passes.passes)
Expand All @@ -104,17 +118,31 @@ def define_nopython_pipeline(state, name="dpex_dpjit_nopython"):
return pm


class _DpjitPassBuilderMlir(_DpjitPassBuilder):
_use_mlir = True


class DpjitCompiler(CompilerBase):
"""Dpex's compiler pipeline to offload parfor nodes into SYCL kernels."""

_pass_builder = _DpjitPassBuilder

def define_pipelines(self):
pms = []
self.state.parfor_diagnostics = ExtendedParforDiagnostics()
self.state.metadata[
"parfor_diagnostics"
] = self.state.parfor_diagnostics
if not self.state.flags.force_pyobject:
pms.append(_DpjitPassBuilder.define_nopython_pipeline(self.state))
pms.append(self._pass_builder.define_nopython_pipeline(self.state))
if self.state.status.can_fallback or self.state.flags.force_pyobject:
raise UnsupportedCompilationModeError()
return pms


class DpjitCompilerMlir(DpjitCompiler):
_pass_builder = _DpjitPassBuilderMlir


def get_compiler(use_mlir):
return DpjitCompilerMlir if use_mlir else DpjitCompiler
9 changes: 7 additions & 2 deletions numba_dpex/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
compile_func,
compile_func_template,
)
from numba_dpex.core.pipelines.dpjit_compiler import DpjitCompiler
from numba_dpex.core.pipelines.dpjit_compiler import get_compiler

from .config import USE_MLIR


def kernel(
Expand Down Expand Up @@ -152,9 +154,12 @@ def dpjit(*args, **kws):
"pipeline class is set for dpjit and is ignored", RuntimeWarning
)
del kws["forceobj"]

use_mlir = kws.pop("use_mlir", bool(USE_MLIR))

kws.update({"nopython": True})
kws.update({"parallel": True})
kws.update({"pipeline_class": DpjitCompiler})
kws.update({"pipeline_class": get_compiler(use_mlir)})

# FIXME: When trying to use dpex's target context, overloads do not work
# properly. We will turn on dpex target once the issue is fixed.
Expand Down
25 changes: 24 additions & 1 deletion numba_dpex/tests/_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,23 @@

import contextlib
import shutil
from functools import cache

import dpctl
import dpnp
import pytest

from numba_dpex import config, numba_sem_version
from numba_dpex import config, dpjit, numba_sem_version


@cache
def has_numba_mlir():
try:
import numba_mlir
except ImportError:
return False

return True


def has_opencl_gpu():
Expand Down Expand Up @@ -89,6 +100,10 @@ def is_windows():
not has_level_zero(),
reason="No level-zero GPU platforms available",
)
skip_no_numba_mlir = pytest.mark.skipif(
not has_numba_mlir(),
reason="numba-mlir package is not availabe",
)

filter_strings = [
pytest.param("level_zero:gpu:0", marks=skip_no_level_zero_gpu),
Expand Down Expand Up @@ -123,6 +138,14 @@ def is_windows():
)


decorators = [
pytest.param(dpjit, id="dpjit"),
pytest.param(
dpjit(use_mlir=True), id="dpjit_mlir", marks=skip_no_numba_mlir
),
]


@contextlib.contextmanager
def override_config(name, value, config=config):
"""
Expand Down
39 changes: 35 additions & 4 deletions numba_dpex/tests/test_prange.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@

from numba_dpex import dpjit, prange

from ._helper import decorators

def test_one_prange_mul():
@dpjit

@pytest.mark.parametrize("jit", decorators)
def test_one_prange_mul(jit):
@jit
def f(a, b):
for i in prange(4):
b[i, 0] = a[i, 0] * 10
Expand All @@ -35,6 +38,33 @@ def f(a, b):
assert nb[i, 0] == na[i, 0] * 10


@pytest.mark.parametrize("jit", decorators)
def test_one_prange_mul_nested(jit):
@jit
def f_inner(a, b):
for i in prange(4):
b[i, 0] = a[i, 0] * 10
return

@jit
def f(a, b):
return f_inner(a, b)

device = dpctl.select_default_device()

m = 8
n = 8
a = dpnp.ones((m, n), device=device)
b = dpnp.ones((m, n), device=device)

f(a, b)
na = dpnp.asnumpy(a)
nb = dpnp.asnumpy(b)

for i in range(4):
assert nb[i, 0] == na[i, 0] * 10


@pytest.mark.skip(reason="dpnp.add() doesn't support variable + scalar.")
def test_one_prange_add_scalar():
@dpjit
Expand Down Expand Up @@ -155,8 +185,9 @@ def f(a, b):
assert np.all(b.asnumpy() == 12)


def test_two_consecutive_prange():
@dpjit
@pytest.mark.parametrize("jit", decorators)
def test_two_consecutive_prange(jit):
@jit
def prange_example(a, b, c, d):
for i in prange(n):
c[i] = a[i] + b[i]
Expand Down
Loading