From def1e399537660429b9ed0117b37060afeb745f2 Mon Sep 17 00:00:00 2001 From: Matthew Silverman Date: Tue, 20 Feb 2024 15:18:43 -0500 Subject: [PATCH 1/2] fix QubitDensityMatrix for subsystems with jit on default.mixed --- doc/releases/changelog-dev.md | 1 + pennylane/devices/default_mixed.py | 2 +- tests/devices/test_default_mixed_jax.py | 11 ++++++++--- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 940d879fefa..ad5fce81d7c 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -601,6 +601,7 @@ * `QubitDensityMatrix` now works with jax-jit on the `default.mixed` device. [(#5203)](https://github.com/PennyLaneAI/pennylane/pull/5203) + [()](https://github.com/PennyLaneAI/pennylane/pull/)

Contributors ✍️

diff --git a/pennylane/devices/default_mixed.py b/pennylane/devices/default_mixed.py index 88facf5be74..44b6e03a749 100644 --- a/pennylane/devices/default_mixed.py +++ b/pennylane/devices/default_mixed.py @@ -573,7 +573,7 @@ def _apply_density_matrix(self, state, device_wires): right_axes.append(index + self.num_wires) transpose_axes = left_axes + right_axes rho = qnp.transpose(rho, axes=transpose_axes) - assert qnp.allclose( + assert qml.math.is_abstract(rho) or qnp.allclose( qnp.trace(qnp.reshape(rho, (2**self.num_wires, 2**self.num_wires))), 1.0, atol=tolerance, diff --git a/tests/devices/test_default_mixed_jax.py b/tests/devices/test_default_mixed_jax.py index 4bb178cc690..c9fedbb4a6c 100644 --- a/tests/devices/test_default_mixed_jax.py +++ b/tests/devices/test_default_mixed_jax.py @@ -89,10 +89,11 @@ def circuit(a): assert np.allclose(state, expected, atol=tol, rtol=0) - def test_qubit_density_matrix_jit_compatible(self, mocker): + @pytest.mark.parametrize("n_qubits", [1, 2]) + def test_qubit_density_matrix_jit_compatible(self, n_qubits, mocker): """Test that _apply_density_matrix works with jax-jit""" - dev = qml.device("default.mixed", wires=1) + dev = qml.device("default.mixed", wires=n_qubits) spy = mocker.spy(dev, "_apply_density_matrix") @jax.jit @@ -106,7 +107,11 @@ def circuit(state_ini): rho_out = circuit(rho_ini) spy.assert_called_once() assert qml.math.get_interface(rho_out) == "jax" - assert np.array_equal(rho_out, [[1, 0], [0, 0]]) + + dim = 2**n_qubits + expected = np.zeros((dim, dim)) + expected[0, 0] = 1.0 + assert np.array_equal(rho_out, expected) class TestDtypePreserved: From 8b45b04a8e88fc5fd50051ab75fc610f1e9d13b2 Mon Sep 17 00:00:00 2001 From: Matthew Silverman Date: Tue, 20 Feb 2024 15:19:52 -0500 Subject: [PATCH 2/2] changelog --- doc/releases/changelog-dev.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index ad5fce81d7c..69b02dab897 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -601,7 +601,7 @@ * `QubitDensityMatrix` now works with jax-jit on the `default.mixed` device. [(#5203)](https://github.com/PennyLaneAI/pennylane/pull/5203) - [()](https://github.com/PennyLaneAI/pennylane/pull/) + [(#5236)](https://github.com/PennyLaneAI/pennylane/pull/5236)

Contributors ✍️