Skip to content

Commit

Permalink
feat: require providing tolerances in CheckInit and OverrideInit
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Nov 14, 2024
1 parent e64fcf5 commit 0f41da2
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 26 deletions.
61 changes: 48 additions & 13 deletions src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,15 @@ function Base.showerror(io::IO, e::OverrideInitMissingAlgorithm)
"OverrideInit specified but no NonlinearSolve.jl algorithm provided. Provide an algorithm via the `nlsolve_alg` keyword argument to `get_initial_values`.")
end

struct OverrideInitNoTolerance <: Exception
tolerance::Symbol
end

function Base.showerror(io::IO, e::CheckInitNoTolerance)
print(io,
"Tolerances were not provided to `OverrideInit`. `$(e.tolerance)` must be provided as a keyword argument to `get_initial_values` or as a keyword argument to the `OverrideInit` constructor.")
end

"""
Utility function to evaluate the RHS of the ODE, using the integrator's `tmp_cache` if
it is in-place or simply calling the function if not.
Expand Down Expand Up @@ -98,11 +107,16 @@ _vec(v::AbstractVector) = v
Check if the algebraic constraints are satisfied, and error if they aren't. Returns
the `u0` and `p` as-is, and is always successful if it returns. Valid only for
`ODEProblem` and `DAEProblem`. Requires a `DEIntegrator` as its second argument.
`AbstractDEProblem` and `AbstractDAEProblem`. Requires a `DEIntegrator` as its second argument.
Keyword arguments:
- `abstol`: The absolute value below which the norm of the residual of algebraic equations
should lie. The norm function used is `integrator.opts.internalnorm` if present, and
`LinearAlgebra.norm` if not.
"""
function get_initial_values(
prob::AbstractDEProblem, integrator::DEIntegrator, f, alg::CheckInit,
isinplace::Union{Val{true}, Val{false}}; kwargs...)
isinplace::Union{Val{true}, Val{false}}; abstol, kwargs...)
u0 = state_values(integrator)
p = parameter_values(integrator)
t = current_time(integrator)
Expand All @@ -117,8 +131,8 @@ function get_initial_values(

normresid = isdefined(integrator.opts, :internalnorm) ?
integrator.opts.internalnorm(tmp, t) : norm(tmp)
if normresid > integrator.opts.abstol
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
if normresid > abstol
throw(CheckInitFailureError(normresid, abstol))
end
return u0, p, true
end
Expand All @@ -139,16 +153,20 @@ end

function get_initial_values(
prob::AbstractDAEProblem, integrator::DEIntegrator, f, alg::CheckInit,
isinplace::Union{Val{true}, Val{false}}; kwargs...)
isinplace::Union{Val{true}, Val{false}}; abstol = nothing, kwargs...)
u0 = state_values(integrator)
p = parameter_values(integrator)
t = current_time(integrator)

resid = _evaluate_f_dae(integrator, f, isinplace, integrator.du, u0, p, t)
normresid = isdefined(integrator.opts, :internalnorm) ?
integrator.opts.internalnorm(resid, t) : norm(resid)
if normresid > integrator.opts.abstol
throw(CheckInitFailureError(normresid, integrator.opts.abstol))

