diff --git a/src/problems/nonlinear_problems.jl b/src/problems/nonlinear_problems.jl index 78a6e05ef..971173f5c 100644 --- a/src/problems/nonlinear_problems.jl +++ b/src/problems/nonlinear_problems.jl @@ -104,6 +104,9 @@ function IntervalNonlinearProblem(f, tspan, p = NullParameters(); kwargs...) IntervalNonlinearProblem(IntervalNonlinearFunction(f), tspan, p; kwargs...) end + +_default_nl_specialize(p) = sizeof(p)==0 || ismutable(p) ? AutoSpecialize : FullSpecialize + @doc doc""" Defines a nonlinear system problem. @@ -183,7 +186,7 @@ mutable struct NonlinearProblem{uType, isinplace, P, F, K, PT} <: This is determined automatically, but not inferred. """ function NonlinearProblem{iip}(f, u0, p = NullParameters(); kwargs...) where {iip} - NonlinearProblem{iip}(NonlinearFunction{iip}(f), u0, p; kwargs...) + NonlinearProblem{iip}(NonlinearFunction{iip, _default_nl_specialize(p)}(f), u0, p; kwargs...) end end @@ -198,7 +201,9 @@ function NonlinearProblem(f::AbstractNonlinearFunction, u0, p = NullParameters() end function NonlinearProblem(f, u0, p = NullParameters(); kwargs...) - NonlinearProblem(NonlinearFunction(f), u0, p; kwargs...) + iip = isinplace(f, 3) + + NonlinearProblem(NonlinearFunction{iip, _default_nl_specialize(p)}(f), u0, p; kwargs...) end """ diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 239791d61..cc02213df 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -1724,19 +1724,16 @@ For more details on this argument, see the ODEFunction documentation. The fields of the NonlinearFunction type directly match the names of the inputs. """ -struct NonlinearFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, +struct NonlinearFunction{iip, specialize, F, TMM, Ta, TJ, JVP, VJP, JP, SP, TPJ, O, TCV, SYS, RP, ID} <: AbstractNonlinearFunction{iip} f::F mass_matrix::TMM analytic::Ta - tgrad::Tt jac::TJ jvp::JVP vjp::VJP jac_prototype::JP sparsity::SP - Wfact::TW - Wfact_t::TWt paramjac::TPJ observed::O colorvec::TCV @@ -3801,23 +3798,17 @@ end SDDEFunction(f::SDDEFunction; kwargs...) = f function NonlinearFunction{iip, specialize}(f; - mass_matrix = __has_mass_matrix(f) ? - f.mass_matrix : - I, - analytic = __has_analytic(f) ? f.analytic : + mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : I, + analytic = __has_analytic(f) ? Void(f.analytic) : nothing, - tgrad = __has_tgrad(f) ? f.tgrad : nothing, - jac = __has_jac(f) ? f.jac : nothing, - jvp = __has_jvp(f) ? f.jvp : nothing, - vjp = __has_vjp(f) ? f.vjp : nothing, + jac = __has_jac(f) ? Void(f.jac) : nothing, + jvp = __has_jvp(f) ? Void(f.jvp) : nothing, + vjp = __has_vjp(f) ? Void(f.vjp) : nothing, jac_prototype = __has_jac_prototype(f) ? f.jac_prototype : nothing, sparsity = __has_sparsity(f) ? f.sparsity : jac_prototype, - Wfact = __has_Wfact(f) ? f.Wfact : nothing, - Wfact_t = __has_Wfact_t(f) ? f.Wfact_t : - nothing, - paramjac = __has_paramjac(f) ? f.paramjac : + paramjac = __has_paramjac(f) ? Void(f.paramjac) : nothing, syms = nothing, paramsyms = nothing, @@ -3864,40 +3855,51 @@ function NonlinearFunction{iip, specialize}(f; sys = sys_or_symbolcache(sys, syms, paramsyms) if specialize === NoSpecialize NonlinearFunction{iip, specialize, + Any, Any, Any, Any, Any, Any, Any, Any, Any, - Any, Any, Any, Any, Any, - Any, Any, Any, + Any, typeof(_colorvec), Any, Any, Any}(_f, mass_matrix, - analytic, tgrad, jac, + analytic, jac, jvp, vjp, jac_prototype, - sparsity, Wfact, - Wfact_t, paramjac, + sparsity, paramjac, observed, _colorvec, sys, resid_prototype, initialization_data) + elseif specialize === AutoSpecialize && iip + NonlinearFunction{iip, specialize, + Void, typeof(mass_matrix), + analytic isa Void ? Void : typeof(analytic), + jac isa Void ? Void : typeof(jac), + jvp isa Void ? Void : typeof(jvp), + vjp isa Void ? Void : typeof(vjp), + typeof(jac_prototype), + typeof(sparsity), typeof(paramjac), + observed isa Void ? Void : typeof(observed), + typeof(_colorvec), typeof(sys), typeof(resid_prototype), + typeof(initialization_data)}(Void(_f), mass_matrix, + analytic, jac, + jvp, vjp, jac_prototype, sparsity, paramjac, + observed, _colorvec, sys, resid_prototype, initialization_data) else NonlinearFunction{iip, specialize, - typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), + typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), - typeof(sparsity), typeof(Wfact), - typeof(Wfact_t), typeof(paramjac), + typeof(sparsity), typeof(paramjac), typeof(observed), typeof(_colorvec), typeof(sys), typeof(resid_prototype), typeof(initialization_data)}(_f, mass_matrix, - analytic, tgrad, jac, - jvp, vjp, jac_prototype, sparsity, - Wfact, - Wfact_t, paramjac, + analytic, jac, + jvp, vjp, jac_prototype, sparsity, paramjac, observed, _colorvec, sys, resid_prototype, initialization_data) end end function NonlinearFunction{iip}(f; kwargs...) where {iip} - NonlinearFunction{iip, FullSpecialize}(f; kwargs...) + NonlinearFunction{iip, AutoSpecialize}(f; kwargs...) end NonlinearFunction{iip}(f::NonlinearFunction; kwargs...) where {iip} = f function NonlinearFunction(f; kwargs...) - NonlinearFunction{isinplace(f, 3), FullSpecialize}(f; kwargs...) + NonlinearFunction{isinplace(f, 3), AutoSpecialize}(f; kwargs...) end NonlinearFunction(f::NonlinearFunction; kwargs...) = f