Skip to content

Commit

Permalink
AutoSpecialize for NonlinearProblem
Browse files Browse the repository at this point in the history
  • Loading branch information
oscardssmith committed Nov 14, 2024
1 parent ca77954 commit 4e542f2
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 32 deletions.
9 changes: 7 additions & 2 deletions src/problems/nonlinear_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ function IntervalNonlinearProblem(f, tspan, p = NullParameters(); kwargs...)
IntervalNonlinearProblem(IntervalNonlinearFunction(f), tspan, p; kwargs...)
end


_default_nl_specialize(p) = sizeof(p)==0 || ismutable(p) ? AutoSpecialize : FullSpecialize

@doc doc"""
Defines a nonlinear system problem.
Expand Down Expand Up @@ -183,7 +186,7 @@ mutable struct NonlinearProblem{uType, isinplace, P, F, K, PT} <:
This is determined automatically, but not inferred.
"""
function NonlinearProblem{iip}(f, u0, p = NullParameters(); kwargs...) where {iip}
NonlinearProblem{iip}(NonlinearFunction{iip}(f), u0, p; kwargs...)
NonlinearProblem{iip}(NonlinearFunction{iip, _default_nl_specialize(p)}(f), u0, p; kwargs...)
end
end

Expand All @@ -198,7 +201,9 @@ function NonlinearProblem(f::AbstractNonlinearFunction, u0, p = NullParameters()
end

function NonlinearProblem(f, u0, p = NullParameters(); kwargs...)
NonlinearProblem(NonlinearFunction(f), u0, p; kwargs...)
iip = isinplace(f, 3)

NonlinearProblem(NonlinearFunction{iip, _default_nl_specialize(p)}(f), u0, p; kwargs...)
end

"""
Expand Down
62 changes: 32 additions & 30 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1724,19 +1724,16 @@ For more details on this argument, see the ODEFunction documentation.
The fields of the NonlinearFunction type directly match the names of the inputs.
"""
struct NonlinearFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt,
struct NonlinearFunction{iip, specialize, F, TMM, Ta, TJ, JVP, VJP, JP, SP,
TPJ, O, TCV, SYS, RP, ID} <: AbstractNonlinearFunction{iip}
f::F
mass_matrix::TMM
analytic::Ta
tgrad::Tt
jac::TJ
jvp::JVP
vjp::VJP
jac_prototype::JP
sparsity::SP
Wfact::TW
Wfact_t::TWt
paramjac::TPJ
observed::O
colorvec::TCV
Expand Down Expand Up @@ -3801,23 +3798,17 @@ end
SDDEFunction(f::SDDEFunction; kwargs...) = f

function NonlinearFunction{iip, specialize}(f;
mass_matrix = __has_mass_matrix(f) ?
f.mass_matrix :
I,
analytic = __has_analytic(f) ? f.analytic :
mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : I,
analytic = __has_analytic(f) ? Void(f.analytic) :
nothing,
tgrad = __has_tgrad(f) ? f.tgrad : nothing,
jac = __has_jac(f) ? f.jac : nothing,
jvp = __has_jvp(f) ? f.jvp : nothing,
vjp = __has_vjp(f) ? f.vjp : nothing,
jac = __has_jac(f) ? Void(f.jac) : nothing,
jvp = __has_jvp(f) ? Void(f.jvp) : nothing,
vjp = __has_vjp(f) ? Void(f.vjp) : nothing,
jac_prototype = __has_jac_prototype(f) ?
f.jac_prototype : nothing,
sparsity = __has_sparsity(f) ? f.sparsity :
jac_prototype,
Wfact = __has_Wfact(f) ? f.Wfact : nothing,
Wfact_t = __has_Wfact_t(f) ? f.Wfact_t :
nothing,
paramjac = __has_paramjac(f) ? f.paramjac :
paramjac = __has_paramjac(f) ? Void(f.paramjac) :
nothing,
syms = nothing,
paramsyms = nothing,
Expand Down Expand Up @@ -3864,40 +3855,51 @@ function NonlinearFunction{iip, specialize}(f;
sys = sys_or_symbolcache(sys, syms, paramsyms)
if specialize === NoSpecialize
NonlinearFunction{iip, specialize,
Any, Any, Any, Any,
Any, Any, Any, Any, Any,
Any, Any, Any, Any, Any,
Any, Any, Any,
Any,
typeof(_colorvec), Any, Any, Any}(_f, mass_matrix,
analytic, tgrad, jac,
analytic, jac,
jvp, vjp,
jac_prototype,
sparsity, Wfact,
Wfact_t, paramjac,
sparsity, paramjac,
observed,
_colorvec, sys, resid_prototype, initialization_data)
elseif specialize === AutoSpecialize && iip
NonlinearFunction{iip, specialize,
Void, typeof(mass_matrix),
analytic isa Void ? Void : typeof(analytic),
jac isa Void ? Void : typeof(jac),
jvp isa Void ? Void : typeof(jvp),
vjp isa Void ? Void : typeof(vjp),
typeof(jac_prototype),
typeof(sparsity), typeof(paramjac),
observed isa Void ? Void : typeof(observed),
typeof(_colorvec), typeof(sys), typeof(resid_prototype),
typeof(initialization_data)}(Void(_f), mass_matrix,
analytic, jac,
jvp, vjp, jac_prototype, sparsity, paramjac,
observed, _colorvec, sys, resid_prototype, initialization_data)
else
NonlinearFunction{iip, specialize,
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
typeof(_f), typeof(mass_matrix), typeof(analytic),
typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype),
typeof(sparsity), typeof(Wfact),
typeof(Wfact_t), typeof(paramjac),
typeof(sparsity), typeof(paramjac),
typeof(observed),
typeof(_colorvec), typeof(sys), typeof(resid_prototype),
typeof(initialization_data)}(_f, mass_matrix,
analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity,
Wfact,
Wfact_t, paramjac,
analytic, jac,
jvp, vjp, jac_prototype, sparsity, paramjac,
observed, _colorvec, sys, resid_prototype, initialization_data)
end
end

function NonlinearFunction{iip}(f; kwargs...) where {iip}
NonlinearFunction{iip, FullSpecialize}(f; kwargs...)
NonlinearFunction{iip, AutoSpecialize}(f; kwargs...)
end
NonlinearFunction{iip}(f::NonlinearFunction; kwargs...) where {iip} = f
function NonlinearFunction(f; kwargs...)
NonlinearFunction{isinplace(f, 3), FullSpecialize}(f; kwargs...)
NonlinearFunction{isinplace(f, 3), AutoSpecialize}(f; kwargs...)
end
NonlinearFunction(f::NonlinearFunction; kwargs...) = f

Expand Down

0 comments on commit 4e542f2

Please sign in to comment.