diff --git a/src/ODE_nlsolve.jl b/src/ODE_nlsolve.jl new file mode 100644 index 000000000..0e45d2d0e --- /dev/null +++ b/src/ODE_nlsolve.jl @@ -0,0 +1,46 @@ +""" + $(TYPEDEF) + +A collection of all the data required for custom ODE Nonlinear problem solving +""" +struct ODE_NLProbData{NLProb, UNLProb, NLProbMap, NLProbPmap} + """ + The `AbstractNonlinearProblem` to define custom nonlinear problems to be used for + implicit time discretizations. This allows to use extra structure of the ODE function (e.g. + multi-level structure). The nonlinear function must match that form of the function implicit + ODE integration algorithms need do solve the a nonlinear problems, + specifically of the form `z = outer_tmp + dt⋅f(γ⋅z+inner_tmp,p,t)`. + Here `z` is the stage solution vector, `p` is the parameter of the ODE problem, `t` is + the time, `dt` the respective time increment`, `γ` is some scaling factor and the temporary + variables are some compatible vectors set by the specific solver. + Note that this field will not be used for integrators such as fully-implicit Runge-Kutta methods + that need to solve different nonlinear systems. + The inner nonlinear function of the nonlinear problem is in general of the form `g(z,p') = 0` + where `p'` is a NamedTuple with all information about the specific nonlinear problem at hand to solve + for a specific time discretization. Specifically, it is `(;dt, γ, inner_tmp, outer_tmp, t, p)`, such that + `g(z,p') = dt⋅f(γ⋅z+inner_tmp,p,t) + outer_tmp - z = 0`. + """ + nlprob::NLProb + """ + A function which takes `(nlprob, value_provider)` and updates + the parameters of the former with their values in the latter. + If absent (`nothing`) this will not be called, and the parameters + in `nlprob` will be used without modification. `value_provider` + refers to a value provider as defined by SymbolicIndexingInterface.jl. + Usually this will refer to a problem or integrator. + """ + update_nlprob!::UNLProb + """ + A function which takes the solution of `nlprob` and returns + the state vector of the original problem. + """ + nlprobmap::NLProbMap + """ + A function which takes the solution of `nlprob` and returns + the parameter object of the original problem. If absent (`nothing`), + this will not be called and the parameters of the problem being + solved will be returned as-is. + """ + nlprobpmap::NLProbPmap +end + diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index e930101e7..5ab13494c 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -658,6 +658,8 @@ Internal. Used for signifying the AD context comes from a Tracker.jl context. """ struct TrackerOriginator <: ADOriginator end +include("initialization.jl") +include("ODE_nlsolve.jl") include("utils.jl") include("function_wrappers.jl") include("scimlfunctions.jl") @@ -744,7 +746,6 @@ include("ensemble/ensemble_problems.jl") include("ensemble/basic_ensemble_solve.jl") include("ensemble/ensemble_analysis.jl") -include("initialization.jl") include("solve.jl") include("interpolation.jl") include("integrator_interface.jl") diff --git a/src/initialization.jl b/src/initialization.jl index 86c71560d..58610269b 100644 --- a/src/initialization.jl +++ b/src/initialization.jl @@ -100,7 +100,7 @@ Check if the algebraic constraints are satisfied, and error if they aren't. Retu the `u0` and `p` as-is, and is always successful if it returns. Valid only for `ODEProblem` and `DAEProblem`. Requires a `DEIntegrator` as its second argument. """ -function get_initial_values(prob::ODEProblem, integrator, f, alg::CheckInit, +function get_initial_values(prob::AbstractODEProblem, integrator, f, alg::CheckInit, isinplace::Union{Val{true}, Val{false}}; kwargs...) u0 = state_values(integrator) p = parameter_values(integrator) @@ -135,7 +135,7 @@ function _evaluate_f_dae(integrator, f, isinplace::Val{false}, args...) return f(args...) end -function get_initial_values(prob::DAEProblem, integrator, f, alg::CheckInit, +function get_initial_values(prob::AbstractDAEProblem, integrator, f, alg::CheckInit, isinplace::Union{Val{true}, Val{false}}; kwargs...) u0 = state_values(integrator) p = parameter_values(integrator) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 239791d61..afd1dc80a 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -289,11 +289,6 @@ the usage of `f`. These include: based on the sparsity pattern. Defaults to `nothing`, which means a color vector will be internally computed on demand when required. The cost of this operation is highly dependent on the sparsity pattern. -- `nlprob`: a `NonlinearProblem` that solves `f(u, t, p) = u_tmp` - where the nonlinear parameters are the tuple `(t, u_tmp, p)`. - This will be used as the nonlinear problem inside an implicit solver by specifying `u, u_tmp` and `t` - such that solving this function produces a solution to the implicit step of your solver. - ## iip: In-Place vs Out-Of-Place `iip` is the optional boolean for determining whether a given function is written to @@ -406,7 +401,7 @@ numerically-defined functions. """ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ, O, TCV, - SYS, ID, NLP} <: AbstractODEFunction{iip} + SYS, ID<:Union{Nothing, OverrideInitData}, NLP<:Union{Nothing, ODE_NLProbData}} <: AbstractODEFunction{iip} f::F mass_matrix::TMM analytic::Ta @@ -424,7 +419,7 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW colorvec::TCV sys::SYS initialization_data::ID - nlprob::NLP + nlprob_data::NLP end @doc doc""" @@ -527,8 +522,7 @@ information on generating the SplitFunction from this symbolic engine. """ struct SplitFunction{ iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, WP, SP, TW, TWt, - TPJ, O, - TCV, SYS, ID, NLP} <: AbstractODEFunction{iip} + TPJ, O, TCV, SYS, ID<:Union{Nothing, OverrideInitData}, NLP<:Union{Nothing, ODE_NLProbData}} <: AbstractODEFunction{iip} f1::F1 f2::F2 mass_matrix::TMM @@ -547,8 +541,8 @@ struct SplitFunction{ observed::O colorvec::TCV sys::SYS - nlprob::NLP initialization_data::ID + nlprob_data::NLP end @doc doc""" @@ -2446,9 +2440,9 @@ function ODEFunction{iip, specialize}(f; f.update_initializeprob! : nothing, initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing, initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing, - nlprob = __has_nlprob(f) ? f.nlprob : nothing, initialization_data = __has_initialization_data(f) ? f.initialization_data : - nothing + nothing, + nlprob_data = __has_nlprob_data(f) ? f.nlprob_data : nothing, ) where {iip, specialize } @@ -2506,10 +2500,10 @@ function ODEFunction{iip, specialize}(f; typeof(sparsity), Any, Any, typeof(W_prototype), Any, Any, typeof(_colorvec), - typeof(sys), Any, Any}(_f, mass_matrix, analytic, tgrad, jac, + typeof(sys), Union{Nothing, OverrideInitData}, Union{Nothing, ODE_NLProbData}}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, - observed, _colorvec, sys, initdata, nlprob) + observed, _colorvec, sys, initdata, nlprob_data) elseif specialize === false ODEFunction{iip, FunctionWrapperSpecialize, typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), @@ -2518,11 +2512,11 @@ function ODEFunction{iip, specialize}(f; typeof(paramjac), typeof(observed), typeof(_colorvec), - typeof(sys), typeof(initdata), typeof(nlprob)}(_f, mass_matrix, + typeof(sys), typeof(initdata), typeof(nlprob_data)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, - observed, _colorvec, sys, initdata, nlprob) + observed, _colorvec, sys, initdata, nlprob_data) else ODEFunction{iip, specialize, typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), @@ -2531,11 +2525,11 @@ function ODEFunction{iip, specialize}(f; typeof(paramjac), typeof(observed), typeof(_colorvec), - typeof(sys), typeof(initdata), typeof(nlprob)}( + typeof(sys), typeof(initdata), typeof(nlprob_data)}( _f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, - observed, _colorvec, sys, initdata, nlprob) + observed, _colorvec, sys, initdata, nlprob_data) end end @@ -2552,11 +2546,11 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f)) Any, Any, Any, Any, typeof(f.jac_prototype), typeof(f.sparsity), Any, Any, Any, Any, typeof(f.colorvec), - typeof(f.sys), Any, Any}( + typeof(f.sys), Union{Nothing, OverrideInitData}, Union{Nothing, ODE_NLProbData}}( newf, f.mass_matrix, f.analytic, f.tgrad, f.jac, f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact, f.Wfact_t, f.W_prototype, f.paramjac, - f.observed, f.colorvec, f.sys, f.initialization_data, f.nlprob) + f.observed, f.colorvec, f.sys, f.initialization_data, f.nlprob_data) else ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix), typeof(f.analytic), typeof(f.tgrad), @@ -2564,11 +2558,11 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f)) typeof(f.sparsity), typeof(f.Wfact), typeof(f.Wfact_t), typeof(f.W_prototype), typeof(f.paramjac), typeof(f.observed), typeof(f.colorvec), - typeof(f.sys), typeof(f.initialization_data), typeof(f.nlprob)}( + typeof(f.sys), typeof(f.initialization_data), typeof(f.nlprob_data)}( newf, f.mass_matrix, f.analytic, f.tgrad, f.jac, f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact, f.Wfact_t, f.W_prototype, f.paramjac, - f.observed, f.colorvec, f.sys, f.initialization_data, f.nlprob) + f.observed, f.colorvec, f.sys, f.initialization_data, f.nlprob_data) end end @@ -2703,7 +2697,7 @@ end @add_kwonly function SplitFunction(f1, f2, mass_matrix, cache, analytic, tgrad, jac, jvp, vjp, jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys, initializeprob = nothing, update_initializeprob! = nothing, - initializeprobmap = nothing, initializeprobpmap = nothing, nlprob = nothing, initialization_data = nothing) + initializeprobmap = nothing, initializeprobpmap = nothing, initialization_data = nothing, nlprob_data = nothing) f1 = ODEFunction(f1) f2 = ODEFunction(f2) @@ -2721,11 +2715,11 @@ end typeof(cache), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(W_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(colorvec), - typeof(sys), typeof(initdata), typeof(nlprob)}( + typeof(sys), typeof(initdata), typeof(nlprob_data)}( f1, f2, mass_matrix, cache, analytic, tgrad, jac, jvp, vjp, jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys, - initdata, nlprob) + initdata, nlprob_data) end function SplitFunction{iip, specialize}(f1, f2; mass_matrix = __has_mass_matrix(f1) ? @@ -2762,7 +2756,7 @@ function SplitFunction{iip, specialize}(f1, f2; f1.update_initializeprob! : nothing, initializeprobmap = __has_initializeprobmap(f1) ? f1.initializeprobmap : nothing, initializeprobpmap = __has_initializeprobpmap(f1) ? f1.initializeprobpmap : nothing, - nlprob = __has_nlprob(f1) ? f1.nlprob : nothing, + nlprob_data = __has_nlprob_data(f1) ? f1.nlprob_data : nothing, initialization_data = __has_initialization_data(f1) ? f1.initialization_data : nothing ) where {iip, @@ -2776,11 +2770,11 @@ function SplitFunction{iip, specialize}(f1, f2; if specialize === NoSpecialize SplitFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, - Any, Any, Any, Any}(f1, f2, mass_matrix, _func_cache, + Any, Any, Union{Nothing, OverrideInitData}, Union{Nothing, ODE_NLProbData}}(f1, f2, mass_matrix, _func_cache, analytic, tgrad, jac, jvp, vjp, jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac, - observed, colorvec, sys, initdata, nlprob) + observed, colorvec, sys, initdata, nlprob_data) else SplitFunction{iip, specialize, typeof(f1), typeof(f2), typeof(mass_matrix), typeof(_func_cache), typeof(analytic), @@ -2788,11 +2782,11 @@ function SplitFunction{iip, specialize}(f1, f2; typeof(jac_prototype), typeof(W_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(colorvec), - typeof(sys), typeof(initdata), typeof(nlprob)}(f1, f2, + typeof(sys), typeof(initdata), typeof(nlprob_data)}(f1, f2, mass_matrix, _func_cache, analytic, tgrad, jac, jvp, vjp, jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys, - initdata, nlprob) + initdata, nlprob_data) end end @@ -4488,7 +4482,7 @@ __has_colorvec(f) = isdefined(f, :colorvec) __has_sys(f) = isdefined(f, :sys) __has_analytic_full(f) = isdefined(f, :analytic_full) __has_resid_prototype(f) = isdefined(f, :resid_prototype) -__has_nlprob(f) = isdefined(f, :nlprob) +__has_nlprob_data(f) = isdefined(f, :nlprob_data) function __has_initializeprob(f) has_initialization_data(f) && isdefined(f.initialization_data, :initializeprob) end