Tracking potential performance improvement #72
Replies: 5 comments 7 replies
-
Nice finds! I think these would be worth exploring. If they seem useful I can add a performance tips page in the documentation. As far as I can tell, there are three varieties of performance tips:
The two packages you found fall into the third category. I think I should have some time tomorrow to look at one of them and do a comparison to NUTS.
-
@DominiqueMakowski, I am seeking clarification on AD requirements and parallel capabilities. So far, I have only been able to run MuseInference with Zygote AD on a single thread, where its performance is comparable to NUTS with ReverseDiff AD (but we could get higher ESS by running multiple chains across threads).
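For reference, Turing can run multiple NUTS chains in parallel via `MCMCThreads`, which is one way to buy extra ESS per unit of wall-clock time. A minimal sketch (the `demo` model is a placeholder, not the model from this thread, and Julia must be started with multiple threads, e.g. `julia -t 4`):

```julia
using Turing

# Toy model used only to illustrate the parallel-sampling call
@model function demo(y)
    μ ~ Normal(0, 1)
    y ~ Normal(μ, 1)
end

# Run 4 chains in parallel, one per available thread;
# total ESS scales roughly linearly with the number of chains
chains = sample(demo(1.5), NUTS(), MCMCThreads(), 1000, 4)
```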
-
Just wanted to flag Tor's fantastic talk on what's new in Turing. Some questions that I have:
-
I was just chatting with him over lunch! The syntax change seems cool. We also spoke about getting TuringGLM off the ground. I am going to play around with the new syntax on a few examples.
-
Today I performed a quick benchmark with Enzyme. The first piece of good news is that it ran without error and produced the correct output. The second is that it ran faster than ForwardDiff and ReverseDiff.

```julia
cd(@__DIR__)
using Pkg
Pkg.activate("")

using Distributions
using Enzyme
using LinearAlgebra
using ReverseDiff
using SequentialSamplingModels
using Turing

Enzyme.API.runtimeActivity!(true)

n_choices = 20
ν = fill(1, n_choices)

# Generate some data with known parameters
dist = LBA(; ν, A = 0.8, k = 0.2, τ = 0.3)
data = rand(dist, 100)

# Specify LBA model
@model function model(data, n_choices; min_rt = minimum(data.rt))
    # Priors
    ν ~ MvNormal(zeros(n_choices), I * 2)
    A ~ truncated(Normal(0.8, 0.4), 0.0, Inf)
    k ~ truncated(Normal(0.2, 0.2), 0.0, Inf)
    τ ~ Uniform(0.0, min_rt)
    # Likelihood
    data ~ LBA(; ν, A, k, τ)
end

# ForwardDiff (Turing's default AD): 97.95 seconds
chains_forward = sample(model(data, n_choices), NUTS(1000, 0.85), 1000)

# Enzyme: 47.58 seconds
chains_enzyme = sample(model(data, n_choices), NUTS(1000, 0.85; adtype = AutoEnzyme()), 1000)

# ReverseDiff:
#   compile = false ≈ 3960 seconds (early termination)
#   compile = true  ≈ 960 seconds (early termination, potentially unsafe caching)
chains_reverse = sample(model(data, n_choices), NUTS(1000, 0.85; adtype = AutoReverseDiff()), 1000)
```

The downside is that there are still some known problems with Distributions.jl. Nonetheless, progress is being made.
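Wall-clock time alone can be misleading if the backends produce chains with different mixing, so ESS per second may be the fairer comparison. A sketch of that calculation, assuming the `chains_forward`/`chains_enzyme` objects and timings from the benchmark above, and MCMCChains' `ess` summary (column names may differ across MCMCChains versions):

```julia
using MCMCChains
using Statistics

# Wall-clock times (seconds) reported in the benchmark above
times = Dict("ForwardDiff" => 97.95, "Enzyme" => 47.58)
chains = Dict("ForwardDiff" => chains_forward, "Enzyme" => chains_enzyme)

for (name, chn) in chains
    # `ess` returns one ESS estimate per model parameter;
    # averaging gives a single rough efficiency figure
    mean_ess = mean(ess(chn)[:, :ess])
    println(name, ": ", round(mean_ess / times[name]; digits = 2), " ESS/second")
end
```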
-
As we know, performance is currently the main limitation. Maybe we can use this thread to collect potential directions, optimization tips, benchmarks, etc., and track how things evolve on the performance front.