Skip to content

Commit

Permalink
Fix Pylint issues in the numba_dpex.experimental module.
Browse files Browse the repository at this point in the history
  • Loading branch information
Diptorup Deb committed Nov 6, 2023
1 parent de4f7de commit ade98ca
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 59 deletions.
10 changes: 10 additions & 0 deletions numba_dpex/core/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ def _optimize_final_module(self):
pmb.populate(pm)
pm.run(self._final_module)

def optimize_final_module(self):
"""Public member function to optimize the final LLVM module in the
library. The function calls the protected overridden function.
"""
self._optimize_final_module()

def _finalize_specific(self):
# Fix global naming
for gv in self._final_module.global_variables:
Expand All @@ -68,6 +74,10 @@ def get_asm_str(self):
# generated (in numba_dpex.compiler).
return None

@property
def final_module(self):
return self._final_module


class JITSPIRVCodegen(CPUCodegen):
"""
Expand Down
10 changes: 2 additions & 8 deletions numba_dpex/experimental/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .kernel_dispatcher import KernelDispatcher


def kernel(func_or_sig=None, debug=False, cache=False, **options):
def kernel(func_or_sig=None, **options):
"""A decorator to define a kernel function.
A kernel function is conceptually equivalent to a SYCL kernel function, and
Expand All @@ -27,12 +27,9 @@ def kernel(func_or_sig=None, debug=False, cache=False, **options):
# FIXME: The options need to be evaluated and checked here like it is
# done in numba.core.decorators.jit

def _kernel_dispatcher(pyfunc, sigs=None):
def _kernel_dispatcher(pyfunc):
return KernelDispatcher(
pyfunc=pyfunc,
debug_flags=debug,
enable_cache=cache,
specialization_sigs=sigs,
targetoptions=options,
)

Expand Down Expand Up @@ -64,9 +61,6 @@ def _kernel_dispatcher(pyfunc, sigs=None):
def _specialized_kernel_dispatcher(pyfunc):
return KernelDispatcher(
pyfunc=pyfunc,
debug_flags=debug,
enable_cache=cache,
specialization_sigs=func_or_sig,
)

return _specialized_kernel_dispatcher
Expand Down
72 changes: 21 additions & 51 deletions numba_dpex/experimental/kernel_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,14 @@
"""Implements a new numba dispatcher class and a compiler class to compile and
call numba_dpex.kernel decorated function.
"""
import functools
from collections import Counter, OrderedDict, namedtuple
from collections import namedtuple
from contextlib import ExitStack

import numba.core.event as ev
from numba.core import errors, sigutils, types, utils
from numba.core.caching import NullCache
from numba.core import errors, sigutils, types
from numba.core.compiler import CompileResult
from numba.core.compiler_lock import global_compiler_lock
from numba.core.dispatcher import Dispatcher, _DispatcherBase, _FunctionCompiler
from numba.core.dispatcher import Dispatcher, _FunctionCompiler
from numba.core.typing.typeof import Purpose, typeof

from numba_dpex import config, spirv_generator
Expand Down Expand Up @@ -85,12 +83,12 @@ def _compile_to_spirv(

# makes sure that the spir_func is completely inlined into the
# spir_kernel wrapper
kernel_library._optimize_final_module()
kernel_library.optimize_final_module()
# Compiled the LLVM IR to SPIR-V
kernel_spirv_module = spirv_generator.llvm_to_spirv(
kernel_targetctx,
kernel_library._final_module,
kernel_library._final_module.as_bitcode(),
kernel_library.final_module,
kernel_library.final_module.as_bitcode(),
)
return _KernelModule(
kernel_name=kernel_fn.name, kernel_bitcode=kernel_spirv_module
Expand Down Expand Up @@ -159,7 +157,7 @@ def _compile_cached(
"w",
encoding="UTF-8",
) as f:
f.write(kernel_cres.library._final_module)
f.write(kernel_cres.library.final_module)

except errors.TypingError as e:
self._failed_cache[key] = e
Expand Down Expand Up @@ -188,57 +186,29 @@ class KernelDispatcher(Dispatcher):
def __init__(
self,
pyfunc,
debug_flags=None,
compile_flags=None,
specialization_sigs=None,
enable_cache=True,
locals={},
targetoptions={},
impl_kind="kernel",
local_vars_to_numba_types=None,
targetoptions=None,
pipeline_class=kernel_compiler.KernelCompiler,
):
if targetoptions is None:
targetoptions = {}

if local_vars_to_numba_types is None:
local_vars_to_numba_types = {}

targetoptions["nopython"] = True
targetoptions["experimental"] = True

self._kernel_name = pyfunc.__name__
self.typingctx = self.targetdescr.typing_context
self.targetctx = self.targetdescr.target_context

pysig = utils.pysignature(pyfunc)
arg_count = len(pysig.parameters)

self.overloads = OrderedDict()

can_fallback = not targetoptions.get("nopython", False)

_DispatcherBase.__init__(
self,
arg_count,
pyfunc,
pysig,
can_fallback,
exact_match_required=False,
super().__init__(
py_func=pyfunc,
locals=local_vars_to_numba_types,
impl_kind="kernel",
targetoptions=targetoptions,
pipeline_class=pipeline_class,
)

functools.update_wrapper(self, pyfunc)

self.targetoptions = targetoptions
self.locals = locals
self._cache = NullCache()
compiler_class = self._impl_kinds[impl_kind]
self._impl_kind = impl_kind
self._compiler: _KernelCompiler = compiler_class(
pyfunc, self.targetdescr, targetoptions, locals, pipeline_class
)
self._cache_hits = Counter()
self._cache_misses = Counter()

self._type = types.Dispatcher(self)
self.typingctx.insert_global(self, self._type)

# Remember target restriction
self._required_target_backend = targetoptions.get("target_backend")

def typeof_pyval(self, val):
"""
Resolve the Numba type of Python value *val*.
Expand Down

0 comments on commit ade98ca

Please sign in to comment.