Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve issue of qjit(static_argnums=...) fails when the marked static argument has a default value #1295

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
5 changes: 4 additions & 1 deletion frontend/catalyst/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,10 @@ def capture(self, args, **kwargs):
PyTreeDef: PyTree metadata of the function output
Tuple[Any]: the dynamic argument signature
"""
verify_static_argnums(args, self.compile_options.static_argnums)

# use inspect to get parameters defined in the function declaration
sig_args = inspect.signature(self.original_function).parameters
verify_static_argnums(args, sig_args, self.compile_options.static_argnums)
AniketDalvi marked this conversation as resolved.
Show resolved Hide resolved
static_argnums = self.compile_options.static_argnums
abstracted_axes = self.compile_options.abstracted_axes

Expand Down
9 changes: 6 additions & 3 deletions frontend/catalyst/tracing/type_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,21 +104,24 @@ def verify_static_argnums_type(static_argnums):
return None


def verify_static_argnums(args, static_argnums):
def verify_static_argnums(args, sig_args, static_argnums):
"""Verify that static_argnums have correct type and range.

Args:
args (Iterable): arguments to a compiled function
sig_args (Iterable): arguments from the signature of a compiled function
static_argnums (Iterable[int]): indices to verify

Returns:
None
"""
verify_static_argnums_type(static_argnums)

# `static_argnums` should be compared to the maximum args that can be passed to a function
arg_limit = max(len(args), len(sig_args))
for argnum in static_argnums:
if argnum < 0 or argnum >= len(args):
msg = f"argnum {argnum} is beyond the valid range of [0, {len(args)})."
if argnum < 0 or argnum >= arg_limit:
msg = f"argnum {argnum} is beyond the valid range of [0, {arg_limit})."
raise CompileError(msg)
return None

Expand Down
13 changes: 13 additions & 0 deletions frontend/test/pytest/test_static_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,19 @@ def f(x: MyClass, y: int, z: MyClass):
assert f(MyClass(5), 2, MyClass(5)) == 12
assert function == f.compiled_function

def test_default_static_arguments(self):
"""Test QJIT with static arguments that have a default value."""

@qjit(static_argnums=[1])
def f(y, x=9):
if x < 10:
return x + y
return 42000

assert f(20) == 29
assert f(20, 3) == 23
assert f(20, 300000) == 42000

erick-xanadu marked this conversation as resolved.
Show resolved Hide resolved
def test_mutable_static_arguments(self):
"""Test QJIT with mutable static arguments."""

Expand Down