From d254ee4efd3d297d2810175a75a05c8c4d051ef0 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Thu, 12 Oct 2023 17:45:12 -0500 Subject: [PATCH] Refactoring the kernel_launcher.KernelLaunchIRBuilder API. - Changes to the class constructor to make it easier to use from places other than the parfor_lowerer. Removed the need to pass a lowerer object and instead pass a context and builder. - Adds a new helper function populate_kernel_args_and_args_ty_arrays that populates arrays storing kernel args and kernel arg types. --- numba_dpex/core/parfors/parfor_lowerer.py | 80 +++++++-------------- numba_dpex/core/parfors/reduction_helper.py | 10 ++- numba_dpex/core/utils/kernel_launcher.py | 68 +++++++++++++++--- 3 files changed, 90 insertions(+), 68 deletions(-) diff --git a/numba_dpex/core/parfors/parfor_lowerer.py b/numba_dpex/core/parfors/parfor_lowerer.py index 07fbe50ba3..a88b3f2ab5 100644 --- a/numba_dpex/core/parfors/parfor_lowerer.py +++ b/numba_dpex/core/parfors/parfor_lowerer.py @@ -12,6 +12,7 @@ ) from numba_dpex import config +from numba_dpex.core.datamodel.models import dpex_data_model_manager as dpex_dmm from numba_dpex.core.parfors.reduction_helper import ( ReductionHelper, ReductionKernelVariables, @@ -26,8 +27,6 @@ create_reduction_remainder_kernel_for_parfor, ) -from numba_dpex.core.datamodel.models import dpex_data_model_manager as dpex_dmm - # A global list of kernels to keep the objects alive indefinitely. keep_alive_kernels = [] @@ -89,7 +88,9 @@ def _get_exec_queue(self, kernel_fn, lowerer): """Creates a stack variable storing the sycl queue pointer used to launch the kernel function. """ - self.kernel_builder = KernelLaunchIRBuilder(lowerer, kernel_fn.kernel) + self.kernel_builder = KernelLaunchIRBuilder( + lowerer.context, lowerer.builder, kernel_fn.kernel.addressof_ref() + ) # Create a local variable storing a pointer to a DPCTLSyclQueueRef # pointer. @@ -109,7 +110,7 @@ def _build_kernel_arglist(self, kernel_fn, lowerer): AssertionError: If the LLVM IR Value for an argument defined in Numba IR is not found. """ - num_flattened_args = 0 + self.num_flattened_args = 0 # Compute number of args to be passed to the kernel. Note that the # actual number of kernel arguments is greater than the count of @@ -117,63 +118,30 @@ def _build_kernel_arglist(self, kernel_fn, lowerer): for arg_type in kernel_fn.kernel_arg_types: if isinstance(arg_type, DpnpNdArray): datamodel = dpex_dmm.lookup(arg_type) - num_flattened_args += datamodel.flattened_field_count + self.num_flattened_args += datamodel.flattened_field_count elif arg_type == types.complex64 or arg_type == types.complex128: - num_flattened_args += 2 + self.num_flattened_args += 2 else: - num_flattened_args += 1 + self.num_flattened_args += 1 # Create LLVM values for the kernel args list and kernel arg types list self.args_list = self.kernel_builder.allocate_kernel_arg_array( - num_flattened_args + self.num_flattened_args ) self.args_ty_list = self.kernel_builder.allocate_kernel_arg_ty_array( - num_flattened_args + self.num_flattened_args + ) + callargs_ptrs = [] + for arg in kernel_fn.kernel_args: + callargs_ptrs.append(_getvar(lowerer, arg)) + + self.kernel_builder.populate_kernel_args_and_args_ty_arrays( + kernel_argtys=kernel_fn.kernel_arg_types, + callargs_ptrs=callargs_ptrs, + args_list=self.args_list, + args_ty_list=self.args_ty_list, + datamodel_mgr=dpex_dmm, ) - # Populate the args_list and the args_ty_list LLVM arrays - self.kernel_arg_num = 0 - for arg_num, arg in enumerate(kernel_fn.kernel_args): - argtype = kernel_fn.kernel_arg_types[arg_num] - llvm_val = _getvar(lowerer, arg) - if isinstance(argtype, DpnpNdArray): - datamodel = dpex_dmm.lookup(argtype) - self.kernel_builder.build_array_arg( - array_val=llvm_val, - array_data_model=datamodel, - array_rank=argtype.ndim, - arg_list=self.args_list, - args_ty_list=self.args_ty_list, - arg_num=self.kernel_arg_num, - ) - self.kernel_arg_num += datamodel.flattened_field_count - else: - if argtype == types.complex64: - self.kernel_builder.build_complex_arg( - llvm_val, - types.float32, - self.args_list, - self.args_ty_list, - self.kernel_arg_num, - ) - self.kernel_arg_num += 2 - elif argtype == types.complex128: - self.kernel_builder.build_complex_arg( - llvm_val, - types.float64, - self.args_list, - self.args_ty_list, - self.kernel_arg_num, - ) - self.kernel_arg_num += 2 - else: - self.kernel_builder.build_arg( - llvm_val, - argtype, - self.args_list, - self.args_ty_list, - self.kernel_arg_num, - ) - self.kernel_arg_num += 1 def _submit_parfor_kernel( self, @@ -213,7 +181,7 @@ def _submit_parfor_kernel( # Submit a synchronous kernel self.kernel_builder.submit_sync_kernel( self.curr_queue, - self.kernel_arg_num, + self.num_flattened_args, self.args_list, self.args_ty_list, global_range, @@ -255,7 +223,7 @@ def _submit_reduction_main_parfor_kernel( # Submit a synchronous kernel self.kernel_builder.submit_sync_kernel( self.curr_queue, - self.kernel_arg_num, + self.num_flattened_args, self.args_list, self.args_ty_list, global_range, @@ -290,7 +258,7 @@ def _submit_reduction_remainder_parfor_kernel( # Submit a synchronous kernel self.kernel_builder.submit_sync_kernel( self.curr_queue, - self.kernel_arg_num, + self.num_flattened_args, self.args_list, self.args_ty_list, global_range, diff --git a/numba_dpex/core/parfors/reduction_helper.py b/numba_dpex/core/parfors/reduction_helper.py index cb8471ee19..890e2c0a4a 100644 --- a/numba_dpex/core/parfors/reduction_helper.py +++ b/numba_dpex/core/parfors/reduction_helper.py @@ -393,13 +393,17 @@ def lowerer(self): def work_group_size(self): return self._work_group_size - def copy_final_sum_to_host(self, psrfor_kernel): + def copy_final_sum_to_host(self, parfor_kernel): lowerer = self.lowerer - ir_builder = KernelLaunchIRBuilder(lowerer, psrfor_kernel.kernel) + ir_builder = KernelLaunchIRBuilder( + lowerer.context, + lowerer.builder, + parfor_kernel.kernel.addressof_ref(), + ) # Create a local variable storing a pointer to a DPCTLSyclQueueRef # pointer. - curr_queue = ir_builder.get_queue(exec_queue=psrfor_kernel.queue) + curr_queue = ir_builder.get_queue(exec_queue=parfor_kernel.queue) builder = lowerer.builder context = lowerer.context diff --git a/numba_dpex/core/utils/kernel_launcher.py b/numba_dpex/core/utils/kernel_launcher.py index 5865160c6f..df3fd4f56c 100644 --- a/numba_dpex/core/utils/kernel_launcher.py +++ b/numba_dpex/core/utils/kernel_launcher.py @@ -6,6 +6,7 @@ from numba_dpex import utils from numba_dpex.core.runtime.context import DpexRTContext +from numba_dpex.core.types import DpnpNdArray from numba_dpex.dpctl_iface import DpctlCAPIFnBuilder from numba_dpex.dpctl_iface._helpers import numba_type_to_dpctl_typenum @@ -19,19 +20,17 @@ class KernelLaunchIRBuilder: for submitting kernels. The LLVM Values that """ - def __init__(self, lowerer, kernel): + def __init__(self, context, builder, kernel_addr): """Create a KernelLauncher for the specified kernel. Args: - lowerer: The Numba Lowerer that will be used to generate the code. - kernel: The SYCL kernel for which we are generating the code. - num_inputs: The number of arguments to the kernels. + context: A Numba target context that will be used to generate the code. + builder: An llvmlite IRBuilder instance used to generate LLVM IR. + kernel_addr: The address of a SYCL kernel. """ - self.lowerer = lowerer - self.context = self.lowerer.context - self.builder = self.lowerer.builder - self.kernel = kernel - self.kernel_addr = self.kernel.addressof_ref() + self.context = context + self.builder = builder + self.kernel_addr = kernel_addr self.rtctx = DpexRTContext(self.context) def _build_nullptr(self): @@ -402,3 +401,54 @@ def submit_sync_kernel( lr = self._create_sycl_range(local_range) args = args1 + [lr] + args2 self.rtctx.submit_ndrange(self.builder, *args) + + def populate_kernel_args_and_args_ty_arrays( + self, + kernel_argtys, + callargs_ptrs, + args_list, + args_ty_list, + datamodel_mgr, + ): + kernel_arg_num = 0 + for arg_num, argtype in enumerate(kernel_argtys): + llvm_val = callargs_ptrs[arg_num] + if isinstance(argtype, DpnpNdArray): + datamodel = datamodel_mgr.lookup(argtype) + self.build_array_arg( + array_val=llvm_val, + array_data_model=datamodel, + array_rank=argtype.ndim, + arg_list=args_list, + args_ty_list=args_ty_list, + arg_num=kernel_arg_num, + ) + kernel_arg_num += datamodel.flattened_field_count + else: + if argtype == types.complex64: + self.build_complex_arg( + llvm_val, + types.float32, + args_list, + args_ty_list, + kernel_arg_num, + ) + kernel_arg_num += 2 + elif argtype == types.complex128: + self.build_complex_arg( + llvm_val, + types.float64, + args_list, + args_ty_list, + kernel_arg_num, + ) + kernel_arg_num += 2 + else: + self.build_arg( + llvm_val, + argtype, + args_list, + args_ty_list, + kernel_arg_num, + ) + kernel_arg_num += 1