Skip to content

Commit

Permalink
Merge pull request #1168 from IntelPython/cleanups_to_kernel_target
Browse files Browse the repository at this point in the history
Cleanups to kernel target
  • Loading branch information
Diptorup Deb authored Oct 12, 2023
2 parents 1ece135 + 262ed52 commit 86e1be7
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 54 deletions.
33 changes: 31 additions & 2 deletions numba_dpex/core/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from functools import cached_property

from numba.core import typing
from numba.core import options, targetconfig, typing
from numba.core.cpu import CPUTargetOptions
from numba.core.descriptors import TargetDescriptor

Expand All @@ -15,13 +15,42 @@
DpexKernelTypingContext,
)

_option_mapping = options._mapping


def _inherit_if_not_set(flags, options, name, default=targetconfig._NotSet):
if name in options:
setattr(flags, name, options[name])
return

cstk = targetconfig.ConfigStack()
if cstk:
# inherit
top = cstk.top()
if hasattr(top, name):
setattr(flags, name, getattr(top, name))
return

if default is not targetconfig._NotSet:
setattr(flags, name, default)


class DpexTargetOptions(CPUTargetOptions):
experimental = _option_mapping("experimental")
release_gil = _option_mapping("release_gil")

def finalize(self, flags, options):
super().finalize(flags, options)
_inherit_if_not_set(flags, options, "experimental", False)
_inherit_if_not_set(flags, options, "release_gil", False)


class DpexKernelTarget(TargetDescriptor):
"""
Implements a target descriptor for numba_dpex.kernel decorated functions.
"""

options = CPUTargetOptions
options = DpexTargetOptions

@cached_property
def _toplevel_target_context(self):
Expand Down
6 changes: 2 additions & 4 deletions numba_dpex/core/kernel_interface/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
#
# SPDX-License-Identifier: Apache-2.0

"""_summary_
"""

from numba.core import sigutils, types
from numba.core.typing.templates import AbstractTemplate, ConcreteTemplate
Expand Down Expand Up @@ -67,7 +65,7 @@ def compile(self, arg_types, return_types):
debug=self._debug,
)
func = cres.library.get_function(cres.fndesc.llvm_func_name)
cres.target_context.mark_ocl_device(func)
cres.target_context.set_spir_func_calling_conv(func)

return cres

Expand Down Expand Up @@ -159,7 +157,7 @@ def compile(self, args):
debug=self._debug,
)
func = cres.library.get_function(cres.fndesc.llvm_func_name)
cres.target_context.mark_ocl_device(func)
cres.target_context.set_spir_func_calling_conv(func)
libs = [cres.library]

cres.target_context.insert_user_function(self, cres.fndesc, libs)
Expand Down
2 changes: 1 addition & 1 deletion numba_dpex/core/kernel_interface/spirv_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def compile(
)

