
Allow device to configure conversion to numpy and use of pure_callback #6788

Open · wants to merge 8 commits into base: master

Conversation

@albi3ro (Contributor) commented Jan 8, 2025:

Context:

While we have logic for sampling with jax, it does not integrate well into the workflow. You can technically set diff_method=None right now and jit the execution end-to-end, but trying to jit with diff_method=None causes incomprehensible error messages on non-default.qubit devices.

We want to forbid differentiating with diff_method=None, but keep a way to jit a finite-shot execution.

Description of the Change:

In order to allow jitting finite-shot executions, we need a way for the device to configure whether or not the data is converted to numpy. To do so, we add a new property to the ExecutionConfig, convert_to_numpy. If True, we use a pure_callback and convert the parameters to numpy; if False, we skip the pure_callback and leave the parameters unconverted.
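As an illustration (not part of the diff), here is a minimal sketch of the end-to-end jitted, finite-shot workflow this enables; the shot count, seed, and circuit are made up for the example:

```python
import jax
import jax.numpy as jnp
import pennylane as qml

# Illustrative only: a seeded finite-shot device so sampling is jit-compatible.
dev = qml.device("default.qubit", shots=1000, seed=jax.random.PRNGKey(0))

@jax.jit
@qml.qnode(dev, diff_method="parameter-shift")
def circuit(x):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(0))

x = jnp.array(0.123)
print(circuit(x))            # finite-shot estimate of <Z>
print(jax.grad(circuit)(x))  # parameter-shift gradient of the jitted execution
```

With convert_to_numpy=False, the parameters stay as jax tracers instead of being routed through a pure_callback, so the whole execution can be traced and compiled.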

Benefits:

Speed-ups from being able to jit the entire execution.


Possible Drawbacks:

ExecutionConfig gets an additional property, making it more complicated.

Related GitHub Issues:

Fixes #6054 Fixes #3259 Blocks #6770

github-actions bot (Contributor) commented Jan 8, 2025:

Hello. You may have forgotten to update the changelog!
Please edit doc/releases/changelog-dev.md with:

  • A one-to-two sentence description of the change. You may include a small working example for new features.
  • A link back to this PR.
  • Your name (or GitHub username) in the contributors section.

codecov bot commented Jan 9, 2025:

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 99.60%. Comparing base (5efeffb) to head (2110c9b).
Report is 1 commit behind head on master.

Additional details and impacted files
@@           Coverage Diff           @@
##           master    #6788   +/-   ##
=======================================
  Coverage   99.60%   99.60%           
=======================================
  Files         476      476           
  Lines       45232    45242   +10     
=======================================
+ Hits        45055    45065   +10     
  Misses        177      177           


Comment on lines +710 to +711
out = jitted_qnode1(0.123)
print(out)

Suggested change:
- out = jitted_qnode1(0.123)
- print(out)
+ jitted_qnode1(0.123)

Comment on lines +9 to +13
* Devices can now configure whether or not the data is converted to numpy and `jax.pure_callback`
is used by the new `ExecutionConfig.convert_to_numpy` property. Finite shot executions
on `default.qubit` can now be jitted end-to-end, even with parameter shift.
[(#6788)](https://github.com/PennyLaneAI/pennylane/pull/6788)

Suggested change:
- * Devices can now configure whether or not the data is converted to numpy and `jax.pure_callback`
-   is used by the new `ExecutionConfig.convert_to_numpy` property. Finite shot executions
-   on `default.qubit` can now be jitted end-to-end, even with parameter shift.
-   [(#6788)](https://github.com/PennyLaneAI/pennylane/pull/6788)
+ * Devices can now configure whether or not the data is converted to numpy and whether `jax.pure_callback` is used, via the new `ExecutionConfig.convert_to_numpy` property. Finite shot executions on `default.qubit` can now be jitted end-to-end, leading to performance improvements, even with parameter shift.
+   [(#6788)](https://github.com/PennyLaneAI/pennylane/pull/6788)

@andrijapau (Contributor) left a comment:
Just some comments on my initial pass through.

@@ -581,6 +591,12 @@ def _setup_execution_config(self, execution_config: ExecutionConfig) -> ExecutionConfig:
"""
updated_values = {}

updated_values["convert_to_numpy"] = (
execution_config.interface.value not in {"jax", "jax-jit"}
Contributor:

Do we want to continue using the Interface enum instead?

Suggested change:
- execution_config.interface.value not in {"jax", "jax-jit"}
+ execution_config.interface not in {qml.math.Interface.JAX, qml.math.Interface.JAX_JIT}

@astralcai (Contributor) left a comment:
LGTM

Comment on lines +155 to +168
def test_convert_to_numpy_with_adjoint(self):
    """Test that we will convert to numpy with adjoint."""
    config = qml.devices.ExecutionConfig(gradient_method="adjoint", interface="jax-jit")
    dev = qml.device("default.qubit")
    processed = dev.setup_execution_config(config)
    assert processed.convert_to_numpy

@pytest.mark.parametrize("interface", ("autograd", "torch", "tf"))
def test_convert_to_numpy_non_jax(self, interface):
    """Test that other interfaces are still converted to numpy."""
    config = qml.devices.ExecutionConfig(gradient_method="adjoint", interface=interface)
    dev = qml.device("default.qubit")
    processed = dev.setup_execution_config(config)
    assert processed.convert_to_numpy
Contributor:
Can you include "jax" as an interface in the testing?

Also, I'm curious if converting to numpy with adjoint has negative effects on performance.
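For illustration only, the requested addition might look roughly like the test below; the test name is made up, and the expected `convert_to_numpy` value for `"jax"` with adjoint is an assumption based on the `jax-jit` adjoint test above:

```python
@pytest.mark.parametrize("interface", ("autograd", "torch", "tf", "jax"))
def test_convert_to_numpy_with_adjoint_any_interface(self, interface):
    """Hypothetical: adjoint should still force conversion to numpy for every interface."""
    config = qml.devices.ExecutionConfig(gradient_method="adjoint", interface=interface)
    dev = qml.device("default.qubit")
    processed = dev.setup_execution_config(config)
    assert processed.convert_to_numpy
```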

Contributor (author):

I do think we could make adjoint jittable and get some nice speed boosts, but I think that would need to be a follow-on task.

@@ -417,7 +417,7 @@ def circuit(state):


@pytest.mark.jax
@pytest.mark.parametrize("shots, atol", [(None, 0.005), (1000000, 0.05)])
@pytest.mark.parametrize("shots, atol", [(None, 0.005), (1000000, 0.1)])
Contributor:

This is a lot of shots...

Contributor:

If we're using a seed, maybe it's worth reducing the number of shots.

Contributor (author):

Given the tolerance is still almost too high... the problem is that this test is now doing finite-shot finite differences with float32. I'm not sure we can get much accuracy out of that.


result = execute([tape1, tape2], dev, diff_method=param_shift, max_diff=1)
return result[0] + result[1][0]

res = cost_fn(params)
x, y = params
expected = 0.5 * (3 + jnp.cos(x) ** 2 * jnp.cos(2 * y))
- assert np.allclose(res, expected, atol=tol, rtol=0)
+ assert np.allclose(res, expected, atol=2e-2, rtol=0)
Contributor:

If the tol isn't being used, it can be removed as an argument to this test.
