Skip to content

Commit

Permalink
add a fast path for noops, otherwise ak.transform does a non-negligib…
Browse files Browse the repository at this point in the history
…le overhead...
  • Loading branch information
pfackeldey committed Jan 13, 2025
1 parent c0177d9 commit 6377820
Show file tree
Hide file tree
Showing 11 changed files with 78 additions and 4 deletions.
18 changes: 18 additions & 0 deletions src/vector/_compute/lorentz/t.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def xy_z_t(lib, x, y, z, t):
return t


xy_z_t.__awkward_transform_allowed__ = False # type:ignore[attr-defined]


def xy_z_tau(lib, x, y, z, tau):
return lib.sqrt(t2.xy_z_tau(lib, x, y, z, tau))

Expand All @@ -45,6 +48,9 @@ def xy_theta_t(lib, x, y, theta, t):
return t


xy_theta_t.__awkward_transform_allowed__ = False # type:ignore[attr-defined]


def xy_theta_tau(lib, x, y, theta, tau):
return lib.sqrt(t2.xy_theta_tau(lib, x, y, theta, tau))

Expand All @@ -53,6 +59,9 @@ def xy_eta_t(lib, x, y, eta, t):
return t


xy_eta_t.__awkward_transform_allowed__ = False # type:ignore[attr-defined]


def xy_eta_tau(lib, x, y, eta, tau):
return lib.sqrt(t2.xy_eta_tau(lib, x, y, eta, tau))

Expand All @@ -61,6 +70,9 @@ def rhophi_z_t(lib, rho, phi, z, t):
return t


rhophi_z_t.__awkward_transform_allowed__ = False # type:ignore[attr-defined]


def rhophi_z_tau(lib, rho, phi, z, tau):
return lib.sqrt(t2.rhophi_z_tau(lib, rho, phi, z, tau))

Expand All @@ -69,6 +81,9 @@ def rhophi_theta_t(lib, rho, phi, theta, t):
return t


rhophi_theta_t.__awkward_transform_allowed__ = False # type:ignore[attr-defined]


def rhophi_theta_tau(lib, rho, phi, theta, tau):
return lib.sqrt(t2.rhophi_theta_tau(lib, rho, phi, theta, tau))

Expand All @@ -77,6 +92,9 @@ def rhophi_eta_t(lib, rho, phi, eta, t):
return t


rhophi_eta_t.__awkward_transform_allowed__ = False # type:ignore[attr-defined]


def rhophi_eta_tau(lib, rho, phi, eta, tau):
return lib.sqrt(t2.rhophi_eta_tau(lib, rho, phi, eta, tau))

Expand Down
18 changes: 18 additions & 0 deletions src/vector/_compute/lorentz/tau.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ def xy_z_tau(lib, x, y, z, tau):
return tau


xy_z_tau.__awkward_transform_allowed__ = False # type:ignore[attr-defined]


def xy_theta_t(lib, x, y, theta, t):
squared = tau2.xy_theta_t(lib, x, y, theta, t)
return lib.copysign(lib.sqrt(lib.absolute(squared)), squared)
Expand All @@ -51,6 +54,9 @@ def xy_theta_tau(lib, x, y, theta, tau):
return tau


xy_theta_tau.__awkward_transform_allowed__ = False # type:ignore[attr-defined]


def xy_eta_t(lib, x, y, eta, t):
squared = tau2.xy_eta_t(lib, x, y, eta, t)
return lib.copysign(lib.sqrt(lib.absolute(squared)), squared)
Expand All @@ -60,6 +66,9 @@ def xy_eta_tau(lib, x, y, eta, tau):
return tau


xy_eta_tau.__awkward_transform_allowed__ = False # type:ignore[attr-defined]


def rhophi_z_t(lib, rho, phi, z, t):
squared = tau2.rhophi_z_t(lib, rho, phi, z, t)
return lib.copysign(lib.sqrt(lib.absolute(squared)), squared)
Expand All @@ -69,6 +78,9 @@ def rhophi_z_tau(lib, rho, phi, z, tau):
return tau


rhophi_z_tau.__awkward_transform_allowed__ = False # type:ignore[attr-defined]


def rhophi_theta_t(lib, rho, phi, theta, t):
squared = tau2.rhophi_theta_t(lib, rho, phi, theta, t)
return lib.copysign(lib.sqrt(lib.absolute(squared)), squared)
Expand All @@ -78,6 +90,9 @@ def rhophi_theta_tau(lib, rho, phi, theta, tau):
return tau


rhophi_theta_tau.__awkward_transform_allowed__ = False # type:ignore[attr-defined]


def rhophi_eta_t(lib, rho, phi, eta, t):
squared = tau2.rhophi_eta_t(lib, rho, phi, eta, t)
return lib.copysign(lib.sqrt(lib.absolute(squared)), squared)
Expand All @@ -87,6 +102,9 @@ def rhophi_eta_tau(lib, rho, phi, eta, tau):
return tau


rhophi_eta_tau.__awkward_transform_allowed__ = False # type:ignore[attr-defined]


