From eb56f47ba51f8a4ba41295bfb6abf79741428c28 Mon Sep 17 00:00:00 2001 From: Aniket Dalvi Date: Thu, 7 Nov 2024 16:33:08 -0500 Subject: [PATCH 1/5] potential fix using max of function arguments and function signature arguments to compare validity of `static_argnums` --- frontend/catalyst/jit.py | 3 ++- frontend/catalyst/tracing/type_signatures.py | 9 +++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/frontend/catalyst/jit.py b/frontend/catalyst/jit.py index 86722edd56..82d402d733 100644 --- a/frontend/catalyst/jit.py +++ b/frontend/catalyst/jit.py @@ -642,7 +642,8 @@ def capture(self, args, **kwargs): PyTreeDef: PyTree metadata of the function output Tuple[Any]: the dynamic argument signature """ - verify_static_argnums(args, self.compile_options.static_argnums) + sig_args = inspect.signature(self.original_function).parameters + verify_static_argnums(args, sig_args, self.compile_options.static_argnums) static_argnums = self.compile_options.static_argnums abstracted_axes = self.compile_options.abstracted_axes diff --git a/frontend/catalyst/tracing/type_signatures.py b/frontend/catalyst/tracing/type_signatures.py index 84a8098b8e..0dec5b2ee1 100644 --- a/frontend/catalyst/tracing/type_signatures.py +++ b/frontend/catalyst/tracing/type_signatures.py @@ -104,21 +104,22 @@ def verify_static_argnums_type(static_argnums): return None -def verify_static_argnums(args, static_argnums): +def verify_static_argnums(args, sig_args, static_argnums): """Verify that static_argnums have correct type and range. Args: args (Iterable): arguments to a compiled function + sig_args (Iterable): arguments from the signature of a compiled function static_argnums (Iterable[int]): indices to verify Returns: None """ verify_static_argnums_type(static_argnums) - + arg_limit = max(len(args), len(sig_args)) for argnum in static_argnums: - if argnum < 0 or argnum >= len(args): - msg = f"argnum {argnum} is beyond the valid range of [0, {len(args)})." + if argnum < 0 or argnum >= arg_limit: + msg = f"argnum {argnum} is beyond the valid range of [0, {arg_limit})." raise CompileError(msg) return None From 73824ec2541fd77afd49f922113142f678d3a975 Mon Sep 17 00:00:00 2001 From: Aniket Dalvi Date: Thu, 7 Nov 2024 17:18:27 -0500 Subject: [PATCH 2/5] added test cases and some comments --- frontend/catalyst/tracing/type_signatures.py | 2 +- frontend/test/pytest/test_static_arguments.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/frontend/catalyst/tracing/type_signatures.py b/frontend/catalyst/tracing/type_signatures.py index 0dec5b2ee1..911061056b 100644 --- a/frontend/catalyst/tracing/type_signatures.py +++ b/frontend/catalyst/tracing/type_signatures.py @@ -116,7 +116,7 @@ def verify_static_argnums(args, sig_args, static_argnums): None """ verify_static_argnums_type(static_argnums) - arg_limit = max(len(args), len(sig_args)) + arg_limit = max(len(args), len(sig_args)) # accommodates variable args for argnum in static_argnums: if argnum < 0 or argnum >= arg_limit: msg = f"argnum {argnum} is beyond the valid range of [0, {arg_limit})." diff --git a/frontend/test/pytest/test_static_arguments.py b/frontend/test/pytest/test_static_arguments.py index 4b0b05517e..81693477a1 100644 --- a/frontend/test/pytest/test_static_arguments.py +++ b/frontend/test/pytest/test_static_arguments.py @@ -123,6 +123,18 @@ def f(x: MyClass, y: int, z: MyClass): assert f(MyClass(5), 2, MyClass(5)) == 12 assert function == f.compiled_function + def test_default_static_arguments(self): + + @qjit(static_argnums=[1]) + def f(y, x=9): + if x < 10: + return x + y + return 42000 + + assert f(20) == 29 + assert f(20, 3) == 23 + assert f(20, 300000) == 42000 + def test_mutable_static_arguments(self): """Test QJIT with mutable static arguments.""" From 7953afca02e9f0effb222e44afd48a8a3c077845 Mon Sep 17 00:00:00 2001 From: Aniket Dalvi Date: Thu, 7 Nov 2024 19:34:43 -0500 Subject: [PATCH 3/5] added comments --- frontend/catalyst/jit.py | 2 ++ frontend/catalyst/tracing/type_signatures.py | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/frontend/catalyst/jit.py b/frontend/catalyst/jit.py index 82d402d733..a2acafe474 100644 --- a/frontend/catalyst/jit.py +++ b/frontend/catalyst/jit.py @@ -642,6 +642,8 @@ def capture(self, args, **kwargs): PyTreeDef: PyTree metadata of the function output Tuple[Any]: the dynamic argument signature """ + + # use inspect to get parameters defined in the function declaration sig_args = inspect.signature(self.original_function).parameters verify_static_argnums(args, sig_args, self.compile_options.static_argnums) static_argnums = self.compile_options.static_argnums diff --git a/frontend/catalyst/tracing/type_signatures.py b/frontend/catalyst/tracing/type_signatures.py index 911061056b..16636881a1 100644 --- a/frontend/catalyst/tracing/type_signatures.py +++ b/frontend/catalyst/tracing/type_signatures.py @@ -116,7 +116,9 @@ def verify_static_argnums(args, sig_args, static_argnums): None """ verify_static_argnums_type(static_argnums) - arg_limit = max(len(args), len(sig_args)) # accommodates variable args + + # `static_argnums` should be compared against the maximum args that can legally be passed to a function + arg_limit = max(len(args), len(sig_args)) for argnum in static_argnums: if argnum < 0 or argnum >= arg_limit: msg = f"argnum {argnum} is beyond the valid range of [0, {arg_limit})." From dc8476ad6220c34de0c960e2480616ca400cbbfb Mon Sep 17 00:00:00 2001 From: Aniket Dalvi Date: Fri, 8 Nov 2024 11:03:25 -0500 Subject: [PATCH 4/5] formatting changes --- frontend/catalyst/tracing/type_signatures.py | 2 +- frontend/test/pytest/test_static_arguments.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/frontend/catalyst/tracing/type_signatures.py b/frontend/catalyst/tracing/type_signatures.py index 16636881a1..82c5a289f7 100644 --- a/frontend/catalyst/tracing/type_signatures.py +++ b/frontend/catalyst/tracing/type_signatures.py @@ -117,7 +117,7 @@ def verify_static_argnums(args, sig_args, static_argnums): """ verify_static_argnums_type(static_argnums) - # `static_argnums` should be compared against the maximum args that can legally be passed to a function + # `static_argnums` should be compared to the maximum args that can be passed to a function arg_limit = max(len(args), len(sig_args)) for argnum in static_argnums: if argnum < 0 or argnum >= arg_limit: diff --git a/frontend/test/pytest/test_static_arguments.py b/frontend/test/pytest/test_static_arguments.py index 81693477a1..34b553f479 100644 --- a/frontend/test/pytest/test_static_arguments.py +++ b/frontend/test/pytest/test_static_arguments.py @@ -124,6 +124,7 @@ def f(x: MyClass, y: int, z: MyClass): assert function == f.compiled_function def test_default_static_arguments(self): + """Test QJIT with static arguments that have a default value.""" @qjit(static_argnums=[1]) def f(y, x=9): From 9fac10e49373729a84349f32d176e764f598e661 Mon Sep 17 00:00:00 2001 From: Aniket Dalvi Date: Fri, 8 Nov 2024 19:41:39 -0500 Subject: [PATCH 5/5] PR comment of passing the fn instead of the inspected signature parameters --- frontend/catalyst/jit.py | 4 +--- frontend/catalyst/tracing/type_signatures.py | 7 +++++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/frontend/catalyst/jit.py b/frontend/catalyst/jit.py index a2acafe474..dfb654f5d6 100644 --- a/frontend/catalyst/jit.py +++ b/frontend/catalyst/jit.py @@ -643,9 +643,7 @@ def capture(self, args, **kwargs): Tuple[Any]: the dynamic argument signature """ - # use inspect to get parameters defined in the function declaration - sig_args = inspect.signature(self.original_function).parameters - verify_static_argnums(args, sig_args, self.compile_options.static_argnums) + verify_static_argnums(args, self.original_function, self.compile_options.static_argnums) static_argnums = self.compile_options.static_argnums abstracted_axes = self.compile_options.abstracted_axes diff --git a/frontend/catalyst/tracing/type_signatures.py b/frontend/catalyst/tracing/type_signatures.py index 82c5a289f7..7eb4c5df4a 100644 --- a/frontend/catalyst/tracing/type_signatures.py +++ b/frontend/catalyst/tracing/type_signatures.py @@ -104,12 +104,12 @@ def verify_static_argnums_type(static_argnums): return None -def verify_static_argnums(args, sig_args, static_argnums): +def verify_static_argnums(args, fn, static_argnums): """Verify that static_argnums have correct type and range. Args: args (Iterable): arguments to a compiled function - sig_args (Iterable): arguments from the signature of a compiled function + fn (Callable): the quantum or classical function in question static_argnums (Iterable[int]): indices to verify Returns: @@ -117,6 +117,9 @@ def verify_static_argnums(args, sig_args, static_argnums): """ verify_static_argnums_type(static_argnums) + # use inspect to get parameters defined in the function declaration + sig_args = inspect.signature(fn).parameters + # `static_argnums` should be compared to the maximum args that can be passed to a function arg_limit = max(len(args), len(sig_args)) for argnum in static_argnums: