diff --git a/src/initialization.jl b/src/initialization.jl index 8111c1452..7acbea8e9 100644 --- a/src/initialization.jl +++ b/src/initialization.jl @@ -68,6 +68,15 @@ function Base.showerror(io::IO, e::OverrideInitMissingAlgorithm) "OverrideInit specified but no NonlinearSolve.jl algorithm provided. Provide an algorithm via the `nlsolve_alg` keyword argument to `get_initial_values`.") end +struct OverrideInitNoTolerance <: Exception + tolerance::Symbol +end + +function Base.showerror(io::IO, e::CheckInitNoTolerance) + print(io, + "Tolerances were not provided to `OverrideInit`. `$(e.tolerance)` must be provided as a keyword argument to `get_initial_values` or as a keyword argument to the `OverrideInit` constructor.") +end + """ Utility function to evaluate the RHS of the ODE, using the integrator's `tmp_cache` if it is in-place or simply calling the function if not. @@ -98,11 +107,16 @@ _vec(v::AbstractVector) = v Check if the algebraic constraints are satisfied, and error if they aren't. Returns the `u0` and `p` as-is, and is always successful if it returns. Valid only for -`ODEProblem` and `DAEProblem`. Requires a `DEIntegrator` as its second argument. +`AbstractDEProblem` and `AbstractDAEProblem`. Requires a `DEIntegrator` as its second argument. + +Keyword arguments: +- `abstol`: The absolute value below which the norm of the residual of algebraic equations + should lie. The norm function used is `integrator.opts.internalnorm` if present, and + `LinearAlgebra.norm` if not. """ function get_initial_values( prob::AbstractDEProblem, integrator::DEIntegrator, f, alg::CheckInit, - isinplace::Union{Val{true}, Val{false}}; kwargs...) + isinplace::Union{Val{true}, Val{false}}; abstol, kwargs...) u0 = state_values(integrator) p = parameter_values(integrator) t = current_time(integrator) @@ -117,8 +131,8 @@ function get_initial_values( normresid = isdefined(integrator.opts, :internalnorm) ? integrator.opts.internalnorm(tmp, t) : norm(tmp) - if normresid > integrator.opts.abstol - throw(CheckInitFailureError(normresid, integrator.opts.abstol)) + if normresid > abstol + throw(CheckInitFailureError(normresid, abstol)) end return u0, p, true end @@ -139,7 +153,7 @@ end function get_initial_values( prob::AbstractDAEProblem, integrator::DEIntegrator, f, alg::CheckInit, - isinplace::Union{Val{true}, Val{false}}; kwargs...) + isinplace::Union{Val{true}, Val{false}}; abstol = nothing, kwargs...) u0 = state_values(integrator) p = parameter_values(integrator) t = current_time(integrator) @@ -147,8 +161,12 @@ function get_initial_values( resid = _evaluate_f_dae(integrator, f, isinplace, integrator.du, u0, p, t) normresid = isdefined(integrator.opts, :internalnorm) ? integrator.opts.internalnorm(resid, t) : norm(resid) - if normresid > integrator.opts.abstol - throw(CheckInitFailureError(normresid, integrator.opts.abstol)) + + if abstol === nothing + abstol = cache_get_abstol(integrator) + end + if normresid > abstol + throw(CheckInitFailureError(normresid, abstol)) end return u0, p, true end @@ -159,12 +177,19 @@ end Solve a `NonlinearProblem`/`NonlinearLeastSquaresProblem` to obtain the initial `u0` and `p`. Requires that `f` have the field `initialization_data` which is an `OverrideInitData`. If the field is absent or the value is `nothing`, return `u0` and `p` successfully as-is. -The NonlinearSolve.jl algorithm to use must be specified through the `nlsolve_alg` keyword -argument, failing which this function will throw an error. The success value returned -depends on the success of the nonlinear solve. + +The success value returned depends on the success of the nonlinear solve. + +Keyword arguments: +- `nlsolve_alg`: The NonlinearSolve.jl algorithm to use. If not provided, this function will + throw an error. +- `abstol`, `reltol`: The `abstol` (`reltol`) to use for the nonlinear solve. The value + provided to the `OverrideInit` constructor takes priority over this keyword argument. + If the former is `nothing`, this keyword argument will be used. If it is also not provided, + an error will be thrown. """ function get_initial_values(prob, valp, f, alg::OverrideInit, - iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, kwargs...) + iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, abstol = nothing, reltol = nothing, kwargs...) u0 = state_values(valp) p = parameter_values(valp) @@ -185,10 +210,20 @@ function get_initial_values(prob, valp, f, alg::OverrideInit, end if alg.abstol !== nothing - nlsol = solve(initprob, nlsolve_alg; abstol = alg.abstol) + _abstol = alg.abstol + elseif abstol !== nothing + _abstol = abstol + else + throw(OverrideInitNoTolerance(:abstol)) + end + if alg.reltol !== nothing + _reltol = alg.reltol + elseif reltol !== nothing + _reltol = reltol else - nlsol = solve(initprob, nlsolve_alg) + throw(OverrideInitNoTolerance(:reltol)) end + nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol) u0 = initdata.initializeprobmap(nlsol) if initdata.initializeprobpmap !== nothing diff --git a/test/initialization.jl b/test/initialization.jl index c99ce74b6..4ed011569 100644 --- a/test/initialization.jl +++ b/test/initialization.jl @@ -17,13 +17,15 @@ using StochasticDiffEq, OrdinaryDiffEq, NonlinearSolve, SymbolicIndexingInterfac prob = ODEProblem(f, [1.0, 1.0], (0.0, 1.0)) integ = init(prob) u0, _, success = SciMLBase.get_initial_values( - prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f))) + prob, integ, f, SciMLBase.CheckInit(), + Val(SciMLBase.isinplace(f)); abstol = 1e-10) @test success @test u0 == prob.u0 integ.u[2] = 2.0 @test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values( - prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f))) + prob, integ, f, SciMLBase.CheckInit(), + Val(SciMLBase.isinplace(f)); abstol = 1e-10) end end @@ -43,18 +45,21 @@ using StochasticDiffEq, OrdinaryDiffEq, NonlinearSolve, SymbolicIndexingInterfac prob = DAEProblem(f, [1.0, 0.0], [1.0, 1.0], (0.0, 1.0), 1.0) integ = init(prob, DImplicitEuler()) u0, _, success = SciMLBase.get_initial_values( - prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f))) + prob, integ, f, SciMLBase.CheckInit(), + Val(SciMLBase.isinplace(f)); abstol = 1e-10) @test success @test u0 == prob.u0 integ.u[2] = 2.0 @test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values( - prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f))) + prob, integ, f, SciMLBase.CheckInit(), + Val(SciMLBase.isinplace(f)); abstol = 1e-10) integ.u[2] = 1.0 integ.du[1] = 2.0 @test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values( - prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f))) + prob, integ, f, SciMLBase.CheckInit(), + Val(SciMLBase.isinplace(f)); abstol = 1e-10) end end @@ -86,13 +91,15 @@ using StochasticDiffEq, OrdinaryDiffEq, NonlinearSolve, SymbolicIndexingInterfac prob = SDEProblem(f, [1.0, 1.0, -1.0], (0.0, 1.0)) integ = init(prob, ImplicitEM()) u0, _, success = SciMLBase.get_initial_values( - prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f))) + prob, integ, f, SciMLBase.CheckInit(), + Val(SciMLBase.isinplace(f)); abstol = 1e-10) @test success @test u0 == prob.u0 integ.u[2] = 2.0 @test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values( - prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f))) + prob, integ, f, SciMLBase.CheckInit(), + Val(SciMLBase.isinplace(f)); abstol = 1e-10) end end end @@ -138,11 +145,13 @@ end prob, integ, fn, SciMLBase.OverrideInit(), Val(false)) end + abstol = 1e-10 + reltol = 1e-10 @testset "Solves" begin @testset "with explicit alg" begin u0, p, success = SciMLBase.get_initial_values( prob, integ, fn, SciMLBase.OverrideInit(), - Val(false); nlsolve_alg = NewtonRaphson()) + Val(false); nlsolve_alg = NewtonRaphson(), abstol, reltol) @test u0 ≈ [2.0, 2.0] @test p ≈ 1.0 @@ -152,7 +161,8 @@ end end @testset "with alg in `OverrideInit`" begin u0, p, success = SciMLBase.get_initial_values( - prob, integ, fn, SciMLBase.OverrideInit(nlsolve = NewtonRaphson()), + prob, integ, fn, + SciMLBase.OverrideInit(; nlsolve = NewtonRaphson(), abstol, reltol), Val(false)) @test u0 ≈ [2.0, 2.0] @@ -170,7 +180,7 @@ end _integ = init(_prob; initializealg = NoInit()) u0, p, success = SciMLBase.get_initial_values( - _prob, _integ, _fn, SciMLBase.OverrideInit(), Val(false)) + _prob, _integ, _fn, SciMLBase.OverrideInit(), Val(false); abstol, reltol) @test u0 ≈ [1.0, 1.0] @test p ≈ 1.0 @@ -182,7 +192,7 @@ end _integ = ProblemState(; u = integ.u, p = parameter_values(integ), t = integ.t) u0, p, success = SciMLBase.get_initial_values( prob, _integ, fn, SciMLBase.OverrideInit(), - Val(false); nlsolve_alg = NewtonRaphson()) + Val(false); nlsolve_alg = NewtonRaphson(), abstol, reltol) @test u0 ≈ [2.0, 2.0] @test p ≈ 1.0 @@ -199,7 +209,7 @@ end u0, p, success = SciMLBase.get_initial_values( prob, integ, fn, SciMLBase.OverrideInit(), - Val(false); nlsolve_alg = NewtonRaphson()) + Val(false); nlsolve_alg = NewtonRaphson(), abstol, reltol) @test u0 ≈ [1.0, 1.0] @test p ≈ 1.0 @test success @@ -213,7 +223,7 @@ end u0, p, success = SciMLBase.get_initial_values( prob, integ, fn, SciMLBase.OverrideInit(), - Val(false); nlsolve_alg = NewtonRaphson()) + Val(false); nlsolve_alg = NewtonRaphson(), abstol, reltol) @test u0 ≈ [2.0, 2.0] @test p ≈ 0.0