Skip to content

Commit

Permalink
Adjoint canonicalization for CustomOp and MultiRZOp (#1205)
Browse files Browse the repository at this point in the history
**Context:**
We lack canonicalization when the adjoint flag is on.

**Description of the Change:**
- Add canonicalization patterns for cutsomOp (rotations and hermitian)
- Add the patterns in the merge rotation pass.

**Benefits:**
Merge rotations works with adjoint.

---------

Co-authored-by: Haochen Wang <[email protected]>
Co-authored-by: paul0403 <[email protected]>
  • Loading branch information
3 people authored Oct 24, 2024
1 parent de7208d commit e5435fd
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 1 deletion.
5 changes: 5 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@
* A peephole merge rotations pass is now available in MLIR. It can be added to `catalyst.passes.pipeline`, or the
Python function `catalyst.passes.merge_rotations` can be directly called on a `QNode`.
[(#1162)](https://github.com/PennyLaneAI/catalyst/pull/1162)
[(#1205)](https://github.com/PennyLaneAI/catalyst/pull/1205)
[(#1206)](https://github.com/PennyLaneAI/catalyst/pull/1206)

Using the pipeline, one can run:
Expand Down Expand Up @@ -211,6 +212,10 @@

<h3>Improvements</h3>

* Adjoint canonicalization is now available in MLIR for `CustomOp` and `MultiRZOp`. It can be used
with the `--canonicalize` pass in `quantum-opt`.
[(#1205)](https://github.com/PennyLaneAI/catalyst/pull/1205)

* Implement a Catalyst runtime plugin that mocks out all functions in the QuantumDevice interface.
[(#1179)](https://github.com/PennyLaneAI/catalyst/pull/1179)

Expand Down
6 changes: 6 additions & 0 deletions frontend/test/pytest/test_peephole_optimizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ def workflow():
def f(x):
qml.RX(x, wires=0)
qml.RX(x, wires=0)
qml.RZ(x, wires=0)
qml.adjoint(qml.RZ)(x, wires=0)
qml.Rot(x, x, x, wires=0)
qml.Rot(x, x, x, wires=0)
qml.PhaseShift(x, wires=0)
Expand All @@ -85,6 +87,8 @@ def f(x):
def g(x):
qml.RX(x, wires=0)
qml.RX(x, wires=0)
qml.RZ(x, wires=0)
qml.adjoint(qml.RZ)(x, wires=0)
qml.Rot(x, x, x, wires=0)
qml.Rot(x, x, x, wires=0)
qml.PhaseShift(x, wires=0)
Expand All @@ -99,6 +103,8 @@ def g(x):
def reference(x):
qml.RX(x, wires=0)
qml.RX(x, wires=0)
qml.RZ(x, wires=0)
qml.adjoint(qml.RZ)(x, wires=0)
qml.Rot(x, x, x, wires=0)
qml.Rot(x, x, x, wires=0)
qml.PhaseShift(x, wires=0)
Expand Down
3 changes: 3 additions & 0 deletions mlir/include/Quantum/IR/QuantumOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ def CustomOp : UnitaryGate_Op<"custom", [DifferentiableGate, NoMemoryEffect,
return getParams();
}
}];
let hasCanonicalizeMethod = 1;
}

def GlobalPhaseOp : UnitaryGate_Op<"gphase", [DifferentiableGate, AttrSizedOperandSegments]> {
Expand Down Expand Up @@ -454,6 +455,8 @@ def MultiRZOp : UnitaryGate_Op<"multirz", [DifferentiableGate, NoMemoryEffect,
return getODSOperands(getParamOperandIdx());
}
}];

let hasCanonicalizeMethod = 1;
}

def QubitUnitaryOp : UnitaryGate_Op<"unitary", [ParametrizedGate, NoMemoryEffect,
Expand Down
46 changes: 46 additions & 0 deletions mlir/lib/Quantum/IR/QuantumOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include <optional>

Expand All @@ -34,6 +36,50 @@ using namespace catalyst::quantum;
//===----------------------------------------------------------------------===//
// Quantum op canonicalizers.
//===----------------------------------------------------------------------===//
static const mlir::StringSet<> hermitianOps = {"Hadamard", "PauliX", "PauliY", "PauliZ", "CNOT",
"CY", "CZ", "SWAP", "Toffoli"};
static const mlir::StringSet<> rotationsOps = {"RX", "RY", "RZ", "PhaseShift",
"CRX", "CRY", "CRZ", "ControlledPhaseShift"};
LogicalResult CustomOp::canonicalize(CustomOp op, mlir::PatternRewriter &rewriter)
{
if (op.getAdjoint()) {
auto name = op.getGateName();
if (hermitianOps.contains(name)) {
op.setAdjoint(false);
return success();
}
else if (rotationsOps.contains(name)) {
auto params = op.getParams();
SmallVector<Value> paramsNeg;
for (auto param : params) {
auto paramNeg = rewriter.create<mlir::arith::NegFOp>(op.getLoc(), param);
paramsNeg.push_back(paramNeg);
}

rewriter.replaceOpWithNewOp<CustomOp>(
op, op.getOutQubits().getTypes(), op.getOutCtrlQubits().getTypes(), paramsNeg,
op.getInQubits(), name, nullptr, op.getInCtrlQubits(), op.getInCtrlValues());

return success();
}
return failure();
}
return failure();
}

LogicalResult MultiRZOp::canonicalize(MultiRZOp op, mlir::PatternRewriter &rewriter)
{
if (op.getAdjoint()) {
auto paramNeg = rewriter.create<mlir::arith::NegFOp>(op.getLoc(), op.getTheta());

rewriter.replaceOpWithNewOp<MultiRZOp>(
op, op.getOutQubits().getTypes(), op.getOutCtrlQubits().getTypes(), paramNeg,
op.getInQubits(), nullptr, op.getInCtrlQubits(), op.getInCtrlValues());

return success();
};
return failure();
}

LogicalResult DeallocOp::canonicalize(DeallocOp dealloc, mlir::PatternRewriter &rewriter)
{
Expand Down
1 change: 0 additions & 1 deletion mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ struct MergeRotationsRewritePattern : public mlir::OpRewritePattern<CustomOp> {
LLVM_DEBUG(dbgs() << "Simplifying the following operation:\n" << op << "\n");
auto loc = op.getLoc();
StringRef opGateName = op.getGateName();

if (!rotationsSet.contains(opGateName))
return failure();
ValueRange inQubits = op.getInQubits();
Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/Quantum/Transforms/merge_rotation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,14 @@ struct MergeRotationsPass : impl::MergeRotationsPassBase<MergeRotationsPass> {
// Do nothing and exit!
return;
}
RewritePatternSet patternsCanonicalization(&getContext());
catalyst::quantum::CustomOp::getCanonicalizationPatterns(patternsCanonicalization,
&getContext());
catalyst::quantum::MultiRZOp::getCanonicalizationPatterns(patternsCanonicalization,
&getContext());
if (failed(applyPatternsAndFoldGreedily(targetfunc, std::move(patternsCanonicalization)))) {
return signalPassFailure();
}

RewritePatternSet patterns(&getContext());
populateMergeRotationsPatterns(patterns);
Expand Down
39 changes: 39 additions & 0 deletions mlir/test/Quantum/CanonicalizationTest.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,42 @@ func.func @test_insert_canonicalize(%r1: !quantum.reg, %i: i64) -> !quantum.bit
quantum.dealloc %r2 : !quantum.reg
return %4 : !quantum.bit
}

// CHECK-LABEL: test_hermitian_adjoint_canonicalize
func.func @test_hermitian_adjoint_canonicalize() -> !quantum.bit {
%0 = quantum.alloc( 1) : !quantum.reg
%1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
// CHECK: [[reg:%.+]] = quantum.alloc( 1) : !quantum.reg
// CHECK: [[qubit:%.+]] = quantum.extract [[reg]][ 0] : !quantum.reg -> !quantum.bit
%2 = quantum.custom "Hadamard"() %1 {adjoint}: !quantum.bit
// CHECK: quantum.custom "Hadamard"() [[qubit]] : !quantum.bit
return %2 : !quantum.bit
}

// CHECK-LABEL: test_rotation_adjoint_canonicalize
func.func @test_rotation_adjoint_canonicalize(%arg0: f64) -> !quantum.bit {
%0 = quantum.alloc( 1) : !quantum.reg
%1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
// CHECK: [[reg:%.+]] = quantum.alloc( 1) : !quantum.reg
// CHECK: [[qubit:%.+]] = quantum.extract [[reg]][ 0] : !quantum.reg -> !quantum.bit
%2 = quantum.custom "RX"(%arg0) %1 {adjoint}: !quantum.bit
// CHECK: [[arg0neg:%.+]] = arith.negf %arg0 : f64
// CHECK: quantum.custom "RX"([[arg0neg]]) [[qubit]] : !quantum.bit
return %2 : !quantum.bit
}

// CHECK-LABEL: test_multirz_adjoint_canonicalize
func.func @test_multirz_adjoint_canonicalize(%arg0: f64) -> (!quantum.bit, !quantum.bit) {
// CHECK: [[reg:%.+]] = quantum.alloc( 2) : !quantum.reg
// CHECK: [[qubit1:%.+]] = quantum.extract [[reg]][ 0] : !quantum.reg -> !quantum.bit
// CHECK: [[qubit2:%.+]] = quantum.extract [[reg]][ 1] : !quantum.reg -> !quantum.bit
%0 = quantum.alloc( 2) : !quantum.reg
%1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
%2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit

// CHECK: [[arg0neg:%.+]] = arith.negf %arg0 : f64
// CHECK: [[ret:%.+]]:2 = quantum.multirz([[arg0neg]]) [[qubit1]], [[qubit2]] : !quantum.bit, !quantum.bit
%3:2 = quantum.multirz (%arg0) %1, %2 {adjoint} : !quantum.bit, !quantum.bit
return %3#0, %3#1 : !quantum.bit, !quantum.bit
}

39 changes: 39 additions & 0 deletions mlir/test/Quantum/MergeRotationsTest.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -273,3 +273,42 @@ func.func @test_merge_rotations(%arg0: f64) -> (!quantum.bit, !quantum.bit, !qua
// CHECK: return [[ret]], [[ctrlret]]#0, [[ctrlret]]#1
return %out_qubits_1, %out_ctrl_qubits_1#0, %out_ctrl_qubits_1#1 : !quantum.bit, !quantum.bit, !quantum.bit
}

// -----


func.func @test_merge_rotations(%arg0: f64, %arg1: f64) -> !quantum.bit {
%0 = quantum.alloc( 1) : !quantum.reg
%1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
// CHECK: [[reg:%.+]] = quantum.alloc( 1) : !quantum.reg
// CHECK: [[qubit:%.+]] = quantum.extract [[reg]][ 0] : !quantum.reg -> !quantum.bit
// CHECK: [[arg0neg:%.+]] = arith.negf %arg0 : f64
// CHECK: [[arg1neg:%.+]] = arith.negf %arg1 : f64
// CHECK: [[add:%.+]] = arith.addf [[arg0neg]], [[arg1neg]] : f64
// CHECK: [[ret:%.+]] = quantum.custom "RX"([[add]]) [[qubit]] : !quantum.bit
%2 = quantum.custom "RX"(%arg0) %1 {adjoint}: !quantum.bit
%3 = quantum.custom "RX"(%arg1) %2 {adjoint}: !quantum.bit

// CHECK: return [[ret]]
return %3 : !quantum.bit
}

// -----


func.func @test_merge_rotations(%arg0: f64, %arg1: f64, %arg2: f64) -> (!quantum.bit, !quantum.bit) {
// CHECK: [[reg:%.+]] = quantum.alloc( 2) : !quantum.reg
// CHECK: [[qubit1:%.+]] = quantum.extract [[reg]][ 0] : !quantum.reg -> !quantum.bit
// CHECK: [[qubit2:%.+]] = quantum.extract [[reg]][ 1] : !quantum.reg -> !quantum.bit
%0 = quantum.alloc( 2) : !quantum.reg
%1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
%2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit
// CHECK: [[arg0neg:%.+]] = arith.negf %arg0 : f64
// CHECK: [[arg1neg:%.+]] = arith.negf %arg1 : f64
// CHECK: [[add:%.+]] = arith.addf [[arg0neg]], [[arg1neg]] : f64
// CHECK: [[ret:%.+]]:2 = quantum.multirz([[add]]) [[qubit1]], [[qubit2]] : !quantum.bit, !quantum.bit
%3:2 = quantum.multirz (%arg0) %1, %2 {adjoint}: !quantum.bit, !quantum.bit
%4:2 = quantum.multirz (%arg1) %3#0, %3#1 {adjoint}: !quantum.bit, !quantum.bit
// CHECK: return [[ret]]#0, [[ret]]#1
return %4#0, %4#1 : !quantum.bit, !quantum.bit
}

0 comments on commit e5435fd

Please sign in to comment.