diffrax with mesolve is very slow #26

Open
nwlambert opened this issue Oct 18, 2023 · 8 comments

@nwlambert
Member

Doing some basic tests with mesolve()+qutip-jax, the diffrax method seems excessively slow, both on CPU and GPU. Manually extracting the ODE data and passing it to diffrax.diffeqsolve() is much quicker, so perhaps there is a bottleneck somewhere?

e.g., comparing

import qutip as qt
import numpy as np
import qutip_jax
import jax
with qt.CoreOptions(default_dtype="jax"):
    N = 2
    a = qt.destroy(N) & qt.qeye(N) & qt.qeye(N)
    b = qt.qeye(N) & qt.destroy(N)  & qt.qeye(N)
    c = qt.qeye(N) & qt.qeye(N) & qt.destroy(N)  
    H = (a.dag()*a + b.dag()*b + c.dag()*c + 
        (a.dag()+a) * (b+b.dag()) + 
        (b.dag()+b) * (c+c.dag())
        )
    
    c_ops =[a,b,c]

    t = 10
    options = {"method": "diffrax", "normalize_output": False}
    solver = qt.MESolver(H, c_ops, options=options)

    result = solver.run(
            qt.basis(N, 1, dtype="jax") & qt.basis(N,0, dtype="jax") & qt.basis(N,0, dtype="jax"),
            [0, t],
            e_ops=qt.num(N, dtype="jax")& qt.qeye(N, dtype="jax") & qt.qeye(N, dtype="jax")
        )

Compare this to a manual attempt using the RHS and initial condition from above (not super sure this is correct, but it seems to give reasonably the same output):

import jax.numpy as jnp
from diffrax import diffeqsolve, ODETerm, Dopri5, PIDController

# Dense Liouvillian and vectorized initial state (reuses H, c_ops, N, t from above)
L = qt.liouvillian(H, c_ops)
LJ = jnp.array(L.full())
rho0 = qt.ket2dm(qt.basis(N, 1) & qt.basis(N, 0) & qt.basis(N, 0))
rho0J = jnp.array(qt.operator_to_vector(rho0).full())

def f(t, y, args):
    # Linear ODE: d(vec rho)/dt = L vec(rho)
    return LJ @ y

stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)
term = ODETerm(f)
solver = Dopri5()
y0 = rho0J
solution = diffeqsolve(term, solver, t0=0, t1=t, dt0=0.01, y0=y0, stepsize_controller=stepsize_controller)
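
Since the manual version is flagged as "not super sure this is correct", here is a small sanity check of the vectorized form on a two-level system. It is only a sketch: the matrices are made-up values (not the three-mode system above), and it assumes the column-stacking convention for the vectorized density matrix, which is what qt.operator_to_vector is understood to use.

```python
import jax.numpy as jnp

# Hypothetical two-level example, chosen only to exercise the identity.
H = jnp.array([[0.0, 0.3], [0.3, 1.0]], dtype=jnp.complex64)
C = jnp.array([[0.0, 1.0], [0.0, 0.0]], dtype=jnp.complex64)  # lowering operator
rho = jnp.array([[0.25, 0.1j], [-0.1j, 0.75]], dtype=jnp.complex64)
eye = jnp.eye(2, dtype=jnp.complex64)

def vec(x):
    # Column-stacking, so that vec(A X B) = kron(B.T, A) @ vec(X).
    return x.T.reshape(-1)

CdC = C.conj().T @ C

def lindblad_rhs(r):
    # Direct Lindblad right-hand side acting on the density matrix.
    return (-1j * (H @ r - r @ H)
            + C @ r @ C.conj().T
            - 0.5 * (CdC @ r + r @ CdC))

# Same generator written as a matrix acting on vec(rho).
L = (-1j * (jnp.kron(eye, H) - jnp.kron(H.T, eye))
     + jnp.kron(C.conj(), C)
     - 0.5 * (jnp.kron(eye, CdC) + jnp.kron(CdC.T, eye)))

match = bool(jnp.allclose(L @ vec(rho), vec(lindblad_rhs(rho)), atol=1e-6))
print("vectorized RHS matches direct RHS:", match)
```

If the two agree, `LJ @ y` in the snippet above is the same linear ODE that mesolve integrates, up to how the solver and step size are chosen.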
@nwlambert
Member Author

Just to clarify: for the above example, the first snippet takes ~100 s on a GPU (a bit less on a CPU), while the second snippet using diffeqsolve() directly takes 114 ms. The same example with standard mesolve() and the normal CSR data layer takes < 600 microseconds.

There's probably some overhead on the GPU side, but this also scales very badly (increasing N quickly makes the first snippet unusable).

@Ericgig
Member

Ericgig commented Oct 18, 2023

I did expect diffrax solver to be slower than normal mesolve (on cpu), but not that much...

On cpu the first snippet takes 3.5s on my computer and using diffeqsolve directly takes 400ms.
Not as bad but still not great.

When I use the same algorithm by adding "stepsize_controller": PIDController(rtol=1e-5, atol=1e-5), "solver": Dopri5() to the options, I get 500ms using qutip, which is about the same. So we are not that bad on CPUs.
(The defaults we use are "solver": diffrax.Tsit5(), "stepsize_controller": diffrax.ConstantStepSize().)

Can you try to profile on gpu to see why we are so inefficient?
My guess is that we always compute a coefficient even for a constant QobjEvo with the diffrax method. This coefficient is a function returning a constant, but it's probably computed on the CPU...

ps. I am on the jaxdia branch.

@nwlambert
Member Author

Thanks Eric, that helps a lot! Playing around with combinations of options, it seems like the stepsize controller is the thing that was really slowing it down. E.g., running the native diffeqsolve() with the same constant stepsize is also very slow (though a little faster than qutip-jax, which could just be because of a different choice of dt0 in qutip-jax).

I didn't have much luck with the profiler (will keep trying to get something useful out of it), but after some playing around, it's not really clear to me if there's really a problem or not. some examples, with N=4 (to slow things down a bit), and with the Dopri5() solver, PIDController for stepsize:

  1. Standard CSR qutip: 0.2s
  2. qutip-jax cpu: 14s
  3. qutip-jax gpu: 0.7s
  4. native jax diffeqsolve() cpu: 14s
  5. native jax diffeqsolve() gpu: 0.6s
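
When collecting wall-clock numbers like these, it can help to separate the first call (which includes tracing and compilation of jitted code) from later calls, and to wait for the asynchronous result before stopping the clock. A minimal sketch with a stand-in workload follows; the function `step` is hypothetical, and whether compilation time contributes to any of the numbers above has not been checked.

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def step(y):
    # Stand-in workload: ten matrix-vector products.
    a = jnp.eye(y.shape[0]) * 0.99
    for _ in range(10):
        y = a @ y
    return y

y0 = jnp.ones(64)

start = time.perf_counter()
step(y0).block_until_ready()  # first call: trace + compile + run
first = time.perf_counter() - start

start = time.perf_counter()
out = step(y0).block_until_ready()  # later call: reuses the compiled function
second = time.perf_counter() - start

print(f"first call:  {first:.6f} s")
print(f"second call: {second:.6f} s")
```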

@nwlambert
Member Author

nwlambert commented Oct 20, 2023

Just a quick addition: I gave the jaxdia branch a try, and this is super encouraging! With jaxdia we can really push up the Hilbert space size, and I see some pretty impressive numbers.
For N=10
qutip standard CSR or Dia: 200s
qutip-jaxdia-gpu: 5.3s

For N=12
qutip standard CSR or Dia: 824s
qutip-jaxdia-gpu: 16s

with standard qutip-jax I tend to run out of memory around N=5, so jaxdia really helps us see some crossover.

I will try and double check I am not messing something up, but this seems very impressive!

@jan-o-e

jan-o-e commented Feb 17, 2024

Hey guys, interesting discussion here.

Is there a way to include time-dependent Hamiltonian parameters in mesolve with JAX? I just gave the jaxdia branch a try with no luck.

@Ericgig
Member

Ericgig commented Feb 19, 2024

The normal list format works if jitted functions are used: H = [H0, [H1, jax.jit(f)]].
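
A sketch of what such a jitted coefficient could look like. The function `f` and its cosine envelope are hypothetical, and taking only `t` as an argument is an assumption about the expected signature; H0 and H1 are left as comments because they are not defined in this thread.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(t):
    # Hypothetical drive envelope; any jit-compatible function of t would do.
    return jnp.cos(2.0 * t)

# With H0 and H1 built under qt.CoreOptions(default_dtype="jax"), the list
# format from the comment above would then read:
#     H = [H0, [H1, f]]
print(f(0.0))
```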

@jan-o-e

jan-o-e commented Feb 21, 2024

Thanks. For the time-dependent parameters, can you also pass arrays whose length matches the number of steps in the numerical solver, instead of a function?

This is option 3 described here: https://qutip.org/docs/latest/guide/dynamics/dynamics-time.html

@Ericgig
Member

Ericgig commented Feb 22, 2024

No (not yet). That version uses scipy and cython, which don't mix well with jax.
For now, you would have to make or find a spline function that supports jit.
jax-cosmo seems to have something promising.
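
Until a jit-compatible spline is available, one simple stand-in is linear interpolation with jnp.interp, which works under jax.jit. This is a sketch, not a cubic spline: the names `tlist`, `samples` and `coeff` are hypothetical, and it has not been verified here that qutip-jax accepts this function as a coefficient.

```python
import jax
import jax.numpy as jnp

# Hypothetical sampled coefficient on a fixed time grid.
tlist = jnp.linspace(0.0, 10.0, 101)
samples = jnp.sin(tlist)

@jax.jit
def coeff(t):
    # Piecewise-linear interpolation of the samples at time t.
    return jnp.interp(t, tlist, samples)

print(coeff(5.05), jnp.sin(5.05))
```

Following the list format mentioned earlier in the thread, the function could then be tried as H = [H0, [H1, coeff]]. The interpolation error depends on the grid spacing, so a finer tlist may be needed for rapidly varying parameters.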
