Skip to content

Commit

Permalink
allow passing a callable to de/serialization funcs (#6855)
Browse files Browse the repository at this point in the history
* allow passing func to de/serialization funcs

* coverage

* simplify

* typecheck

* nit

* mypy

* comments

* comments
  • Loading branch information
senecameeks authored Dec 19, 2024
1 parent 2448121 commit 0d9a6ee
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 18 deletions.
59 changes: 41 additions & 18 deletions cirq-google/cirq_google/api/v2/sweeps.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, cast, Dict, List, Optional
from typing import Any, cast, Callable, Dict, List, Optional

import sympy
import tunits

import cirq
from cirq.study import sweeps
from cirq_google.api.v2 import run_context_pb2
from cirq_google.study.device_parameter import DeviceParameter

Expand Down Expand Up @@ -55,14 +56,18 @@ def _recover_sweep_const(const_pb: run_context_pb2.ConstValue) -> Any:


def sweep_to_proto(
sweep: cirq.Sweep, *, out: Optional[run_context_pb2.Sweep] = None
sweep: cirq.Sweep,
*,
out: Optional[run_context_pb2.Sweep] = None,
sweep_transformer: Callable[[sweeps.SingleSweep], sweeps.SingleSweep] = lambda x: x,
) -> run_context_pb2.Sweep:
"""Converts a Sweep to v2 protobuf message.
Args:
sweep: The sweep to convert.
out: Optional message to be populated. If not given, a new message will
be created.
sweep_transformer: A function called on Linspace, Points.
Returns:
Populated sweep protobuf message.
Expand Down Expand Up @@ -91,6 +96,7 @@ def sweep_to_proto(
for s in sweep.sweeps:
sweep_to_proto(s, out=out.sweep_function.sweeps.add())
elif isinstance(sweep, cirq.Linspace) and not isinstance(sweep.key, sympy.Expr):
sweep = cast(cirq.Linspace, sweep_transformer(sweep))
out.single_sweep.parameter_key = sweep.key
if isinstance(sweep.start, tunits.Value):
unit = sweep.start.unit
Expand All @@ -110,6 +116,7 @@ def sweep_to_proto(
if sweep.metadata and getattr(sweep.metadata, 'units', None):
out.single_sweep.parameter.units = sweep.metadata.units
elif isinstance(sweep, cirq.Points) and not isinstance(sweep.key, sympy.Expr):
sweep = cast(cirq.Points, sweep_transformer(sweep))
out.single_sweep.parameter_key = sweep.key
if len(sweep.points) == 1:
out.single_sweep.const_value.MergeFrom(_build_sweep_const(sweep.points[0]))
Expand Down Expand Up @@ -142,8 +149,17 @@ def sweep_to_proto(
return out


def sweep_from_proto(msg: run_context_pb2.Sweep) -> cirq.Sweep:
"""Creates a Sweep from a v2 protobuf message."""
def sweep_from_proto(
msg: run_context_pb2.Sweep,
sweep_transformer: Callable[[sweeps.SingleSweep], sweeps.SingleSweep] = lambda x: x,
) -> cirq.Sweep:
"""Creates a Sweep from a v2 protobuf message.
Args:
msg: Serialized sweep message.
sweep_transformer: A function called on Linspace, Point, and ConstValue.
"""
which = msg.WhichOneof('sweep')
if which is None:
return cirq.UnitSweep
Expand Down Expand Up @@ -178,31 +194,38 @@ def sweep_from_proto(msg: run_context_pb2.Sweep) -> cirq.Sweep:
)
else:
metadata = None

if msg.single_sweep.WhichOneof('sweep') == 'linspace':
unit: float | tunits.Value = 1.0
if msg.single_sweep.linspace.HasField('unit'):
unit = tunits.Value.from_proto(msg.single_sweep.linspace.unit)
return cirq.Linspace(
key=key,
start=msg.single_sweep.linspace.first_point * unit, # type: ignore[arg-type]
stop=msg.single_sweep.linspace.last_point * unit, # type: ignore[arg-type]
length=msg.single_sweep.linspace.num_points,
metadata=metadata,
return sweep_transformer(
cirq.Linspace(
key=key,
start=msg.single_sweep.linspace.first_point * unit, # type: ignore[arg-type]
stop=msg.single_sweep.linspace.last_point * unit, # type: ignore[arg-type]
length=msg.single_sweep.linspace.num_points,
metadata=metadata,
)
)
if msg.single_sweep.WhichOneof('sweep') == 'points':
unit = 1.0
if msg.single_sweep.points.HasField('unit'):
unit = tunits.Value.from_proto(msg.single_sweep.points.unit)
return cirq.Points(
key=key,
points=[p * unit for p in msg.single_sweep.points.points],
metadata=metadata,
return sweep_transformer(
cirq.Points(
key=key,
points=[p * unit for p in msg.single_sweep.points.points],
metadata=metadata,
)
)
if msg.single_sweep.WhichOneof('sweep') == 'const_value':
return cirq.Points(
key=key,
points=[_recover_sweep_const(msg.single_sweep.const_value)],
metadata=metadata,
return sweep_transformer(
cirq.Points(
key=key,
points=[_recover_sweep_const(msg.single_sweep.const_value)],
metadata=metadata,
)
)

raise ValueError(f'single sweep type not set: {msg}')
Expand Down
87 changes: 87 additions & 0 deletions cirq-google/cirq_google/api/v2/sweeps_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,58 @@ def test_sweep_to_proto_points():
assert list(proto.single_sweep.points.points) == [-1, 0, 1, 1.5]


def test_sweep_to_proto_with_simple_func_succeeds():
def func(sweep: sweeps.SingleSweep):
if isinstance(sweep, cirq.Points):
sweep.points = [point + 3 for point in sweep.points]

return sweep

sweep = cirq.Points('foo', [1, 2, 3])
proto = v2.sweep_to_proto(sweep, sweep_transformer=func)

assert list(proto.single_sweep.points.points) == [4.0, 5.0, 6.0]


def test_sweep_to_proto_with_func_linspace():
def func(sweep: sweeps.SingleSweep):
return cirq.Linspace('foo', 3 * tunits.ns, 6 * tunits.ns, 3) # type: ignore[arg-type]

sweep = cirq.Linspace('foo', start=1, stop=3, length=3)
proto = v2.sweep_to_proto(sweep, sweep_transformer=func)

assert proto.single_sweep.linspace.first_point == 3.0
assert proto.single_sweep.linspace.last_point == 6.0
assert tunits.Value.from_proto(proto.single_sweep.linspace.unit) == tunits.ns


def test_sweep_to_proto_with_func_const_value():
def func(sweep: sweeps.SingleSweep):
if isinstance(sweep, cirq.Points):
sweep.points = [point + 3 for point in sweep.points]

return sweep

sweep = cirq.Points('foo', points=[1])
proto = v2.sweep_to_proto(sweep, sweep_transformer=func)

assert proto.single_sweep.const_value.int_value == 4


@pytest.mark.parametrize('sweep', [(cirq.Points('foo', [1, 2, 3])), (cirq.Points('foo', [1]))])
def test_sweep_to_proto_with_func_round_trip(sweep):
def add_tunit_func(sweep: sweeps.SingleSweep):
if isinstance(sweep, cirq.Points):
sweep.points = [point * tunits.ns for point in sweep.points] # type: ignore[misc]

return sweep

proto = v2.sweep_to_proto(sweep, sweep_transformer=add_tunit_func)
recovered = v2.sweep_from_proto(proto)

assert list(recovered.points)[0] == 1 * tunits.ns


def test_sweep_to_proto_unit():
proto = v2.sweep_to_proto(cirq.UnitSweep)
assert isinstance(proto, v2.run_context_pb2.Sweep)
Expand Down Expand Up @@ -188,6 +240,41 @@ def test_sweep_from_proto_single_sweep_type_not_set():
v2.sweep_from_proto(proto)


@pytest.mark.parametrize('sweep', [cirq.Points('foo', [1, 2, 3]), cirq.Points('foo', [1])])
def test_sweep_from_proto_with_func_succeeds(sweep):
def add_tunit_func(sweep: sweeps.SingleSweep):
if isinstance(sweep, cirq.Points):
sweep.points = [point * tunits.ns for point in sweep.points] # type: ignore[misc]

return sweep

msg = v2.sweep_to_proto(sweep)
sweep = v2.sweep_from_proto(msg, sweep_transformer=add_tunit_func)

assert list(sweep.points)[0] == [1.0 * tunits.ns]


@pytest.mark.parametrize('sweep', [cirq.Points('foo', [1, 2, 3]), cirq.Points('foo', [1])])
def test_sweep_from_proto_with_func_round_trip(sweep):
def add_tunit_func(sweep: sweeps.SingleSweep):
if isinstance(sweep, cirq.Points):
sweep.points = [point * tunits.ns for point in sweep.points] # type: ignore[misc]

return sweep

def strip_tunit_func(sweep: sweeps.SingleSweep):
if isinstance(sweep, cirq.Points):
if isinstance(sweep.points[0], tunits.Value):
sweep.points = [point[point.unit] for point in sweep.points]

return sweep

msg = v2.sweep_to_proto(sweep, sweep_transformer=add_tunit_func)
sweep = v2.sweep_from_proto(msg, sweep_transformer=strip_tunit_func)

assert list(sweep.points)[0] == 1.0


def test_sweep_with_list_sweep():
ls = cirq.study.to_sweep([{'a': 1, 'b': 2}, {'a': 3, 'b': 4}])
proto = v2.sweep_to_proto(ls)
Expand Down

0 comments on commit 0d9a6ee

Please sign in to comment.