Skip to content

Commit

Permalink
Refactoring the kernel_launcher.KernelLaunchIRBuilder API.
Browse files Browse the repository at this point in the history
    - Changes to the class constructor to make it easier to use from
      places other than the parfor_lowerer. Removed the need to pass
      a lowerer object and instead pass a context and builder.
    - Adds a new helper function populate_kernel_args_and_args_ty_arrays
      that populates arrays storing kernel args and kernel arg types.
  • Loading branch information
Diptorup Deb committed Oct 12, 2023
1 parent 3594cac commit d254ee4
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 68 deletions.
80 changes: 24 additions & 56 deletions numba_dpex/core/parfors/parfor_lowerer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
)

from numba_dpex import config
from numba_dpex.core.datamodel.models import dpex_data_model_manager as dpex_dmm
from numba_dpex.core.parfors.reduction_helper import (
ReductionHelper,
ReductionKernelVariables,
Expand All @@ -26,8 +27,6 @@
create_reduction_remainder_kernel_for_parfor,
)

from numba_dpex.core.datamodel.models import dpex_data_model_manager as dpex_dmm

# A global list of kernels to keep the objects alive indefinitely.
keep_alive_kernels = []

Expand Down Expand Up @@ -89,7 +88,9 @@ def _get_exec_queue(self, kernel_fn, lowerer):
"""Creates a stack variable storing the sycl queue pointer used to
launch the kernel function.
"""
self.kernel_builder = KernelLaunchIRBuilder(lowerer, kernel_fn.kernel)
self.kernel_builder = KernelLaunchIRBuilder(
lowerer.context, lowerer.builder, kernel_fn.kernel.addressof_ref()
)

# Create a local variable storing a pointer to a DPCTLSyclQueueRef
# pointer.
Expand All @@ -109,71 +110,38 @@ def _build_kernel_arglist(self, kernel_fn, lowerer):
AssertionError: If the LLVM IR Value for an argument defined in
Numba IR is not found.
"""
num_flattened_args = 0
self.num_flattened_args = 0

# Compute number of args to be passed to the kernel. Note that the
# actual number of kernel arguments is greater than the count of
# kernel_fn.kernel_args as arrays get flattened.
for arg_type in kernel_fn.kernel_arg_types:
if isinstance(arg_type, DpnpNdArray):
datamodel = dpex_dmm.lookup(arg_type)
num_flattened_args += datamodel.flattened_field_count
self.num_flattened_args += datamodel.flattened_field_count
elif arg_type == types.complex64 or arg_type == types.complex128:
num_flattened_args += 2
self.num_flattened_args += 2
else:
num_flattened_args += 1
self.num_flattened_args += 1

# Create LLVM values for the kernel args list and kernel arg types list
self.args_list = self.kernel_builder.allocate_kernel_arg_array(
num_flattened_args
self.num_flattened_args
)
self.args_ty_list = self.kernel_builder.allocate_kernel_arg_ty_array(
num_flattened_args
self.num_flattened_args
)
callargs_ptrs = []
for arg in kernel_fn.kernel_args:
callargs_ptrs.append(_getvar(lowerer, arg))

self.kernel_builder.populate_kernel_args_and_args_ty_arrays(
kernel_argtys=kernel_fn.kernel_arg_types,
callargs_ptrs=callargs_ptrs,
args_list=self.args_list,
args_ty_list=self.args_ty_list,
datamodel_mgr=dpex_dmm,
)
# Populate the args_list and the args_ty_list LLVM arrays
self.kernel_arg_num = 0
for arg_num, arg in enumerate(kernel_fn.kernel_args):
argtype = kernel_fn.kernel_arg_types[arg_num]
llvm_val = _getvar(lowerer, arg)
if isinstance(argtype, DpnpNdArray):
datamodel = dpex_dmm.lookup(argtype)
self.kernel_builder.build_array_arg(
array_val=llvm_val,
array_data_model=datamodel,
array_rank=argtype.ndim,
arg_list=self.args_list,
args_ty_list=self.args_ty_list,
arg_num=self.kernel_arg_num,
)
self.kernel_arg_num += datamodel.flattened_field_count
else:
if argtype == types.complex64:
self.kernel_builder.build_complex_arg(
llvm_val,
types.float32,
self.args_list,
self.args_ty_list,
self.kernel_arg_num,
)
self.kernel_arg_num += 2
elif argtype == types.complex128:
self.kernel_builder.build_complex_arg(
llvm_val,
types.float64,
self.args_list,
self.args_ty_list,
self.kernel_arg_num,
)
self.kernel_arg_num += 2
else:
self.kernel_builder.build_arg(
llvm_val,
argtype,
self.args_list,
self.args_ty_list,
self.kernel_arg_num,
)
self.kernel_arg_num += 1

def _submit_parfor_kernel(
self,
Expand Down Expand Up @@ -213,7 +181,7 @@ def _submit_parfor_kernel(
# Submit a synchronous kernel
self.kernel_builder.submit_sync_kernel(
self.curr_queue,
self.kernel_arg_num,
self.num_flattened_args,
self.args_list,
self.args_ty_list,
global_range,
Expand Down Expand Up @@ -255,7 +223,7 @@ def _submit_reduction_main_parfor_kernel(
# Submit a synchronous kernel
self.kernel_builder.submit_sync_kernel(
self.curr_queue,
self.kernel_arg_num,
self.num_flattened_args,
self.args_list,
self.args_ty_list,
global_range,
Expand Down Expand Up @@ -290,7 +258,7 @@ def _submit_reduction_remainder_parfor_kernel(
# Submit a synchronous kernel
self.kernel_builder.submit_sync_kernel(
self.curr_queue,
self.kernel_arg_num,
self.num_flattened_args,
self.args_list,
self.args_ty_list,
global_range,
Expand Down
10 changes: 7 additions & 3 deletions numba_dpex/core/parfors/reduction_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,13 +393,17 @@ def lowerer(self):
def work_group_size(self):
return self._work_group_size

def copy_final_sum_to_host(self, psrfor_kernel):
def copy_final_sum_to_host(self, parfor_kernel):
lowerer = self.lowerer
ir_builder = KernelLaunchIRBuilder(lowerer, psrfor_kernel.kernel)
ir_builder = KernelLaunchIRBuilder(
lowerer.context,
lowerer.builder,
parfor_kernel.kernel.addressof_ref(),
)

# Create a local variable storing a pointer to a DPCTLSyclQueueRef
# pointer.
curr_queue = ir_builder.get_queue(exec_queue=psrfor_kernel.queue)
curr_queue = ir_builder.get_queue(exec_queue=parfor_kernel.queue)

builder = lowerer.builder
context = lowerer.context
Expand Down
68 changes: 59 additions & 9 deletions numba_dpex/core/utils/kernel_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from numba_dpex import utils
from numba_dpex.core.runtime.context import DpexRTContext
from numba_dpex.core.types import DpnpNdArray
from numba_dpex.dpctl_iface import DpctlCAPIFnBuilder
from numba_dpex.dpctl_iface._helpers import numba_type_to_dpctl_typenum

Expand All @@ -19,19 +20,17 @@ class KernelLaunchIRBuilder:
for submitting kernels. The LLVM Values that
"""

