eliminate window_adaptation overhead? #728
I noticed that when I run the following code snippet twice in a row, the second run takes about half as long as the first.
Is this a known feature of the window_adaptation warmup? Is there some jit compilation happening behind the scenes? Is it possible to eliminate this overhead with an explicit jit? Any insights or suggestions are appreciated, and sorry if I missed this in the documentation. Thanks!
Replies: 1 comment 1 reply
This is a property of JAX: the first call both compiles and runs the code, so the second call is faster because it reuses the already compiled version.
Also, while we did not explicitly jit warmup.run, we use JAX scan(), which "compiles f, so while it can be combined with jit(), it’s usually unnecessary."
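To see the compile-then-reuse behaviour in isolation, here is a minimal sketch in plain JAX. It does not use the warmup code from this thread: `cumulative_sum`, `step`, and `timed` are hypothetical names made up for illustration, and the scan body is a toy stand-in for a real adaptation step. It times the same jitted function twice and waits for the result each time, so asynchronous dispatch does not hide the cost.

```python
import time

import jax
import jax.numpy as jnp


def cumulative_sum(xs):
    # Toy stand-in for a loop body driven by jax.lax.scan (hypothetical example).
    def step(carry, x):
        carry = carry + x
        return carry, carry

    total, partials = jax.lax.scan(step, jnp.zeros((), xs.dtype), xs)
    return total, partials


def timed(fn, *args):
    # Block until the result is ready so the measurement covers the actual work.
    start = time.perf_counter()
    out = jax.block_until_ready(fn(*args))
    return out, time.perf_counter() - start


xs = jnp.arange(1000, dtype=jnp.float32)
jitted = jax.jit(cumulative_sum)

# First call: trace + compile + run.
(first_total, first_partials), first_time = timed(jitted, xs)
# Second call with the same shapes and dtypes: reuses the compiled executable.
(second_total, second_partials), second_time = timed(jitted, xs)

print(f"first call:  {first_time:.4f} s")
print(f"second call: {second_time:.4f} s")
```

The compilation cost on the first call cannot be removed entirely, only paid once: keeping a single jitted callable around and calling it with inputs of the same shape and dtype lets later calls skip the compile step. Exact timings depend on the machine and backend.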