func = cres.library.get_function(cres.fndesc.llvm_func_name)
kernel = cres.target_context.prepare_ocl_kernel(
kernel = cres.target_context.prepare_spir_kernel(
func, cres.signature.args
)
cres.library._optimize_final_module()
Expand Down
63 changes: 17 additions & 46 deletions numba_dpex/core/targets/kernel_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

import re

from functools import cached_property

import dpnp
Expand All @@ -19,16 +19,14 @@
from numba_dpex.core.datamodel.models import _init_data_model_manager
from numba_dpex.core.exceptions import UnsupportedKernelArgumentError
from numba_dpex.core.typeconv import to_usm_ndarray
from numba_dpex.core.types import DpnpNdArray, USMNdArray
from numba_dpex.core.types import USMNdArray
from numba_dpex.core.utils import get_info_from_suai
from numba_dpex.utils import address_space, calling_conv

from .. import codegen

CC_SPIR_KERNEL = "spir_kernel"
CC_SPIR_FUNC = "spir_func"
VALID_CHARS = re.compile(r"[^a-z0-9]", re.I)
LINK_ATOMIC = 111
LLVM_SPIRV_ARGS = 112


Expand Down Expand Up @@ -89,14 +87,15 @@ def resolve_argument_type(self, val):

def load_additional_registries(self):
"""Register the OpenCL API and math and other functions."""
from numba.core.typing import cmathdecl, npydecl
from numba.core.typing import cmathdecl, enumdecl, npydecl

from ...ocl import mathdecl, ocldecl

self.install_registry(ocldecl.registry)
self.install_registry(mathdecl.registry)
self.install_registry(cmathdecl.registry)
self.install_registry(npydecl.registry)
self.install_registry(enumdecl.registry)


class SyclDevice(GPU):
Expand All @@ -105,7 +104,7 @@ class SyclDevice(GPU):
pass


DPEX_KERNEL_TARGET_NAME = "SyclDevice"
DPEX_KERNEL_TARGET_NAME = "dpex_kernel"

target_registry[DPEX_KERNEL_TARGET_NAME] = SyclDevice

Expand Down Expand Up @@ -165,7 +164,7 @@ def _gen_arg_base_type(self, fn):
name = llvmir.MetaDataString(mod, "kernel_arg_base_type")
return mod.add_metadata([name] + consts)

def _finalize_wrapper_module(self, fn):
def _finalize_kernel_wrapper_module(self, fn):
"""Add metadata and calling convention to the wrapper function.
The helper function adds function metadata to the wrapper function and
Expand All @@ -177,41 +176,12 @@ def _finalize_wrapper_module(self, fn):
fn: LLVM function representing the "kernel" wrapper function.
"""
mod = fn.module
# Set norecurse
fn.attributes.add("norecurse")
# Set SPIR kernel calling convention
fn.calling_convention = CC_SPIR_KERNEL

# Mark kernels
ocl_kernels = cgutils.get_or_insert_named_metadata(
mod, "opencl.kernels"
)
ocl_kernels.add(
mod.add_metadata(
[
fn,
self._gen_arg_addrspace_md(fn),
self._gen_arg_type(fn),
self._gen_arg_type_qual(fn),
self._gen_arg_base_type(fn),
],
)
)

# Other metadata
others = [
"opencl.used.extensions",
"opencl.used.optional.core.features",
"opencl.compiler.options",
]

for name in others:
nmd = cgutils.get_or_insert_named_metadata(mod, name)
if not nmd.operands:
mod.add_metadata([])

def _generate_kernel_wrapper(self, func, argtypes):
def _generate_spir_kernel_wrapper(self, func, argtypes):
module = func.module
arginfo = self.get_arg_packer(argtypes)
wrapperfnty = llvmir.FunctionType(
Expand All @@ -227,7 +197,7 @@ def _generate_kernel_wrapper(self, func, argtypes):
func = llvmir.Function(wrapper_module, fnty, name=func.name)
func.calling_convention = CC_SPIR_FUNC
wrapper = llvmir.Function(wrapper_module, wrapperfnty, name=wrappername)
builder = llvmir.IRBuilder(wrapper.append_basic_block(""))
builder = llvmir.IRBuilder(wrapper.append_basic_block("entry"))

callargs = arginfo.from_arguments(builder, wrapper.args)

Expand All @@ -237,7 +207,7 @@ def _generate_kernel_wrapper(self, func, argtypes):
)
builder.ret_void()

self._finalize_wrapper_module(wrapper)
self._finalize_kernel_wrapper_module(wrapper)

# Link the spir_func module to the wrapper module
module.link_in(ll.parse_assembly(str(wrapper_module)))
Expand All @@ -251,7 +221,10 @@ def __init__(self, typingctx, target=DPEX_KERNEL_TARGET_NAME):
super().__init__(typingctx, target)

def init(self):
self._internal_codegen = codegen.JITSPIRVCodegen("numba_dpex.jit")
"""Called by the super().__init__ constructor to initalize the child
class.
"""
self._internal_codegen = codegen.JITSPIRVCodegen("numba_dpex.kernel")
self._target_data = ll.create_target_data(
codegen.SPIR_DATA_LAYOUT[utils.MACHINE_BITS]
)
Expand All @@ -271,7 +244,6 @@ def init(self):
self.ufunc_db = copy.deepcopy(ufunc_db)
self.cpu_context = cpu_target.target_context

# Overrides
def create_module(self, name):
return self._internal_codegen._create_empty_module(name)

Expand Down Expand Up @@ -355,14 +327,14 @@ def mangler(self, name, argtypes, abi_tags=(), uid=None):
name + "dpex_fn", argtypes, abi_tags=abi_tags, uid=uid
)

def prepare_ocl_kernel(self, func, argtypes):
def prepare_spir_kernel(self, func, argtypes):
module = func.module
func.linkage = "linkonce_odr"
module.data_layout = codegen.SPIR_DATA_LAYOUT[self.address_size]
wrapper = self._generate_kernel_wrapper(func, argtypes)
wrapper = self._generate_spir_kernel_wrapper(func, argtypes)
return wrapper

def mark_ocl_device(self, func):
def set_spir_func_calling_conv(self, func):
# Adapt to SPIR
func.calling_convention = CC_SPIR_FUNC
func.linkage = "linkonce_odr"
Expand Down Expand Up @@ -436,7 +408,6 @@ def addrspacecast(self, builder, src, addrspace):
ptras = llvmir.PointerType(src.type.pointee, addrspace=addrspace)
return builder.addrspacecast(src, ptras)

# Overrides
def get_ufunc_info(self, ufunc_key):
return self.ufunc_db[ufunc_key]

Expand All @@ -446,7 +417,7 @@ class DpexCallConv(MinimalCallConv):
numba_dpex's calling convention derives from
:class:`numba.core.callconv import MinimalCallConv`. The
:class:`DpexCallConv` overriddes :func:`call_function`.
:class:`DpexCallConv` overrides :func:`call_function`.
"""

Expand Down
2 changes: 1 addition & 1 deletion numba_dpex/spirv_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from subprocess import CalledProcessError, check_call

from numba_dpex import config
from numba_dpex.core.targets.kernel_target import LINK_ATOMIC, LLVM_SPIRV_ARGS
from numba_dpex.core.targets.kernel_target import LLVM_SPIRV_ARGS


def _raise_bad_env_path(msg, path, extra=None):
Expand Down

0 comments on commit 86e1be7

Please sign in to comment.