diff --git a/.gitignore b/.gitignore index 841fe6c9d..cc40feb02 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ out.txt .DS_Store tests/data/jobs/ main.html +main*.py # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/src/bloqade/codegen/hardware/piecewise_constant.py b/src/bloqade/codegen/hardware/piecewise_constant.py new file mode 100644 index 000000000..b049c4107 --- /dev/null +++ b/src/bloqade/codegen/hardware/piecewise_constant.py @@ -0,0 +1,222 @@ +import bloqade.ir.control.waveform as waveform +from bloqade.ir.visitor.waveform import WaveformVisitor + +from typing import Dict, Tuple, List, Union +from pydantic.dataclasses import dataclass +from bisect import bisect_left +import numbers +from decimal import Decimal + + +@dataclass(frozen=True) +class PiecewiseConstant: + times: List[Decimal] + values: List[Decimal] + + def eval(self, time): + if time >= self.times[-1]: + return self.values[-1] + elif time <= self.times[0]: + return self.values[0] + else: + index = bisect_left(self.times, time) + + if self.times[index] == time: + index += 1 + + return self.values[index] + + def slice(self, start_time, stop_time) -> "PiecewiseConstant": + start_time = Decimal(str(start_time)) + stop_time = Decimal(str(stop_time)) + + if start_time == stop_time: + return PiecewiseConstant( + [Decimal(0.0), Decimal(0.0)], [Decimal(0.0), Decimal(0.0)] + ) + + start_index = bisect_left(self.times, start_time) + stop_index = bisect_left(self.times, stop_time) + + if start_time == self.times[start_index]: + if stop_time == self.times[stop_index]: + absolute_times = list(self.times[start_index : stop_index + 1]) + values = list(self.values[start_index : stop_index + 1]) + else: + absolute_times = self.times[start_index:stop_index] + [stop_time] + values = self.values[start_index:stop_index] + [self.values[stop_index]] + else: + if stop_time == self.times[stop_index]: + absolute_times = [start_time] + self.times[start_index : stop_index + 1] + values = [self.values[start_index - 1]] + self.values[ + start_index : stop_index + 1 + ] + else: + absolute_times = ( + [start_time] + self.times[start_index:stop_index] + [stop_time] + ) + values = ( + [self.values[start_index - 1]] + + self.values[start_index:stop_index] + + [self.values[stop_index]] + ) + + values[-1] = values[-2] + + return PiecewiseConstant([time - start_time for time in absolute_times], values) + + def append(self, other: "PiecewiseConstant"): + return PiecewiseConstant( + times=self.times + [time + self.times[-1] for time in other.times[1:]], + values=self.values[:-1] + other.values, + ) + + +class PiecewiseConstantCodeGen(WaveformVisitor): + def __init__(self, assignments: Dict[str, Union[numbers.Real, List[numbers.Real]]]): + self.assignments = assignments + self.times = [] + self.values = [] + + def append_timeseries(self, value, duration): + if len(self.times) > 0: + self.times.append(duration + self.times[-1]) + self.values[-1] = value + self.values.append(value) + else: + self.times = [Decimal(0), duration] + self.values = [value, value] + + def visit_linear(self, ast: waveform.Linear) -> Tuple[List[Decimal], List[Decimal]]: + duration = ast.duration(**self.assignments) + start = ast.start(**self.assignments) + stop = ast.stop(**self.assignments) + + if start != stop: + raise ValueError( + "Failed to compile Waveform to piecewise constant, " + "found non-constant Linear piece." + ) + + self.append_timeseries(start, duration) + + def visit_constant( + self, ast: waveform.Constant + ) -> Tuple[List[Decimal], List[Decimal]]: + duration = ast.duration(**self.assignments) + value = ast.value(**self.assignments) + + self.append_timeseries(value, duration) + + def visit_poly(self, ast: waveform.Poly) -> Tuple[List[Decimal], List[Decimal]]: + order = len(ast.coeffs) - 1 + duration = ast.duration(**self.assignments) + + if len(ast.coeffs) == 0: + value = Decimal(0) + + elif len(ast.coeffs) == 1: + value = ast.coeffs[0](**self.assignments) + + elif len(ast.coeffs) == 2: + start = ast.coeffs[0](**self.assignments) + stop = start + ast.coeffs[1](**self.assignments) * duration + + if start != stop: + raise ValueError( + "Failed to compile Waveform to piecewise constant, " + "found non-constant Polynomial piece." + ) + + value = start + + else: + raise ValueError( + "Failed to compile Waveform to piecewise constant," + f"found Polynomial of order {order}." + ) + + self.append_timeseries(value, duration) + + def visit_negative( + self, ast: waveform.Negative + ) -> Tuple[List[Decimal], List[Decimal]]: + self.visit(ast.waveform) + + self.values = [-value for value in self.values] + + def visit_scale(self, ast: waveform.Scale) -> Tuple[List[Decimal], List[Decimal]]: + self.visit(ast.waveform) + scale = ast.scalar(**self.assignments) + self.values = [scale * value for value in self.values] + + def visit_slice(self, ast: waveform.Slice) -> Tuple[List[Decimal], List[Decimal]]: + duration = ast.waveform.duration(**self.assignments) + + if ast.interval.start is None: + start_time = Decimal(0) + else: + start_time = ast.interval.start(**self.assignments) + + if ast.interval.stop is None: + stop_time = duration + else: + stop_time = ast.interval.stop(**self.assignments) + + if start_time < 0: + raise ValueError((f"start time for slice {start_time} is smaller than 0.")) + + if stop_time > duration: + raise ValueError( + (f"end time for slice {stop_time} is larger than duration {duration}.") + ) + + if stop_time < start_time: + raise ValueError( + ( + f"end time for slice {stop_time} is smaller than " + f"starting value for slice {start_time}." + ) + ) + + new_pwc = ( + PiecewiseConstantCodeGen(self.assignments) + .emit(ast.waveform) + .slice(start_time, stop_time) + ) + + self.times = new_pwc.times + self.values = new_pwc.values + + def visit_append(self, ast: waveform.Append) -> Tuple[List[Decimal], List[Decimal]]: + pwc = PiecewiseConstantCodeGen(self.assignments).emit(ast.waveforms[0]) + + for sub_expr in ast.waveforms[1:]: + new_pwc = PiecewiseConstantCodeGen(self.assignments).emit(sub_expr) + + # skip instructions with duration=0 + if new_pwc.times[-1] == Decimal(0): + continue + + pwc = pwc.append(new_pwc) + + self.times = pwc.times + self.values = pwc.values + + def visit_sample(self, ast: waveform.Sample) -> Tuple[List[Decimal], List[Decimal]]: + if ast.interpolation is not waveform.Interpolation.Constant: + raise ValueError( + "Failed to compile waveform to piecewise constant, " + f"found piecewise {ast.interpolation.value} interpolation." + ) + self.times, values = ast.samples(**self.assignments) + values[-1] = values[-2] + self.values = values + + def visit_record(self, ast: waveform.Record) -> Tuple[List[Decimal], List[Decimal]]: + self.visit(ast.waveform) + + def emit(self, ast: waveform.Waveform) -> PiecewiseConstant: + self.visit(ast) + + return PiecewiseConstant(times=self.times, values=self.values) diff --git a/src/bloqade/codegen/hardware/piecewise_linear.py b/src/bloqade/codegen/hardware/piecewise_linear.py new file mode 100644 index 000000000..a80587929 --- /dev/null +++ b/src/bloqade/codegen/hardware/piecewise_linear.py @@ -0,0 +1,218 @@ +import bloqade.ir.control.waveform as waveform +from bloqade.ir.control.waveform import Record +from bloqade.ir.visitor.waveform import WaveformVisitor + +from typing import Dict, Tuple, List, Union +from pydantic.dataclasses import dataclass +from bisect import bisect_left, bisect_right +import numbers +from decimal import Decimal + + +@dataclass(frozen=True) +class PiecewiseLinear: + times: List[Decimal] + values: List[Decimal] + + def eval(self, time: Decimal) -> Decimal: + if time >= self.times[-1]: + return self.values[-1] + + elif time <= self.times[0]: + return self.values[0] + else: + index = bisect_right(self.times, time) - 1 + + m = (self.values[index + 1] - self.values[index]) / ( + self.times[index + 1] - self.times[index] + ) + t = time - self.times[index] + b = self.values[index] + + return m * t + b + + def slice(self, start_time: Decimal, stop_time: Decimal) -> "PiecewiseLinear": + start_time = ( + Decimal(str(start_time)) + if not isinstance(start_time, Decimal) + else start_time + ) + stop_time = ( + Decimal(str(stop_time)) if not isinstance(stop_time, Decimal) else stop_time + ) + + if start_time == stop_time: + return PiecewiseLinear( + [Decimal(0.0), Decimal(0.0)], [Decimal(0.0), Decimal(0.0)] + ) + + start_index = bisect_left(self.times, start_time) + stop_index = bisect_left(self.times, stop_time) + start_value = self.eval(start_time) + stop_value = self.eval(stop_time) + + if start_time == self.times[start_index]: + if stop_time == self.times[start_index]: + absolute_times = list(self.times[start_index : stop_index + 1]) + values = list(self.values[start_index : stop_index + 1]) + else: + absolute_times = self.times[start_index:stop_index] + [stop_time] + values = self.values[start_index:stop_index] + [stop_value] + else: + if stop_time == self.times[stop_index]: + absolute_times = [start_time] + self.times[start_index : stop_index + 1] + values = [start_value] + self.values[start_index : stop_index + 1] + else: + absolute_times = ( + [start_time] + self.times[start_index:stop_index] + [stop_time] + ) + values = ( + [start_value] + self.values[start_index:stop_index] + [stop_value] + ) + + return PiecewiseLinear([time - start_time for time in absolute_times], values) + + def append(self, other: "PiecewiseLinear"): + assert self.values[-1] == other.values[0] + + return PiecewiseLinear( + times=self.times + [time + self.times[-1] for time in other.times[1:]], + values=self.values + other.values[1:], + ) + + +class PiecewiseLinearCodeGen(WaveformVisitor): + def __init__(self, assignments: Dict[str, Union[numbers.Real, List[numbers.Real]]]): + self.assignments = assignments + self.times = [] + self.values = [] + + @staticmethod + def check_continiuity(left, right): + if left != right: + diff = abs(left - right) + raise ValueError( + f"discontinuity with a jump of {diff} found when compiling to " + "piecewise linear." + ) + + def append_timeseries(self, start, stop, duration): + if len(self.times) == 0: + self.times = [Decimal(0), duration] + self.values = [start, stop] + else: + self.check_continiuity(self.values[-1], start) + + self.times.append(duration + self.times[-1]) + self.values.append(stop) + + def visit_linear(self, ast: waveform.Linear) -> Tuple[List[Decimal], List[Decimal]]: + duration = ast.duration(**self.assignments) + start = ast.start(**self.assignments) + stop = ast.stop(**self.assignments) + self.append_timeseries(start, stop, duration) + + def visit_constant( + self, ast: waveform.Constant + ) -> Tuple[List[Decimal], List[Decimal]]: + duration = ast.duration(**self.assignments) + value = ast.value(**self.assignments) + self.append_timeseries(value, value, duration) + + def visit_poly(self, ast: waveform.Poly) -> Tuple[List[Decimal], List[Decimal]]: + order = len(ast.coeffs) - 1 + duration = ast.duration(**self.assignments) + + if len(ast.coeffs) == 0: + start = Decimal(0) + stop = Decimal(0) + elif len(ast.coeffs) == 1: + start = ast.coeffs[0](**self.assignments) + stop = start + elif len(ast.coeffs) == 2: + start = ast.coeffs[0](**self.assignments) + stop = start + ast.coeffs[1](**self.assignments) * duration + else: + raise ValueError( + "Failed to compile Waveform to piecewise linear," + f"found Polynomial of order {order}." + ) + + self.append_timeseries(start, stop, duration) + + def visit_negative( + self, ast: waveform.Negative + ) -> Tuple[List[Decimal], List[Decimal]]: + self.visit(ast.waveform) + self.values = [-value for value in self.values] + + def visit_scale(self, ast: waveform.Scale) -> Tuple[List[Decimal], List[Decimal]]: + self.visit(ast.waveform) + scaler = ast.scalar(**self.assignments) + self.values = [scaler * value for value in self.values] + + def visit_slice(self, ast: waveform.Slice) -> Tuple[List[Decimal], List[Decimal]]: + duration = ast.waveform.duration(**self.assignments) + if ast.interval.start is None: + start_time = Decimal(0) + else: + start_time = ast.interval.start(**self.assignments) + + if ast.interval.stop is None: + stop_time = duration + else: + stop_time = ast.interval.stop(**self.assignments) + + if start_time < 0: + raise ValueError((f"start time for slice {start_time} is smaller than 0.")) + + if stop_time > duration: + raise ValueError( + (f"end time for slice {stop_time} is larger than duration {duration}.") + ) + + if stop_time < start_time: + raise ValueError( + ( + f"end time for slice {stop_time} is smaller than " + f"starting value for slice {start_time}." + ) + ) + + pwl = PiecewiseLinearCodeGen(self.assignments).emit(ast.waveform) + new_pwl = pwl.slice(start_time, stop_time) + + self.times = new_pwl.times + self.values = new_pwl.values + + def visit_append(self, ast: waveform.Append) -> Tuple[List[Decimal], List[Decimal]]: + pwl = PiecewiseLinearCodeGen(self.assignments).emit(ast.waveforms[0]) + + for sub_expr in ast.waveforms[1:]: + new_pwl = PiecewiseLinearCodeGen(self.assignments).emit(sub_expr) + + # skip instructions with duration=0 + if new_pwl.times[-1] == Decimal(0): + continue + + self.check_continiuity(pwl.values[-1], new_pwl.values[0]) + pwl = pwl.append(new_pwl) + + self.times = pwl.times + self.values = pwl.values + + def visit_sample(self, ast: waveform.Sample) -> Tuple[List[Decimal], List[Decimal]]: + if ast.interpolation is not waveform.Interpolation.Linear: + raise ValueError( + "Failed to compile waveform to piecewise linear, " + f"found piecewise {ast.interpolation.value} interpolation." + ) + + self.times, self.values = ast.samples(**self.assignments) + + def visit_record(self, ast: Record) -> Tuple[List[Decimal], List[Decimal]]: + self.visit(ast.waveform) + + def emit(self, ast: waveform.Waveform) -> PiecewiseLinear: + self.visit(ast) + return PiecewiseLinear(self.times, self.values) diff --git a/src/bloqade/codegen/hardware/quera.py b/src/bloqade/codegen/hardware/quera.py index 655da6cee..6f51f98b9 100644 --- a/src/bloqade/codegen/hardware/quera.py +++ b/src/bloqade/codegen/hardware/quera.py @@ -1,368 +1,322 @@ +from functools import cached_property from bloqade.ir.analog_circuit import AnalogCircuit from bloqade.ir.scalar import Literal -import bloqade.ir.control.waveform as waveform import bloqade.ir.control.field as field import bloqade.ir.control.pulse as pulse import bloqade.ir.control.sequence as sequence from bloqade.ir.location.base import AtomArrangement, ParallelRegister -from bloqade.ir.control.waveform import Record - from bloqade.ir.visitor.analog_circuit import AnalogCircuitVisitor -from bloqade.ir.visitor.waveform import WaveformVisitor import bloqade.submission.ir.task_specification as task_spec +from bloqade.submission.ir.braket import BraketTaskSpecification +from bloqade.submission.ir.task_specification import QuEraTaskSpecification from bloqade.submission.ir.parallel import ParallelDecoder, ClusterLocationInfo from bloqade.submission.ir.capabilities import QuEraCapabilities +from bloqade.codegen.hardware.piecewise_linear import ( + PiecewiseLinearCodeGen, + PiecewiseLinear, +) +from bloqade.codegen.hardware.piecewise_constant import ( + PiecewiseConstantCodeGen, + PiecewiseConstant, +) from typing import Any, Dict, Tuple, List, Union, Optional -from bisect import bisect_left +from pydantic.dataclasses import dataclass import numbers from decimal import Decimal import numpy as np -class PiecewiseLinearCodeGen(WaveformVisitor): - def __init__(self, assignments: Dict[str, Union[numbers.Real, List[numbers.Real]]]): - self.assignments = assignments - - def visit_negative( - self, ast: waveform.Negative - ) -> Tuple[List[Decimal], List[Decimal]]: - times, values = self.visit(ast.waveform) - return times, [-value for value in values] - - def visit_scale(self, ast: waveform.Scale) -> Tuple[List[Decimal], List[Decimal]]: - times, values = self.visit(ast.waveform) - scaler = ast.scalar(**self.assignments) - return times, [scaler * value for value in values] - - def visit_linear(self, ast: waveform.Linear) -> Tuple[List[Decimal], List[Decimal]]: - duration = ast.duration(**self.assignments) - start = ast.start(**self.assignments) - stop = ast.stop(**self.assignments) - - return [Decimal(0), duration], [start, stop] - - def visit_constant( - self, ast: waveform.Constant - ) -> Tuple[List[Decimal], List[Decimal]]: - duration = ast.duration(**self.assignments) - value = ast.value(**self.assignments) - - return [Decimal(0), duration], [value, value] - - def visit_poly(self, ast: waveform.Poly) -> Tuple[List[Decimal], List[Decimal]]: - if len(ast.coeffs) == 1: - duration = ast.duration(**self.assignments) - value = ast.coeffs[0](**self.assignments) - return [Decimal(0), duration], [value, value] - - if len(ast.coeffs) == 2: - duration = ast.duration(**self.assignments) - start = ast.coeffs[0](**self.assignments) - stop = start + ast.coeffs[1](**self.assignments) * duration - - return [Decimal(0), duration], [start, stop] - - order = len(ast.coeffs) - 1 - - raise ValueError( - "Failed to compile Waveform to piecewise linear," - f"found Polynomial of order {order}." +@dataclass(frozen=True) +class AHSCodegenResult: + nshots: int + parallel_decoder: Optional[ParallelDecoder] + sites: List[Tuple[Decimal, Decimal]] + filling: List[bool] + global_detuning: PiecewiseLinear + global_rabi_amplitude: PiecewiseLinear + global_rabi_phase: PiecewiseConstant + lattice_site_coefficients: Optional[List[Decimal]] + local_detuning: Optional[PiecewiseLinear] + + def slice(self, start_time: Decimal, stop_time: Decimal) -> "AHSCodegenResult": + return AHSCodegenResult( + self.nshots, + self.parallel_decoder, + self.sites, + self.filling, + self.global_detuning.slice(start_time, stop_time), + self.global_rabi_amplitude.slice(start_time, stop_time), + self.global_rabi_phase.slice(start_time, stop_time), + self.lattice_site_coefficients, + self.local_detuning.slice(start_time, stop_time), ) - def visit_slice(self, ast: waveform.Slice) -> Tuple[List[Decimal], List[Decimal]]: - duration = ast.waveform.duration(**self.assignments) - if ast.interval.start is None: - start_time = Decimal(0) - else: - start_time = ast.interval.start(**self.assignments) - - if ast.interval.stop is None: - stop_time = duration - else: - stop_time = ast.interval.stop(**self.assignments) - - if start_time < 0: - raise ValueError((f"start time for slice {start_time} is smaller than 0.")) - - if stop_time > duration: - raise ValueError( - (f"end time for slice {stop_time} is larger than duration {duration}.") - ) - - if stop_time < start_time: - raise ValueError( - ( - f"end time for slice {stop_time} is smaller than " - f"starting value for slice {start_time}." - ) - ) - - if start_time == stop_time: - return [Decimal(0.0), Decimal(0.0)], [Decimal(0.0), Decimal(0.0)] - - times, values = self.visit(ast.waveform) - - start_index = bisect_left(times, start_time) - stop_index = bisect_left(times, stop_time) - start_value = ast.waveform.eval_decimal(start_time, **self.assignments) - stop_value = ast.waveform.eval_decimal(stop_time, **self.assignments) - - if start_index == 0: - if stop_time == duration: - absolute_times = times - values = values - else: - absolute_times = times[:stop_index] + [stop_time] - values = values[:stop_index] + [stop_value] - elif start_time == times[start_index]: - if stop_time == duration: - absolute_times = times[start_index:] - values = values[start_index:] - else: - absolute_times = times[start_index:stop_index] + [stop_time] - values = values[start_index:stop_index] + [stop_value] - else: - if stop_time == duration: - absolute_times = [start_time] + times[start_index:] - values = [start_value] + values[start_index:] - else: - absolute_times = ( - [start_time] + times[start_index:stop_index] + [stop_time] - ) - values = [start_value] + values[start_index:stop_index] + [stop_value] - - times = [time - start_time for time in absolute_times] - - return times, values - - def visit_append(self, ast: waveform.Append) -> Tuple[List[Decimal], List[Decimal]]: - times, values = self.visit(ast.waveforms[0]) - - for sub_expr in ast.waveforms[1:]: - new_times, new_values = self.visit(sub_expr) - - # skip instructions with duration=0 - if new_times[-1] == Decimal(0): - continue - if values[-1] != new_values[0]: - diff = abs(new_values[0] - values[-1]) - raise ValueError( - f"discontinuity with a jump of {diff} found when compiling to " - "piecewise linear." - ) - - shifted_times = [time + times[-1] for time in new_times[1:]] - times.extend(shifted_times) - values.extend(new_values[1:]) - - return times, values - - def visit_sample(self, ast: waveform.Sample) -> Tuple[List[Decimal], List[Decimal]]: - if ast.interpolation is not waveform.Interpolation.Linear: - raise ValueError( - "Failed to compile waveform to piecewise linear, " - f"found piecewise {ast.interpolation.value} interpolation." - ) - return ast.samples(**self.assignments) - - def visit_record(self, ast: Record) -> Tuple[List[Decimal], List[Decimal]]: - return self.visit(ast.waveform) - - -class PiecewiseConstantCodeGen(WaveformVisitor): - def __init__(self, assignments: Dict[str, Union[numbers.Real, List[numbers.Real]]]): - self.assignments = assignments - - def visit_negative( - self, ast: waveform.Negative - ) -> Tuple[List[Decimal], List[Decimal]]: - times, values = self.visit(ast.waveform) - return times, [-value for value in values] - - def visit_scale(self, ast: waveform.Scale) -> Tuple[List[Decimal], List[Decimal]]: - times, values = self.visit(ast.waveform) - scaler = ast.scalar(**self.assignments) - return times, [scaler * value for value in values] - - def visit_linear(self, ast: waveform.Linear) -> Tuple[List[Decimal], List[Decimal]]: - duration = ast.duration(**self.assignments) - start = ast.start(**self.assignments) - stop = ast.stop(**self.assignments) - - if start != stop: - raise ValueError( - "Failed to compile Waveform to piecewise constant, " - "found non-constant Linear piecce." - ) - - return [0, duration], [start, stop] - - def visit_constant( - self, ast: waveform.Constant - ) -> Tuple[List[Decimal], List[Decimal]]: - duration = ast.duration(**self.assignments) - value = ast.value(**self.assignments) - - return [Decimal(0), duration], [value, value] - - def visit_poly(self, ast: waveform.Poly) -> Tuple[List[Decimal], List[Decimal]]: - if len(ast.coeffs) == 1: - duration = ast.duration(**self.assignments) - value = ast.coeffs[0](**self.assignments) - return [Decimal(0), duration], [value, value] - - if len(ast.coeffs) == 2: - duration = ast.duration(**self.assignments) - start = ast.coeffs[0](**self.assignments) - stop = start + ast.coeffs[1](**self.assignments) * duration - - if start != stop: - raise ValueError( - "Failed to compile Waveform to piecewise constant, " - "found non-constant Polynomial piece." - ) - - return [Decimal(0), duration], [start, stop] - - order = len(ast.coeffs) - 1 - - raise ValueError( - "Failed to compile Waveform to piecewise constant," - f"found Polynomial of order {order}." + def append(self, other: "AHSCodegenResult") -> "AHSCodegenResult": + assert self.nshots == other.nshots + assert self.parallel_decoder == other.parallel_decoder + assert self.sites == other.sites + assert self.filling == other.filling + assert self.lattice_site_coefficients == other.lattice_site_coefficients + + return AHSCodegenResult( + self.nshots, + self.parallel_decoder, + self.sites, + self.filling, + self.global_detuning.append(other.global_detuning), + self.global_rabi_amplitude.append(other.global_rabi_amplitude), + self.global_rabi_phase.append(other.global_rabi_phase), + self.lattice_site_coefficients, + self.local_detuning.append(other.local_detuning), ) - def visit_slice(self, ast: waveform.Slice) -> Tuple[List[Decimal], List[Decimal]]: - duration = ast.waveform.duration(**self.assignments) - if ast.interval.start is None: - start_time = Decimal(0) - else: - start_time = ast.interval.start(**self.assignments) - - if ast.interval.stop is None: - stop_time = duration - else: - stop_time = ast.interval.stop(**self.assignments) - - if start_time < 0: - raise ValueError((f"start time for slice {start_time} is smaller than 0.")) + @staticmethod + def convert_position_units(position): + return tuple(coordinate * Decimal("1e-6") for coordinate in position) - if stop_time > duration: - raise ValueError( - (f"end time for slice {stop_time} is larger than duration {duration}.") - ) + @staticmethod + def convert_time_units(time): + return Decimal("1e-6") * time - if stop_time < start_time: - raise ValueError( - ( - f"end time for slice {stop_time} is smaller than " - f"starting value for slice {start_time}." - ) - ) + @staticmethod + def convert_energy_units(energy): + return Decimal("1e6") * energy + + @cached_property + def braket_task_ir(self) -> BraketTaskSpecification: + import braket.ir.ahs as ir + + return BraketTaskSpecification( + nshots=self.nshots, + program=ir.Program( + setup=ir.Setup( + ahs_register=ir.AtomArrangement( + sites=list(map(self.convert_position_units, self.sites)), + filling=self.filling, + ) + ), + hamiltonian=ir.Hamiltonian( + drivingFields=[ + ir.DrivingField( + amplitude=ir.PhysicalField( + time_series=ir.TimeSeries( + times=list( + map( + self.convert_time_units, + self.global_rabi_amplitude.times, + ) + ), + values=list( + map( + self.convert_energy_units, + self.global_rabi_amplitude.values, + ) + ), + ), + pattern="uniform", + ), + phase=ir.PhysicalField( + time_series=ir.TimeSeries( + times=list( + map( + self.convert_time_units, + self.global_rabi_phase.times, + ) + ), + values=self.global_rabi_phase.values, + ), + pattern="uniform", + ), + detuning=ir.PhysicalField( + time_series=ir.TimeSeries( + times=list( + map( + self.convert_time_units, + self.global_detuning.times, + ) + ), + values=list( + map( + self.convert_energy_units, + self.global_detuning.values, + ) + ), + ), + pattern="uniform", + ), + ) + ], + shiftingFields=( + [] + if self.lattice_site_coefficients is None + else [ + ir.ShiftingField( + amplitude=ir.PhysicalField( + time_series=ir.TimeSeries( + times=list( + map( + self.convert_time_units, + self.local_detuning.times, + ) + ), + values=list( + map( + self.convert_energy_units, + self.local_detuning.values, + ) + ), + ), + pattern=self.lattice_site_coefficients, + ) + ) + ] + ), + ), + ), + ) - if start_time == stop_time: - return [Decimal(0.0), Decimal(0.0)], [Decimal(0.0), Decimal(0.0)] - - times, values = self.visit(ast.waveform) - - start_index = bisect_left(times, start_time) - stop_index = bisect_left(times, stop_time) - - if start_index == 0: - if stop_time == duration: - absolute_times = times - values = values - else: - absolute_times = times[:stop_index] + [stop_time] - values = values[:stop_index] + [values[stop_index - 1]] - elif start_time == times[start_index]: - if stop_time == duration: - absolute_times = times[start_index:] - values = values[start_index:] - else: - absolute_times = times[start_index:stop_index] + [stop_time] - values = values[start_index:stop_index] + [values[stop_index - 1]] - else: - if stop_time == duration: - absolute_times = [start_time] + times[start_index:] - values = [values[start_index - 1]] + values[start_index:] - else: - absolute_times = ( - [start_time] + times[start_index:stop_index] + [stop_time] - ) - values = ( - [values[start_index - 1]] - + values[start_index:stop_index] - + [values[stop_index - 1]] + @cached_property + def quera_task_ir(self) -> QuEraTaskSpecification: + import bloqade.submission.ir.task_specification as task_spec + + return task_spec.QuEraTaskSpecification( + nshots=self.nshots, + lattice=task_spec.Lattice( + sites=list(map(self.convert_position_units, self.sites)), + filling=self.filling, + ), + effective_hamiltonian=task_spec.EffectiveHamiltonian( + rydberg=task_spec.RydbergHamiltonian( + rabi_frequency_amplitude=task_spec.RabiFrequencyAmplitude( + global_=task_spec.GlobalField( + times=list( + map( + self.convert_time_units, + self.global_rabi_amplitude.times, + ) + ), + values=list( + map( + self.convert_energy_units, + self.global_rabi_amplitude.values, + ) + ), + ) + ), + rabi_frequency_phase=task_spec.RabiFrequencyPhase( + global_=task_spec.GlobalField( + times=list( + map( + self.convert_time_units, + self.global_rabi_phase.times, + ) + ), + values=self.global_rabi_phase.values, + ) + ), + detuning=task_spec.Detuning( + global_=task_spec.GlobalField( + times=list( + map(self.convert_time_units, self.global_detuning.times) + ), + values=list( + map( + self.convert_energy_units, + self.global_detuning.values, + ) + ), + ), + local=( + None + if self.lattice_site_coefficients is None + else task_spec.LocalField( + times=list( + map( + self.convert_time_units, + self.local_detuning.times, + ) + ), + values=list( + map( + self.convert_energy_units, + self.local_detuning.values, + ) + ), + lattice_site_coefficients=self.lattice_site_coefficients, + ) + ), + ), ) - - times = [time - start_time for time in absolute_times] - - return times, values - - def visit_append(self, ast: waveform.Append) -> Tuple[List[Decimal], List[Decimal]]: - times, values = self.visit(ast.waveforms[0]) - - for sub_expr in ast.waveforms[1:]: - new_times, new_values = self.visit(sub_expr) - - # skip instructions with duration=0 - if new_times[-1] == Decimal(0): - continue - - shifted_times = [time + times[-1] for time in new_times[1:]] - times.extend(shifted_times) - values[-1] = new_values[0] - values.extend(new_values[1:]) - - return times, values - - def visit_sample(self, ast: waveform.Sample) -> Tuple[List[Decimal], List[Decimal]]: - if ast.interpolation is not waveform.Interpolation.Constant: - raise ValueError( - "Failed to compile waveform to piecewise constant, " - f"found piecewise {ast.interpolation.value} interpolation." - ) - times, values = ast.samples(**self.assignments) - - values[-1] = values[-2] - return times, values - - def visit_record(self, ast: Record) -> Tuple[List[Decimal], List[Decimal]]: - return self.visit(ast.waveform) + ), + ) -class QuEraCodeGen(AnalogCircuitVisitor): +class AHSCodegen(AnalogCircuitVisitor): def __init__( self, + shots: int, assignments: Dict[str, Union[numbers.Real, List[numbers.Real]]] = {}, capabilities: Optional[QuEraCapabilities] = None, ): + self.nshots = shots self.capabilities = capabilities self.assignments = assignments self.parallel_decoder = None - self.lattice = None - self.effective_hamiltonian = None - self.rydberg = None - self.field_name = None - self.rabi_frequency_amplitude = None - self.rabi_frequency_phase = None - self.detuning = None + self.sites = [] + self.filling = [] + self.global_detuning = None + self.local_detuning = None self.lattice_site_coefficients = None + self.global_rabi_amplitude = None + self.global_rabi_phase = None + + def extract_fields(self, ahs_result: AHSCodegenResult) -> None: + self.nshots = ahs_result.nshots + self.sites = ahs_result.sites + self.filling = ahs_result.filling + self.global_detuning = ahs_result.global_detuning + self.global_rabi_amplitude = ahs_result.global_rabi_amplitude + self.global_rabi_phase = ahs_result.global_rabi_phase + self.lattice_site_coefficients = ahs_result.lattice_site_coefficients + self.local_detuning = ahs_result.local_detuning + + def fix_up_missing_fields(self) -> None: + # fix-up any missing fields + duration = 0.0 - @staticmethod - def convert_time_to_SI_units(times: List[Decimal]): - return [time * Decimal("1e-6") for time in times] + if self.global_rabi_amplitude: + duration = max(duration, self.global_rabi_amplitude.times[-1]) - @staticmethod - def convert_energy_to_SI_units(values: List[Decimal]): - return [value * Decimal("1e6") for value in values] + if self.global_rabi_phase: + duration = max(duration, self.global_rabi_phase.times[-1]) - @staticmethod - def convert_position_to_SI_units(position: Tuple[Decimal]): - return tuple(coordinate * Decimal("1e-6") for coordinate in position) + if self.global_detuning: + duration = max(duration, self.global_detuning.times[-1]) + + if self.local_detuning: + duration = max(duration, self.local_detuning.times[-1]) + + if duration > 0: + if self.global_rabi_amplitude is None: + self.global_rabi_amplitude = PiecewiseLinear( + [Decimal(0), duration], [Decimal(0), Decimal(0)] + ) + + if self.global_rabi_phase is None: + self.global_rabi_phase = PiecewiseConstant( + [Decimal(0), duration], [Decimal(0), Decimal(0)] + ) + + if self.global_detuning is None: + self.global_detuning = PiecewiseLinear( + [Decimal(0), duration], [Decimal(0), Decimal(0)] + ) + + if self.local_detuning is None: + pass def post_visit_spatial_modulation(self, lattice_site_coefficients): self.lattice_site_coefficients = [] @@ -417,64 +371,40 @@ def visit_assigned_run_time_vector(self, ast: field.AssignedRunTimeVector) -> An lattice_site_coefficients = ast.value self.post_visit_spatial_modulation(lattice_site_coefficients) - def visit_detuning(self, ast: field.Field) -> Any: + def calculate_detuning(self, ast: field.Field) -> Any: if len(ast.drives) == 1 and field.Uniform in ast.drives: - times, values = PiecewiseLinearCodeGen(self.assignments).visit( + self.global_detuning = PiecewiseLinearCodeGen(self.assignments).emit( ast.drives[field.Uniform] ) - times = QuEraCodeGen.convert_time_to_SI_units(times) - values = QuEraCodeGen.convert_energy_to_SI_units(values) - - self.detuning = task_spec.Detuning( - global_=task_spec.GlobalField(times=times, values=values) - ) elif len(ast.drives) == 1: ((spatial_modulation, waveform),) = ast.drives.items() - times, values = PiecewiseLinearCodeGen(self.assignments).visit(waveform) + self.local_detuning = PiecewiseLinearCodeGen(self.assignments).emit( + waveform + ) - times = QuEraCodeGen.convert_time_to_SI_units(times) - values = QuEraCodeGen.convert_energy_to_SI_units(values) + self.global_detuning = PiecewiseLinear( + [Decimal(0), self.local_detuning.times[-1]], [Decimal(0), Decimal(0)] + ) self.visit(spatial_modulation) - self.detuning = task_spec.Detuning( - global_=task_spec.GlobalField(times=[0, times[-1]], values=[0.0, 0.0]), - local=task_spec.LocalField( - times=times, - values=values, - lattice_site_coefficients=self.lattice_site_coefficients, - ), - ) + elif len(ast.drives) == 2 and field.Uniform in ast.drives: # will only be two keys for k in ast.drives.keys(): if k == field.Uniform: - global_times, global_values = PiecewiseLinearCodeGen( + self.global_detuning = PiecewiseLinearCodeGen( self.assignments - ).visit(ast.drives[field.Uniform]) + ).emit(ast.drives[field.Uniform]) else: # can be field.RunTimeVector or field.ScaledLocations spatial_modulation = k - local_times, local_values = PiecewiseLinearCodeGen( - self.assignments - ).visit(ast.drives[k]) + self.local_detuning = PiecewiseLinearCodeGen(self.assignments).emit( + ast.drives[k] + ) self.visit(spatial_modulation) # just visit the non-uniform locations - global_times = QuEraCodeGen.convert_time_to_SI_units(global_times) - local_times = QuEraCodeGen.convert_time_to_SI_units(local_times) - - global_values = QuEraCodeGen.convert_energy_to_SI_units(global_values) - local_values = QuEraCodeGen.convert_energy_to_SI_units(local_values) - - self.detuning = task_spec.Detuning( - local=task_spec.LocalField( - times=local_times, - values=local_values, - lattice_site_coefficients=self.lattice_site_coefficients, - ), - global_=task_spec.GlobalField(times=global_times, values=global_values), - ) else: raise ValueError( "Failed to compile Detuning to QuEra task, " @@ -482,18 +412,12 @@ def visit_detuning(self, ast: field.Field) -> Any: f"{repr(ast)}." ) - def visit_rabi_amplitude(self, ast: field.Field) -> Any: + def calculate_rabi_amplitude(self, ast: field.Field) -> Any: if len(ast.drives) == 1 and field.Uniform in ast.drives: - times, values = PiecewiseLinearCodeGen(self.assignments).visit( + self.global_rabi_amplitude = PiecewiseLinearCodeGen(self.assignments).emit( ast.drives[field.Uniform] ) - times = QuEraCodeGen.convert_time_to_SI_units(times) - values = QuEraCodeGen.convert_energy_to_SI_units(values) - - self.rabi_frequency_amplitude = task_spec.RabiFrequencyAmplitude( - global_=task_spec.GlobalField(times=times, values=values) - ) else: raise ValueError( "Failed to compile Rabi Amplitude to QuEra task, " @@ -501,17 +425,12 @@ def visit_rabi_amplitude(self, ast: field.Field) -> Any: f"{repr(ast)}." ) - def visit_rabi_phase(self, ast: field.Field) -> Any: + def calculate_rabi_phase(self, ast: field.Field) -> Any: if len(ast.drives) == 1 and field.Uniform in ast.drives: # has to be global - times, values = PiecewiseConstantCodeGen(self.assignments).visit( + self.global_rabi_phase = PiecewiseConstantCodeGen(self.assignments).emit( ast.drives[field.Uniform] ) - times = QuEraCodeGen.convert_time_to_SI_units(times) - - self.rabi_frequency_phase = task_spec.RabiFrequencyAmplitude( - global_=task_spec.GlobalField(times=times, values=values) - ) else: raise ValueError( "Failed to compile Rabi Phase to QuEra task, " @@ -521,108 +440,76 @@ def visit_rabi_phase(self, ast: field.Field) -> Any: def visit_field(self, ast: field.Field): if self.field_name == pulse.detuning: - self.visit_detuning(ast) + self.calculate_detuning(ast) elif self.field_name == pulse.rabi.amplitude: - self.visit_rabi_amplitude(ast) + self.calculate_rabi_amplitude(ast) elif self.field_name == pulse.rabi.phase: - self.visit_rabi_phase(ast) + self.calculate_rabi_phase(ast) def visit_pulse(self, ast: pulse.Pulse): for field_name in ast.fields.keys(): self.field_name = field_name self.visit(ast.fields[field_name]) - # fix-up any missing fields - duration = 0.0 - - if self.rabi_frequency_amplitude is not None: - duration = max(duration, self.rabi_frequency_amplitude.global_.times[-1]) - - if self.rabi_frequency_phase is not None: - duration = max(duration, self.rabi_frequency_phase.global_.times[-1]) - - if self.detuning is not None: - duration = max(duration, self.detuning.global_.times[-1]) - - if duration == 0.0: - raise ValueError("No Fields found in pulse.") - - if self.rabi_frequency_amplitude is None: - self.rabi_frequency_amplitude = task_spec.RabiFrequencyAmplitude( - global_=task_spec.GlobalField(times=[0, duration], values=[0.0, 0.0]) - ) + def visit_named_pulse(self, ast: pulse.NamedPulse): + self.visit(ast.pulse) - if self.rabi_frequency_phase is None: - self.rabi_frequency_phase = task_spec.RabiFrequencyPhase( - global_=task_spec.GlobalField(times=[0, duration], values=[0.0, 0.0]) - ) + def visit_append_pulse(self, ast: pulse.Append): + subexpr_compiler = AHSCodegen(self.nshots, self.assignments) + ahs_result = subexpr_compiler.emit(ast.pulses) - if self.detuning is None: - self.detuning = task_spec.Detuning( - global_=task_spec.GlobalField(times=[0, duration], values=[0.0, 0.0]) - ) + for seq in ast.sequences: + new_ahs_result = subexpr_compiler.emit(seq) + ahs_result = ahs_result.append(new_ahs_result) - self.rydberg = task_spec.RydbergHamiltonian( - rabi_frequency_amplitude=self.rabi_frequency_amplitude, - rabi_frequency_phase=self.rabi_frequency_phase, - detuning=self.detuning, - ) + self.extract_fields(ahs_result) - def visit_named_pulse(self, ast: pulse.NamedPulse): - self.visit(ast.pulse) + def visit_slice_pulse(self, ast: pulse.Slice): + start_time = ast.interval.start(**self.assignments) + stop_time = ast.interval.stop(**self.assignments) - def visit_append_pulse(self, ast: pulse.Append): - raise NotImplementedError( - "Failed to compile Append to QuEra task, " - "found non-atomic pulse expression: " - f"{repr(ast)}." - ) + ahs_result = AHSCodegen(self.nshots, self.assignments).emit(ast.pulse) + ahs_result = ahs_result.slice(start_time, stop_time) - def visit_slice_pulse(self, ast: pulse.Append): - raise NotImplementedError( - "Failed to compile Append to QuEra task, " - "found non-atomic pulse expression: " - f"{repr(ast)}." - ) + self.extract_fields(ahs_result) def visit_sequence(self, ast: sequence.Sequence): if sequence.HyperfineLevelCoupling() in ast.pulses: raise ValueError("QuEra tasks does not support Hyperfine coupling.") self.visit(ast.pulses.get(sequence.RydbergLevelCoupling(), pulse.Pulse({}))) - self.effective_hamiltonian = task_spec.EffectiveHamiltonian( - rydberg=self.rydberg - ) def visit_named_sequence(self, ast: sequence.NamedSequence): self.visit(ast.sequence) def visit_append_sequence(self, ast: sequence.Append): - raise NotImplementedError( - "Failed to compile Append to QuEra task, " - "found non-atomic sequence expression: " - f"{repr(ast)}." - ) + subexpr_compiler = AHSCodegen(self.nshots, self.assignments) + ahs_result = subexpr_compiler.emit(ast.sequences[0]) + + for sub_sequence in ast.sequences[1:]: + new_ahs_result = subexpr_compiler.emit(sub_sequence) + ahs_result = ahs_result.append(new_ahs_result) + + self.extract_fields(ahs_result) def visit_slice_sequence(self, ast: sequence.Slice): - raise NotImplementedError( - "Failed to compile Slice to QuEra task, " - "found non-atomic sequence expression: " - f"{repr(ast)}." - ) + start_time = ast.interval.start(**self.assignments) + stop_time = ast.interval.stop(**self.assignments) + + ahs_result = AHSCodegen(self.nshots, self.assignments).emit(ast.sequence) + ahs_result = ahs_result.slice(start_time, stop_time) + self.extract_fields(ahs_result) def visit_register(self, ast: AtomArrangement): - sites = [] - filling = [] + self.sites = [] + self.filling = [] for location_info in ast.enumerate(): site = tuple(ele(**self.assignments) for ele in location_info.position) - sites.append(QuEraCodeGen.convert_position_to_SI_units(site)) - filling.append(location_info.filling.value) + self.sites.append(site) + self.filling.append(location_info.filling.value) - self.n_atoms = len(sites) - - self.lattice = task_spec.Lattice(sites=sites, filling=filling) + self.n_atoms = len(self.sites) def visit_parallel_register(self, ast: ParallelRegister) -> Any: if self.capabilities is None: @@ -697,7 +584,7 @@ def visit_parallel_register(self, ast: ParallelRegister) -> Any: for cluster_location_index, (location, filled) in enumerate( zip(new_register_locations[:], register_filling) ): - site = QuEraCodeGen.convert_position_to_SI_units(tuple(location)) + site = tuple(location) sites.append(site) filling.append(filled) @@ -719,15 +606,28 @@ def visit_analog_circuit(self, ast: AnalogCircuit) -> Any: self.visit(ast.register) self.visit(ast.sequence) - def emit( - self, nshots: int, analog_circuit: AnalogCircuit - ) -> Tuple[task_spec.QuEraTaskSpecification, Optional[ParallelDecoder]]: - self.visit(analog_circuit) + def emit(self, ast) -> AHSCodegenResult: + self.visit(ast) + self.fix_up_missing_fields() - task_ir = task_spec.QuEraTaskSpecification( - nshots=nshots, - lattice=self.lattice, - effective_hamiltonian=self.effective_hamiltonian, + if all( # TODO: move this into analysis portion. + [ + self.global_detuning is None, + self.global_rabi_amplitude is None, + self.global_rabi_phase is None, + self.local_detuning is None, + ] + ): + raise ValueError("No fields were specified.") + + return AHSCodegenResult( + nshots=self.nshots, + parallel_decoder=self.parallel_decoder, + sites=self.sites, + filling=self.filling, + global_detuning=self.global_detuning, + global_rabi_amplitude=self.global_rabi_amplitude, + global_rabi_phase=self.global_rabi_phase, + lattice_site_coefficients=self.lattice_site_coefficients, + local_detuning=self.local_detuning, ) - - return task_ir, self.parallel_decoder diff --git a/src/bloqade/ir/analysis/scan_variables.py b/src/bloqade/ir/analysis/scan_variables.py index 7a290b514..53ed204e5 100644 --- a/src/bloqade/ir/analysis/scan_variables.py +++ b/src/bloqade/ir/analysis/scan_variables.py @@ -46,8 +46,11 @@ def visit_div(self, ast: scalar.Div) -> Any: self.visit(ast.rhs) def visit_interval(self, ast: scalar.Interval) -> Any: - self.visit(ast.start) - self.visit(ast.stop) + if ast.start is not None: + self.visit(ast.start) + + if ast.stop is not None: + self.visit(ast.stop) def visit_slice(self, ast: scalar.Slice) -> Any: self.visit(ast.expr) diff --git a/src/bloqade/ir/routine/braket.py b/src/bloqade/ir/routine/braket.py index 7aaa2c81e..906c01da6 100644 --- a/src/bloqade/ir/routine/braket.py +++ b/src/bloqade/ir/routine/braket.py @@ -36,7 +36,7 @@ def _compile( ## fall passes here ### from bloqade.codegen.common.assign_variables import AssignAnalogCircuit from bloqade.ir.analysis.assignment_scan import AssignmentScan - from bloqade.codegen.hardware.quera import QuEraCodeGen + from bloqade.codegen.hardware.quera import AHSCodegen capabilities = self.backend.get_capabilities() @@ -50,17 +50,15 @@ def _compile( final_circuit = AssignAnalogCircuit(record_params).visit(circuit) # TODO: Replace these two steps with: # task_ir, parallel_decoder = BraketCodeGen().emit(shots, final_circuit) - task_ir, parallel_decoder = QuEraCodeGen(capabilities=capabilities).emit( - shots, final_circuit - ) + result = AHSCodegen(shots, capabilities=capabilities).emit(final_circuit) metadata = {**params.static_params, **record_params} - task_ir = task_ir.discretize(capabilities) + task_ir = result.quera_task_ir.discretize(capabilities) tasks[task_number] = BraketTask( None, self.backend, task_ir, metadata, - parallel_decoder, + result.parallel_decoder, None, ) @@ -171,9 +169,8 @@ def _compile( ## fall passes here ### from bloqade.ir import ParallelRegister from bloqade.codegen.common.assign_variables import AssignAnalogCircuit - from bloqade.codegen.hardware.quera import QuEraCodeGen + from bloqade.codegen.hardware.quera import AHSCodegen from bloqade.ir.analysis.assignment_scan import AssignmentScan - from bloqade.submission.ir.braket import to_braket_task_ir circuit, params = self.circuit, self.params circuit = AssignAnalogCircuit(params.static_params).visit(circuit) @@ -191,12 +188,11 @@ def _compile( final_circuit = AssignAnalogCircuit(record_params).visit(circuit) # TODO: Replace these two steps with: # task_ir, _ = BraketCodeGen().emit(shots, final_circuit) - quera_task_ir, _ = QuEraCodeGen().emit(shots, final_circuit) + result = AHSCodegen(shots).emit(final_circuit) - task_ir = to_braket_task_ir(quera_task_ir) metadata = {**params.static_params, **record_params} tasks[task_number] = BraketEmulatorTask( - task_ir, + result.braket_task_ir, metadata, None, ) diff --git a/src/bloqade/ir/routine/quera.py b/src/bloqade/ir/routine/quera.py index 802f825ee..25b457215 100644 --- a/src/bloqade/ir/routine/quera.py +++ b/src/bloqade/ir/routine/quera.py @@ -55,7 +55,7 @@ def _compile( from bloqade.codegen.common.assign_variables import AssignAnalogCircuit from bloqade.ir.analysis.assignment_scan import AssignmentScan - from bloqade.codegen.hardware.quera import QuEraCodeGen + from bloqade.codegen.hardware.quera import AHSCodegen circuit, params = self.circuit, self.params capabilities = self.backend.get_capabilities() @@ -66,14 +66,11 @@ def _compile( for task_number, batch_params in enumerate(params.batch_assignments(*args)): record_params = AssignmentScan(batch_params).emit(circuit) final_circuit = AssignAnalogCircuit(record_params).visit(circuit) - task_ir, parallel_decoder = QuEraCodeGen(capabilities=capabilities).emit( - shots, final_circuit - ) - - task_ir = task_ir.discretize(capabilities) + result = AHSCodegen(shots, capabilities=capabilities).emit(final_circuit) + task_ir = result.quera_task_ir.discretize(capabilities) metadata = {**params.static_params, **record_params} tasks[task_number] = QuEraTask( - None, self.backend, task_ir, metadata, parallel_decoder + None, self.backend, task_ir, metadata, result.parallel_decoder ) batch = RemoteBatch(source=self.source, tasks=tasks, name=name) diff --git a/tests/test_codegen_quera.py b/tests/test_codegen_quera.py index 2bd224e78..0e652d4af 100644 --- a/tests/test_codegen_quera.py +++ b/tests/test_codegen_quera.py @@ -111,40 +111,40 @@ def test_plin_codegen_slice(): scanner = quer.PiecewiseLinearCodeGen(asgn) wf = wv[:0.5] - times, values = scanner.visit_slice(wf) - assert (cf(times), cf(values)) == ([0, 0.5], [0, 0.5]) + pwl = scanner.emit(wf) + assert (cf(pwl.times), cf(pwl.values)) == ([0, 0.5], [0, 0.5]) wf2 = wv[0:1] - times, values = scanner.visit_slice(wf2) - assert (cf(times), cf(values)) == ([0, 1], [0, 1]) + pwl = scanner.emit(wf2) + assert (cf(pwl.times), cf(pwl.values)) == ([0, 1], [0, 1]) wf3 = wv[0.7:1] - times, values = scanner.visit_slice(wf3) - assert (cf(times), cf(values)) == ([0, 0.3], [0.7, 1]) + pwl = scanner.emit(wf3) + assert (cf(pwl.times), cf(pwl.values)) == ([0, 0.3], [0.7, 1]) wf4 = wv[0:0.6] - times, values = scanner.visit_slice(wf4) - assert (cf(times), cf(values)) == ([0, 0.6], [0.0, 0.6]) + pwl = scanner.emit(wf4) + assert (cf(pwl.times), cf(pwl.values)) == ([0, 0.6], [0.0, 0.6]) wf5 = wv[0.2:0.6] - times, values = scanner.visit_slice(wf5) - assert (cf(times), cf(values)) == ([0, 0.4], [0.2, 0.6]) + pwl = scanner.emit(wf5) + assert (cf(pwl.times), cf(pwl.values)) == ([0, 0.4], [0.2, 0.6]) wf6 = wv[0.6:1.5] - times, values = scanner.visit_slice(wf6) - assert (cf(times), cf(values)) == ([0, 0.4, 0.9], [0.6, 1.0, 1.25]) + pwl = scanner.emit(wf6) + assert (cf(pwl.times), cf(pwl.values)) == ([0, 0.4, 0.9], [0.6, 1.0, 1.25]) wf7 = wv[1.0:1.5] - times, values = scanner.visit_slice(wf7) - assert (cf(times), cf(values)) == ([0, 0.5], [1.0, 1.25]) + pwl = scanner.emit(wf7) + assert (cf(pwl.times), cf(pwl.values)) == ([0, 0.5], [1.0, 1.25]) wf8 = wv[1.0:3.0] - times, values = scanner.visit_slice(wf8) - assert (cf(times), cf(values)) == ([0, 2.0], [1.0, 2.0]) + pwl = scanner.emit(wf8) + assert (cf(pwl.times), cf(pwl.values)) == ([0, 2.0], [1.0, 2.0]) wf9 = wv[1.5:3.0] - times, values = scanner.visit_slice(wf9) - assert (cf(times), cf(values)) == ([0, 1.5], [1.25, 2.0]) + pwl = scanner.emit(wf9) + assert (cf(pwl.times), cf(pwl.values)) == ([0, 1.5], [1.25, 2.0]) def test_pconst_codegen_slice(): @@ -155,49 +155,49 @@ def test_pconst_codegen_slice(): scanner = quer.PiecewiseConstantCodeGen(asgn) wf = wv[:1.3] - times, values = scanner.visit_slice(wf) - assert (cf(times), cf(values)) == ([0, 1.0, 1.3], [1.0, 2.0, 2.0]) + pwc = scanner.emit(wf) + assert (cf(pwc.times), cf(pwc.values)) == ([0, 1.0, 1.3], [1.0, 2.0, 2.0]) wf2 = wv[0:1] - times, values = scanner.visit_slice(wf2) - assert (cf(times), cf(values)) == ([0, 1.0], [1.0, 1.0]) + pwc = scanner.emit(wf2) + assert (cf(pwc.times), cf(pwc.values)) == ([0, 1.0], [1.0, 1.0]) wf3 = wv[0.7:1.2] - times, values = scanner.visit_slice(wf3) - assert (cf(times), cf(values)) == ([0, 0.3, 0.5], [1.0, 2.0, 2.0]) + pwc = scanner.emit(wf3) + assert (cf(pwc.times), cf(pwc.values)) == ([0, 0.3, 0.5], [1.0, 2.0, 2.0]) wf4 = wv[0.7:1] - times, values = scanner.visit_slice(wf4) - assert (cf(times), cf(values)) == ([0, 0.3], [1.0, 1]) + pwc = scanner.emit(wf4) + assert (cf(pwc.times), cf(pwc.values)) == ([0, 0.3], [1.0, 1]) wf4 = wv[0:0.6] - times, values = scanner.visit_slice(wf4) - assert (cf(times), cf(values)) == ([0, 0.6], [1.0, 1.0]) + pwc = scanner.emit(wf4) + assert (cf(pwc.times), cf(pwc.values)) == ([0, 0.6], [1.0, 1.0]) wf5 = wv[0.2:0.6] - times, values = scanner.visit_slice(wf5) - assert (cf(times), cf(values)) == ([0, 0.4], [1.0, 1.0]) + pwc = scanner.emit(wf5) + assert (cf(pwc.times), cf(pwc.values)) == ([0, 0.4], [1.0, 1.0]) wf6 = wv[0:0] - times, values = scanner.visit_slice(wf6) - assert (cf(times), cf(values)) == ([0, 0], [0, 0]) + pwc = scanner.emit(wf6) + assert (cf(pwc.times), cf(pwc.values)) == ([0, 0], [0, 0]) wf8 = wv[1:1] - times, values = scanner.visit_slice(wf8) - assert (cf(times), cf(values)) == ([0, 0], [0, 0]) + pwc = scanner.emit(wf8) + assert (cf(pwc.times), cf(pwc.values)) == ([0, 0], [0, 0]) wf7 = wv[1.3:1.3] - times, values = scanner.visit_slice(wf7) - assert (cf(times), cf(values)) == ([0, 0], [0, 0]) + pwc = scanner.emit(wf7) + assert (cf(pwc.times), cf(pwc.values)) == ([0, 0], [0, 0]) wf9 = wv[1:2.5] - times, values = scanner.visit_slice(wf9) - assert (cf(times), cf(values)) == ([0, 1.5], [2.0, 2.0]) + pwc = scanner.emit(wf9) + assert (cf(pwc.times), cf(pwc.values)) == ([0, 1.5], [2.0, 2.0]) wf10 = wv[1.5:2.5] - times, values = scanner.visit_slice(wf10) - assert (cf(times), cf(values)) == ([0, 1.0], [2.0, 2.0]) + pwc = scanner.emit(wf10) + assert (cf(pwc.times), cf(pwc.values)) == ([0, 1.0], [2.0, 2.0]) wf11 = wv[1:1.5] - times, values = scanner.visit_slice(wf11) - assert (cf(times), cf(values)) == ([0, 0.5], [2.0, 2.0]) + pwc = scanner.emit(wf11) + assert (cf(pwc.times), cf(pwc.values)) == ([0, 0.5], [2.0, 2.0])