Skip to content

Commit

Permalink
default.qutrit now returns integer samples (#6385)
Browse files Browse the repository at this point in the history
**Context:**

`default.qutrit` was returning float samples instead of integer samples.
This causes errors when jitting, as the jit boundary is extremely
sensitive to types and shapes.

**Description of the Change:**

Only casts samples to float if the `obs` is not a measurement process.

I also add a test to the device test suite for `qml.sample`. Though the
device test suite does not apply to `default.qutrit`, this test will
still help validate that sample returns are working as expected on
devices.

**Benefits:**

**Possible Drawbacks:**

**Related GitHub Issues:**

Fixes #6384 [sc-75867]
  • Loading branch information
albi3ro authored Oct 11, 2024
1 parent 38a5140 commit 0c87b9a
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 1 deletion.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,9 @@

<h3>Bug fixes 🐛</h3>

* `default.qutrit` now returns integer samples.
[(#6385)](https://github.com/PennyLaneAI/pennylane/pull/6385)

* `adjoint_metric_tensor` now works with circuits containing state preparation operations.
[(#6358)](https://github.com/PennyLaneAI/pennylane/pull/6358)

Expand Down
3 changes: 2 additions & 1 deletion pennylane/devices/_qubit_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,8 @@ def statistics(

elif isinstance(m, SampleMP):
samples = self.sample(obs, shot_range=shot_range, bin_size=bin_size, counts=False)
result = self._asarray(qml.math.squeeze(samples))
dtype = int if isinstance(obs, SampleMP) else None
result = self._asarray(qml.math.squeeze(samples), dtype=dtype)

elif isinstance(m, CountsMP):
result = self.sample(m, shot_range=shot_range, bin_size=bin_size, counts=True)
Expand Down
19 changes: 19 additions & 0 deletions pennylane/devices/tests/test_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,25 @@ def result():
class TestSample:
"""Tests for the sample return type."""

def test_sample_wires(self, device):
"""Test that a device can return samples."""

n_wires = 1
dev = device(n_wires)

if not dev.shots:
pytest.skip("Device is in analytic mode, cannot test sampling.")

@qml.qnode(dev)
def circuit():
qml.X(0)
return qml.sample(wires=0)

res = circuit()
assert qml.math.allclose(res, 1) # note, might be violated with a noisy device?
assert qml.math.shape(res) == (dev.shots.total_shots,)
assert qml.math.get_dtype_name(res)[0:3] == "int" # either 32 or 64 precision.

def test_sample_values(self, device, tol):
"""Tests if the samples returned by sample have
the correct values
Expand Down
10 changes: 10 additions & 0 deletions tests/devices/test_default_qutrit.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,16 @@ def circuit():
class TestSample:
"""Tests that samples are properly calculated."""

def test_sample_dtype(self):
"""Test that if the raw samples are requested, they are of dtype int."""

dev = qml.device("default.qutrit", wires=1, shots=10)

tape = qml.tape.QuantumScript([], [qml.sample(wires=0)], shots=10)
res = dev.execute(tape)
assert qml.math.get_dtype_name(res)[0:3] == "int"
assert res.shape == (10,)

def test_sample_dimensions(self):
"""Tests if the samples returned by the sample function have
the correct dimensions
Expand Down

0 comments on commit 0c87b9a

Please sign in to comment.