Skip to content

Commit

Permalink
[MLIR] Remove scheduling at FunctionOpInterfaces (#1407)
Browse files Browse the repository at this point in the history
**Context:** The `apply_pass` function from the frontend is capable of
scheduling any MLIR pass from the frontend for a specific qnode.
However, in order for the pass to succeed in its compilation, it must
allow itself to be scheduled to transform a module.

**Description of the Change:** There is another branch that I worked on
where I moved the bulk of the transformation to Patterns; however, the
pattern applicators that are upstream in MLIR do not support running a
single iteration of the worklist. I would like at some point to add a
custom pattern applicator that doesn't fold and only does a single pass
through the worklist (or take a variable). The passes as written
represent a single iteration through the worklist. Running a single
iteration through the worklist may not yield all full optimizations, but
that's how it is currently coded. I think this is acceptable as we don't
yet do a cost benefit analysis for any transformation.

**Benefits:** Can write the following:

```python
import jax
import pennylane as qml

import catalyst


@qml.qjit(keep_intermediate=True)
@catalyst.passes.apply_pass("disentangle-CNOT")
@qml.qnode(qml.device("lightning.qubit", wires=2))
def foo():
    qml.Hadamard(0)
    qml.Hadamard(0)
    qml.Hadamard(1)
    qml.Hadamard(1)
    qml.CNOT(wires=[0, 1])
    return qml.state()


foo()
print(foo.mlir)
```
  • Loading branch information
erick-xanadu authored Jan 9, 2025
1 parent 00a9f8a commit d08502e
Show file tree
Hide file tree
Showing 5 changed files with 175 additions and 165 deletions.
170 changes: 88 additions & 82 deletions mlir/lib/Quantum/Transforms/DisentangleCNOT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,108 +30,114 @@
using namespace mlir;
using namespace catalyst;

namespace catalyst {
#define GEN_PASS_DEF_DISENTANGLECNOTPASS
#define GEN_PASS_DECL_DISENTANGLECNOTPASS
#include "Quantum/Transforms/Passes.h.inc"
namespace {
void disentangleCNOTs(FunctionOpInterface &func, bool verbose)
{
mlir::IRRewriter builder(func->getContext());
Location loc = func->getLoc();

struct DisentangleCNOTPass : public impl::DisentangleCNOTPassBase<DisentangleCNOTPass> {
using impl::DisentangleCNOTPassBase<DisentangleCNOTPass>::DisentangleCNOTPassBase;
PropagateSimpleStatesAnalysis pssa(func);
llvm::DenseMap<Value, QubitState> qubitValues = pssa.getQubitValues();

bool canScheduleOn(RegisteredOperationName opInfo) const override
{
return opInfo.hasInterface<FunctionOpInterface>();
if (verbose) {
for (auto it = qubitValues.begin(); it != qubitValues.end(); ++it) {
it->first.getDefiningOp()->emitRemark(pssa.QubitState2String(it->second));
}
}

void runOnOperation() override
{
FunctionOpInterface func = cast<FunctionOpInterface>(getOperation());
mlir::IRRewriter builder(func->getContext());
Location loc = func->getLoc();
func->walk([&](quantum::CustomOp op) {
StringRef gate = op.getGateName();
if (gate != "CNOT") {
return;
}

PropagateSimpleStatesAnalysis &pssa = getAnalysis<PropagateSimpleStatesAnalysis>();
llvm::DenseMap<Value, QubitState> qubitValues = pssa.getQubitValues();
Value controlIn = op->getOperand(0);
Value targetIn = op->getOperand(1);
Value controlOut = op->getResult(0);
Value targetOut = op->getResult(1);

if (EmitFSMStateRemark) {
for (auto it = qubitValues.begin(); it != qubitValues.end(); ++it) {
it->first.getDefiningOp()->emitRemark(pssa.QubitState2String(it->second));
}
// Do nothing if the inputs states are not tracked
if (!qubitValues.contains(controlIn) || !qubitValues.contains(targetIn)) {
return;
}

func->walk([&](quantum::CustomOp op) {
StringRef gate = op.getGateName();
if (gate != "CNOT") {
return;
}
// |0> control, always do nothing
if (pssa.isZero(qubitValues[controlIn])) {
builder.replaceAllUsesWith(controlOut, controlIn);
builder.replaceAllUsesWith(targetOut, targetIn);
builder.eraseOp(op);
return;
}

Value controlIn = op->getOperand(0);
Value targetIn = op->getOperand(1);
Value controlOut = op->getResult(0);
Value targetOut = op->getResult(1);
// |1> control, insert PauliX gate on target
if (pssa.isOne(qubitValues[controlIn])) {
builder.replaceAllUsesWith(controlOut, controlIn);

// Do nothing if the inputs states are not tracked
if (!qubitValues.contains(controlIn) || !qubitValues.contains(targetIn)) {
// PauliX on |+-> is unnecessary: they are eigenstates!
if ((pssa.isPlus(qubitValues[targetIn])) || (pssa.isMinus(qubitValues[targetIn]))) {
builder.replaceAllUsesWith(targetOut, targetIn);
builder.eraseOp(op);
return;
}

// |0> control, always do nothing
if (pssa.isZero(qubitValues[controlIn])) {
controlOut.replaceAllUsesWith(controlIn);
targetOut.replaceAllUsesWith(targetIn);
op->erase();
else {
builder.setInsertionPoint(op);
quantum::CustomOp xgate =
builder.create<quantum::CustomOp>(loc, /*gate_name=*/"PauliX",
/*in_qubits=*/mlir::ValueRange({targetIn}));
builder.replaceAllUsesWith(targetOut, xgate->getResult(0));
builder.eraseOp(op);
return;
}
}

// |1> control, insert PauliX gate on target
if (pssa.isOne(qubitValues[controlIn])) {
controlOut.replaceAllUsesWith(controlIn);

// PauliX on |+-> is unnecessary: they are eigenstates!
if ((pssa.isPlus(qubitValues[targetIn])) || (pssa.isMinus(qubitValues[targetIn]))) {
targetOut.replaceAllUsesWith(targetIn);
op->erase();
return;
}
else {
builder.setInsertionPoint(op);
quantum::CustomOp xgate = builder.create<quantum::CustomOp>(
loc, /*gate_name=*/"PauliX",
/*in_qubits=*/mlir::ValueRange({targetIn}));
targetOut.replaceAllUsesWith(xgate->getResult(0));
op->erase();
return;
}
}
// |+> target, always do nothing
if (pssa.isPlus(qubitValues[targetIn])) {
builder.replaceAllUsesWith(controlOut, controlIn);
builder.replaceAllUsesWith(targetOut, targetIn);
builder.eraseOp(op);
return;
}

// |-> target, insert PauliZ on control
if (pssa.isMinus(qubitValues[targetIn])) {
builder.replaceAllUsesWith(targetOut, targetIn);

// |+> target, always do nothing
if (pssa.isPlus(qubitValues[targetIn])) {
controlOut.replaceAllUsesWith(controlIn);
targetOut.replaceAllUsesWith(targetIn);
op->erase();
// PauliZ on |01> is unnecessary: they are eigenstates!
if ((pssa.isZero(qubitValues[controlIn])) || (pssa.isOne(qubitValues[controlIn]))) {
builder.replaceAllUsesWith(controlOut, controlIn);
builder.eraseOp(op);
return;
}
else {
builder.setInsertionPoint(op);
quantum::CustomOp zgate =
builder.create<quantum::CustomOp>(loc, /*gate_name=*/"PauliZ",
/*in_qubits=*/mlir::ValueRange({controlIn}));
builder.replaceAllUsesWith(controlOut, zgate->getResult(0));
builder.eraseOp(op);
return;
}
}
});
}
} // namespace

namespace catalyst {
#define GEN_PASS_DEF_DISENTANGLECNOTPASS
#define GEN_PASS_DECL_DISENTANGLECNOTPASS
#include "Quantum/Transforms/Passes.h.inc"

// |-> target, insert PauliZ on control
if (pssa.isMinus(qubitValues[targetIn])) {
targetOut.replaceAllUsesWith(targetIn);

// PauliZ on |01> is unnecessary: they are eigenstates!
if ((pssa.isZero(qubitValues[controlIn])) || (pssa.isOne(qubitValues[controlIn]))) {
controlOut.replaceAllUsesWith(controlIn);
op->erase();
return;
}
else {
builder.setInsertionPoint(op);
quantum::CustomOp zgate = builder.create<quantum::CustomOp>(
loc, /*gate_name=*/"PauliZ",
/*in_qubits=*/mlir::ValueRange({controlIn}));
controlOut.replaceAllUsesWith(zgate->getResult(0));
op->erase();
return;
}
struct DisentangleCNOTPass : public impl::DisentangleCNOTPassBase<DisentangleCNOTPass> {
using impl::DisentangleCNOTPassBase<DisentangleCNOTPass>::DisentangleCNOTPassBase;

void runOnOperation() override
{
auto op = getOperation();
for (Operation &nestedOp : op->getRegion(0).front().getOperations()) {
if (auto func = dyn_cast<FunctionOpInterface>(nestedOp)) {
disentangleCNOTs(func, EmitFSMStateRemark);
}
});
}
}
};

Expand Down
Loading

0 comments on commit d08502e

Please sign in to comment.