diff --git a/src/mcmcestimator.jl b/src/mcmcestimator.jl index 6414c52..2f21ede 100644 --- a/src/mcmcestimator.jl +++ b/src/mcmcestimator.jl @@ -13,9 +13,9 @@ $(TYPEDFIELDS) """ @kwdef struct MCMCEstimator{ST<:Turing.InferenceAlgorithm, SAT<:NamedTuple} <: EstimationMethod "Inference algorithm type for MCMC sampling. Defaults to `NUTS`." - samplealg::ST = NUTS - "kwargs passed to `Turing.sample`. Defaults to `(; )`." - sampleargs::SAT = (; ) + samplealg::ST = NUTS() + "kwargs passed to `Turing.sample`. Defaults to `(; drop_warmup=true, progress=false, verbose=false)`." + sampleargs::SAT = (; drop_warmup=true, progress=false, verbose=false) end solvealg(est::MCMCEstimator) = est.samplealg solveargs(est::MCMCEstimator) = est.sampleargs @@ -28,13 +28,12 @@ solveargs(est::MCMCEstimator) = est.sampleargs return end -function predictsamples(est::MCMCEstimator, f, xs, ysmeas, paramprior::Sampleable, noisemodel::NoiseModel, nsamples; - drop_warmup=true, progress=false, verbose=false, - kwargs...) +function predictsamples(est::MCMCEstimator, f, xs, ysmeas, paramprior::Sampleable, noisemodel::NoiseModel, nsamples) + chain = with_logger(ConsoleLogger(Warn)) do # ignore "Info" outputs. - alg = solvealg(est)(; solveargs(est)...) + alg = solvealg(est) sample(bayesianmodel(est, f, xs, maybeflatten(ysmeas), paramprior, noisemodel), alg, nsamples; - drop_warmup, progress, verbose, kwargs...) + solveargs(est)...) end d = length(paramprior) θsamples = stack([chain[Symbol("θ[$i]")][:] for i in 1:d]; dims=1)