diffrax with mesolve is very slow #26
Just to clarify: for the above example, the first snippet takes ~100 s on a GPU and a bit less on a CPU, while the second snippet, which uses diffeqsolve() directly, takes 114 ms (and the same example with the standard mesolve() and the normal CSR data layer takes < 600 microseconds). There's probably some overhead on the GPU side, but this also scales very badly: increasing N quickly makes the first snippet unusable.
I did expect the diffrax solver to be slower than the normal mesolve (on CPU), but not by that much... On CPU the first snippet takes 3.5 s on my computer, and using diffeqsolve directly takes 400 ms. When I use the same algorithm by adding… Can you try to profile on GPU to see why we are so inefficient? P.S. I am on the jaxdia branch.
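Timings like ~100 s vs 114 ms can mix JIT compilation with the actual integration. A minimal sketch for separating the two, assuming the callable being timed returns JAX arrays (`time_jax` and `workload` are hypothetical names, not part of qutip-jax or diffrax):

```python
import time

import jax
import jax.numpy as jnp


def time_jax(fn, *args, repeats=5):
    """Return (first_call_s, steady_state_s) for a JAX callable.

    The first call includes tracing and XLA compilation; later calls reuse the
    compiled executable. jax.block_until_ready waits for JAX's asynchronous
    dispatch, so the timings cover the actual computation.
    """
    start = time.perf_counter()
    jax.block_until_ready(fn(*args))
    first = time.perf_counter() - start

    start = time.perf_counter()
    for _ in range(repeats):
        jax.block_until_ready(fn(*args))
    steady = (time.perf_counter() - start) / repeats
    return first, steady


# Stand-in workload (hypothetical); substitute the mesolve or diffeqsolve call being benchmarked.
@jax.jit
def workload(x):
    return jnp.tanh(x @ x.T).sum()


x = jnp.ones((256, 256))
first_s, steady_s = time_jax(workload, x)
print(f"first call: {first_s:.4f} s, steady state: {steady_s:.6f} s")
```

If most of the time lands in the first call (or every call is as slow as the first), tracing and compilation are a more likely culprit than the solver itself; for per-operation detail, `jax.profiler.trace` can capture a trace.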
Thanks Eric, that helps a lot! Playing around with combinations of options, it seems like the step size is what was really slowing it down. For example, calling diffeqsolve() natively with the same constant step size is also very slow (though a little faster than qutip-jax, which could just be due to a different choice of dt0 in qutip-jax). I didn't have much luck with the profiler (I will keep trying to get something useful out of it), but after some experimenting it's not really clear to me whether there is a real problem or not. Some examples, with N=4 (to slow things down a bit) and the Dopri5() solver with PIDController for the step size:
Just a quick addition: I gave the jaxdia branch a try, and this is super encouraging! With jaxdia we can really push up the Hilbert space size, and I see some pretty impressive numbers (for N=12). With standard qutip-jax I tend to run out of memory around N=5, so jaxdia really helps us see some crossover. I will double-check that I am not messing something up, but this seems very impressive!
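The memory gain comes from diagonal (DIA) storage. A generic sketch of the idea, not the jaxdia code itself: a banded operator keeps one length-n array per nonzero diagonal, so memory scales as (number of diagonals) × n instead of n² (a tridiagonal 1000×1000 matrix stores 3,000 entries rather than 1,000,000). `dense_to_dia` and `dia_matvec` are hypothetical names, and the row-indexed layout `data[k, i] = A[i, i + offsets[k]]` is an assumption that differs from SciPy's column-indexed `dia_matrix`:

```python
import jax
import jax.numpy as jnp


def dense_to_dia(A, offsets):
    """Pack diagonals: data[k, i] = A[i, i + offsets[k]], or 0 when out of range."""
    n = A.shape[0]
    i = jnp.arange(n)
    rows = []
    for off in offsets:
        j = i + off
        valid = (j >= 0) & (j < n)
        rows.append(jnp.where(valid, A[i, jnp.clip(j, 0, n - 1)], 0))
    return jnp.stack(rows)


def dia_matvec(offsets, data, x):
    """Compute y = A @ x using only the stored diagonals."""
    n = x.shape[0]
    i = jnp.arange(n)
    y = jnp.zeros(n, dtype=jnp.result_type(data, x))
    for k, off in enumerate(offsets):
        j = i + off
        valid = (j >= 0) & (j < n)
        # Neighbours outside the matrix contribute zero.
        y = y + data[k] * jnp.where(valid, x[jnp.clip(j, 0, n - 1)], 0)
    return y


# offsets is a tuple of Python ints, so it is passed as a static argument.
dia_matvec_jit = jax.jit(dia_matvec, static_argnums=0)

# Example: H = a^dag a + 0.3 (a + a^dag) for a truncated oscillator is tridiagonal.
n = 8
a = jnp.diag(jnp.sqrt(jnp.arange(1, n, dtype=jnp.float32)), k=1)
H = a.T @ a + 0.3 * (a + a.T)
offsets = (-1, 0, 1)
data = dense_to_dia(H, offsets)
x = jax.random.normal(jax.random.PRNGKey(0), (n,))
y = dia_matvec_jit(offsets, data, x)
```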
Hey guys, interesting discussion here. Is there a way to include time-dependent Hamiltonian parameters in mesolve with JAX? I just gave the jaxdia branch a try with no luck.
The normal list format works if jitted functions are used:
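In QuTiP's list format, `[H0, [H1, f]]` stands for H(t) = H0 + f(t) H1. A plain-JAX sketch of what evaluating such a list with jitted coefficients amounts to (`make_hamiltonian` and `drive` are hypothetical names; this is not the qutip-jax `QobjEvo` implementation, and the `f(t, **args)` calling convention is an assumption):

```python
import jax
import jax.numpy as jnp

sx = jnp.array([[0, 1], [1, 0]], dtype=jnp.complex64)
sz = jnp.array([[1, 0], [0, -1]], dtype=jnp.complex64)


@jax.jit
def drive(t, w):
    # Hypothetical jitted coefficient f(t) = cos(w t).
    return jnp.cos(w * t)


def make_hamiltonian(h0, terms):
    """Return a jitted H(t, args) = h0 + sum_k f_k(t, **args) * H_k for terms [[H_k, f_k], ...]."""

    @jax.jit
    def hamiltonian(t, args):
        h = h0
        for op, coeff in terms:
            h = h + coeff(t, **args) * op
        return h

    return hamiltonian


H = make_hamiltonian(0.5 * sz, [[sx, drive]])
args = {"w": 1.0}
H_at_zero = H(0.0, args)
```

Because the whole evaluation is traced by JAX, H(t) can be called inside a jitted ODE right-hand side without returning to Python at every time step.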
Thanks. Can you also put arrays, with a length corresponding to the number of steps in the numerical solver, in there for the time-dependent parameters instead of a function? This is option 3 described here: https://qutip.org/docs/latest/guide/dynamics/dynamics-time.html
No, not yet. That version uses SciPy and Cython, which don't mix well with JAX.
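Until array coefficients are supported, one possible workaround, sketched under the assumption that linear interpolation between samples is acceptable, is to wrap the sampled values in a jitted function and use it wherever a function coefficient is accepted. `array_coefficient` is a hypothetical helper, and `jnp.interp`'s linear interpolation may differ from the interpolation QuTiP applies to array coefficients:

```python
import jax
import jax.numpy as jnp


def array_coefficient(tlist, values):
    """Turn samples values[i] = f(tlist[i]) into a jitted, differentiable f(t)."""
    tlist = jnp.asarray(tlist)
    values = jnp.asarray(values)

    @jax.jit
    def coeff(t):
        # Piecewise-linear interpolation; t outside tlist is clamped to the end values.
        return jnp.interp(t, tlist, values)

    return coeff


tlist = jnp.linspace(0.0, 10.0, 101)  # hypothetical sample times
samples = jnp.sin(tlist)              # hypothetical sampled drive
f = array_coefficient(tlist, samples)
```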
In some basic tests, mesolve() with qutip-jax and the diffrax method seems excessively slow, both on CPU and GPU. Manually extracting the ODE data and passing it to diffrax.diffeqsolve() is much quicker, so perhaps there is a bottleneck somewhere?
For example, comparing
Compare this to a manual attempt using the RHS and initial condition from above (I'm not entirely sure this is correct, but it seems to give reasonably similar output).