Skip to content

Commit

Permalink
Style
Browse files Browse the repository at this point in the history
  • Loading branch information
erick-xanadu committed Jul 31, 2024
1 parent ef47ae1 commit 47ed503
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 10 deletions.
8 changes: 6 additions & 2 deletions frontend/catalyst/api_extensions/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,7 +1107,9 @@ def trace_quantum(self, ctx, device, trace, qrp) -> QRegPromise:
op = self
for region in op.regions:
with EvaluationContext.frame_tracing_context(ctx, region.trace):
qreg_in = _input_type_to_tracers(region.trace.new_arg, [AbstractQreg(qrp.base.length)])[0]
qreg_in = _input_type_to_tracers(
region.trace.new_arg, [AbstractQreg(qrp.base.length)]
)[0]
qreg_out = trace_quantum_operations(
region.quantum_tape, device, qreg_in, ctx, region.trace
).actualize()
Expand Down Expand Up @@ -1165,7 +1167,9 @@ def trace_quantum(self, ctx, device, trace, qrp) -> QRegPromise:
expansion_strategy = self.expansion_strategy

with EvaluationContext.frame_tracing_context(ctx, inner_trace):
qreg_in = _input_type_to_tracers(inner_trace.new_arg, [AbstractQreg(qrp.base.length)])[0]
qreg_in = _input_type_to_tracers(inner_trace.new_arg, [AbstractQreg(qrp.base.length)])[
0
]
qrp_out = trace_quantum_operations(inner_tape, device, qreg_in, ctx, inner_trace)
qreg_out = qrp_out.actualize()

Expand Down
1 change: 1 addition & 0 deletions frontend/catalyst/jax_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -1922,6 +1922,7 @@ def _set_state_lowering(jax_ctx: mlir.LoweringRuleContext, *qubits_or_params):
out_qubits = [qubit.type for qubit in qubits]
return SetStateOp(out_qubits, param, qubits).results


#
# set_basis_state
#
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Quantum/Transforms/ConversionPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,8 @@ struct SetBasisStateOpPattern : public OpConversionPattern<SetBasisStateOp> {
auto indexVal = op.getIndex();
auto indexTy = indexVal.getType();
ModuleOp moduleOp = op->getParentOfType<ModuleOp>();
auto func = mlir::LLVM::lookupOrCreateFn(moduleOp, "__catalyst__qis__SetBasisState", {indexTy},
voidTy, isVarArg);
auto func = mlir::LLVM::lookupOrCreateFn(moduleOp, "__catalyst__qis__SetBasisState",
{indexTy}, voidTy, isVarArg);

Location loc = op.getLoc();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ void LightningSimulator::PrintState()
cout << state[idx] << "]" << endl;
}

void LightningSimulator::SetState(DataView<std::complex<double>, 1> &data) {
void LightningSimulator::SetState(DataView<std::complex<double>, 1> &data)
{
const size_t num_qubits = this->device_sv->getNumQubits();
const size_t size = Pennylane::Util::exp2(num_qubits);
auto &&state = this->device_sv->getDataVector();
Expand All @@ -116,7 +117,8 @@ void LightningSimulator::SetState(DataView<std::complex<double>, 1> &data) {
}
}

void LightningSimulator::SetBasisState(const std::size_t index) {
void LightningSimulator::SetBasisState(const std::size_t index)
{
this->device_sv->setBasisState(index);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,8 @@ class StateVectorLQubitDynamic : public StateVectorLQubit<fp_t, StateVectorLQubi
*
* @param index Index of the target element.
*/
void setBasisState(const std::size_t index) {
void setBasisState(const std::size_t index)
{
std::fill(data_.begin(), data_.end(), 0.0);
data_[index] = {1.0, 0.0};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ void LightningKokkosSimulator::PrintState()

void LightningKokkosSimulator::SetState(DataView<std::complex<double>, 1> &) {}

void LightningKokkosSimulator::SetBasisState(const std::size_t index) {
void LightningKokkosSimulator::SetBasisState(const std::size_t index)
{
this->device_sv->setBasisState(index);
}

Expand Down
5 changes: 3 additions & 2 deletions runtime/lib/capi/RuntimeCAPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -462,11 +462,12 @@ void __catalyst__qis__SetState(MemRefT_CplxT_double_1d *data)
// But what is not guaranteed is the strided.
MemRefT<std::complex<double>, 1> *data_p = (MemRefT<std::complex<double>, 1> *)data;
DataView<std::complex<double>, 1> data_vector(data_p->data_aligned, data_p->offset,
data_p->sizes, data_p->strides);
data_p->sizes, data_p->strides);
getQuantumDevicePtr()->SetState(data_vector);
}

void __catalyst__qis__SetBasisState(uint64_t index) {
void __catalyst__qis__SetBasisState(uint64_t index)
{
std::size_t index_cast = static_cast<std::size_t>(index);
getQuantumDevicePtr()->SetBasisState(index_cast);
}
Expand Down

0 comments on commit 47ed503

Please sign in to comment.