Skip to content

Commit

Permalink
ensemble error depwarn fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Sep 7, 2024
1 parent c30f1de commit 0841a5e
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SciMLBase"
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
authors = ["Chris Rackauckas <[email protected]> and contributors"]
version = "2.52.1"
version = "2.52.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
26 changes: 21 additions & 5 deletions src/ensemble/ensemble_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ function calculate_ensemble_errors(u; elapsedTime = 0.0, converged = false,
errors = Dict{Symbol, Vector{eltype(u[1].u[1])}}() #Should add type information
error_means = Dict{Symbol, eltype(u[1].u[1])}()
error_medians = Dict{Symbol, eltype(u[1].u[1])}()

analyticvoa = u[1].u_analytic isa AbstractVectorOfArray ? true : false

for k in keys(u[1].errors)
errors[k] = [sol.errors[k] for sol in u]
error_means[k] = mean(errors[k])
Expand All @@ -98,12 +101,24 @@ function calculate_ensemble_errors(u; elapsedTime = 0.0, converged = false,
weak_errors = Dict{Symbol, eltype(u[1].u[1])}()
# Final
m_final = mean([s.u[end] for s in u])
m_final_analytic = mean([s.u_analytic[end] for s in u])

if analyticvoa
m_final_analytic = mean([s.u_analytic.u[end] for s in u])
else
m_final_analytic = mean([s.u_analytic[end] for s in u])
end

res = norm(m_final - m_final_analytic)
weak_errors[:weak_final] = res
if weak_timeseries_errors
ts_weak_errors = [mean([u[j].u[i] - u[j].u_analytic[i] for j in 1:length(u)])
for i in 1:length(u[1])]

if analyticvoa
ts_weak_errors = [mean([u[j].u[i] - u[j].u_analytic.u[i] for j in 1:length(u)])
for i in 1:length(u[1])]
else
ts_weak_errors = [mean([u[j].u[i] - u[j].u_analytic[i] for j in 1:length(u)])
for i in 1:length(u[1])]
end
ts_l2_errors = [sqrt.(sum(abs2, err) / length(err)) for err in ts_weak_errors]
l2_tmp = sqrt(sum(abs2, ts_l2_errors) / length(ts_l2_errors))
max_tmp = maximum([maximum(abs.(err)) for err in ts_weak_errors])
Expand All @@ -113,8 +128,9 @@ function calculate_ensemble_errors(u; elapsedTime = 0.0, converged = false,
if weak_dense_errors
densetimes = collect(range(u[1].t[1], stop = u[1].t[end], length = 100))
u_analytic = [[sol.prob.f.analytic(sol.prob.u0, sol.prob.p, densetimes[i],
sol.W(densetimes[i])[1])
for i in eachindex(densetimes)] for sol in u]
sol.W(densetimes[i])[1])
for i in eachindex(densetimes)] for sol in u]

udense = [u[j](densetimes) for j in 1:length(u)]
dense_weak_errors = [mean([udense[j].u[i] - u_analytic[j][i] for j in 1:length(u)])
for i in eachindex(densetimes)]
Expand Down

0 comments on commit 0841a5e

Please sign in to comment.