Skip to content

Commit

Permalink
Implement ExpectationMP.process_counts
Browse files Browse the repository at this point in the history
  • Loading branch information
Tarun-Kumar07 committed Feb 25, 2024
1 parent 2ae851b commit c48e6ce
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 1 deletion.
6 changes: 5 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,9 @@

<h3>Improvements 🛠</h3>

* Implemented the method `process_counts` in `ExpectationMp`.
[(#5241)](https://github.com/PennyLaneAI/pennylane/issues/5241)

<h4>Faster gradients with VJPs and other performance improvements</h4>

* Adjoint device VJP's are now supported with `jax.jacobian`. `device_vjp=True` is
Expand Down Expand Up @@ -683,4 +686,5 @@ Lee J. O'Riordan,
Mudit Pandey,
Alex Preciado,
Matthew Silverman,
Jay Soni.
Jay Soni,
Tarun Kumar Allamsetty.
4 changes: 4 additions & 0 deletions pennylane/measurements/expval.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,7 @@ def process_state(self, state: Sequence[complex], wire_order: Wires):
prob = qml.probs(wires=self.wires).process_state(state=state, wire_order=wire_order)
# In case of broadcasting, `prob` has two axes and this is a matrix-vector product
return qml.math.dot(prob, eigvals)

def process_counts(self, counts: dict, wire_order: Wires):
probs = qml.probs(wires=self.wires).process_counts(counts=counts, wire_order=wire_order)
return qml.math.dot(probs, self.eigvals())
16 changes: 16 additions & 0 deletions tests/measurements/test_expval.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,3 +318,19 @@ def cost_circuit(params):
energy_batched = cost_circuit(params)

assert qml.math.allequal(energy_batched, energy)

@pytest.mark.parametrize(
"wire, expected",
[
(0, 0.0),
(1, 1.0),
],
)
def test_estimate_expectation_with_counts(self, wire, expected):
counts = {"000": 100, "100": 100}

wire_order = qml.wires.Wires((0, 1, 2))

res = qml.expval(qml.Z(wire)).process_counts(counts=counts, wire_order=wire_order)

assert np.allclose(res, expected)

0 comments on commit c48e6ce

Please sign in to comment.