diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 919806fb84a..a763d67d880 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -276,6 +276,11 @@ * Improve builtin types support with `qml.pauli_decompose`. [(#4577)](https://github.com/PennyLaneAI/pennylane/pull/4577) +* The function `integrals.py` is modified to replace indexing with slicing in the computationally + expensive functions `electron_repulsion` and `_hermite_coulomb`, for a better compatibility with + JAX. + [(#4685)](https://github.com/PennyLaneAI/pennylane/pull/4685) + * Various changes to measurements to improve feature parity between the legacy `default.qubit` and the new `DefaultQubit2`. This includes not trying to squeeze batched `CountsMP` results and implementing `MutualInfoMP.map_wires`. diff --git a/pennylane/qchem/integrals.py b/pennylane/qchem/integrals.py index b277f2d70c5..b6065110ebd 100644 --- a/pennylane/qchem/integrals.py +++ b/pennylane/qchem/integrals.py @@ -774,7 +774,7 @@ def _hermite_coulomb(t, u, v, n, p, dr): Returns: array[float]: value of the Hermite integral """ - x, y, z = dr[0], dr[1], dr[2] + x, y, z = dr[0:3] T = p * (dr**2).sum(axis=0) r = 0 @@ -967,12 +967,17 @@ def electron_repulsion(la, lb, lc, ld, ra, rb, rc, rd, alpha, beta, gamma, delta + delta * rd[:, np.newaxis, np.newaxis, np.newaxis, np.newaxis] ) / (gamma + delta) - g_t = [expansion(l1, l2, ra[0], rb[0], alpha, beta, t) for t in range(l1 + l2 + 1)] - g_u = [expansion(m1, m2, ra[1], rb[1], alpha, beta, u) for u in range(m1 + m2 + 1)] - g_v = [expansion(n1, n2, ra[2], rb[2], alpha, beta, v) for v in range(n1 + n2 + 1)] - g_r = [expansion(l3, l4, rc[0], rd[0], gamma, delta, r) for r in range(l3 + l4 + 1)] - g_s = [expansion(m3, m4, rc[1], rd[1], gamma, delta, s) for s in range(m3 + m4 + 1)] - g_w = [expansion(n3, n4, rc[2], rd[2], gamma, delta, w) for w in range(n3 + n4 + 1)] + ra0, ra1, ra2 = ra[0:3] + rb0, rb1, rb2 = rb[0:3] + rc0, rc1, rc2 = rc[0:3] + rd0, rd1, rd2 = rd[0:3] + + g_t = [expansion(l1, l2, ra0, rb0, alpha, beta, t) for t in range(l1 + l2 + 1)] + g_u = [expansion(m1, m2, ra1, rb1, alpha, beta, u) for u in range(m1 + m2 + 1)] + g_v = [expansion(n1, n2, ra2, rb2, alpha, beta, v) for v in range(n1 + n2 + 1)] + g_r = [expansion(l3, l4, rc0, rd0, gamma, delta, r) for r in range(l3 + l4 + 1)] + g_s = [expansion(m3, m4, rc1, rd1, gamma, delta, s) for s in range(m3 + m4 + 1)] + g_w = [expansion(n3, n4, rc2, rd2, gamma, delta, w) for w in range(n3 + n4 + 1)] g = 0.0 lengths = [l1 + l2 + 1, m1 + m2 + 1, n1 + n2 + 1, l3 + l4 + 1, m3 + m4 + 1, n3 + n4 + 1]