From 665cfc03ab4c434b373b77cf966e34605809a039 Mon Sep 17 00:00:00 2001 From: Romeo Valentin Date: Sat, 25 May 2024 01:09:07 -0700 Subject: [PATCH] Slightly refactor MCMCEstimator again. --- src/mcmcestimator.jl | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/mcmcestimator.jl b/src/mcmcestimator.jl index 6414c52..2f21ede 100644 --- a/src/mcmcestimator.jl +++ b/src/mcmcestimator.jl @@ -13,9 +13,9 @@ $(TYPEDFIELDS) """ @kwdef struct MCMCEstimator{ST<:Turing.InferenceAlgorithm, SAT<:NamedTuple} <: EstimationMethod "Inference algorithm type for MCMC sampling. Defaults to `NUTS`." - samplealg::ST = NUTS - "kwargs passed to `Turing.sample`. Defaults to `(; )`." - sampleargs::SAT = (; ) + samplealg::ST = NUTS() + "kwargs passed to `Turing.sample`. Defaults to `(; drop_warmup=true, progress=false, verbose=false)`." + sampleargs::SAT = (; drop_warmup=true, progress=false, verbose=false) end solvealg(est::MCMCEstimator) = est.samplealg solveargs(est::MCMCEstimator) = est.sampleargs @@ -28,13 +28,12 @@ solveargs(est::MCMCEstimator) = est.sampleargs return end -function predictsamples(est::MCMCEstimator, f, xs, ysmeas, paramprior::Sampleable, noisemodel::NoiseModel, nsamples; - drop_warmup=true, progress=false, verbose=false, - kwargs...) +function predictsamples(est::MCMCEstimator, f, xs, ysmeas, paramprior::Sampleable, noisemodel::NoiseModel, nsamples) + chain = with_logger(ConsoleLogger(Warn)) do # ignore "Info" outputs. - alg = solvealg(est)(; solveargs(est)...) + alg = solvealg(est) sample(bayesianmodel(est, f, xs, maybeflatten(ysmeas), paramprior, noisemodel), alg, nsamples; - drop_warmup, progress, verbose, kwargs...) + solveargs(est)...) end d = length(paramprior) θsamples = stack([chain[Symbol("θ[$i]")][:] for i in 1:d]; dims=1)