add fused spherical harmonics #46
Conversation
Mario Geiger:
Thanks for the PR in both torch and JAX. I have a few questions:
1. What is the difference between this algorithm and this one
<https://github.com/e3nn/e3nn-jax/blob/main/e3nn_jax/_src/spherical_harmonics.py#L435-L464>?
2. A preliminary test on my CPU shows that the current algorithm (e3nn.sh(range(max_degree + 1), x, True, algorithm=("legendre",))) is faster than calc_Ylm(max_degree, x) when max_degree is larger than 5. For smaller degrees, the algorithm e3nn.sh(range(lmax + 1), x, True, algorithm=("recursive", "dense", "custom_jvp")) becomes faster (see the timing sketch after this comment). In which case do you see an improvement over the current algorithms?
3. Please use the same conventions as e3nn:

x = jnp.array([[0, 0, 1.0]])
print(e3nn.sh(range(1 + 1), x, True, algorithm=("legendre",)))
# [[1. 0. 0. 1.732]]
print(calc_Ylm(1, x))
# [[0.141 0. 0.244 0. ]]

4. Don't use lru_cache; jax.jit will automatically optimize the constants.
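Regarding question 2, a rough timing harness along those lines might look as follows. This is a sketch under assumptions: the batch size, iteration count, and jit wrapping are not from the discussion, and calc_Ylm is the function added in this PR.

import time

import jax
import e3nn_jax as e3nn

max_degree = 5  # the reported CPU crossover degree
x = jax.random.normal(jax.random.PRNGKey(0), (10_000, 3))

legendre = jax.jit(
    lambda x: e3nn.sh(range(max_degree + 1), x, True, algorithm=("legendre",))
)
fused = jax.jit(lambda x: calc_Ylm(max_degree, x))  # calc_Ylm from this PR

for name, fn in [("legendre", legendre), ("fused", fused)]:
    jax.block_until_ready(fn(x))  # compile outside the timed region
    t0 = time.perf_counter()
    for _ in range(100):
        jax.block_until_ready(fn(x))
    print(name, (time.perf_counter() - t0) / 100)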
PR author:
1. At a high level, the difference is that I put the expensive polynomial calls into a matrix multiplication instead of using loops. My hypothesis is that rewriting the most expensive part as a matrix multiplication would improve efficiency, because JAX's matrix multiplication is heavily optimized (see the sketch after this list). I haven't tested the performance yet, though.
2. I don't expect it to perform as well on a CPU, because the coefficient tensor I'm using is quite sparse (~81% of its entries are zero), so overall there are a lot of unnecessary operations. However, I anticipate that running it on a GPU would give a performance advantage, especially at higher degrees.
3. OK, I'll make that change :)
4. I'm still new to JAX, so I don't know why I wouldn't want to cache the constants or how JAX would optimize them. Their computational graph isn't necessary for backpropagation, and I didn't include them in the jit-compiled function, which is why caching them seemed like a good idea. Would the jit-compiled function cache the constants for me?
Mario Geiger:
I reorganized the code to separate the different implementations into different files.
1. Is your code aiming to optimize the function jax.scipy.special.lpmn_values? If so, could you rearrange your code as an alternative implementation of the legendre function?
2. Can you run those tests?
3. Great!
4. Yes, JAX will take the computational graph and optimize it. It's very good at dead code elimination, common subexpression elimination, and constant folding (see the sketch below).
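A tiny illustration of point 4, with made-up constants: everything inside a jitted function that does not depend on the traced input is handled at compile time, so lru_cache is unnecessary.

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Built from nothing traced: XLA constant-folds this table during
    # compilation, so it is not recomputed on later calls.
    table = jnp.sqrt(jnp.arange(1.0, 10.0))
    unused = 2.0 * table   # dead code: eliminated entirely
    s = table.sum()
    return x * s + x * s   # common subexpression: x * s is computed once

print(f(1.0))  # compiled on the first call; later calls reuse the optimized program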
PR author:
I just want to ping you to let you know I'm still planning to work on this. I've been very busy, so I haven't had much time for the project; I'll look into it again once I have more free time. To answer your question in the meantime:
1. Yes, it optimizes that function (see the usage sketch below). I'll rearrange my code as an alternative to the legendre_spherical_harmonics code when I have more free time.
Hope you had a great Thanksgiving!
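For reference, a quick usage sketch of the jax.scipy.special.lpmn_values function discussed above; the degree 2 and the sample points are arbitrary choices for illustration.

import jax.numpy as jnp
from jax.scipy.special import lpmn_values

# Normalized associated Legendre values P_l^m(z) for all m, l <= 2 at three
# sample points; the result has shape (m_max + 1, l_max + 1, num_points).
z = jnp.array([-0.5, 0.0, 0.5])
p = lpmn_values(2, 2, z, is_normalized=True)
print(p.shape)  # (3, 3, 3)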