From ade98ca97f7ae66128a821a433dd7deb469f63e7 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Sun, 5 Nov 2023 09:54:36 -0600 Subject: [PATCH] Fix Pylint issues in the numba_dpex.experimental module. --- numba_dpex/core/codegen.py | 10 +++ numba_dpex/experimental/decorators.py | 10 +-- numba_dpex/experimental/kernel_dispatcher.py | 72 ++++++-------------- 3 files changed, 33 insertions(+), 59 deletions(-) diff --git a/numba_dpex/core/codegen.py b/numba_dpex/core/codegen.py index 11e22418fe..7c97f705a5 100644 --- a/numba_dpex/core/codegen.py +++ b/numba_dpex/core/codegen.py @@ -57,6 +57,12 @@ def _optimize_final_module(self): pmb.populate(pm) pm.run(self._final_module) + def optimize_final_module(self): + """Public member function to optimize the final LLVM module in the + library. The function calls the protected overridden function. + """ + self._optimize_final_module() + def _finalize_specific(self): # Fix global naming for gv in self._final_module.global_variables: @@ -68,6 +74,10 @@ def get_asm_str(self): # generated (in numba_dpex.compiler). return None + @property + def final_module(self): + return self._final_module + class JITSPIRVCodegen(CPUCodegen): """ diff --git a/numba_dpex/experimental/decorators.py b/numba_dpex/experimental/decorators.py index d87ae14771..3559a7ec69 100644 --- a/numba_dpex/experimental/decorators.py +++ b/numba_dpex/experimental/decorators.py @@ -13,7 +13,7 @@ from .kernel_dispatcher import KernelDispatcher -def kernel(func_or_sig=None, debug=False, cache=False, **options): +def kernel(func_or_sig=None, **options): """A decorator to define a kernel function. A kernel function is conceptually equivalent to a SYCL kernel function, and @@ -27,12 +27,9 @@ def kernel(func_or_sig=None, debug=False, cache=False, **options): # FIXME: The options need to be evaluated and checked here like it is # done in numba.core.decorators.jit - def _kernel_dispatcher(pyfunc, sigs=None): + def _kernel_dispatcher(pyfunc): return KernelDispatcher( pyfunc=pyfunc, - debug_flags=debug, - enable_cache=cache, - specialization_sigs=sigs, targetoptions=options, ) @@ -64,9 +61,6 @@ def _kernel_dispatcher(pyfunc, sigs=None): def _specialized_kernel_dispatcher(pyfunc): return KernelDispatcher( pyfunc=pyfunc, - debug_flags=debug, - enable_cache=cache, - specialization_sigs=func_or_sig, ) return _specialized_kernel_dispatcher diff --git a/numba_dpex/experimental/kernel_dispatcher.py b/numba_dpex/experimental/kernel_dispatcher.py index ba8dba48a5..908fe23c8b 100644 --- a/numba_dpex/experimental/kernel_dispatcher.py +++ b/numba_dpex/experimental/kernel_dispatcher.py @@ -5,16 +5,14 @@ """Implements a new numba dispatcher class and a compiler class to compile and call numba_dpex.kernel decorated function. """ -import functools -from collections import Counter, OrderedDict, namedtuple +from collections import namedtuple from contextlib import ExitStack import numba.core.event as ev -from numba.core import errors, sigutils, types, utils -from numba.core.caching import NullCache +from numba.core import errors, sigutils, types from numba.core.compiler import CompileResult from numba.core.compiler_lock import global_compiler_lock -from numba.core.dispatcher import Dispatcher, _DispatcherBase, _FunctionCompiler +from numba.core.dispatcher import Dispatcher, _FunctionCompiler from numba.core.typing.typeof import Purpose, typeof from numba_dpex import config, spirv_generator @@ -85,12 +83,12 @@ def _compile_to_spirv( # makes sure that the spir_func is completely inlined into the # spir_kernel wrapper - kernel_library._optimize_final_module() + kernel_library.optimize_final_module() # Compiled the LLVM IR to SPIR-V kernel_spirv_module = spirv_generator.llvm_to_spirv( kernel_targetctx, - kernel_library._final_module, - kernel_library._final_module.as_bitcode(), + kernel_library.final_module, + kernel_library.final_module.as_bitcode(), ) return _KernelModule( kernel_name=kernel_fn.name, kernel_bitcode=kernel_spirv_module @@ -159,7 +157,7 @@ def _compile_cached( "w", encoding="UTF-8", ) as f: - f.write(kernel_cres.library._final_module) + f.write(kernel_cres.library.final_module) except errors.TypingError as e: self._failed_cache[key] = e @@ -188,57 +186,29 @@ class KernelDispatcher(Dispatcher): def __init__( self, pyfunc, - debug_flags=None, - compile_flags=None, - specialization_sigs=None, - enable_cache=True, - locals={}, - targetoptions={}, - impl_kind="kernel", + local_vars_to_numba_types=None, + targetoptions=None, pipeline_class=kernel_compiler.KernelCompiler, ): + if targetoptions is None: + targetoptions = {} + + if local_vars_to_numba_types is None: + local_vars_to_numba_types = {} + targetoptions["nopython"] = True targetoptions["experimental"] = True self._kernel_name = pyfunc.__name__ - self.typingctx = self.targetdescr.typing_context - self.targetctx = self.targetdescr.target_context - - pysig = utils.pysignature(pyfunc) - arg_count = len(pysig.parameters) - - self.overloads = OrderedDict() - - can_fallback = not targetoptions.get("nopython", False) - _DispatcherBase.__init__( - self, - arg_count, - pyfunc, - pysig, - can_fallback, - exact_match_required=False, + super().__init__( + py_func=pyfunc, + locals=local_vars_to_numba_types, + impl_kind="kernel", + targetoptions=targetoptions, + pipeline_class=pipeline_class, ) - functools.update_wrapper(self, pyfunc) - - self.targetoptions = targetoptions - self.locals = locals - self._cache = NullCache() - compiler_class = self._impl_kinds[impl_kind] - self._impl_kind = impl_kind - self._compiler: _KernelCompiler = compiler_class( - pyfunc, self.targetdescr, targetoptions, locals, pipeline_class - ) - self._cache_hits = Counter() - self._cache_misses = Counter() - - self._type = types.Dispatcher(self) - self.typingctx.insert_global(self, self._type) - - # Remember target restriction - self._required_target_backend = targetoptions.get("target_backend") - def typeof_pyval(self, val): """ Resolve the Numba type of Python value *val*.