def __init__(self, lowerer, kernel):
def __init__(self, context, builder, kernel_addr):
"""Create a KernelLauncher for the specified kernel.
Args:
lowerer: The Numba Lowerer that will be used to generate the code.
kernel: The SYCL kernel for which we are generating the code.
num_inputs: The number of arguments to the kernels.
context: A Numba target context that will be used to generate the code.
builder: An llvmlite IRBuilder instance used to generate LLVM IR.
kernel_addr: The address of a SYCL kernel.
"""
self.lowerer = lowerer
self.context = self.lowerer.context
self.builder = self.lowerer.builder
self.kernel = kernel
self.kernel_addr = self.kernel.addressof_ref()
self.context = context
self.builder = builder
self.kernel_addr = kernel_addr
self.rtctx = DpexRTContext(self.context)

def _build_nullptr(self):
Expand Down Expand Up @@ -402,3 +401,54 @@ def submit_sync_kernel(
lr = self._create_sycl_range(local_range)
args = args1 + [lr] + args2
self.rtctx.submit_ndrange(self.builder, *args)

def populate_kernel_args_and_args_ty_arrays(
self,
kernel_argtys,
callargs_ptrs,
args_list,
args_ty_list,
datamodel_mgr,
):
kernel_arg_num = 0
for arg_num, argtype in enumerate(kernel_argtys):
llvm_val = callargs_ptrs[arg_num]
if isinstance(argtype, DpnpNdArray):
datamodel = datamodel_mgr.lookup(argtype)
self.build_array_arg(
array_val=llvm_val,
array_data_model=datamodel,
array_rank=argtype.ndim,
arg_list=args_list,
args_ty_list=args_ty_list,
arg_num=kernel_arg_num,
)
kernel_arg_num += datamodel.flattened_field_count
else:
if argtype == types.complex64:
self.build_complex_arg(
llvm_val,
types.float32,
args_list,
args_ty_list,
kernel_arg_num,
)
kernel_arg_num += 2
elif argtype == types.complex128:
self.build_complex_arg(
llvm_val,
types.float64,
args_list,
args_ty_list,
kernel_arg_num,
)
kernel_arg_num += 2
else:
self.build_arg(
llvm_val,
argtype,
args_list,
args_ty_list,
kernel_arg_num,
)
kernel_arg_num += 1

0 comments on commit d254ee4

Please sign in to comment.