Skip to content

Commit

Permalink
implementation of state-vector interface.
Browse files Browse the repository at this point in the history
  • Loading branch information
weinbe58 committed Oct 25, 2023
1 parent e89e439 commit b068a5a
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 22 deletions.
29 changes: 22 additions & 7 deletions pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ dev = [
"vcrpy==4.4.0",
"pyinstrument>=4.5.3",
"scikit-optimize>=0.9.0",
"icecream>=2.1.3",
]

[tool.pdm.scripts]
Expand Down
144 changes: 129 additions & 15 deletions src/bloqade/ir/routine/bloqade.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
from collections import OrderedDict
from decimal import Decimal

from bloqade.ir.routine.base import RoutineBase, __pydantic_dataclass_config__
from bloqade.builder.typing import LiteralType
from bloqade.task.batch import LocalBatch
from beartype import beartype
from beartype.typing import Optional, Tuple
from beartype.typing import Optional, Tuple, Callable, Dict, Any, VarArg
from pydantic.dataclasses import dataclass
import numpy as np

from bloqade.emulate.codegen.hamiltonian import CompileCache, RydbergHamiltonianCodeGen
from bloqade.emulate.ir.state_vector import AnalogGate


@dataclass(frozen=True, config=__pydantic_dataclass_config__)
Expand All @@ -15,6 +21,48 @@ def python(self):

@dataclass(frozen=True, config=__pydantic_dataclass_config__)
class BloqadePythonRoutine(RoutineBase):
@staticmethod
def process_tasks(tasks, results, runner):
while not tasks.empty():
task_id, (emulator_ir, metadata) = tasks.get()
result = runner.run_task(emulator_ir, metadata)
results.put((task_id, result))

@dataclass(config=__pydantic_dataclass_config__)
class EmuRunner:
compile_cache: CompileCache
solver_args: Dict
callback: Callable
callback_args: Tuple

def run_task(self, emulator_ir, metadata):
hamiltonian = RydbergHamiltonianCodeGen(
compile_cache=self.compile_cache
).emit(emulator_ir)

zero_state = hamiltonian.space.zero_state(np.complex128)
result = AnalogGate(hamiltonian).apply(zero_state, **self.solver_args)

return self.callback(result, metadata, *self.callback_args)

def _generate_ir(self, args, blockade_radius):
from bloqade.ir.analysis.assignment_scan import AssignmentScan
from bloqade.codegen.common.assign_variables import AssignAnalogCircuit
from bloqade.codegen.emulator_ir import EmulatorProgramCodeGen

circuit, params = self.circuit, self.params

circuit = AssignAnalogCircuit(params.static_params).visit(circuit)

for task_number, batch_param in enumerate(params.batch_assignments(*args)):
record_params = AssignmentScan(batch_param).emit(circuit)
final_circuit = AssignAnalogCircuit(record_params).visit(circuit)
metadata = {**params.static_params, **record_params}
emulator_ir = EmulatorProgramCodeGen(blockade_radius=blockade_radius).emit(
final_circuit
)
yield task_number, emulator_ir, metadata

def _compile(
self,
shots: int,
Expand All @@ -23,29 +71,17 @@ def _compile(
blockade_radius: LiteralType = 0.0,
cache_matrices: bool = False,
) -> LocalBatch:
from bloqade.ir.analysis.assignment_scan import AssignmentScan
from bloqade.codegen.common.assign_variables import AssignAnalogCircuit
from bloqade.codegen.emulator_ir import EmulatorProgramCodeGen
from bloqade.emulate.codegen.hamiltonian import CompileCache
from bloqade.task.bloqade import BloqadeTask

circuit, params = self.circuit, self.params

circuit = AssignAnalogCircuit(params.static_params).visit(circuit)

if cache_matrices:
matrix_cache = CompileCache()
else:
matrix_cache = None

tasks = OrderedDict()
for task_number, batch_param in enumerate(params.batch_assignments(*args)):
record_params = AssignmentScan(batch_param).emit(circuit)
final_circuit = AssignAnalogCircuit(record_params).visit(circuit)
emulator_ir = EmulatorProgramCodeGen(blockade_radius=blockade_radius).emit(
final_circuit
)
metadata = {**params.static_params, **record_params}
it_iter = self._generate_ir(args, blockade_radius)
for task_number, metadata, emulator_ir in it_iter:
tasks[task_number] = BloqadeTask(shots, emulator_ir, metadata, matrix_cache)

return LocalBatch(self.source, tasks, name)
Expand Down Expand Up @@ -157,3 +193,81 @@ def __call__(
interaction_picture=interaction_picture,
)
return self.run(**options)

@beartype
def run_callback(
self,
callback: Callable[[np.ndarray, Dict[str, Decimal], VarArg], Any],
program_args: Tuple[LiteralType, ...] = (),
callback_args: Tuple = (),
blockade_radius: float = 0.0,
interaction_picture: bool = False,
cache_matrices: bool = False,
multiprocessing: bool = False,
num_workers: Optional[int] = None,
solver_name: str = "dop853",
atol: float = 1e-14,
rtol: float = 1e-7,
nsteps: int = 2_147_483_647,
):
from bloqade.emulate.codegen.hamiltonian import CompileCache
from multiprocessing import Process, Queue, cpu_count

if cache_matrices:
compile_cache = CompileCache()
else:
compile_cache = None

solver_args = dict(
solver_name=solver_name,
atol=atol,
rtol=rtol,
nsteps=nsteps,
interaction_picture=interaction_picture,
)

runner = self.EmuRunner(
compile_cache=compile_cache,
solver_args=solver_args,
callback=callback,
callback_args=callback_args,
)

tasks = Queue()
results = Queue()

it_iter = self._generate_ir(program_args, blockade_radius)
for task_number, metadata, emulator_ir in it_iter:
tasks.put((task_number, (emulator_ir, metadata)))

if multiprocessing:
num_workers = int(num_workers or cpu_count())

workers = [
Process(
target=BloqadePythonRoutine.process_tasks,
args=(tasks, results, runner),
)
for i in range(num_workers)
]

for worker in workers:
worker.start()

for worker in workers:
worker.join()

else:
while not tasks.empty():
task_id, (emulator_ir, metadata) = tasks.get()
result = runner.run_task(emulator_ir, metadata)
results.put((task_id, result))

id_results = []
while not results.empty():
id_results.append(results.get())

id_results.sort(key=lambda x: x[0])
results = [result for _, result in id_results]

return results

0 comments on commit b068a5a

Please sign in to comment.