Skip to content

Commit

Permalink
Use global BB optimization, then pathfinder for initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
sefffal committed Nov 25, 2024
1 parent 50fb251 commit 1f1e654
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 54 deletions.
10 changes: 6 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MathTeXEngine = "0a4f8689-d25c-4efe-a92b-7142dfc1aa53"
NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationBBO = "3e6eede4-6085-4f62-9a71-46d9bc1eb92b"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Pathfinder = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454"
Expand Down Expand Up @@ -80,6 +81,7 @@ Distributions = "0.25"
DistributionsAD = "0.6"
Dynesty = "0.4.0"
FITSIO = "0.16, 0.17"
FiniteDiff = "2"
ForwardDiff = "0.10"
HDF5 = "0.17"
HORIZONS = "0.4"
Expand All @@ -95,6 +97,7 @@ Makie = "0.21"
MathTeXEngine = "0.6"
NamedTupleTools = "0.13, 0.14"
Optimization = "4"
OptimizationBBO = "0.4"
OptimizationOptimJL = "0.4"
OrderedCollections = "1.6"
PairPlots = "2"
Expand All @@ -116,14 +119,13 @@ Tables = "1.6"
Transducers = "0.4"
TypedTables = "1.4"
julia = "1.9"
FiniteDiff = "2"