dispatch_map = {
(AzimuthalXY, LongitudinalZ, TemporalT): (xy_z_t, float),
(AzimuthalXY, LongitudinalZ, TemporalTau): (xy_z_tau, float),
Expand Down
3 changes: 3 additions & 0 deletions src/vector/_compute/planar/phi.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ def rhophi(lib, rho, phi):
return phi


rhophi.__awkward_transform_allowed__ = False # type:ignore[attr-defined]


dispatch_map = {
(AzimuthalXY,): (xy, float),
(AzimuthalRhoPhi,): (rhophi, float),
Expand Down
3 changes: 3 additions & 0 deletions src/vector/_compute/planar/rho.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def rhophi(lib, rho, phi):
return rho


rhophi.__awkward_transform_allowed__ = False # type:ignore[attr-defined]


dispatch_map = {
(AzimuthalXY,): (xy, float),
(AzimuthalRhoPhi,): (rhophi, float),
Expand Down
3 changes: 3 additions & 0 deletions src/vector/_compute/planar/unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ def rhophi(lib, rho, phi):
return (1, phi)


rhophi.__awkward_transform_allowed__ = False # type:ignore[attr-defined]


dispatch_map = {
(AzimuthalXY,): (xy, AzimuthalXY),
(AzimuthalRhoPhi,): (rhophi, AzimuthalRhoPhi),
Expand Down
3 changes: 3 additions & 0 deletions src/vector/_compute/planar/x.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ def xy(lib, x, y):
return x


xy.__awkward_transform_allowed__ = False # type:ignore[attr-defined]


def rhophi(lib, rho, phi):
return rho * lib.cos(phi)

Expand Down
3 changes: 3 additions & 0 deletions src/vector/_compute/planar/y.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ def xy(lib, x, y):
return y


xy.__awkward_transform_allowed__ = False # type:ignore[attr-defined]


def rhophi(lib, rho, phi):
return rho * lib.sin(phi)

Expand Down
6 changes: 6 additions & 0 deletions src/vector/_compute/spatial/eta.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ def xy_eta(lib, x, y, eta):
return eta


xy_eta.__awkward_transform_allowed__ = False # type:ignore[attr-defined]


def rhophi_z(lib, rho, phi, z):
return lib.nan_to_num(
lib.arcsinh(z / rho),
Expand All @@ -66,6 +69,9 @@ def rhophi_eta(lib, rho, phi, eta):
return eta


rhophi_eta.__awkward_transform_allowed__ = False # type:ignore[attr-defined]


dispatch_map = {
(AzimuthalXY, LongitudinalZ): (xy_z, float),
(AzimuthalXY, LongitudinalTheta): (xy_theta, float),
Expand Down
6 changes: 6 additions & 0 deletions src/vector/_compute/spatial/theta.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ def xy_theta(lib, x, y, theta):
return theta


xy_theta.__awkward_transform_allowed__ = False # type:ignore[attr-defined]


def xy_eta(lib, x, y, eta):
return 2.0 * lib.arctan(lib.exp(-eta))

Expand All @@ -50,6 +53,9 @@ def rhophi_theta(lib, rho, phi, theta):
return theta


rhophi_theta.__awkward_transform_allowed__ = False # type:ignore[attr-defined]


def rhophi_eta(lib, rho, phi, eta):
return 2.0 * lib.arctan(lib.exp(-eta))

Expand Down
6 changes: 6 additions & 0 deletions src/vector/_compute/spatial/z.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ def xy_z(lib, x, y, z):
return z


xy_z.__awkward_transform_allowed__ = False # type:ignore[attr-defined]


def xy_theta(lib, x, y, theta):
return lib.nan_to_num(
rho.xy(lib, x, y) / lib.tan(theta), nan=0.0, posinf=inf, neginf=-inf
Expand All @@ -49,6 +52,9 @@ def rhophi_z(lib, rho, phi, z):
return z


rhophi_z.__awkward_transform_allowed__ = False # type:ignore[attr-defined]


def rhophi_theta(lib, rho, phi, theta):
return lib.nan_to_num(rho / lib.tan(theta), nan=0.0, posinf=inf, neginf=-inf)

Expand Down
13 changes: 9 additions & 4 deletions src/vector/backends/awkward.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,10 +994,14 @@ def __init__(self, func: typing.Callable) -> None: # type: ignore[type-arg]
self.func = func

def __call__(self, *args: typing.Any, **kwargs: typing.Any) -> typing.Callable: # type: ignore[type-arg]
# '__awkward_transform_allowed__' is a flag that can be set to False to disable this wrapping, e.g. for no-ops
if not getattr(self.func, "__awkward_transform_allowed__", True):
return self.func(*args, **kwargs)

# only pos args
assert not kwargs

# prepare the function and its args; we;re currently assuming non-nested input args
# prepare the function and its args; we're currently assuming non-nested input args
args2bind, awkward_arrays = [], []
n_orig_akarrays = 0
for arg in args:
Expand All @@ -1011,7 +1015,7 @@ def __call__(self, *args: typing.Any, **kwargs: typing.Any) -> typing.Callable:
else:
args2bind.append(arg)

# this means we're working with awkward-arrays
# this means we're working with awkward-arrays and we should group operations with ak.transform
if n_orig_akarrays > 0:

def transformer(
Expand All @@ -1034,7 +1038,7 @@ def transformer(
f"but this routine received `{rule}`"
) from None

# apply the function to the numpy arrays, first we need to 'partial it out' all non-awkward array arguments
# apply the function to the numpy arrays, first we need to 'partial out' all non-awkward array arguments
out_numpys = bind(self.func, *args2bind)(
*(map(operator.attrgetter("data"), layouts))
)
Expand All @@ -1043,7 +1047,8 @@ def transformer(
out_numpys = (out_numpys,)
# propagate parameters
out_params = parameters_factory(
(layout.parameters for layout in layouts), len(out_numpys)
tuple(map(operator.attrgetter("parameters"), layouts)),
len(out_numpys),
)
# wrap the numpy arrays in awkward arrays
out_arrays = tuple(
Expand Down

0 comments on commit 6377820

Please sign in to comment.