From 71e89f5f0c47374dae210b16ddcaa603e20095a1 Mon Sep 17 00:00:00 2001 From: Olivier Peltre Date: Mon, 3 Jun 2024 17:16:52 +0200 Subject: [PATCH 1/3] print at every CG computation --- e3nn_jax/_src/so3.py | 2 +- test.py | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) create mode 100644 test.py diff --git a/e3nn_jax/_src/so3.py b/e3nn_jax/_src/so3.py index 27ee0db..4886c8b 100644 --- a/e3nn_jax/_src/so3.py +++ b/e3nn_jax/_src/so3.py @@ -17,7 +17,6 @@ def change_basis_real_to_complex(l: int) -> np.ndarray: # Added factor of 1j**l to make the Clebsch-Gordan coefficients real return (-1j) ** l * q - def clebsch_gordan(l1: int, l2: int, l3: int) -> np.ndarray: r"""The Clebsch-Gordan coefficients of the real irreducible representations of :math:`SO(3)`. @@ -29,6 +28,7 @@ def clebsch_gordan(l1: int, l2: int, l3: int) -> np.ndarray: Returns: np.ndarray: the Clebsch-Gordan coefficients """ + print("I'm computing Clebsch-Gordan coefficients!") C = su2_clebsch_gordan(l1, l2, l3) Q1 = change_basis_real_to_complex(l1) Q2 = change_basis_real_to_complex(l2) diff --git a/test.py b/test.py new file mode 100644 index 0000000..f511686 --- /dev/null +++ b/test.py @@ -0,0 +1,13 @@ +import jax +import jax.numpy as np + +import e3nn_jax as e3nn + +# some dummy e3-array +a = e3nn.IrrepsArray.zeros("8x0e + 8x1o + 8x2e", (1,)) +# some scalar array +n = a.irreps.num_irreps +b = e3nn.IrrepsArray.zeros(f"{n}x0e", (1,)) + +for i in range(3): + print(a * b) From fe2ac122c61ffeb6f22549a8c7aa98deaaa1902d Mon Sep 17 00:00:00 2001 From: Olivier Peltre Date: Mon, 3 Jun 2024 17:19:40 +0200 Subject: [PATCH 2/3] cache CG coefficients --- e3nn_jax/_src/so3.py | 3 ++- test.py | 13 ------------- 2 files changed, 2 insertions(+), 14 deletions(-) delete mode 100644 test.py diff --git a/e3nn_jax/_src/so3.py b/e3nn_jax/_src/so3.py index 4886c8b..e2820b4 100644 --- a/e3nn_jax/_src/so3.py +++ b/e3nn_jax/_src/so3.py @@ -1,4 +1,5 @@ import numpy as np +import functools from e3nn_jax._src.su2 import su2_clebsch_gordan, su2_generators @@ -17,6 +18,7 @@ def change_basis_real_to_complex(l: int) -> np.ndarray: # Added factor of 1j**l to make the Clebsch-Gordan coefficients real return (-1j) ** l * q +@functools.cache def clebsch_gordan(l1: int, l2: int, l3: int) -> np.ndarray: r"""The Clebsch-Gordan coefficients of the real irreducible representations of :math:`SO(3)`. @@ -28,7 +30,6 @@ def clebsch_gordan(l1: int, l2: int, l3: int) -> np.ndarray: Returns: np.ndarray: the Clebsch-Gordan coefficients """ - print("I'm computing Clebsch-Gordan coefficients!") C = su2_clebsch_gordan(l1, l2, l3) Q1 = change_basis_real_to_complex(l1) Q2 = change_basis_real_to_complex(l2) diff --git a/test.py b/test.py deleted file mode 100644 index f511686..0000000 --- a/test.py +++ /dev/null @@ -1,13 +0,0 @@ -import jax -import jax.numpy as np - -import e3nn_jax as e3nn - -# some dummy e3-array -a = e3nn.IrrepsArray.zeros("8x0e + 8x1o + 8x2e", (1,)) -# some scalar array -n = a.irreps.num_irreps -b = e3nn.IrrepsArray.zeros(f"{n}x0e", (1,)) - -for i in range(3): - print(a * b) From de7822e76ffca6d2dace9a39a083be94e8322be6 Mon Sep 17 00:00:00 2001 From: Mit <53411468+mitkotak@users.noreply.github.com> Date: Sun, 23 Jun 2024 16:05:24 -0400 Subject: [PATCH 3/3] safe cache --- e3nn_jax/_src/so3.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/e3nn_jax/_src/so3.py b/e3nn_jax/_src/so3.py index e2820b4..29440d9 100644 --- a/e3nn_jax/_src/so3.py +++ b/e3nn_jax/_src/so3.py @@ -18,8 +18,13 @@ def change_basis_real_to_complex(l: int) -> np.ndarray: # Added factor of 1j**l to make the Clebsch-Gordan coefficients real return (-1j) ** l * q -@functools.cache + def clebsch_gordan(l1: int, l2: int, l3: int) -> np.ndarray: + return _clebsch_gordan(l1, l2, l3).copy() + + +@functools.cache +def _clebsch_gordan(l1: int, l2: int, l3: int) -> np.ndarray: r"""The Clebsch-Gordan coefficients of the real irreducible representations of :math:`SO(3)`. Args: