Skip to content

Commit

Permalink
feat: add implementations of CheckInit and OverrideInit
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Nov 4, 2024
1 parent 77fff1e commit f5efaba
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 3 deletions.
7 changes: 6 additions & 1 deletion src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,11 @@ $(TYPEDEF)
"""
struct CheckInit <: DAEInitializationAlgorithm end

"""
$(TYPEDEF)
"""
struct OverrideInit <: DAEInitializationAlgorithm end

# PDE Discretizations

"""
Expand Down Expand Up @@ -654,7 +659,6 @@ Internal. Used for signifying the AD context comes from a Tracker.jl context.
struct TrackerOriginator <: ADOriginator end

include("utils.jl")
include("initialization.jl")
include("function_wrappers.jl")
include("scimlfunctions.jl")
include("alg_traits.jl")
Expand Down Expand Up @@ -740,6 +744,7 @@ include("ensemble/ensemble_problems.jl")
include("ensemble/basic_ensemble_solve.jl")
include("ensemble/ensemble_analysis.jl")

include("initialization.jl")
include("solve.jl")
include("interpolation.jl")
include("integrator_interface.jl")
Expand Down
162 changes: 160 additions & 2 deletions src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap}
"""
initializeprob::IProb
"""
A function which takes `(initializeprob, prob)` and updates
A function which takes `(initializeprob, value_provider)` and updates
the parameters of the former with their values in the latter.
If absent (`nothing`) this will not be called, and the parameters
in `initializeprob` will be used without modification. `value_provider`
refers to a value provider as defined by SymbolicIndexingInterface.jl.
Usually this will refer to a problem or integrator.
"""
update_initializeprob!::UIProb
"""
Expand All @@ -20,7 +24,9 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap}
initializeprobmap::IProbMap
"""
A function which takes the solution of `initializeprob` and returns
the parameter object of the original problem.
the parameter object of the original problem. If absent (`nothing`),
this will not be called and the parameters of the problem being
initialized will be returned as-is.
"""
initializeprobpmap::IProbPmap

Expand All @@ -30,3 +36,155 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap}
return new{I, J, K, L}(initprob, update_initprob!, initprobmap, initprobpmap)
end
end

"""
get_initial_values(prob, valp, f, alg, isinplace; kwargs...)
Return the initial `u0` and `p` for the given SciMLProblem and initialization algorithm,
and a boolean indicating whether the initialization process was successful. Keyword
arguments to this function are dependent on the initialization algorithm. `prob` is only
required for dispatching. `valp` refers the appropriate data structure from which the
current state and parameter values should be obtained. `valp` is a non-timeseries value
provider as defined by SymbolicIndexingInterface.jl. `f` is the SciMLFunction for the
problem. `alg` is the initialization algorithm to use. `isinplace` is either `Val{true}`
if `valp` and the SciMLFunction are inplace, and `Val{false}` otherwise.
"""
function get_initial_values end

struct CheckInitFailureError <: Exception
normresid::Any
abstol::Any
end

function Base.showerror(io::IO, e::CheckInitFailureError)
print(io,
"CheckInit specified but initialization not satisfied. normresid = $(e.normresid) > abstol = $(e.abstol)")
end

struct OverrideInitMissingAlgorithm <: Exception end

function Base.showerror(io::IO, e::OverrideInitMissingAlgorithm)
print(io,
"OverrideInit specified but no NonlinearSolve.jl algorithm provided. Provide an algorithm via the `nlsolve_alg` keyword argument to `get_initial_values`.")
end

"""
Utility function to evaluate the RHS of the ODE, using the integrator's `tmp_cache` if
it is in-place or simply calling the function if not.
"""
function _evaluate_f_ode(integrator, f, isinplace::Val{true}, args...)
tmp = first(get_tmp_cache(integrator))
f(tmp, args...)
return tmp
end

function _evaluate_f_ode(integrator, f, isinplace::Val{false}, args...)
return f(args...)
end

"""
$(TYPEDSIGNATURES)
A utility function equivalent to `Base.vec` but also handles `Number` and
`AbstractSciMLScalarOperator`.
"""
_vec(v) = vec(v)
_vec(v::Number) = v
_vec(v::SciMLOperators.AbstractSciMLScalarOperator) = v
_vec(v::AbstractVector) = v

"""
$(TYPEDSIGNATURES)
Check if the algebraic constraints are satisfied, and error if they aren't. Returns
the `u0` and `p` as-is, and is always successful if it returns. Valid only for
`ODEProblem` and `DAEProblem`. Requires a `DEIntegrator` as its second argument.
"""
function get_initial_values(prob::ODEProblem, integrator, f, alg::CheckInit,
isinplace::Union{Val{true}, Val{false}}; kwargs...)
u0 = state_values(integrator)
p = parameter_values(integrator)
t = current_time(integrator)
M = f.mass_matrix

algebraic_vars = [all(iszero, x) for x in eachcol(M)]
algebraic_eqs = [all(iszero, x) for x in eachrow(M)]
(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return
update_coefficients!(M, u0, p, t)
tmp = _evaluate_f_ode(integrator, f, isinplace, u0, p, t)
tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp))

normresid = integrator.opts.internalnorm(tmp, t)
if normresid > integrator.opts.abstol
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
end
return u0, p, true
end

"""
Utility function to evaluate the RHS of the DAE, using the integrator's `tmp_cache` if
it is in-place or simply calling the function if not.
"""
function _evaluate_f_dae(integrator, f, isinplace::Val{true}, args...)
tmp = get_tmp_cache(integrator)[2]
f(tmp, args...)
return tmp
end

function _evaluate_f_dae(integrator, f, isinplace::Val{false}, args...)
return f(args...)
end

function get_initial_values(prob::DAEProblem, integrator, f, alg::CheckInit,
isinplace::Union{Val{true}, Val{false}}; kwargs...)
u0 = state_values(integrator)
p = parameter_values(integrator)
t = current_time(integrator)

resid = _evaluate_f_dae(integrator, f, isinplace, integrator.du, u0, p, t)
normresid = integrator.opts.internalnorm(resid, t)
if normresid > integrator.opts.abstol
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
end
return u0, p, true
end

"""
$(TYPEDSIGNATURES)
Solve a `NonlinearProblem`/`NonlinearLeastSquaresProblem` to obtain the initial `u0` and
`p`. Requires that `f` have the field `initialization_data` which is an `OverrideInitData`.
If the field is absent or the value is `nothing`, return `u0` and `p` successfully as-is.
The NonlinearSolve.jl algorithm to use must be specified through the `nlsolve_alg` keyword
argument, failing which this function will throw an error. The success value returned
depends on the success of the nonlinear solve.
"""
function get_initial_values(prob, valp, f, alg::OverrideInit,
isinplace::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, kwargs...)
u0 = state_values(valp)
p = parameter_values(valp)

if !has_initialization_data(f)
return u0, p, true
end

initdata::OverrideInitData = f.initialization_data
initprob = initdata.initializeprob

if nlsolve_alg === nothing
throw(OverrideInitMissingAlgorithm())
end

if initdata.update_initializeprob! !== nothing
initdata.update_initializeprob!(initprob, valp)
end

nlsol = solve(initprob, nlsolve_alg)

u0 = initdata.initializeprobmap(nlsol)
if initdata.initializeprobpmap !== nothing
p = initdata.initializeprobpmap(nlsol)
end

return u0, p, SciMLBase.successful_retcode(nlsol)
end

0 comments on commit f5efaba

Please sign in to comment.