From 22a900f2707dd08fd8165babca71091863866399 Mon Sep 17 00:00:00 2001 From: David Ittah Date: Wed, 23 Oct 2024 10:04:49 -0400 Subject: [PATCH] Fix extrapolation in ZNE function (#1213) Extrapolation was done against the folding numbers instead of the scale factors. Since the folding numbers start at 0, extrapolation would always yield a result very close to the first data point. --------- Co-authored-by: Romain Moyard --- doc/releases/changelog-dev.md | 4 +++ .../api_extensions/error_mitigation.py | 28 +++++++++---------- frontend/catalyst/jax_primitives.py | 3 +- frontend/test/lit/test_mitigation.py | 2 +- 4 files changed, 20 insertions(+), 17 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index d817f0d661..3728c64971 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -298,6 +298,10 @@

Bug fixes

+* Fix a bug in `catalyst.mitigate_with_zne` that would lead + to incorrectly extrapolated results. + [(#1213)](https://github.com/PennyLaneAI/catalyst/pull/1213) + * Fix a bug preventing the target of `qml.adjoint` and `qml.ctrl` calls from being transformed by AutoGraph. [(#1212)](https://github.com/PennyLaneAI/catalyst/pull/1212) diff --git a/frontend/catalyst/api_extensions/error_mitigation.py b/frontend/catalyst/api_extensions/error_mitigation.py index 2517c2c82b..0665c931e1 100644 --- a/frontend/catalyst/api_extensions/error_mitigation.py +++ b/frontend/catalyst/api_extensions/error_mitigation.py @@ -143,9 +143,7 @@ def workflow(weights, s): if not _is_odd_positive(scale_factors): raise ValueError("The scale factors must be positive odd integers: {scale_factors}") - num_folds = jnp.array([jnp.floor((s - 1) / 2) for s in scale_factors], dtype=int) - - return ZNECallable(fn, num_folds, extrapolate, folding) + return ZNECallable(fn, scale_factors, extrapolate, folding) ## IMPL ## @@ -164,14 +162,14 @@ class ZNECallable(CatalystCallable): def __init__( self, fn: Callable, - num_folds: jnp.ndarray, + scale_factors: Sequence[int], extrapolate: Callable[[Sequence[float], Sequence[float]], float], folding: str, ): functools.update_wrapper(self, fn) self.fn = fn self.__name__ = f"zne.{getattr(fn, '__name__', 'unknown')}" - self.num_folds = num_folds + self.scale_factors = scale_factors self.extrapolate = extrapolate self.folding = folding @@ -209,16 +207,18 @@ def __call__(self, *args, **kwargs): callable_fn ), "expected callable set as param on the first operation in zne target" - results = zne_p.bind( - *args_data, self.num_folds, folding=folding, jaxpr=jaxpr, fn=callable_fn + fold_numbers = (jnp.asarray(self.scale_factors, dtype=int) - 1) // 2 + fold_results = zne_p.bind( + *args_data, fold_numbers, folding=folding, jaxpr=jaxpr, fn=callable_fn ) - float_num_folds = jnp.array(self.num_folds, dtype=float) - results = self.extrapolate(float_num_folds, results[0]) - # Single measurement - if results.shape == (): - return results - # Multiple measurements - return tuple(res for res in results) + + scale_factors = jnp.asarray(self.scale_factors, dtype=float) + zne_results = self.extrapolate(scale_factors, fold_results) + + # if multiple measurement processes, split array back into tuple + if len(zne_results.shape): + zne_results = tuple(zne_results) + return zne_results def polynomial_extrapolation(degree): diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index 25d8c84106..ad901df934 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -238,7 +238,6 @@ class Folding(Enum): ############## zne_p = core.Primitive("zne") -zne_p.multiple_results = True qdevice_p = core.Primitive("qdevice") qdevice_p.multiple_results = True qalloc_p = core.Primitive("qalloc") @@ -1053,7 +1052,7 @@ def _zne_abstract_eval(*args, folding, jaxpr, fn): # pylint: disable=unused-arg shape = list(args[-1].shape) if len(jaxpr.out_avals) > 1: shape.append(len(jaxpr.out_avals)) - return [core.ShapedArray(shape, jaxpr.out_avals[0].dtype)] + return core.ShapedArray(shape, jaxpr.out_avals[0].dtype) def _folding_attribute(ctx, folding): diff --git a/frontend/test/lit/test_mitigation.py b/frontend/test/lit/test_mitigation.py index f5a575c005..bbe6829bfc 100644 --- a/frontend/test/lit/test_mitigation.py +++ b/frontend/test/lit/test_mitigation.py @@ -37,7 +37,7 @@ def circuit(): # CHECK: func.func public @jit_mcm_method_with_zne() -> tensor -# CHECK: mitigation.zne @one_shot_wrapper(%c_0) folding( global) numFolds(%6 : tensor<2xi64>) : (tensor<5xi1>) -> tensor<2xf64> +# CHECK: mitigation.zne @one_shot_wrapper(%c) folding( global) numFolds(%2 : tensor<2xi64>) : (tensor<5xi1>) -> tensor<2xf64> # CHECK: func.func private @one_shot_wrapper(%arg0: tensor<5xi1>) -> tensor # CHECK: catalyst.launch_kernel @module_circuit::@circuit() : () -> tensor