Added merge rotation patterns for qml.Rot and qml.CRot #1270

Open
wants to merge 9 commits into
base: main
71 changes: 71 additions & 0 deletions frontend/test/pytest/test_peephole_optimizations.py
@@ -23,6 +23,77 @@

# pylint: disable=missing-function-docstring

#
# Complex_merging_rotations
#

# Parameterize test with different angle sets for qml.Rot and qml.CRot to ensure coverage of complex cases.

@pytest.mark.parametrize("params1, params2", [
Contributor @paul0403, Nov 3, 2024

So testing this new merge rotation pattern is a bit tricky: we know that regardless of whether the merge rotation transformation took effect or not, the circuit will produce the same results. Given that, does this test here actually test for whether the rotation gates are merged? If not, what is the best way to test that the rotation gates are merged, and what is the purpose of these end-to-end circuit execution tests here?

Hint: search through the code base and look for how the existing merge rotation patterns are tested!
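To make the concern above concrete, here is a minimal sketch in plain Python (no PennyLane or Catalyst involved). It assumes the usual convention Rot(φ, θ, ω) = RZ(ω) RY(θ) RZ(φ); the helper names `rot`, `matmul`, and `probs` are made up for this illustration. It shows that applying two `Rot` gates in sequence and applying their single merged unitary give identical probabilities, so an output comparison alone cannot tell whether a merge happened.

```python
import cmath
import math


def rot(phi, theta, omega):
    """Matrix of Rot(phi, theta, omega) = RZ(omega) RY(theta) RZ(phi)."""
    c, s = math.cos(theta / 2), math.sin(theta / 2)
    return [
        [cmath.exp(-0.5j * (phi + omega)) * c, -cmath.exp(0.5j * (phi - omega)) * s],
        [cmath.exp(-0.5j * (phi - omega)) * s, cmath.exp(0.5j * (phi + omega)) * c],
    ]


def matmul(a, b):
    """Product of two 2x2 matrices."""
    return [[sum(a[i][k] * b[k][j] for k in range(2)) for j in range(2)] for i in range(2)]


def probs(gates):
    """Apply the gates in order to |0> and return the outcome probabilities."""
    state = [1 + 0j, 0j]
    for u in gates:
        state = [
            u[0][0] * state[0] + u[0][1] * state[1],
            u[1][0] * state[0] + u[1][1] * state[1],
        ]
    return [abs(amp) ** 2 for amp in state]


params1, params2 = (0.5, 1.0, 1.5), (0.6, 0.8, 0.7)
unmerged = [rot(*params1), rot(*params2)]        # two separate gates
merged = [matmul(rot(*params2), rot(*params1))]  # one gate with the same unitary

# The two circuits are indistinguishable by their measurement statistics.
results_match = all(
    math.isclose(p, q, abs_tol=1e-12) for p, q in zip(probs(unmerged), probs(merged))
)
```

Because `results_match` is true with or without merging, a test of this shape only guards correctness; a structural check on the compiled program is still needed to show the pass fired.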

Author

Thank you @paul0403

I’ve added explicit checks in test_complex_merge_rotation to verify that the merge_rotations transformation reduces the number of rotation gates (Rot and CRot) and preserves the circuit's functionality. By explicitly calling qml.transforms.merge_rotations, we can compare the unoptimized and optimized circuits directly. This allows us to confirm both that the rotation gates are actually merged (fewer gates) and that the results remain the same, addressing the need for both functionality and transformation verification in the test. Please let me know if this method is correct.
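The structural half of the check described above could take roughly the following shape. This is only a toy sketch: the circuits are hand-written lists of `(gate_name, wires)` tuples rather than real compiler output, and `count_rotations` is a hypothetical helper, not a PennyLane or Catalyst function.

```python
from collections import Counter


def count_rotations(ops, names=("Rot", "CRot")):
    """Count the operations whose gate name is one of `names`."""
    counts = Counter(name for name, _wires in ops)
    return sum(counts[name] for name in names)


# Hypothetical op lists standing in for a circuit before and after merging.
before = [("Rot", (0,)), ("Rot", (0,)), ("CRot", (0, 1)), ("CRot", (0, 1))]
after = [("Rot", (0,)), ("CRot", (0, 1))]

n_before = count_rotations(before)
n_after = count_rotations(after)
```

A real test would obtain `before` and `after` from the actual program representation; the assertion of interest is simply that `n_after` is strictly smaller than `n_before`.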

Contributor

Thanks for adding the test! An additional approach on top of a plain functionality check is definitely good to have 💯

    # Arbitrary angles for general coverage
    ((0.5, 1.0, 1.5), (0.6, 0.8, 0.7)),
    # Important angles with multiples of π
    ((np.pi / 2, np.pi / 4, np.pi / 6), (np.pi, 3 * np.pi / 4, np.pi / 3)),

])
def test_complex_merge_rotation(params1, params2, backend):
    """Comprehensive test for complex merge rotations with qml.Rot and qml.CRot
    using full-angle formulas."""


    # Test for qml.Rot
    @qjit
    def rot_workflow():
        @qml.qnode(qml.device(backend, wires=1))
        def f():
            qml.Rot(params1[0], params1[1], params1[2], wires=0)
            qml.Rot(params2[0], params2[1], params2[2], wires=0)
            return qml.probs()

        @merge_rotations
        @qml.qnode(qml.device(backend, wires=1))
        def g():
            qml.Rot(params1[0], params1[1], params1[2], wires=0)
            qml.Rot(params2[0], params2[1], params2[2], wires=0)
            return qml.probs()

        return f(), g()

    # Reference function for qml.Rot without merging
    @qml.qnode(qml.device("default.qubit", wires=1))
    def rot_reference():
        qml.Rot(params1[0], params1[1], params1[2], wires=0)
        qml.Rot(params2[0], params2[1], params2[2], wires=0)
        return qml.probs()

    # Verify results for qml.Rot
    rot_results = rot_workflow()
    assert np.allclose(
        rot_results[0], rot_results[1]
    ), "Merged result for qml.Rot differs from unmerged."

    assert np.allclose(
        rot_results[1], rot_reference()
    ), "Merged result for qml.Rot differs from reference."


    # Test for qml.CRot
    @qjit
    def crot_workflow():
        @qml.qnode(qml.device(backend, wires=2))
        def f():
            qml.CRot(params1[0], params1[1], params1[2], wires=[0, 1])
            qml.CRot(params2[0], params2[1], params2[2], wires=[0, 1])
            return qml.probs()

        @merge_rotations
        @qml.qnode(qml.device(backend, wires=2))
        def g():
            qml.CRot(params1[0], params1[1], params1[2], wires=[0, 1])
            qml.CRot(params2[0], params2[1], params2[2], wires=[0, 1])
            return qml.probs()

        return f(), g()

    # Reference function for qml.CRot without merging
    @qml.qnode(qml.device("default.qubit", wires=2))
    def crot_reference():
        qml.CRot(params1[0], params1[1], params1[2], wires=[0, 1])
        qml.CRot(params2[0], params2[1], params2[2], wires=[0, 1])
        return qml.probs()

    # Verify results for qml.CRot
    crot_results = crot_workflow()
    assert np.allclose(
        crot_results[0], crot_results[1]
    ), "Merged result for qml.CRot differs from unmerged."

    assert np.allclose(
        crot_results[1], crot_reference()
    ), "Merged result for qml.CRot differs from reference."


#
# cancel_inverses
125 changes: 103 additions & 22 deletions mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp
@@ -21,13 +21,16 @@
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Errc.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Arith/IR/Arith.h""
Contributor

I think there's an extra quotation mark here, which made all your code into comments below :)

Contributor

(Also this file is already included in here no?)

Author

Thank you, Done!


using llvm::dbgs;
using namespace mlir;
using namespace catalyst::quantum;

static const mlir::StringSet<> rotationsSet = {"RX", "RY", "RZ", "PhaseShift",
"CRX", "CRY", "CRZ", "ControlledPhaseShift"};
"CRX", "CRY", "CRZ", "ControlledPhaseShift",
"qml.Rot", "qml.CRot"};

namespace {

@@ -49,27 +52,105 @@ struct MergeRotationsRewritePattern : public mlir::OpRewritePattern<CustomOp> {
return failure();
}

TypeRange outQubitsTypes = op.getOutQubits().getTypes();
TypeRange outQubitsCtrlTypes = op.getOutCtrlQubits().getTypes();
ValueRange parentInQubits = parentOp.getInQubits();
ValueRange parentInCtrlQubits = parentOp.getInCtrlQubits();
ValueRange parentInCtrlValues = parentOp.getInCtrlValues();

auto parentParams = parentOp.getParams();
auto params = op.getParams();
SmallVector<mlir::Value> sumParams;
for (auto [param, parentParam] : llvm::zip(params, parentParams)) {
mlir::Value sumParam =
rewriter.create<arith::AddFOp>(loc, parentParam, param).getResult();
sumParams.push_back(sumParam);
};
auto mergeOp = rewriter.create<CustomOp>(loc, outQubitsTypes, outQubitsCtrlTypes, sumParams,
parentInQubits, opGateName, nullptr,
parentInCtrlQubits, parentInCtrlValues);

op.replaceAllUsesWith(mergeOp);

return success();
if (opGateName == "qml.Rot" || opGateName == "qml.CRot") {
Contributor

What happens if the rotation gates are adjointed? Should the merge still happen?

(In Catalyst, adjointed gates are indicated by an adjoint unit attribute; see for example here.)

Author

Thank you @paul0403

Could you clarify one thing: for the regular merge rotation patterns (as opposed to merging non-commuting rotations), did you intentionally not check whether the gates are adjointed, or do adjointed gates simply not occur there?

Contributor

The merge rotation pass applies a rotation gate adjoint canonicalization. The canonicalization simply negates all angles and removes the adjoint attribute. See #1205

However, looking at the canonicalization pattern, you will find that Rot and CRot are not canonicalized.
(a) Why do you think that is?
(b) Knowing this, what do you think you should do in your added pattern here (assuming no new canonicalization is added for (C)Rot)?
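For the single-parameter case, the canonicalization described above rests on the identity R(θ)† = R(−θ). A minimal numeric sketch in plain Python follows; the `canonicalize` function and its tuple representation are hypothetical stand-ins for illustration, not the Catalyst implementation.

```python
import cmath
import math


def rz(theta):
    """Matrix of RZ(theta)."""
    return [[cmath.exp(-0.5j * theta), 0j], [0j, cmath.exp(0.5j * theta)]]


def ry(theta):
    """Matrix of RY(theta)."""
    c, s = math.cos(theta / 2), math.sin(theta / 2)
    return [[c + 0j, -s + 0j], [s + 0j, c + 0j]]


def dagger(u):
    """Conjugate transpose of a 2x2 matrix."""
    return [[u[j][i].conjugate() for j in range(2)] for i in range(2)]


def close(a, b, tol=1e-12):
    """Entry-wise comparison of two 2x2 matrices."""
    return all(abs(a[i][j] - b[i][j]) < tol for i in range(2) for j in range(2))


def canonicalize(name, angle, adjoint):
    """Toy stand-in: fold the adjoint flag into the sign of the angle."""
    return (name, -angle, False) if adjoint else (name, angle, False)


theta = 0.7
rz_ok = close(dagger(rz(theta)), rz(-theta))  # RZ(theta)^dagger == RZ(-theta)
ry_ok = close(dagger(ry(theta)), ry(-theta))  # RY(theta)^dagger == RY(-theta)
canonical = canonicalize("RZ", theta, adjoint=True)
```

The same one-line negation does not carry over to three-angle gates, which is what questions (a) and (b) are pointing at.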

Author @Mohxen, Nov 4, 2024

Thank you @paul0403
(a) Rot and CRot are not canonicalized because they involve complex, multi-parameter rotations that cannot be standardized by simply negating a single parameter.
(b) First, check for the adjoint attribute on qml.Rot and qml.CRot. If the operation is adjointed, transform it into its non-adjointed, canonical form by reversing the order of the parameters and negating each parameter. After this transformation, remove the adjoint attribute to standardize the operation. Then, proceed with the merging process as if all rotations are in canonical form, ensuring consistency across operations.
Thus, if I have an operation qml.Rot(π/4, π/2, π/3), its adjoint will be qml.Rot(-π/3, -π/2, -π/4)
Please let me know if my method is correct.
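The claimed identity can be checked numerically. The sketch below is plain Python and assumes the convention Rot(φ, θ, ω) = RZ(ω) RY(θ) RZ(φ); the helper names are made up for this illustration. It compares the conjugate transpose of Rot(π/4, π/2, π/3) against Rot(−π/3, −π/2, −π/4), and also against the naive candidate that negates the angles without reversing their order.

```python
import cmath
import math


def rot(phi, theta, omega):
    """Matrix of Rot(phi, theta, omega) = RZ(omega) RY(theta) RZ(phi)."""
    c, s = math.cos(theta / 2), math.sin(theta / 2)
    return [
        [cmath.exp(-0.5j * (phi + omega)) * c, -cmath.exp(0.5j * (phi - omega)) * s],
        [cmath.exp(-0.5j * (phi - omega)) * s, cmath.exp(0.5j * (phi + omega)) * c],
    ]


def dagger(u):
    """Conjugate transpose of a 2x2 matrix."""
    return [[u[j][i].conjugate() for j in range(2)] for i in range(2)]


def close(a, b, tol=1e-12):
    """Entry-wise comparison of two 2x2 matrices."""
    return all(abs(a[i][j] - b[i][j]) < tol for i in range(2) for j in range(2))


phi, theta, omega = math.pi / 4, math.pi / 2, math.pi / 3
target = dagger(rot(phi, theta, omega))

# Reverse the order of the angles and negate each one.
reversed_negated_matches = close(target, rot(-omega, -theta, -phi))
# Negating without reversing is not the adjoint in general.
negated_only_matches = close(target, rot(-phi, -theta, -omega))
```

Under this convention the reversed-and-negated form reproduces the adjoint, while negation alone does not.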

Contributor

Good insights! This is what I would do as well.

Due to how the work is organized (i.e., adjoint canonicalization happens elsewhere, not in the merge rotation patterns), it suffices to assume that the rotation gates will not carry adjoint attributes when the merge rotation patterns are hit.

Thus the only thing needed here is a check that the (C)Rot gates do not carry adjoint attributes. If they do, the pattern should do nothing.
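The guard described above can be sketched as follows. This is a toy model in Python, not the MLIR pattern itself: the `Gate` dataclass and `can_merge_rot` are hypothetical names that only illustrate the control flow (bail out whenever either gate carries the adjoint flag).

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Gate:
    """Toy gate record; `adjoint` mimics a unit attribute on the op."""
    name: str
    params: Tuple[float, ...]
    adjoint: bool = False


def can_merge_rot(parent: Gate, op: Gate) -> bool:
    """Return True only when the (C)Rot merge is allowed to fire."""
    if op.name not in ("Rot", "CRot") or parent.name != op.name:
        return False
    # Adjointed (C)Rot gates are not canonicalized, so leave them untouched.
    if parent.adjoint or op.adjoint:
        return False
    return True


plain = Gate("Rot", (0.5, 1.0, 1.5))
adjointed = Gate("Rot", (0.6, 0.8, 0.7), adjoint=True)

merge_plain = can_merge_rot(plain, plain)
merge_with_adjoint = can_merge_rot(plain, adjointed)
```

In the real pattern the equivalent of returning `False` would be returning `failure()` before any new ops are created.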

Author

Thank you, added it.

LLVM_DEBUG(dbgs() << "Applying scalar formula for combined rotation operation:\n" << op << "\n");
auto params = op.getParams();
auto parentParams = parentOp.getParams();

// Assuming params[0] = alpha1, params[1] = theta1, params[2] = beta1
// and parentParams[0] = alpha2, parentParams[1] = theta2, parentParams[2] = beta2

// Step 1: Calculate c1, c2, s1, s2
Contributor

Thanks for adding the new merge rotation pattern! The formula is very long, so we appreciate the good work 🥳

Author

Thank you so much for your help :) If there's anything specific you'd like me to refine or expand on, please let me know.

auto c1 = rewriter.create<math::CosOp>(loc, params[1]);
auto s1 = rewriter.create<math::SinOp>(loc, params[1]);
auto c2 = rewriter.create<math::CosOp>(loc, parentParams[1]);
auto s2 = rewriter.create<math::SinOp>(loc, parentParams[1]);

// Step 2: Calculate cf
auto c1Squared = rewriter.create<arith::MulFOp>(loc, c1, c1);
auto c2Squared = rewriter.create<arith::MulFOp>(loc, c2, c2);
auto s1Squared = rewriter.create<arith::MulFOp>(loc, s1, s1);
auto s2Squared = rewriter.create<arith::MulFOp>(loc, s2, s2);
auto cosAlphaDiff = rewriter.create<math::CosOp>(loc, rewriter.create<arith::SubFOp>(loc, params[0], parentParams[0]));

auto term1 = rewriter.create<arith::MulFOp>(loc, c1Squared, c2Squared);
auto term2 = rewriter.create<arith::MulFOp>(loc, s1Squared, s2Squared);
auto product = rewriter.create<arith::MulFOp>(loc, c1, c2);
product = rewriter.create<arith::MulFOp>(loc, product, s1);
product = rewriter.create<arith::MulFOp>(loc, product, s2);
auto two = rewriter.create<arith::ConstantOp>(loc, rewriter.getF64FloatAttr(2.0));
auto term3 = rewriter.create<arith::MulFOp>(loc, two, rewriter.create<arith::MulFOp>(loc, product, cosAlphaDiff));

auto cfSquare = rewriter.create<arith::SubFOp>(loc, rewriter.create<arith::AddFOp>(loc, term1, term2), term3);
auto cf = rewriter.create<math::SqrtOp>(loc, cfSquare);

// Step 3: Calculate theta_f = 2 * arccos(|cf|)
auto absCf = rewriter.create<math::AbsFOp>(loc, cf);
auto acosCf = rewriter.create<math::AcosOp>(loc, absCf);
auto thetaF = rewriter.create<arith::MulFOp>(loc, two, acosCf);

// Step 4: Calculate alpha_f
auto alphaSum = rewriter.create<arith::AddFOp>(loc, params[0], parentParams[0]);
auto betaDiff = rewriter.create<arith::SubFOp>(loc, parentParams[2], params[2]);
auto sinAlphaSum = rewriter.create<math::SinOp>(loc, alphaSum);
auto cosBetaDiff = rewriter.create<math::CosOp>(loc, betaDiff);

auto term1_alpha = rewriter.create<arith::MulFOp>(loc, rewriter.create<arith::MulFOp>(loc, c1, s2), sinAlphaSum);
auto term2_alpha = rewriter.create<arith::MulFOp>(loc, rewriter.create<arith::MulFOp>(loc, s1, s2), cosBetaDiff);
auto numerator_alpha = rewriter.create<arith::SubFOp>(loc, rewriter.create<arith::NegFOp>(loc, term1_alpha), term2_alpha);

auto cosAlphaSum = rewriter.create<math::CosOp>(loc, alphaSum);
auto denominator_alpha = rewriter.create<arith::SubFOp>(loc, rewriter.create<arith::MulFOp>(loc, rewriter.create<arith::MulFOp>(loc, c1, c2), cosAlphaSum), term2_alpha);

auto alphaF = rewriter.create<arith::NegFOp>(loc, rewriter.create<math::AtanOp>(loc, rewriter.create<arith::DivFOp>(loc, numerator_alpha, denominator_alpha)));

// Step 5: Calculate beta_f
auto betaSum = rewriter.create<arith::AddFOp>(loc, params[2], parentParams[2]);
auto alphaDiffReversed = rewriter.create<arith::SubFOp>(loc, parentParams[0], params[0]);
auto sinBetaSum = rewriter.create<math::SinOp>(loc, betaSum);
auto cosAlphaDiffReversed = rewriter.create<math::CosOp>(loc, alphaDiffReversed);

auto term1_beta = rewriter.create<arith::MulFOp>(loc, rewriter.create<arith::MulFOp>(loc, c1, s2), sinBetaSum);
auto term2_beta = rewriter.create<arith::MulFOp>(loc, rewriter.create<arith::MulFOp>(loc, s1, s2), cosAlphaDiffReversed);
auto numerator_beta = rewriter.create<arith::AddFOp>(loc, rewriter.create<arith::NegFOp>(loc, term1_beta), term2_beta);

auto denominator_beta = denominator_alpha; // Reuse from alpha calculation if applicable
auto betaF = rewriter.create<arith::NegFOp>(loc, rewriter.create<math::AtanOp>(loc, rewriter.create<arith::DivFOp>(loc, numerator_beta, denominator_beta)));

// Step 6: Output angles (phi_f, theta_f, omega_f)
// Assign phi_f = alphaF, theta_f = thetaF, omega_f = betaF as the final values
SmallVector<mlir::Value> combinedAngles = {alphaF, thetaF, betaF};
auto outQubitsTypes = op.getOutQubits().getTypes();
auto outCtrlQubitsTypes = op.getOutCtrlQubits().getTypes();
auto inQubits = op.getInQubits();
auto inCtrlQubits = op.getInCtrlQubits();
auto inCtrlValues = op.getInCtrlValues();
rewriter.replaceOpWithNewOp<CustomOp>(op, outQubitsTypes, outCtrlQubitsTypes, combinedAngles, inQubits, opGateName, nullptr, inCtrlQubits, inCtrlValues);

return success();
}
else {
TypeRange outQubitsTypes = op.getOutQubits().getTypes();
TypeRange outQubitsCtrlTypes = op.getOutCtrlQubits().getTypes();
ValueRange parentInQubits = parentOp.getInQubits();
ValueRange parentInCtrlQubits = parentOp.getInCtrlQubits();
ValueRange parentInCtrlValues = parentOp.getInCtrlValues();
auto parentParams = parentOp.getParams();
auto params = op.getParams();
SmallVector<mlir::Value> sumParams;
for (auto [param, parentParam] : llvm::zip(params, parentParams)) {
mlir::Value sumParam =
rewriter.create<arith::AddFOp>(loc, parentParam, param).getResult();
sumParams.push_back(sumParam);
};
auto mergeOp = rewriter.create<CustomOp>(loc, outQubitsTypes, outQubitsCtrlTypes, sumParams,
parentInQubits, opGateName, nullptr,
parentInCtrlQubits, parentInCtrlValues);

op.replaceAllUsesWith(mergeOp);

return success();
}
}
};
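As an independent cross-check for a closed-form merge such as the one in this pattern, the fused angles can be recovered numerically from the product of the two gate matrices. The sketch below is plain Python and is not the formula used in this PR: it multiplies the matrices and reads the ZYZ angles off the result, assuming the convention Rot(φ, θ, ω) = RZ(ω) RY(θ) RZ(φ). The names `rot`, `matmul`, and `fuse_rot` are made up for this illustration.

```python
import cmath
import math


def rot(phi, theta, omega):
    """Matrix of Rot(phi, theta, omega) = RZ(omega) RY(theta) RZ(phi)."""
    c, s = math.cos(theta / 2), math.sin(theta / 2)
    return [
        [cmath.exp(-0.5j * (phi + omega)) * c, -cmath.exp(0.5j * (phi - omega)) * s],
        [cmath.exp(-0.5j * (phi - omega)) * s, cmath.exp(0.5j * (phi + omega)) * c],
    ]


def matmul(a, b):
    """Product of two 2x2 matrices."""
    return [[sum(a[i][k] * b[k][j] for k in range(2)) for j in range(2)] for i in range(2)]


def fuse_rot(first, second):
    """Angles of the single Rot equal to applying `first`, then `second`."""
    u = matmul(rot(*second), rot(*first))
    # |u00| = cos(theta/2) and |u10| = sin(theta/2) with theta in [0, pi].
    theta = 2 * math.atan2(abs(u[1][0]), abs(u[0][0]))
    # arg(u00) = -(phi + omega)/2 and arg(u10) = -(phi - omega)/2.
    a00, a10 = cmath.phase(u[0][0]), cmath.phase(u[1][0])
    return -(a00 + a10), theta, a10 - a00


def max_error(first, second):
    """Largest entry-wise gap between the product and the fused gate."""
    product = matmul(rot(*second), rot(*first))
    fused = rot(*fuse_rot(first, second))
    return max(abs(product[i][j] - fused[i][j]) for i in range(2) for j in range(2))


err_generic = max_error((0.5, 1.0, 1.5), (0.6, 0.8, 0.7))
err_pi = max_error(
    (math.pi / 2, math.pi / 4, math.pi / 6), (math.pi, 3 * math.pi / 4, math.pi / 3)
)
```

Feeding the same angle pairs to a closed-form implementation and comparing the resulting matrices against `rot(*fuse_rot(...))` gives a quick way to catch sign or half-angle slips.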
