From ac4cb5dd8cec2de5cb54441a851c46037be128cb Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 13 Sep 2023 15:46:55 +0200 Subject: [PATCH] Numba-mlir integration --- .github/workflows/conda-package.yml | 11 +++++- numba_dpex/config.py | 2 + numba_dpex/core/descriptor.py | 2 + numba_dpex/core/pipelines/dpjit_compiler.py | 42 +++++++++++++++++---- numba_dpex/decorators.py | 9 ++++- numba_dpex/tests/_helper.py | 25 +++++++++++- numba_dpex/tests/test_prange.py | 39 +++++++++++++++++-- 7 files changed, 115 insertions(+), 15 deletions(-) diff --git a/.github/workflows/conda-package.yml b/.github/workflows/conda-package.yml index a26aced695..946e33796e 100644 --- a/.github/workflows/conda-package.yml +++ b/.github/workflows/conda-package.yml @@ -104,8 +104,9 @@ jobs: python: ['3.9', '3.10', '3.11'] os: [ubuntu-20.04, ubuntu-latest, windows-latest] experimental: [false] + use_mlir: [false] - continue-on-error: ${{ matrix.experimental }} + continue-on-error: ${{ matrix.experimental || matrix.use_mlir }} steps: - name: Setup miniconda @@ -169,6 +170,10 @@ jobs: - name: Install builded package run: mamba install ${{ env.PACKAGE_NAME }}=${{ env.PACKAGE_VERSION }} intel::intel-opencl-rt pytest -c ${{ env.CHANNEL_PATH }} + - name: Install numba-mlir + if: matrix.use_mlir + run: mamba install numba-mlir -c dppy/label/dev -c conda-forge -c intel + - name: Setup OpenCL CPU device if: runner.os == 'Windows' shell: pwsh @@ -184,9 +189,13 @@ jobs: python -c "import dpcpp_llvm_spirv as p; print(p.get_llvm_spirv_path())" - name: Smoke test + env: + NUMBA_DPEX_USE_MLIR: ${{ matrix.use_mlir && '1' || '0' }} run: python -c "import dpnp, dpctl, numba_dpex; dpctl.lsplatform()" - name: Run tests + env: + NUMBA_DPEX_USE_MLIR: ${{ matrix.use_mlir && '1' || '0' }} run: | pytest -q -ra --disable-warnings --pyargs ${{ env.MODULE_NAME }} -vv diff --git a/numba_dpex/config.py b/numba_dpex/config.py index ece1ba8ecd..43dabcd66c 100644 --- a/numba_dpex/config.py +++ b/numba_dpex/config.py @@ -96,3 +96,5 @@ def __getattr__(name): DPEX_OPT = _readenv("NUMBA_DPEX_OPT", int, 2) INLINE_THRESHOLD = _readenv("NUMBA_DPEX_INLINE_THRESHOLD", int, None) + +USE_MLIR = _readenv("NUMBA_DPEX_USE_MLIR", int, 0) diff --git a/numba_dpex/core/descriptor.py b/numba_dpex/core/descriptor.py index 406f7115f8..42d5ee26a3 100644 --- a/numba_dpex/core/descriptor.py +++ b/numba_dpex/core/descriptor.py @@ -39,12 +39,14 @@ class DpexTargetOptions(CPUTargetOptions): experimental = _option_mapping("experimental") release_gil = _option_mapping("release_gil") no_compile = _option_mapping("no_compile") + use_mlir = _option_mapping("use_mlir") def finalize(self, flags, options): super().finalize(flags, options) _inherit_if_not_set(flags, options, "experimental", False) _inherit_if_not_set(flags, options, "release_gil", False) _inherit_if_not_set(flags, options, "no_compile", True) + _inherit_if_not_set(flags, options, "use_mlir", False) class DpexKernelTarget(TargetDescriptor): diff --git a/numba_dpex/core/pipelines/dpjit_compiler.py b/numba_dpex/core/pipelines/dpjit_compiler.py index f84469d528..2b2351d644 100644 --- a/numba_dpex/core/pipelines/dpjit_compiler.py +++ b/numba_dpex/core/pipelines/dpjit_compiler.py @@ -37,6 +37,8 @@ class _DpjitPassBuilder(object): execution. """ + _use_mlir = False + @staticmethod def define_typed_pipeline(state, name="dpex_dpjit_typed"): """Returns the typed part of the nopython pipeline""" @@ -55,7 +57,8 @@ def define_typed_pipeline(state, name="dpex_dpjit_typed"): pm.add_pass(NopythonRewrites, "nopython rewrites") pm.add_pass(ParforPass, "convert to parfors") pm.add_pass( - ParforLegalizeCFDPass, "Legalize parfors for compute follows data" + ParforLegalizeCFDPass, + "Legalize parfors for compute follows data", ) pm.add_pass(ParforFusionPass, "fuse parfors") pm.add_pass(ParforPreLoweringPass, "parfor prelowering") @@ -63,11 +66,22 @@ def define_typed_pipeline(state, name="dpex_dpjit_typed"): pm.finalize() return pm - @staticmethod - def define_nopython_lowering_pipeline(state, name="dpex_dpjit_lowering"): + @classmethod + def define_nopython_lowering_pipeline( + cls, state, name="dpex_dpjit_lowering" + ): """Returns an nopython mode pipeline based PassManager""" pm = PassManager(name) + flags = state.flags + if cls._use_mlir or hasattr(flags, "use_mlir") and flags.use_mlir: + from numba_mlir.mlir.passes import MlirReplaceParfors + + pm.add_pass( + MlirReplaceParfors, + "Lower parfor using MLIR pipeline", + ) + # legalize pm.add_pass( NoPythonSupportedFeatureValidation, @@ -85,11 +99,11 @@ def define_nopython_lowering_pipeline(state, name="dpex_dpjit_lowering"): pm.finalize() return pm - @staticmethod - def define_nopython_pipeline(state, name="dpex_dpjit_nopython"): + @classmethod + def define_nopython_pipeline(cls, state, name="dpex_dpjit_nopython"): """Returns an nopython mode pipeline based PassManager""" # compose pipeline from untyped, typed and lowering parts - dpb = _DpjitPassBuilder + dpb = cls pm = PassManager(name) untyped_passes = DefaultPassBuilder.define_untyped_pipeline(state) pm.passes.extend(untyped_passes.passes) @@ -104,9 +118,15 @@ def define_nopython_pipeline(state, name="dpex_dpjit_nopython"): return pm +class _DpjitPassBuilderMlir(_DpjitPassBuilder): + _use_mlir = True + + class DpjitCompiler(CompilerBase): """Dpex's compiler pipeline to offload parfor nodes into SYCL kernels.""" + _pass_builder = _DpjitPassBuilder + def define_pipelines(self): pms = [] self.state.parfor_diagnostics = ExtendedParforDiagnostics() @@ -114,7 +134,15 @@ def define_pipelines(self): "parfor_diagnostics" ] = self.state.parfor_diagnostics if not self.state.flags.force_pyobject: - pms.append(_DpjitPassBuilder.define_nopython_pipeline(self.state)) + pms.append(self._pass_builder.define_nopython_pipeline(self.state)) if self.state.status.can_fallback or self.state.flags.force_pyobject: raise UnsupportedCompilationModeError() return pms + + +class DpjitCompilerMlir(DpjitCompiler): + _pass_builder = _DpjitPassBuilderMlir + + +def get_compiler(use_mlir): + return DpjitCompilerMlir if use_mlir else DpjitCompiler diff --git a/numba_dpex/decorators.py b/numba_dpex/decorators.py index b80449603f..3a69c2df8e 100644 --- a/numba_dpex/decorators.py +++ b/numba_dpex/decorators.py @@ -13,7 +13,9 @@ compile_func, compile_func_template, ) -from numba_dpex.core.pipelines.dpjit_compiler import DpjitCompiler +from numba_dpex.core.pipelines.dpjit_compiler import get_compiler + +from .config import USE_MLIR def kernel( @@ -152,9 +154,12 @@ def dpjit(*args, **kws): "pipeline class is set for dpjit and is ignored", RuntimeWarning ) del kws["forceobj"] + + use_mlir = kws.pop("use_mlir", bool(USE_MLIR)) + kws.update({"nopython": True}) kws.update({"parallel": True}) - kws.update({"pipeline_class": DpjitCompiler}) + kws.update({"pipeline_class": get_compiler(use_mlir)}) # FIXME: When trying to use dpex's target context, overloads do not work # properly. We will turn on dpex target once the issue is fixed. diff --git a/numba_dpex/tests/_helper.py b/numba_dpex/tests/_helper.py index cb9a166f9b..aea4fb8972 100644 --- a/numba_dpex/tests/_helper.py +++ b/numba_dpex/tests/_helper.py @@ -6,12 +6,23 @@ import contextlib import shutil +from functools import cache import dpctl import dpnp import pytest -from numba_dpex import config, numba_sem_version +from numba_dpex import config, dpjit, numba_sem_version + + +@cache +def has_numba_mlir(): + try: + import numba_mlir + except ImportError: + return False + + return True def has_opencl_gpu(): @@ -89,6 +100,10 @@ def is_windows(): not has_level_zero(), reason="No level-zero GPU platforms available", ) +skip_no_numba_mlir = pytest.mark.skipif( + not has_numba_mlir(), + reason="numba-mlir package is not availabe", +) filter_strings = [ pytest.param("level_zero:gpu:0", marks=skip_no_level_zero_gpu), @@ -123,6 +138,14 @@ def is_windows(): ) +decorators = [ + pytest.param(dpjit, id="dpjit"), + pytest.param( + dpjit(use_mlir=True), id="dpjit_mlir", marks=skip_no_numba_mlir + ), +] + + @contextlib.contextmanager def override_config(name, value, config=config): """ diff --git a/numba_dpex/tests/test_prange.py b/numba_dpex/tests/test_prange.py index 4e42db8cf6..fbf4897b40 100644 --- a/numba_dpex/tests/test_prange.py +++ b/numba_dpex/tests/test_prange.py @@ -12,9 +12,12 @@ from numba_dpex import dpjit, prange +from ._helper import decorators -def test_one_prange_mul(): - @dpjit + +@pytest.mark.parametrize("jit", decorators) +def test_one_prange_mul(jit): + @jit def f(a, b): for i in prange(4): b[i, 0] = a[i, 0] * 10 @@ -35,6 +38,33 @@ def f(a, b): assert nb[i, 0] == na[i, 0] * 10 +@pytest.mark.parametrize("jit", decorators) +def test_one_prange_mul_nested(jit): + @jit + def f_inner(a, b): + for i in prange(4): + b[i, 0] = a[i, 0] * 10 + return + + @jit + def f(a, b): + return f_inner(a, b) + + device = dpctl.select_default_device() + + m = 8 + n = 8 + a = dpnp.ones((m, n), device=device) + b = dpnp.ones((m, n), device=device) + + f(a, b) + na = dpnp.asnumpy(a) + nb = dpnp.asnumpy(b) + + for i in range(4): + assert nb[i, 0] == na[i, 0] * 10 + + @pytest.mark.skip(reason="dpnp.add() doesn't support variable + scalar.") def test_one_prange_add_scalar(): @dpjit @@ -155,8 +185,9 @@ def f(a, b): assert np.all(b.asnumpy() == 12) -def test_two_consecutive_prange(): - @dpjit +@pytest.mark.parametrize("jit", decorators) +def test_two_consecutive_prange(jit): + @jit def prange_example(a, b, c, d): for i in prange(n): c[i] = a[i] + b[i]