[extras]
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
PairPlots = "43a3c2be-4208-490b-832a-a21dcd55d7da"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "FiniteDiff", "CairoMakie", "PairPlots"]
109 changes: 59 additions & 50 deletions src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ function guess_starting_position(rng::Random.AbstractRNG, model::LogDensityModel
if logpost_ofti > logpost
logpost = logpost_ofti
params = params_ofti
# println("accepted OFTI")
end
# println("logpost_after = ", logpost)
if logpost > bestlogpost
Expand Down Expand Up @@ -208,6 +207,8 @@ function get_starting_point!!(model::LogDensityModel; kwargs...)
return get_starting_point!!(Random.default_rng(), model; kwargs...)
end

using OptimizationOptimJL, OptimizationBBO

"""
default_initializer!(model::LogDensityModel; initial_samples=100_000)
Expand All @@ -219,36 +220,48 @@ If this fails repeatedly, simply draw `initial_samples` from the prior and keepi
default_initializer!(model::LogDensityModel; kwargs...) =
    default_initializer!(Random.default_rng(), model; kwargs...)  # forwards to the RNG-accepting method using the task-global RNG
function default_initializer!(rng::Random.AbstractRNG, model::LogDensityModel; initial_point = nothing, nruns=8, ntries=2, ndraws=1000, initial_samples=10000, verbosity=1)
function default_initializer!(rng::Random.AbstractRNG, model::LogDensityModel; nruns=8, ntries=2, ndraws=1000, verbosity=1)
ldm_any = LogDensityModelAny(model)

# Pathfinder (and especially multipathfinder) do not work well with global optimization methods.
# Instead, we do a two-step process.
# Find the global MAP point, then initialize multi-pathfinder in Gaussian ball around that point.

priors = Octofitter._list_priors(model.system)
lb = model.link(quantile.(priors,0.001))
ub = model.link(quantile.(priors,0.999))
f = OptimizationFunction(
(u,p)->-p.ℓπcallback(u),
grad=(G,u,p)->(G .= p.∇ℓπcallback(u)[2])
)
prob = Optimization.OptimizationProblem(f, model.link(quantile.(priors,0.5)), model; lb, ub)
Random.seed!(rand(rng, UInt64))
sol = solve(prob, BBO_adaptive_de_rand_1_bin(), rel_tol=1e-3, maxiters = 100_000, )

model.starting_points = fill(sol.u, 1000)
initial_logpost_range = (-sol.objective, -sol.objective)
if verbosity > 1
@info "Found the global maximum logpost" MAP=-sol.objective
end

# TODO: we don't really need to use pathfinder in this case, we should look into
# more rigorous variational methods
local result_pf = nothing
local metric = nothing
ldm_any = LogDensityModelAny(model)
verbosity >= 1 && @info "Determining initial positions and metric using pathfinder"
verbosity >= 1 && @info "Determining initial positions using pathfinder, around that location."
# It can sometimes hit a PosDefException sometimes when factoring a matrix.
# When that happens, the next try usually succeeds.
try
for i in 1:ntries
verbosity >= 3 && @info "Starting multipathfinder run"
init_sampler = function(rng, x)
if isnothing(initial_point) || length(initial_point) < model.D
if verbosity > 3
@info "drawing new starting guess by sampling IID from priors"
end
initial_θ, mapv = guess_starting_position(rng,model,initial_samples)
if verbosity > 3
@info "Starting point drawn" initial_logpost=mapv
end
end
if !isnothing(initial_point)
if length(initial_point) < model.D
initial_θ = (initial_point..., initial_θ[length(initial_point)+1:end]...)
else
initial_θ = initial_point
init_sampler = function(rng, x)
for _ in 1:10
# take a random step away from the MAP value according to the gradient
x .= sol.u -0.1 .* rand.(rng) .* model.∇ℓπcallback(sol.u)[2]
if all(lb .< x .< ub) && isfinite(model.ℓπcallback(x))
return
end
end
initial_θ_t = model.link(initial_θ)
x .= initial_θ_t
error("Could not find starting point within 0.001 - 0.999 quantiles of priors.")
end
errlogger = ConsoleLogger(stderr, verbosity >=3 ? Logging.Info : Logging.Error)
initial_mt = _kepsolve_use_threads[]
Expand All @@ -257,7 +270,7 @@ function default_initializer!(rng::Random.AbstractRNG, model::LogDensityModel; i
result_pf = Pathfinder.multipathfinder(
ldm_any, ndraws;
nruns,
init_sampler=CallableAny(init_sampler),
init_sampler,
progress=verbosity > 1,
maxiters=25_000,
reltol=1e-6,
Expand Down Expand Up @@ -285,40 +298,36 @@ function default_initializer!(rng::Random.AbstractRNG, model::LogDensityModel; i
verbosity > 2 && display(result_pf)
break
end
catch
catch err
@warn err
end

if !isnothing(result_pf)
model.starting_points = collect.(eachcol(result_pf.draws))
logposts = model.ℓπcallback.(model.starting_points)
initial_logpost_range = extrema(logposts)
end
# Occasionally there is a failure mode of pathfinder where, despite starting it at a reasonable spot, it returns garbage
# starting draws that are orders of magnitude worse.
# Check for this by ensuring the highest a-posteriori pathfinder draw is better than a random guess
_, random_guess_logpost = guess_starting_position(rng,model,100)
if isnothing(result_pf) || maximum(initial_logpost_range) < random_guess_logpost
if !isnothing(result_pf)
verbosity >= 1 && @warn "The highest posterior density sample from pathfinder is worse than a random guess..."
logposts = model.ℓπcallback.(model.starting_points)
initial_logpost_range = extrema(logposts)

if initial_logpost_range[2] < -sol.objective - 10
if verbosity >= 1
@warn "Pathfinder produced samples with log-likelihood 10 worse than global max. Will just initialize at global max."
end
verbosity >= 1 && @warn "Falling back to sampling from the prior and keeping the $ndraws samples with highest posterior density."
samples_t = map(1:1000) do _
initial_θ, mapv = guess_starting_position(rng,model,max(1,initial_samples÷100))
initial_θ_t = model.link(initial_θ)
return initial_θ_t
model.starting_points = fill(sol.u, 1000)
logposts = fill(-sol.objective, 1000)
initial_logpost_range = (-sol.objective, -sol.objective)
else
if verbosity >= 1
@info "Found a sample of initial positions" initial_logpost_range
end
# samples = sample_priors(rng, model, ndraws)
# samples_t = model.link.(samples)
logposts = model.ℓπcallback.(samples_t)
II = sortperm(logposts, rev=true)[begin:ndraws]
model.starting_points = samples_t[II]
initial_logpost_range = extrema(@view logposts[II])
logposts = logposts[II]
end

if verbosity >= 1
@info "Found a sample of initial positions" initial_logpost_range
end

return model.arr2nt(model.invlink(model.starting_points[argmax(logposts)]))
end


# Diagnostic helper: package the model's cached starting points (`model.starting_points`,
# stored in transformed space) as an MCMCChains chain so the pathfinder/initialization
# draws can be inspected like posterior samples. `logpost` is filled with a placeholder 0.
function _startingpoints2chain(model)
    draws = map(model.starting_points) do s
        return (; logpost=0, model.arr2nt(model.invlink(s))...)
    end
    return Octofitter.result2mcmcchain(draws, Dict(:internals => [:logpost]))
end

0 comments on commit 1f1e654

Please sign in to comment.