if abstol === nothing
abstol = cache_get_abstol(integrator)
end
if normresid > abstol
throw(CheckInitFailureError(normresid, abstol))
end
return u0, p, true
end
Expand All @@ -159,12 +177,19 @@ end
Solve a `NonlinearProblem`/`NonlinearLeastSquaresProblem` to obtain the initial `u0` and
`p`. Requires that `f` have the field `initialization_data` which is an `OverrideInitData`.
If the field is absent or the value is `nothing`, return `u0` and `p` successfully as-is.
The NonlinearSolve.jl algorithm to use must be specified through the `nlsolve_alg` keyword
argument, failing which this function will throw an error. The success value returned
depends on the success of the nonlinear solve.
The success value returned depends on the success of the nonlinear solve.
Keyword arguments:
- `nlsolve_alg`: The NonlinearSolve.jl algorithm to use. If not provided, this function will
throw an error.
- `abstol`, `reltol`: The `abstol` (`reltol`) to use for the nonlinear solve. The value
provided to the `OverrideInit` constructor takes priority over this keyword argument.
If the former is `nothing`, this keyword argument will be used. If it is also not provided,
an error will be thrown.
"""
function get_initial_values(prob, valp, f, alg::OverrideInit,
iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, kwargs...)
iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, abstol = nothing, reltol = nothing, kwargs...)
u0 = state_values(valp)
p = parameter_values(valp)

Expand All @@ -185,10 +210,20 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
end

if alg.abstol !== nothing
nlsol = solve(initprob, nlsolve_alg; abstol = alg.abstol)
_abstol = alg.abstol
elseif abstol !== nothing
_abstol = abstol
else
throw(OverrideInitNoTolerance(:abstol))
end
if alg.reltol !== nothing
_reltol = alg.reltol
elseif reltol !== nothing
_reltol = reltol
else
nlsol = solve(initprob, nlsolve_alg)
throw(OverrideInitNoTolerance(:reltol))
end
nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol)

u0 = initdata.initializeprobmap(nlsol)
if initdata.initializeprobpmap !== nothing
Expand Down
36 changes: 23 additions & 13 deletions test/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@ using StochasticDiffEq, OrdinaryDiffEq, NonlinearSolve, SymbolicIndexingInterfac
prob = ODEProblem(f, [1.0, 1.0], (0.0, 1.0))
integ = init(prob)
u0, _, success = SciMLBase.get_initial_values(
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
prob, integ, f, SciMLBase.CheckInit(),
Val(SciMLBase.isinplace(f)); abstol = 1e-10)
@test success
@test u0 == prob.u0

integ.u[2] = 2.0
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
prob, integ, f, SciMLBase.CheckInit(),
Val(SciMLBase.isinplace(f)); abstol = 1e-10)
end
end

Expand All @@ -43,18 +45,21 @@ using StochasticDiffEq, OrdinaryDiffEq, NonlinearSolve, SymbolicIndexingInterfac
prob = DAEProblem(f, [1.0, 0.0], [1.0, 1.0], (0.0, 1.0), 1.0)
integ = init(prob, DImplicitEuler())
u0, _, success = SciMLBase.get_initial_values(
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
prob, integ, f, SciMLBase.CheckInit(),
Val(SciMLBase.isinplace(f)); abstol = 1e-10)
@test success
@test u0 == prob.u0

integ.u[2] = 2.0
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
prob, integ, f, SciMLBase.CheckInit(),
Val(SciMLBase.isinplace(f)); abstol = 1e-10)

integ.u[2] = 1.0
integ.du[1] = 2.0
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
prob, integ, f, SciMLBase.CheckInit(),
Val(SciMLBase.isinplace(f)); abstol = 1e-10)
end
end

Expand Down Expand Up @@ -86,13 +91,15 @@ using StochasticDiffEq, OrdinaryDiffEq, NonlinearSolve, SymbolicIndexingInterfac
prob = SDEProblem(f, [1.0, 1.0, -1.0], (0.0, 1.0))
integ = init(prob, ImplicitEM())
u0, _, success = SciMLBase.get_initial_values(
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
prob, integ, f, SciMLBase.CheckInit(),
Val(SciMLBase.isinplace(f)); abstol = 1e-10)
@test success
@test u0 == prob.u0

integ.u[2] = 2.0
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
prob, integ, f, SciMLBase.CheckInit(),
Val(SciMLBase.isinplace(f)); abstol = 1e-10)
end
end
end
Expand Down Expand Up @@ -138,11 +145,13 @@ end
prob, integ, fn, SciMLBase.OverrideInit(), Val(false))
end

abstol = 1e-10
reltol = 1e-10
@testset "Solves" begin
@testset "with explicit alg" begin
u0, p, success = SciMLBase.get_initial_values(
prob, integ, fn, SciMLBase.OverrideInit(),
Val(false); nlsolve_alg = NewtonRaphson())
Val(false); nlsolve_alg = NewtonRaphson(), abstol, reltol)

@test u0 [2.0, 2.0]
@test p 1.0
Expand All @@ -152,7 +161,8 @@ end
end
@testset "with alg in `OverrideInit`" begin
u0, p, success = SciMLBase.get_initial_values(
prob, integ, fn, SciMLBase.OverrideInit(nlsolve = NewtonRaphson()),
prob, integ, fn,
SciMLBase.OverrideInit(; nlsolve = NewtonRaphson(), abstol, reltol),
Val(false))

@test u0 [2.0, 2.0]
Expand All @@ -170,7 +180,7 @@ end
_integ = init(_prob; initializealg = NoInit())

u0, p, success = SciMLBase.get_initial_values(
_prob, _integ, _fn, SciMLBase.OverrideInit(), Val(false))
_prob, _integ, _fn, SciMLBase.OverrideInit(), Val(false); abstol, reltol)

@test u0 [1.0, 1.0]
@test p 1.0
Expand All @@ -182,7 +192,7 @@ end
_integ = ProblemState(; u = integ.u, p = parameter_values(integ), t = integ.t)
u0, p, success = SciMLBase.get_initial_values(
prob, _integ, fn, SciMLBase.OverrideInit(),
Val(false); nlsolve_alg = NewtonRaphson())
Val(false); nlsolve_alg = NewtonRaphson(), abstol, reltol)

@test u0 [2.0, 2.0]
@test p 1.0
Expand All @@ -199,7 +209,7 @@ end

u0, p, success = SciMLBase.get_initial_values(
prob, integ, fn, SciMLBase.OverrideInit(),
Val(false); nlsolve_alg = NewtonRaphson())
Val(false); nlsolve_alg = NewtonRaphson(), abstol, reltol)
@test u0 [1.0, 1.0]
@test p 1.0
@test success
Expand All @@ -213,7 +223,7 @@ end

u0, p, success = SciMLBase.get_initial_values(
prob, integ, fn, SciMLBase.OverrideInit(),
Val(false); nlsolve_alg = NewtonRaphson())
Val(false); nlsolve_alg = NewtonRaphson(), abstol, reltol)

@test u0 [2.0, 2.0]
@test p 0.0
Expand Down

0 comments on commit 0f41da2

Please sign in to comment.