Skip to content

Commit

Permalink
Slightly refactor MCMCEstimator again.
Browse files Browse the repository at this point in the history
  • Loading branch information
RomeoV committed May 25, 2024
1 parent 1fad752 commit 665cfc0
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions src/mcmcestimator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ $(TYPEDFIELDS)
"""
@kwdef struct MCMCEstimator{ST<:Turing.InferenceAlgorithm, SAT<:NamedTuple} <: EstimationMethod
"Inference algorithm type for MCMC sampling. Defaults to `NUTS`."
samplealg::ST = NUTS
"kwargs passed to `Turing.sample`. Defaults to `(; )`."
sampleargs::SAT = (; )
samplealg::ST = NUTS()
"kwargs passed to `Turing.sample`. Defaults to `(; drop_warmup=true, progress=false, verbose=false)`."
sampleargs::SAT = (; drop_warmup=true, progress=false, verbose=false)
end
solvealg(est::MCMCEstimator) = est.samplealg
solveargs(est::MCMCEstimator) = est.sampleargs
Expand All @@ -28,13 +28,12 @@ solveargs(est::MCMCEstimator) = est.sampleargs
return
end

function predictsamples(est::MCMCEstimator, f, xs, ysmeas, paramprior::Sampleable, noisemodel::NoiseModel, nsamples;
drop_warmup=true, progress=false, verbose=false,
kwargs...)
function predictsamples(est::MCMCEstimator, f, xs, ysmeas, paramprior::Sampleable, noisemodel::NoiseModel, nsamples)

chain = with_logger(ConsoleLogger(Warn)) do # ignore "Info" outputs.
alg = solvealg(est)(; solveargs(est)...)
alg = solvealg(est)
sample(bayesianmodel(est, f, xs, maybeflatten(ysmeas), paramprior, noisemodel), alg, nsamples;
drop_warmup, progress, verbose, kwargs...)
solveargs(est)...)
end
d = length(paramprior)
θsamples = stack([chain[Symbol("θ[$i]")][:] for i in 1:d]; dims=1)
Expand Down

0 comments on commit 665cfc0

Please sign in to comment.