-
Notifications
You must be signed in to change notification settings - Fork 615
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow device to configure conversion to numpy and use of pure_callback
#6788
base: master
Are you sure you want to change the base?
Conversation
Hello. You may have forgotten to update the changelog!
|
…I/pennylane into no-interface-boundary
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #6788 +/- ##
=======================================
Coverage 99.60% 99.60%
=======================================
Files 476 476
Lines 45232 45242 +10
=======================================
+ Hits 45055 45065 +10
Misses 177 177 ☔ View full report in Codecov by Sentry. |
out = jitted_qnode1(0.123) | ||
print(out) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
out = jitted_qnode1(0.123) | |
print(out) | |
jitted_qnode1(0.123) |
* Devices can now configure whether or not the data is converted to numpy and `jax.pure_callback` | ||
is used by the new `ExecutionConfig.convert_to_numpy` property. Finite shot executions | ||
on `default.qubit` can now be jitted end-to-end, even with parameter shift. | ||
[(#6788)](https://github.com/PennyLaneAI/pennylane/pull/6788) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
* Devices can now configure whether or not the data is converted to numpy and `jax.pure_callback` | |
is used by the new `ExecutionConfig.convert_to_numpy` property. Finite shot executions | |
on `default.qubit` can now be jitted end-to-end, even with parameter shift. | |
[(#6788)](https://github.com/PennyLaneAI/pennylane/pull/6788) | |
* Devices can now configure whether or not the data is converted to numpy enabling `jax.pure_callback` to be used by the new `ExecutionConfig.convert_to_numpy` property. Finite shot executions on `default.qubit` can now be jitted end-to-end leading to performance improvements, even with parameter shift. | |
[(#6788)](https://github.com/PennyLaneAI/pennylane/pull/6788) | |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just some comments on my initial pass through.
@@ -581,6 +591,12 @@ def _setup_execution_config(self, execution_config: ExecutionConfig) -> Executio | |||
""" | |||
updated_values = {} | |||
|
|||
updated_values["convert_to_numpy"] = ( | |||
execution_config.interface.value not in {"jax", "jax-jit"} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we want to continue using the Interface enum instead?
execution_config.interface.value not in {"jax", "jax-jit"} | |
execution_config.interface not in {qml.math.Interface.JAX, qml.math.Interface.JAX_JIT} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
def test_convert_to_numpy_with_adjoint(self): | ||
"""Test that we will convert to numpy with adjoint.""" | ||
config = qml.devices.ExecutionConfig(gradient_method="adjoint", interface="jax-jit") | ||
dev = qml.device("default.qubit") | ||
processed = dev.setup_execution_config(config) | ||
assert processed.convert_to_numpy | ||
|
||
@pytest.mark.parametrize("interface", ("autograd", "torch", "tf")) | ||
def test_convert_to_numpy_non_jax(self, interface): | ||
"""Test that other interfaces are still converted to numpy.""" | ||
config = qml.devices.ExecutionConfig(gradient_method="adjoint", interface=interface) | ||
dev = qml.device("default.qubit") | ||
processed = dev.setup_execution_config(config) | ||
assert processed.convert_to_numpy |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you include "jax" as an interface in the testing?
Also, I'm curious if converting to numpy with adjoint has negative effects on performance.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I do think we could make adjoint jittable and get some nice speed boosts, but i think that would need to be a follow on task.
@@ -417,7 +417,7 @@ def circuit(state): | |||
|
|||
|
|||
@pytest.mark.jax | |||
@pytest.mark.parametrize("shots, atol", [(None, 0.005), (1000000, 0.05)]) | |||
@pytest.mark.parametrize("shots, atol", [(None, 0.005), (1000000, 0.1)]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a lot of shots...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we're using a seed, maybe it's worth reducing the number of shots.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Given the tolerance is still almost to high... the problem is that this test is now doing finite shot finite differences with float32. Im not sure we can get any accuracy out of that.
|
||
result = execute([tape1, tape2], dev, diff_method=param_shift, max_diff=1) | ||
return result[0] + result[1][0] | ||
|
||
res = cost_fn(params) | ||
x, y = params | ||
expected = 0.5 * (3 + jnp.cos(x) ** 2 * jnp.cos(2 * y)) | ||
assert np.allclose(res, expected, atol=tol, rtol=0) | ||
assert np.allclose(res, expected, atol=2e-2, rtol=0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the tol
isn't being used, it can be removed as an argument to this test.
Context:
While we have logic for sampling with jax, it does not really integrate very well into the workflow. While you can technically set
diff_method=None
right now and jit the execution end-to-end, trying to jitdiff_method=None
will cause incomprehensible error messages on non-DQ devices.We want to forbid differentiation
diff_method=None
, but keep a way to jit a finite shot execution.Description of the Change:
In order to allow jitting finite shot executions, we need a way for the device to be able to configure whether or not the data is converted to numpy. To do so, we simply add another property to the
ExecutionConfig
,convert_to_numpy
. IfFalse
, then we will not use apure_callback
to convert the parameters to numpy. IfTrue
, we use apure_callback
and convert the parameters to numpy.Benefits:
Speed ups due to being able to jit the entire execution.
Possible Drawbacks:
ExecutionConfig
gets an addtional property, making it more complicated.Related GitHub Issues:
Fixes #6054 Fixes #3259 Blocks #6770