Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Miscellaneous changes to compiler internals #1173

Merged
merged 2 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 24 additions & 56 deletions numba_dpex/core/parfors/parfor_lowerer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
)

from numba_dpex import config
from numba_dpex.core.datamodel.models import dpex_data_model_manager as dpex_dmm
from numba_dpex.core.parfors.reduction_helper import (
ReductionHelper,
ReductionKernelVariables,
Expand All @@ -26,8 +27,6 @@
create_reduction_remainder_kernel_for_parfor,
)

from numba_dpex.core.datamodel.models import dpex_data_model_manager as dpex_dmm

# A global list of kernels to keep the objects alive indefinitely.
keep_alive_kernels = []

Expand Down Expand Up @@ -89,7 +88,9 @@ def _get_exec_queue(self, kernel_fn, lowerer):
"""Creates a stack variable storing the sycl queue pointer used to
launch the kernel function.
"""
self.kernel_builder = KernelLaunchIRBuilder(lowerer, kernel_fn.kernel)
self.kernel_builder = KernelLaunchIRBuilder(
lowerer.context, lowerer.builder, kernel_fn.kernel.addressof_ref()
)

# Create a local variable storing a pointer to a DPCTLSyclQueueRef
# pointer.
Expand All @@ -109,71 +110,38 @@ def _build_kernel_arglist(self, kernel_fn, lowerer):
AssertionError: If the LLVM IR Value for an argument defined in
Numba IR is not found.
"""
num_flattened_args = 0
self.num_flattened_args = 0

# Compute number of args to be passed to the kernel. Note that the
# actual number of kernel arguments is greater than the count of
# kernel_fn.kernel_args as arrays get flattened.
for arg_type in kernel_fn.kernel_arg_types:
if isinstance(arg_type, DpnpNdArray):
datamodel = dpex_dmm.lookup(arg_type)
num_flattened_args += datamodel.flattened_field_count
self.num_flattened_args += datamodel.flattened_field_count
elif arg_type == types.complex64 or arg_type == types.complex128:
num_flattened_args += 2
self.num_flattened_args += 2
else:
num_flattened_args += 1
self.num_flattened_args += 1

# Create LLVM values for the kernel args list and kernel arg types list
self.args_list = self.kernel_builder.allocate_kernel_arg_array(
num_flattened_args
self.num_flattened_args
)
self.args_ty_list = self.kernel_builder.allocate_kernel_arg_ty_array(
num_flattened_args
self.num_flattened_args
)
callargs_ptrs = []
for arg in kernel_fn.kernel_args:
callargs_ptrs.append(_getvar(lowerer, arg))

self.kernel_builder.populate_kernel_args_and_args_ty_arrays(
kernel_argtys=kernel_fn.kernel_arg_types,
callargs_ptrs=callargs_ptrs,
args_list=self.args_list,
args_ty_list=self.args_ty_list,
datamodel_mgr=dpex_dmm,
)
# Populate the args_list and the args_ty_list LLVM arrays
self.kernel_arg_num = 0
for arg_num, arg in enumerate(kernel_fn.kernel_args):
argtype = kernel_fn.kernel_arg_types[arg_num]
llvm_val = _getvar(lowerer, arg)
if isinstance(argtype, DpnpNdArray):
datamodel = dpex_dmm.lookup(argtype)
self.kernel_builder.build_array_arg(
array_val=llvm_val,
array_data_model=datamodel,
array_rank=argtype.ndim,
arg_list=self.args_list,
args_ty_list=self.args_ty_list,
arg_num=self.kernel_arg_num,
)
self.kernel_arg_num += datamodel.flattened_field_count
else:
if argtype == types.complex64:
self.kernel_builder.build_complex_arg(
llvm_val,
types.float32,
self.args_list,
self.args_ty_list,
self.kernel_arg_num,
)
self.kernel_arg_num += 2
elif argtype == types.complex128:
self.kernel_builder.build_complex_arg(
llvm_val,
types.float64,
self.args_list,
self.args_ty_list,
self.kernel_arg_num,
)
self.kernel_arg_num += 2
else:
self.kernel_builder.build_arg(
llvm_val,
argtype,
self.args_list,
self.args_ty_list,
self.kernel_arg_num,
)
self.kernel_arg_num += 1

def _submit_parfor_kernel(
self,
Expand Down Expand Up @@ -213,7 +181,7 @@ def _submit_parfor_kernel(
# Submit a synchronous kernel
self.kernel_builder.submit_sync_kernel(
self.curr_queue,
self.kernel_arg_num,
self.num_flattened_args,
self.args_list,
self.args_ty_list,
global_range,
Expand Down Expand Up @@ -255,7 +223,7 @@ def _submit_reduction_main_parfor_kernel(
# Submit a synchronous kernel
self.kernel_builder.submit_sync_kernel(
self.curr_queue,
self.kernel_arg_num,
self.num_flattened_args,
self.args_list,
self.args_ty_list,
global_range,
Expand Down Expand Up @@ -290,7 +258,7 @@ def _submit_reduction_remainder_parfor_kernel(
# Submit a synchronous kernel
self.kernel_builder.submit_sync_kernel(
self.curr_queue,
self.kernel_arg_num,
self.num_flattened_args,
self.args_list,
self.args_ty_list,
global_range,
Expand Down
10 changes: 7 additions & 3 deletions numba_dpex/core/parfors/reduction_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,13 +393,17 @@ def lowerer(self):
def work_group_size(self):
return self._work_group_size

def copy_final_sum_to_host(self, psrfor_kernel):
def copy_final_sum_to_host(self, parfor_kernel):
lowerer = self.lowerer
ir_builder = KernelLaunchIRBuilder(lowerer, psrfor_kernel.kernel)
ir_builder = KernelLaunchIRBuilder(
lowerer.context,
lowerer.builder,
parfor_kernel.kernel.addressof_ref(),
)

# Create a local variable storing a pointer to a DPCTLSyclQueueRef
# pointer.
curr_queue = ir_builder.get_queue(exec_queue=psrfor_kernel.queue)
curr_queue = ir_builder.get_queue(exec_queue=parfor_kernel.queue)

builder = lowerer.builder
context = lowerer.context
Expand Down
13 changes: 11 additions & 2 deletions numba_dpex/core/pipelines/kernel_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,6 @@ def define_nopython_lowering_pipeline(state, name="dpex_kernel_lowering"):
)
pm.add_pass(IRLegalization, "ensure IR is legal prior to lowering")

# lower
# NativeLowering has an issue with freevar ambiguity, so
# QualNameDisambiguationLowering is used instead. See numba-dpex GitHub
# issue: https://github.com/IntelPython/numba-dpex/issues/898
Expand Down Expand Up @@ -173,12 +172,22 @@ def define_nopython_pipeline(state, name="dpex_kernel_nopython"):


class KernelCompiler(CompilerBase):
    """Compiler class that drives Dpex's kernel compilation pipeline."""

    def define_pipelines(self):
        """Build the list of pass pipelines used to compile a kernel.

        Returns:
            A list holding the nopython pipeline produced by
            ``_KernelPassBuilder.define_nopython_pipeline``.

        Raises:
            UnsupportedCompilationModeError: If the compilation status allows
                falling back or the ``force_pyobject`` flag is set.
        """
        state = self.state
        flags = state.flags

        if flags.force_pyobject:
            pipelines = []
        else:
            pipelines = [_KernelPassBuilder.define_nopython_pipeline(state)]

        # Fallback and forced object mode are both rejected outright.
        if state.status.can_fallback or flags.force_pyobject:
            raise UnsupportedCompilationModeError()

        # Compile the kernel without generating a cpython or a cfunc wrapper.
        flags.no_cpython_wrapper = True
        flags.no_cfunc_wrapper = True

        # The pass pipeline does not generate an executable when compiling a
        # kernel function. Instead, kernel_dispatcher._KernelCompiler.compile
        # generates the executable in the form of a host callable launcher
        # function.
        flags.no_compile = True

        return pipelines
