Skip to content

Commit

Permalink
Add sycl event wait overload
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed Oct 18, 2023
1 parent 9e6c224 commit f904a77
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 0 deletions.
1 change: 1 addition & 0 deletions numba_dpex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def parse_sem_version(version_string: str) -> Tuple[int, int, int]:

# Re-export all type names
from numba_dpex.core.types import * # noqa E402
from numba_dpex.dpctl_iface import _intrinsic # noqa E402
from numba_dpex.dpnp_iface import dpnpimpl # noqa E402

if config.HAS_NON_HOST_DEVICE:
Expand Down
1 change: 1 addition & 0 deletions numba_dpex/core/datamodel/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ def _init_data_model_manager() -> datamodel.DataModelManager:

# Register the DpctlSyclEvent type
register_model(DpctlSyclEvent)(SyclEventModel)

# Register the RangeType type
register_model(RangeType)(RangeModel)

Expand Down
49 changes: 49 additions & 0 deletions numba_dpex/dpctl_iface/_intrinsic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

from numba import types
from numba.core import cgutils
from numba.core.datamodel import default_manager
from numba.extending import intrinsic, overload_method

import numba_dpex.dpctl_iface.libsyclinterface_bindings as sycl
from numba_dpex.core import types as dpex_types


@intrinsic
def sycl_event_wait(typingctx, ty_event):
# check for accepted types
if not isinstance(ty_event, dpex_types.DpctlSyclEvent):
return

result_type = types.void
sig = result_type(ty_event)

# defines the custom code generation
def codegen(context, builder, signature, args):
event_struct_proxy = cgutils.create_struct_proxy(ty_event)(
context, builder
)
event_struct_ptr = event_struct_proxy._getpointer()

event_struct = builder.load(event_struct_ptr)
sycl_event_dm = default_manager.lookup(ty_event)
event_ref = builder.extract_value(
event_struct,
sycl_event_dm.get_field_position("event_ref"),
)

sycl.dpctl_event_wait(builder, event_ref)

return sig, codegen


@overload_method(dpex_types.DpctlSyclEvent, "wait")
def ol_dpctl_sycl_event_wait(
event,
):
"""Implementation of an overload to support dpctl.SyclEvent() inside
a dpjit function.
"""
return lambda event: sycl_event_wait(event)
16 changes: 16 additions & 0 deletions numba_dpex/tests/core/types/DpctlSyclEvent/test_overloads.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import dpctl

from numba_dpex import dpjit


@dpjit
def wait_call(a):
a.wait()
return None


def test_wait_DpctlSyclEvent():
"""Test the dpctl.SyclEvent.wait() call overload."""

e = dpctl.SyclEvent()
wait_call(e)

0 comments on commit f904a77

Please sign in to comment.