catalyst.grad fails to differentiate circuits with decomposed QubitUnitarys #1393

paul0403 opened this issue Dec 19, 2024

Workflows containing `QubitUnitary` cannot be differentiated with `catalyst.grad`, because Lightning does not support computing gradients through `QubitUnitary`.

As a workaround, one can try to decompose the `QubitUnitary` first. However, this also fails. Below are two approaches, and both fail:

(a) With qml.QubitUnitary.compute_decomposition
(Note that this method is restrictive in the first place, as it only works on 2-by-2 and 4-by-4 matrices)

```python
import catalyst
import jax.numpy as jnp
import pennylane as qml
from catalyst import qjit

dev = qml.device("lightning.qubit", wires=1)

@qml.qnode(dev)
def circuit(x):
    U = jnp.array([[1, 0], [0, x]])
    decomp = qml.QubitUnitary.compute_decomposition(U, wires=0)
    for op in decomp:
        qml.apply(op)
    return qml.probs()

@qjit(keep_intermediate=False)
def f(x):
    probs = circuit(x)
    return probs[0] + probs[1]

print(f(1.0))   # fine, 1.0
print(catalyst.grad(f, argnums=0)(1.0))  # fails with the CompileError below
```

```text
catalyst.utils.exceptions.CompileError: catalyst-cli failed with error code 1: Failed to run pipeline: MLIRToLLVMDialect
Compilation failed:
deriv_f:121:12: error: failed to legalize operation 'memref.load' that was explicitly marked illegal
      %1 = stablehlo.reshape %0 : (tensor<1x1x1xcomplex<f64>>) -> tensor<1xcomplex<f64>>
           ^
deriv_f:44:13: note: called from
      %14 = call @det(%13) : (tensor<1x2x2xcomplex<f64>>) -> tensor<1xcomplex<f64>>
            ^
deriv_f:121:12: note: see current operation: %213 = "memref.load"(%202, %204, %207, %210) <{nontemporal = false}> : (memref<1x1x1xcomplex<f64>, strided<[4, 2, 1]>>, index, index, index) -> complex<f64>
      %1 = stablehlo.reshape %0 : (tensor<1x1x1xcomplex<f64>>) -> tensor<1xcomplex<f64>>
           ^
While processing 'MemrefToLLVMWithTBAAPass' pass Failed to lower MLIR module
```
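The compile error above is raised while lowering a `det` call on a `tensor<1x2x2xcomplex<f64>>` inside `deriv_f`. A plausible reason a determinant appears at all is that a single-qubit (ZYZ) decomposition usually first rescales U into SU(2) using a phase derived from det(U). I believe PennyLane's one-qubit decomposition does this, but that is an assumption here. The NumPy sketch below is not PennyLane's implementation; the helpers `rz`, `ry`, `zyz_decompose` and `recompose` are hypothetical. It shows where the determinant enters and checks that the angles reconstruct U exactly, including the global phase.

```python
import numpy as np


def rz(a):
    """RZ(a) = diag(exp(-i a/2), exp(i a/2))."""
    return np.array([[np.exp(-0.5j * a), 0], [0, np.exp(0.5j * a)]])


def ry(t):
    """RY(t) = [[cos(t/2), -sin(t/2)], [sin(t/2), cos(t/2)]]."""
    c, s = np.cos(t / 2), np.sin(t / 2)
    return np.array([[c, -s], [s, c]], dtype=complex)


def zyz_decompose(U):
    """Hypothetical helper: return (phi, theta, omega, gamma) such that
    U = exp(1j*gamma) * RZ(omega) @ RY(theta) @ RZ(phi) for a 2x2 unitary U."""
    U = np.asarray(U, dtype=complex)
    # Rescale U into SU(2): this is the step that needs det(U).
    gamma = np.angle(np.linalg.det(U)) / 2
    V = U * np.exp(-1j * gamma)  # det(V) == 1, so V = [[a, -conj(b)], [b, conj(a)]]
    a, b = V[0, 0], V[1, 0]
    theta = 2 * np.arctan2(np.abs(b), np.abs(a))
    phi = -np.angle(a) - np.angle(b)
    omega = -np.angle(a) + np.angle(b)
    return phi, theta, omega, gamma


def recompose(phi, theta, omega, gamma):
    return np.exp(1j * gamma) * (rz(omega) @ ry(theta) @ rz(phi))


# A diagonal unitary of the same shape as the issue's U.
U = np.diag([1, np.exp(0.3j)])
print(np.allclose(recompose(*zyz_decompose(U)), U))  # True
```

For the issue's input x = 1.0, U is the identity and every angle comes out as 0. Note also that `[[1, 0], [0, x]]` is unitary only when |x| = 1, so away from that point no exact rotation decomposition exists.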

(b) With qml.transforms.unitary_to_rot

Replacing the QNode above with the following gives the same behavior:

```python
@qml.transforms.unitary_to_rot
@qml.qnode(dev)
def circuit(x):
    U = jnp.array([[1, 0], [0, x]])
    qml.QubitUnitary(U, wires=0)
    return qml.probs()
```

Note that with this method, the intermediate MLIR does show the decomposed rotation gates instead of `QubitUnitary` gates.


By contrast, writing out a "decomposition" by hand does not break:

```python
import catalyst
import pennylane as qml
from catalyst import qjit

dev = qml.device("lightning.qubit", wires=1)

@qml.qnode(dev)
def circuit(x):
    qml.RZ(x, wires=0)
    qml.RY(x + 1, wires=0)
    qml.RZ(x * 2, wires=0)
    return qml.probs()

@qjit(keep_intermediate=False)
def f(x):
    probs = circuit(x)
    return probs[0] + probs[1]

print(f(1.0))
print(catalyst.grad(f, argnums=0)(1.0))
```

```text
1.0000000000000002
5.551115123125783e-17
```
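Since `f` returns `probs[0] + probs[1]`, the total probability on a single wire, it equals 1 for every x, so its exact derivative is 0 and the printed 5.55e-17 is floating-point round-off, i.e. the correct answer. The plain NumPy statevector sketch below (hypothetical helpers `rz`, `ry`, `probs_manual`, `f_manual`, `central_diff`; it uses neither Catalyst nor Lightning) simulates the same RZ/RY/RZ circuit and checks this with a central finite difference. It also shows that only the RY gate changes the probabilities: P(0) = cos²((x + 1)/2).

```python
import numpy as np


def rz(a):
    return np.array([[np.exp(-0.5j * a), 0], [0, np.exp(0.5j * a)]])


def ry(t):
    c, s = np.cos(t / 2), np.sin(t / 2)
    return np.array([[c, -s], [s, c]], dtype=complex)


def probs_manual(x):
    # Gates act in program order: RZ(x), then RY(x + 1), then RZ(2x).
    state = np.array([1, 0], dtype=complex)  # |0>
    state = rz(2 * x) @ ry(x + 1) @ rz(x) @ state
    return np.abs(state) ** 2


def f_manual(x):
    p = probs_manual(x)
    return p[0] + p[1]


def central_diff(fn, x, h=1e-6):
    return (fn(x + h) - fn(x - h)) / (2 * h)


print(f_manual(1.0))                # 1.0 up to round-off
print(central_diff(f_manual, 1.0))  # approximately 0
```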

Therefore, this error likely arises from how the decomposition is traced, rather than from differentiating the rotation gates themselves.