68 changes: 59 additions & 9 deletions numba_dpex/core/utils/kernel_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from numba_dpex import utils
from numba_dpex.core.runtime.context import DpexRTContext
from numba_dpex.core.types import DpnpNdArray
from numba_dpex.dpctl_iface import DpctlCAPIFnBuilder
from numba_dpex.dpctl_iface._helpers import numba_type_to_dpctl_typenum

Expand All @@ -19,19 +20,17 @@ class KernelLaunchIRBuilder:
for submitting kernels. The LLVM Values that
"""

def __init__(self, lowerer, kernel):
def __init__(self, context, builder, kernel_addr):
"""Create a KernelLaunchIRBuilder for the specified kernel.

Args:
lowerer: The Numba Lowerer that will be used to generate the code.
kernel: The SYCL kernel for which we are generating the code.
num_inputs: The number of arguments to the kernels.
context: A Numba target context that will be used to generate the code.
builder: An llvmlite IRBuilder instance used to generate LLVM IR.
kernel_addr: The address of a SYCL kernel.
"""
self.lowerer = lowerer
self.context = self.lowerer.context
self.builder = self.lowerer.builder
self.kernel = kernel
self.kernel_addr = self.kernel.addressof_ref()
self.context = context
self.builder = builder
self.kernel_addr = kernel_addr
self.rtctx = DpexRTContext(self.context)

def _build_nullptr(self):
Expand Down Expand Up @@ -402,3 +401,54 @@ def submit_sync_kernel(
lr = self._create_sycl_range(local_range)
args = args1 + [lr] + args2
self.rtctx.submit_ndrange(self.builder, *args)

def populate_kernel_args_and_args_ty_arrays(
    self,
    kernel_argtys,
    callargs_ptrs,
    args_list,
    args_ty_list,
    datamodel_mgr,
):
    """Fill the kernel argument and argument-type arrays for a launch.

    Walks ``kernel_argtys`` in order, pairs each type with the value at the
    same index in ``callargs_ptrs``, and emits it into ``args_list`` and
    ``args_ty_list`` through the matching builder helper. A running slot
    index tracks where the next entry goes: a ``DpnpNdArray`` argument
    advances it by its data model's ``flattened_field_count``, a complex
    argument by 2, and any other argument by 1.

    Args:
        kernel_argtys: Sequence of types of the kernel arguments.
        callargs_ptrs: Sequence of values for the arguments, indexed in
            step with ``kernel_argtys``.
        args_list: Argument array handed to the ``build_*`` helpers.
        args_ty_list: Argument-type array handed to the ``build_*`` helpers.
        datamodel_mgr: Manager whose ``lookup`` returns the data model for
            a ``DpnpNdArray`` type.
    """
    slot = 0
    for position, argtype in enumerate(kernel_argtys):
        value = callargs_ptrs[position]

        # Arrays expand into as many slots as their data model has fields.
        if isinstance(argtype, DpnpNdArray):
            model = datamodel_mgr.lookup(argtype)
            self.build_array_arg(
                array_val=value,
                array_data_model=model,
                array_rank=argtype.ndim,
                arg_list=args_list,
                args_ty_list=args_ty_list,
                arg_num=slot,
            )
            slot += model.flattened_field_count
            continue

        # Complex values are passed with the matching float component type
        # and take two slots.
        if argtype == types.complex64:
            component_ty = types.float32
        elif argtype == types.complex128:
            component_ty = types.float64
        else:
            component_ty = None

        if component_ty is None:
            self.build_arg(value, argtype, args_list, args_ty_list, slot)
            slot += 1
        else:
            self.build_complex_arg(
                value, component_ty, args_list, args_ty_list, slot
            )
            slot += 2
Loading