Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add operations for setting state and basis state and integrate this with skip_initial_state_prep #955

Merged
merged 110 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from 99 commits
Commits
Show all changes
110 commits
Select commit Hold shift + click to select a range
5f34a3a
Prevent StatePrep from being decomposed for Lightning
rauletorresc Jul 24, 2024
7c76701
Add skip_initial_state_prep_flag to DeviceCapabilities
erick-xanadu Jul 26, 2024
d3304ef
Use skip_initial_state_prep_flag during decomposition
erick-xanadu Jul 26, 2024
718c6be
Skip verification for StatePrep as special case
erick-xanadu Jul 26, 2024
462496a
Tracing
erick-xanadu Jul 29, 2024
b19dae8
Add SetStateOp
erick-xanadu Jul 29, 2024
916fa84
Adds lowering and type promotion
erick-xanadu Jul 29, 2024
e1ff4b1
Add Bufferization
erick-xanadu Jul 29, 2024
535f857
Add runtime interface for SetState
erick-xanadu Jul 30, 2024
87f496b
Initial lowering
erick-xanadu Jul 30, 2024
1c8df53
Formatting
erick-xanadu Jul 30, 2024
831a8e9
Fix the interface for the parameter type
erick-xanadu Jul 30, 2024
f88d48a
Propagate the length of the qreg
erick-xanadu Jul 30, 2024
ab2e2b3
Change to DataView reference
erick-xanadu Jul 30, 2024
245cced
Correct assignment for lightning.qubit
erick-xanadu Jul 30, 2024
277a450
Propagate length
erick-xanadu Jul 30, 2024
126ced0
Name
erick-xanadu Jul 30, 2024
efc9e84
Name
erick-xanadu Jul 30, 2024
aa08885
Fix rebase
erick-xanadu Jul 30, 2024
c858d30
Basis state frontend
erick-xanadu Jul 30, 2024
ecb1f1a
Add SetBasisState
erick-xanadu Jul 31, 2024
dcbc36d
Use new casting
erick-xanadu Jul 31, 2024
02cb5cb
Frontend changes
erick-xanadu Jul 31, 2024
cace061
Add lowering for SetBasisState
erick-xanadu Jul 31, 2024
9d461d4
Initial changes for SetBasisState in the runtime
erick-xanadu Jul 31, 2024
59dd647
Style
erick-xanadu Jul 31, 2024
f4fd1af
Add test for setBasisState
erick-xanadu Jul 31, 2024
300f15b
Add test for SetStateVector
erick-xanadu Jul 31, 2024
f9ae675
Add test for SetState
erick-xanadu Jul 31, 2024
f325fbd
Test for BasisState
erick-xanadu Jul 31, 2024
7ecc769
Add conversion test for SetStateOp
erick-xanadu Jul 31, 2024
215d15b
Add conversion test for SetBasisState
erick-xanadu Jul 31, 2024
6b294dd
Add test for code generation
erick-xanadu Jul 31, 2024
920d0e5
Add codegen test for BasisState
erick-xanadu Jul 31, 2024
a0d4610
Tested that only the first call actually uses this instruction
erick-xanadu Jul 31, 2024
b124cd4
Add comment
erick-xanadu Jul 31, 2024
4917aa0
Fix typo
erick-xanadu Jul 31, 2024
79ddb97
TODO
erick-xanadu Jul 31, 2024
3b4096f
Add test cases for errors
erick-xanadu Jul 31, 2024
413116d
Frontend tests
erick-xanadu Jul 31, 2024
a0f0c46
Style
erick-xanadu Jul 31, 2024
ff964a8
Outline
erick-xanadu Jul 31, 2024
d6cdac2
Global phase
erick-xanadu Jul 31, 2024
8f29dff
clang format
erick-xanadu Jul 31, 2024
d9bcfc2
Add dynamic wire pseudo-test (ask josh)
erick-xanadu Jul 31, 2024
6105cc2
Changelog
erick-xanadu Aug 1, 2024
0569978
I think this is better
erick-xanadu Aug 1, 2024
7118cb0
wip
erick-xanadu Aug 1, 2024
db67e9d
Style
erick-xanadu Aug 1, 2024
9b20114
Runtime basis prep index calculation
erick-xanadu Aug 1, 2024
4e74844
Finally
erick-xanadu Aug 1, 2024
64724be
Fix OQC
erick-xanadu Aug 2, 2024
05297de
F
erick-xanadu Aug 2, 2024
4443166
StatePrep no unknown wires
erick-xanadu Aug 2, 2024
0e27102
code factor
erick-xanadu Aug 2, 2024
003866f
Clang format
erick-xanadu Aug 2, 2024
b22878b
Comments
erick-xanadu Aug 2, 2024
d4e252a
More tests
erick-xanadu Aug 2, 2024
84875ec
Grad means error
erick-xanadu Aug 2, 2024
46496ee
Minor changes
erick-xanadu Aug 2, 2024
3f6a645
isort
erick-xanadu Aug 2, 2024
8abc397
style
erick-xanadu Aug 2, 2024
621b68b
Update dep version
erick-xanadu Aug 2, 2024
3ef3da7
small hack
erick-xanadu Aug 2, 2024
1e6f00f
small hack again
erick-xanadu Aug 2, 2024
4141e05
Skip kokkos since it is old device
erick-xanadu Aug 2, 2024
f6ae6c7
Update lightning git tag.
erick-xanadu Aug 5, 2024
7b68e0a
Remove codegen from frontend
erick-xanadu Aug 7, 2024
e460440
f
erick-xanadu Aug 7, 2024
501b88c
f
erick-xanadu Aug 7, 2024
badd2cc
Bufferization
erick-xanadu Aug 7, 2024
ff02176
Conversion
erick-xanadu Aug 7, 2024
5d7ce24
Casting
erick-xanadu Aug 7, 2024
0c261d1
Initial changes to the runtime
erick-xanadu Aug 7, 2024
4ffeafd
Runtime
erick-xanadu Aug 7, 2024
a2ba350
Default implementations
erick-xanadu Aug 7, 2024
b5ad9ee
Use setBasisState and setStateVector from lightning upstream
erick-xanadu Aug 9, 2024
f42d85a
Fix tests
erick-xanadu Aug 9, 2024
670ffad
Style
erick-xanadu Aug 9, 2024
520f653
Comment
erick-xanadu Aug 9, 2024
175452f
error checking
erick-xanadu Aug 9, 2024
2ed335e
Error messages
erick-xanadu Aug 9, 2024
83ca9d6
gg
erick-xanadu Aug 9, 2024
e9a7c28
Yes, no branches, this is only temporary
erick-xanadu Aug 12, 2024
e7305c6
Apply suggestions from code review
erick-xanadu Aug 12, 2024
9472a0f
Fix type
erick-xanadu Aug 12, 2024
01a4bcf
clang-format-16 -> clang-format-14
erick-xanadu Aug 12, 2024
e3c2b71
Use virtual functions
erick-xanadu Aug 12, 2024
d5878a9
Coverage
erick-xanadu Aug 12, 2024
708fa52
Fix plxpr
erick-xanadu Aug 12, 2024
4caa55f
update lightning
erick-xanadu Aug 13, 2024
15c4874
Merge branch 'main' into raultorres/stateprep_lightning
erick-xanadu Aug 13, 2024
5e1c847
Update .github/workflows/scripts/linux_arm64/rh8/build_catalyst.sh
erick-xanadu Aug 13, 2024
b05569c
Apply suggestions from code review
erick-xanadu Aug 13, 2024
0ddb868
Different branch
erick-xanadu Aug 13, 2024
deeaf15
Verify
erick-xanadu Aug 13, 2024
c76b8fd
Update doc/changelog.md
erick-xanadu Aug 13, 2024
0451092
Update doc/changelog.md
erick-xanadu Aug 13, 2024
3e7625f
comment
erick-xanadu Aug 13, 2024
0923e2b
Update doc/changelog.md
erick-xanadu Aug 13, 2024
041990b
Merge branch 'main' into raultorres/stateprep_lightning
erick-xanadu Aug 13, 2024
3586386
RTD
erick-xanadu Aug 13, 2024
ee9c008
Documentation
erick-xanadu Aug 13, 2024
8baa687
test
erick-xanadu Aug 13, 2024
5a80cae
Update demos/adaptive_circuits_demo.ipynb
erick-xanadu Aug 13, 2024
8de8d9b
Apply suggestions from code review
erick-xanadu Aug 13, 2024
9fce036
Merge branch 'main' into raultorres/stateprep_lightning
erick-xanadu Aug 13, 2024
803822f
fix
erick-xanadu Aug 13, 2024
9c76b5f
Install correct kokkos
erick-xanadu Aug 13, 2024
507302c
Disable kokkos, use correct kokkos version on wheels
erick-xanadu Aug 14, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .dep-versions
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ pennylane=0.38.0.dev11
# 'runtime/Makefile' and at all GitHub workflows, using the exact
# commit hash corresponding to the merged PR that implements the
# desired feature.
lightning=0.38.0-dev26
lightning=0.38.0-dev32
2 changes: 1 addition & 1 deletion .github/workflows/build-wheel-linux-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ jobs:
-DPYTHON_EXECUTABLE=$(which python${{ matrix.python_version }}) \
-Dpybind11_DIR=$(python${{ matrix.python_version }} -c "import pybind11; print(pybind11.get_cmake_dir())") \
-DENABLE_LAPACK=OFF \
-DLIGHTNING_GIT_TAG=8f517d24c71c6c7765f5c1bf29b0264b951de96a \
-DLIGHTNING_GIT_TAG=c6b86a5 \
-DENABLE_WARNINGS=OFF \
-DENABLE_OPENQASM=ON \
-DENABLE_OPENMP=OFF \
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/build-wheel-macos-arm64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ jobs:
-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=$GITHUB_WORKSPACE/runtime-build/lib \
-DPYTHON_EXECUTABLE=$(which python${{ matrix.python_version }}) \
-Dpybind11_DIR=$(python${{ matrix.python_version }} -c "import pybind11; print(pybind11.get_cmake_dir())") \
-DLIGHTNING_GIT_TAG=8f517d24c71c6c7765f5c1bf29b0264b951de96a \
-DLIGHTNING_GIT_TAG=c6b86a5 \
-DENABLE_LAPACK=OFF \
-DENABLE_WARNINGS=OFF \
-DENABLE_OPENQASM=ON \
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/build-wheel-macos-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ jobs:
-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=$GITHUB_WORKSPACE/runtime-build/lib \
-DPYTHON_EXECUTABLE=$(which python${{ matrix.python_version }}) \
-Dpybind11_DIR=$(python${{ matrix.python_version }} -c "import pybind11; print(pybind11.get_cmake_dir())") \
-DLIGHTNING_GIT_TAG=8f517d24c71c6c7765f5c1bf29b0264b951de96a \
-DLIGHTNING_GIT_TAG=c6b86a5 \
-DENABLE_LAPACK=OFF \
-DENABLE_WARNINGS=OFF \
-DENABLE_OPENQASM=ON \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ cmake -S runtime -B runtime-build -G Ninja \
-DPYTHON_INCLUDE_DIR=/opt/_internal/cpython-${PYTHON_VERSION}.${PYTHON_SUBVERSION}/include/python${PYTHON_VERSION} \
-DPYTHON_LIBRARY=/opt/_internal/cpython-${PYTHON_VERSION}.${PYTHON_SUBVERSION}/lib \
-Dpybind11_DIR=/opt/_internal/cpython-${PYTHON_VERSION}.${PYTHON_SUBVERSION}/lib/python${PYTHON_VERSION}/site-packages/pybind11/share/cmake/pybind11 \
-DLIGHTNING_GIT_TAG=8f517d24c71c6c7765f5c1bf29b0264b951de96a \
-DLIGHTNING_GIT_TAG=c6b86a5 \
-DENABLE_LAPACK=OFF \
-DENABLE_WARNINGS=OFF \
-DENABLE_OPENQASM=ON \
Expand Down
5 changes: 5 additions & 0 deletions demos/adaptive_circuits_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@
"\n",
"@qml.qnode(dev, diff_method=\"adjoint\")\n",
"def cost_func(params):\n",
" qml.RX(0.0, wires=[0])\n",
" qml.BasisState(hf, wires=range(qubits))\n",
erick-xanadu marked this conversation as resolved.
Show resolved Hide resolved
" qml.DoubleExcitation(params[0], wires=[0, 1, 2, 3])\n",
" qml.DoubleExcitation(params[1], wires=[0, 1, 4, 5])\n",
Expand Down Expand Up @@ -239,6 +240,7 @@
"\n",
"@qml.qnode(qml.device(\"lightning.qubit\", wires=qubits))\n",
"def catalyst_cost_func(params):\n",
" qml.RX(0.0, wires=[0])\n",
erick-xanadu marked this conversation as resolved.
Show resolved Hide resolved
" qml.BasisState(hf, wires=range(qubits))\n",
" qml.DoubleExcitation(params[0], wires=[0, 1, 2, 3])\n",
" qml.DoubleExcitation(params[1], wires=[0, 1, 4, 5])\n",
Expand Down Expand Up @@ -594,6 +596,7 @@
"# Create a circuit that applies a selected group of gates\n",
"# to a reference Hartree-Fock state.\n",
"def circuit_1(params, excitations):\n",
" qml.RX(0.0, wires=[0])\n",
" qml.BasisState(hf, wires=range(qubits))\n",
"\n",
" for i, excitation in enumerate(excitations):\n",
Expand Down Expand Up @@ -723,6 +726,7 @@
"# those that have a non-negligible gradient.\n",
"# Repeat steps 1 and 2 for the single excitations.\n",
"def circuit_2(params, excitations, gates_select, params_select):\n",
" qml.RX(0.0, wires=[0])\n",
" qml.BasisState(hf, wires=range(qubits))\n",
"\n",
" for i, gate in enumerate(gates_select):\n",
Expand Down Expand Up @@ -828,6 +832,7 @@
" selected_double_gates, selected_single_gates,\n",
" gates_verifier, selected_verifier,\n",
" apply_selected):\n",
" qml.RX(0.0, wires=[0])\n",
erick-xanadu marked this conversation as resolved.
Show resolved Hide resolved
" qml.BasisState(hf, wires=range(qubits))\n",
"\n",
" # Used in steps 4-5\n",
Expand Down
11 changes: 11 additions & 0 deletions doc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,10 @@
Catalyst variant jaxpr.
[(#837)](https://github.com/PennyLaneAI/catalyst/pull/837)

* On devices that support it, initial state preparation routines `qml.StatePrep` and `qml.BasisState`
are no longer decomposed when using Catalyst, improving compilation & runtime performance.
[(#955)](https://github.com/PennyLaneAI/catalyst/pull/955)

<h3>Breaking changes</h3>

* Return values of qjit-compiled functions that were previously `numpy.ndarray` are now of type
Expand All @@ -265,6 +269,13 @@
`extrapolate_kwargs` keyword argument in `mitigate_with_zne`.
[(#806)](https://github.com/PennyLaneAI/catalyst/pull/806)

* The QuantumDevice API has now added the functions `SetState` and `SetBasisState`
* The QuantumDevice API has now added the functions `SetState` and `SetBasisState`
erick-xanadu marked this conversation as resolved.
Show resolved Hide resolved
for simulators that may benefit from instructions that directly set the state.
erick-xanadu marked this conversation as resolved.
Show resolved Hide resolved
Implementing these methods is optional, and device support can be indicated via
the `initial_state_prep` flag in the TOML configuration file.
[(#955)](https://github.com/PennyLaneAI/catalyst/pull/955)

<h3>Bug fixes</h3>

* Catalyst no longer generates a `QubitUnitary` operation during decomposition if a device doesn't
Expand Down
12 changes: 8 additions & 4 deletions frontend/catalyst/api_extensions/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,7 +1107,9 @@ def trace_quantum(self, ctx, device, trace, qrp) -> QRegPromise:
op = self
for region in op.regions:
with EvaluationContext.frame_tracing_context(ctx, region.trace):
qreg_in = _input_type_to_tracers(region.trace.new_arg, [AbstractQreg()])[0]
reg_len = qrp.base.length
new_qreg = AbstractQreg(reg_len)
qreg_in = _input_type_to_tracers(region.trace.new_arg, [new_qreg])[0]
qreg_out = trace_quantum_operations(
region.quantum_tape, device, qreg_in, ctx, region.trace
).actualize()
Expand Down Expand Up @@ -1165,7 +1167,9 @@ def trace_quantum(self, ctx, device, trace, qrp) -> QRegPromise:
expansion_strategy = self.expansion_strategy

with EvaluationContext.frame_tracing_context(ctx, inner_trace):
qreg_in = _input_type_to_tracers(inner_trace.new_arg, [AbstractQreg()])[0]
reg_len = qrp.base.length
new_qreg = AbstractQreg(reg_len)
qreg_in = _input_type_to_tracers(inner_trace.new_arg, [new_qreg])[0]
qrp_out = trace_quantum_operations(inner_tape, device, qreg_in, ctx, inner_trace)
qreg_out = qrp_out.actualize()

Expand Down Expand Up @@ -1245,7 +1249,7 @@ def trace_quantum(self, ctx, device, trace, qrp) -> QRegPromise:
res_classical_tracers,
expansion_strategy=expansion_strategy,
)
_input_type_to_tracers(cond_trace.new_arg, [AbstractQreg()])
_input_type_to_tracers(cond_trace.new_arg, [AbstractQreg(qrp.base.length)])
cond_jaxpr, _, cond_consts = trace_to_jaxpr(
cond_trace, arg_expanded_classical_tracers, res_expanded_classical_tracers
)
Expand All @@ -1256,7 +1260,7 @@ def trace_quantum(self, ctx, device, trace, qrp) -> QRegPromise:
with EvaluationContext.frame_tracing_context(ctx, body_trace):
region = self.regions[1]
res_classical_tracers = region.res_classical_tracers
qreg_in = _input_type_to_tracers(body_trace.new_arg, [AbstractQreg()])[0]
qreg_in = _input_type_to_tracers(body_trace.new_arg, [AbstractQreg(qrp.base.length)])[0]
qrp_out = trace_quantum_operations(body_tape, device, qreg_in, ctx, body_trace)
qreg_out = qrp_out.actualize()
arg_expanded_tracers = expand_args(
Expand Down
2 changes: 1 addition & 1 deletion frontend/catalyst/api_extensions/quantum_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def trace_quantum(self, ctx, device, _trace, qrp) -> QRegPromise:
frame_ctx = EvaluationContext.frame_tracing_context(ctx, body_trace)

with frame_ctx as body_trace:
qreg_in = _input_type_to_tracers(body_trace.new_arg, [AbstractQreg()])[0]
qreg_in = _input_type_to_tracers(body_trace.new_arg, [AbstractQreg(qrp.base.length)])[0]
qrp_out = trace_quantum_operations(body_tape, device, qreg_in, ctx, body_trace)
qreg_out = qrp_out.actualize()
body_jaxpr, _, body_consts = ctx.frames[body_trace].to_jaxpr2(
Expand Down
2 changes: 1 addition & 1 deletion frontend/catalyst/device/decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def catalyst_decompose(
(toplevel_tape,), _ = decompose(
tape,
stopping_condition,
skip_initial_state_prep=False,
skip_initial_state_prep=capabilities.initial_state_prep_flag,
decomposer=partial(catalyst_decomposer, capabilities=capabilities),
max_expansion=max_expansion,
name="catalyst on this device",
Expand Down
10 changes: 8 additions & 2 deletions frontend/catalyst/device/verification.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
VnEntropyMP,
)
from pennylane.measurements.shots import Shots
from pennylane.operation import Operation, Tensor
from pennylane.operation import Operation, StatePrepBase, Tensor
from pennylane.ops import (
Adjoint,
CompositeOp,
Expand Down Expand Up @@ -209,7 +209,13 @@ def _op_checker(op, state):
# is handled in _inv_op_checker and _ctrl_op_checker.
# Specialed control op classes (e.g. CRZ) should be checked directly though, which is why we
# can't use isinstance(op, Controlled).
if type(op) in (Controlled, ControlledOp) or isinstance(op, Adjoint):
if type(op) in (Controlled, ControlledOp) or isinstance(op, (Adjoint)):
pass
# Don't check StatePrep since StatePrep is not in the list of device capabilities.
# It is only valid when the TOML file has the initial_state_prep_flag.
elif (
isinstance(op, StatePrepBase) and qjit_device.qjit_capabilities.initial_state_prep_flag
):
pass
elif not qjit_device.qjit_capabilities.native_ops.get(op.name):
raise CompileError(
Expand Down
2 changes: 1 addition & 1 deletion frontend/catalyst/from_plxpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def setup(self):
resets the wire map.
"""
qdevice_p.bind(**_get_device_kwargs(self._device))
self.qreg = qalloc_p.bind(len(self._device.wires))
self.qreg = qalloc_p.bind(len(self._device.wires), static_size=len(self._device.wires))
self.wire_map = {}

def cleanup(self):
Expand Down
73 changes: 68 additions & 5 deletions frontend/catalyst/jax_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@
ProbsOp,
QubitUnitaryOp,
SampleOp,
SetBasisStateOp,
SetStateOp,
StateOp,
TensorOp,
VarianceOp,
Expand Down Expand Up @@ -143,6 +145,9 @@ class AbstractQreg(AbstractValue):

hash_value = hash("AbstractQreg")

def __init__(self, length):
self.length = length

def __eq__(self, other):
return isinstance(other, AbstractQreg)

Expand Down Expand Up @@ -286,6 +291,10 @@ class Folding(Enum):
apply_registered_pass_p = core.Primitive("apply_registered_pass")
transform_named_sequence_p = core.Primitive("transform_named_sequence")
transform_named_sequence_p.multiple_results = True
set_state_p = jax.core.Primitive("state_prep")
set_state_p.multiple_results = True
set_basis_state_p = jax.core.Primitive("set_basis_state")
set_basis_state_p.multiple_results = True


def _assert_jaxpr_without_constants(jaxpr: ClosedJaxpr):
Expand Down Expand Up @@ -968,17 +977,17 @@ def _qdevice_lowering(jax_ctx: mlir.LoweringRuleContext, rtd_lib, rtd_name, rtd_
# qalloc
#
@qalloc_p.def_impl
def _qalloc_def_impl(ctx, size_value): # pragma: no cover
def _qalloc_def_impl(ctx, size_value, static_size=None): # pragma: no cover
raise NotImplementedError()


@qalloc_p.def_abstract_eval
def _qalloc_abstract_eval(size):
def _qalloc_abstract_eval(size, static_size=None):
"""This function is called with abstract arguments for tracing."""
return AbstractQreg()
return AbstractQreg(static_size)


def _qalloc_lowering(jax_ctx: mlir.LoweringRuleContext, size_value: ir.Value):
def _qalloc_lowering(jax_ctx: mlir.LoweringRuleContext, size_value: ir.Value, static_size=None):
ctx = jax_ctx.module_context.context
ctx.allow_unregistered_dialects = True

Expand Down Expand Up @@ -1068,7 +1077,7 @@ def _qinsert_abstract_eval(qreg_old, qubit_idx, qubit):
"""This function is called with abstract arguments for tracing."""
assert isinstance(qreg_old, AbstractQreg)
assert isinstance(qubit, AbstractQbit)
return AbstractQreg()
return AbstractQreg(qreg_old.length)


def _qinsert_lowering(
Expand Down Expand Up @@ -2053,6 +2062,58 @@ def _assert_lowering(jax_ctx: mlir.LoweringRuleContext, assertion, error):
return ()


#
# state_prep
#
@set_state_p.def_impl
def set_state_impl(ctx, *qubits_or_params): # pragma: no cover
"""Concrete evaluation"""
raise NotImplementedError()


@set_state_p.def_abstract_eval
def set_state_abstract(*qubits_or_params):
"""Abstract evaluation"""
length = len(qubits_or_params)
qubits_length = length - 1
return (AbstractQbit(),) * qubits_length


def _set_state_lowering(jax_ctx: mlir.LoweringRuleContext, *qubits_or_params):
"""Lowering of set state"""
qubits_or_params = list(qubits_or_params)
param = qubits_or_params.pop()
qubits = qubits_or_params
out_qubits = [qubit.type for qubit in qubits]
return SetStateOp(out_qubits, param, qubits).results


#
# set_basis_state
#
@set_basis_state_p.def_impl
def set_basis_state_impl(ctx, *qubits_or_params): # pragma: no cover
"""Concrete evaluation"""
raise NotImplementedError()


@set_basis_state_p.def_abstract_eval
def set_basis_state_abstract(*qubits_or_params):
"""Abstract evaluation"""
length = len(qubits_or_params)
qubits_length = length - 1
return (AbstractQbit(),) * qubits_length


def _set_basis_state_lowering(jax_ctx: mlir.LoweringRuleContext, *qubits_or_params):
"""Lowering of set basis state"""
qubits_or_params = list(qubits_or_params)
param = qubits_or_params.pop()
qubits = qubits_or_params
out_qubits = [qubit.type for qubit in qubits]
return SetBasisStateOp(out_qubits, param, qubits).results


#
# adjoint
#
Expand Down Expand Up @@ -2151,6 +2212,8 @@ def _adjoint_lowering(
mlir.register_lowering(value_and_grad_p, _value_and_grad_lowering)
mlir.register_lowering(apply_registered_pass_p, _apply_registered_pass_lowering)
mlir.register_lowering(transform_named_sequence_p, _transform_named_sequence_lowering)
mlir.register_lowering(set_state_p, _set_state_lowering)
mlir.register_lowering(set_basis_state_p, _set_basis_state_lowering)


def _scalar_abstractify(t):
Expand Down
Loading
Loading