From ac60b895cb14b691211a3578277d24e95fa5495d Mon Sep 17 00:00:00 2001 From: ErikQQY <2283984853@qq.com> Date: Thu, 19 Jan 2023 18:42:54 +0800 Subject: [PATCH 1/9] Add BVPFunction Signed-off-by: ErikQQY <2283984853@qq.com> --- src/SciMLBase.jl | 5 +- src/problems/bvp_problems.jl | 12 +- src/scimlfunctions.jl | 231 ++++++++++++++++++++++- test/function_building_error_messages.jl | 88 +++++++++ 4 files changed, 327 insertions(+), 9 deletions(-) diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index 5708d96b7..b64b43bac 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -608,7 +608,8 @@ function specialization(::Union{ODEFunction{iip, specialize}, ImplicitDiscreteFunction{iip, specialize}, RODEFunction{iip, specialize}, NonlinearFunction{iip, specialize}, - OptimizationFunction{iip, specialize}}) where {iip, + OptimizationFunction{iip, specialize}, + BVPFunction{iip, specialize}}) where {iip, specialize} specialize end @@ -726,7 +727,7 @@ export remake export ODEFunction, DiscreteFunction, ImplicitDiscreteFunction, SplitFunction, DAEFunction, DDEFunction, SDEFunction, SplitSDEFunction, RODEFunction, SDDEFunction, - IncrementingODEFunction, NonlinearFunction, IntervalNonlinearFunction + IncrementingODEFunction, NonlinearFunction, IntervalNonlinearFunction, BVPFunction export OptimizationFunction diff --git a/src/problems/bvp_problems.jl b/src/problems/bvp_problems.jl index c06c13696..37e13b139 100644 --- a/src/problems/bvp_problems.jl +++ b/src/problems/bvp_problems.jl @@ -79,7 +79,7 @@ struct BVProblem{uType, tType, isinplace, P, F, bF, PT, K} <: p::P problem_type::PT kwargs::K - @add_kwonly function BVProblem{iip}(f::AbstractODEFunction, bc, u0, tspan, + @add_kwonly function BVProblem{iip}(f::AbstractBVPFunction, bc, u0, tspan, p = NullParameters(), problem_type = StandardBVProblem(); kwargs...) where {iip} @@ -91,26 +91,26 @@ struct BVProblem{uType, tType, isinplace, P, F, bF, PT, K} <: end function BVProblem{iip}(f, bc, u0, tspan, p = NullParameters(); kwargs...) where {iip} - BVProblem(ODEFunction{iip}(f), bc, u0, tspan, p; kwargs...) + BVProblem(BVPFunction{iip}(f), bc, u0, tspan, p; kwargs...) end end -function BVProblem(f::AbstractODEFunction, bc, u0, tspan, args...; kwargs...) +function BVProblem(f::AbstractBVPFunction, bc, u0, tspan, args...; kwargs...) BVProblem{isinplace(f, 4)}(f, bc, u0, tspan, args...; kwargs...) end function BVProblem(f, bc, u0, tspan, p = NullParameters(); kwargs...) - BVProblem(ODEFunction(f), bc, u0, tspan, p; kwargs...) + BVProblem(BVPFunction(f), bc, u0, tspan, p; kwargs...) end # convenience interfaces: # Allow any previous timeseries solution -function BVProblem(f::AbstractODEFunction, bc, sol::T, tspan::Tuple, p = NullParameters(); +function BVProblem(f::AbstractBVPFunction, bc, sol::T, tspan::Tuple, p = NullParameters(); kwargs...) where {T <: AbstractTimeseriesSolution} BVProblem(f, bc, sol.u, tspan, p) end # Allow a function of time for the initial guess -function BVProblem(f::AbstractODEFunction, bc, initialGuess, tspan::AbstractVector, +function BVProblem(f::AbstractBVPFunction, bc, initialGuess, tspan::AbstractVector, p = NullParameters(); kwargs...) u0 = [initialGuess(i) for i in tspan] BVProblem(f, bc, u0, (tspan[1], tspan[end]), p) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 931962f4d..6e7638978 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -2091,6 +2091,120 @@ struct OptimizationFunction{iip, AD, F, G, H, HV, C, CJ, CH, LH, HP, CJP, CHP, L sys::SYS end +""" +$(TYPEDEF) +""" +abstract type AbstractBVPFunction{iip} <: + AbstractDiffEqFunction{iip} end + +@doc doc""" + BVPFunction{iip,F,TMM,Ta,Tt,TJ,JVP,VJP,JP,SP,TW,TWt,TPJ,S,S2,S3,O,TCV} <: AbstractBVPFunction{iip,specialize} + +A representation of a BVP function `f`, defined by: + +```math +\frac{du}{dt}=f(u,p,t) +``` + +and all of its related functions, such as the Jacobian of `f`, its gradient +with respect to time, and more. For all cases, `u0` is the initial condition, +`p` are the parameters, and `t` is the independent variable. + +```julia +BVPFunction{iip,specialize}(f; + mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : I, + analytic = __has_analytic(f) ? f.analytic : nothing, + tgrad= __has_tgrad(f) ? f.tgrad : nothing, + jac = __has_jac(f) ? f.jac : nothing, + jvp = __has_jvp(f) ? f.jvp : nothing, + vjp = __has_vjp(f) ? f.vjp : nothing, + jac_prototype = __has_jac_prototype(f) ? f.jac_prototype : nothing, + sparsity = __has_sparsity(f) ? f.sparsity : jac_prototype, + paramjac = __has_paramjac(f) ? f.paramjac : nothing, + syms = __has_syms(f) ? f.syms : nothing, + indepsym= __has_indepsym(f) ? f.indepsym : nothing, + paramsyms = __has_paramsyms(f) ? f.paramsyms : nothing, + colorvec = __has_colorvec(f) ? f.colorvec : nothing, + sys = __has_sys(f) ? f.sys : nothing) +``` + +Note that only the function `f` itself is required. This function should +be given as `f!(out,du,u,p,t)` or `out = f(du,u,p,t)`. See the section on `iip` +for more details on in-place vs out-of-place handling. + +All of the remaining functions are optional for improving or accelerating +the usage of `f`. These include: + +- `mass_matrix`: the mass matrix `M` represented in the BVP function. Can be used + to determine that the equation is actually a BVP for differential algebraic equation (DAE) + if `M` is singular. +- `analytic(u0,p,t)`: used to pass an analytical solution function for the analytical + solution of the BVP. Generally only used for testing and development of the solvers. +- `tgrad(dT,u,h,p,t)` or dT=tgrad(u,p,t): returns ``\frac{\partial f(u,p,t)}{\partial t}`` +- `jac(J,du,u,p,gamma,t)` or `J=jac(du,u,p,gamma,t)`: returns ``\frac{df}{du}`` +- `jvp(Jv,v,du,u,p,gamma,t)` or `Jv=jvp(v,du,u,p,gamma,t)`: returns the directional + derivative``\frac{df}{du} v`` +- `vjp(Jv,v,du,u,p,gamma,t)` or `Jv=vjp(v,du,u,p,gamma,t)`: returns the adjoint + derivative``\frac{df}{du}^\ast v`` +- `jac_prototype`: a prototype matrix matching the type that matches the Jacobian. For example, + if the Jacobian is tridiagonal, then an appropriately sized `Tridiagonal` matrix can be used + as the prototype and integrators will specialize on this structure where possible. Non-structured + sparsity patterns should use a `SparseMatrixCSC` with a correct sparsity pattern for the Jacobian. + The default is `nothing`, which means a dense Jacobian. +- `paramjac(pJ,u,p,t)`: returns the parameter Jacobian ``\frac{df}{dp}``. +- `syms`: the symbol names for the elements of the equation. This should match `u0` in size. For + example, if `u0 = [0.0,1.0]` and `syms = [:x, :y]`, this will apply a canonical naming to the + values, allowing `sol[:x]` in the solution and automatically naming values in plots. +- `indepsym`: the canonical naming for the independent variable. Defaults to nothing, which + internally uses `t` as the representation in any plots. +- `paramsyms`: the symbol names for the parameters of the equation. This should match `p` in + size. For example, if `p = [0.0, 1.0]` and `paramsyms = [:a, :b]`, this will apply a canonical + naming to the values, allowing `sol[:a]` in the solution. +- `colorvec`: a color vector according to the SparseDiffTools.jl definition for the sparsity + pattern of the `jac_prototype`. This specializes the Jacobian construction when using + finite differences and automatic differentiation to be computed in an accelerated manner + based on the sparsity pattern. Defaults to `nothing`, which means a color vector will be + internally computed on demand when required. The cost of this operation is highly dependent + on the sparsity pattern. + +## iip: In-Place vs Out-Of-Place + +For more details on this argument, see the ODEFunction documentation. + +## specialize: Controlling Compilation and Specialization + +For more details on this argument, see the ODEFunction documentation. + +## Fields + +The fields of the BVPFunction type directly match the names of the inputs. +""" +struct BVPFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, + TPJ, + S, S2, S3, O, TCV, + SYS} <: + AbstractBVPFunction{iip} + f::F + mass_matrix::TMM + analytic::Ta + tgrad::Tt + jac::TJ + jvp::JVP + vjp::VJP + jac_prototype::JP + sparsity::SP + Wfact::TW + Wfact_t::TWt + paramjac::TPJ + syms::S + indepsym::S2 + paramsyms::S3 + observed::O + colorvec::TCV + sys::SYS +end + + ######### Backwards Compatibility Overloads (f::ODEFunction)(args...) = f.f(args...) @@ -2144,6 +2258,8 @@ end (f::RODEFunction)(args...) = f.f(args...) +(f::BVPFunction)(args...) = f.f(args...) + ######### Basic Constructor function ODEFunction{iip, specialize}(f; @@ -3622,6 +3738,118 @@ function OptimizationFunction{iip}(f, adtype::AbstractADType = NoAD(); cons_expr, sys) end +function BVPFunction{iip, specialize}(f; + mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : + I, + analytic = __has_analytic(f) ? f.analytic : nothing, + tgrad = __has_tgrad(f) ? f.tgrad : nothing, + jac = __has_jac(f) ? f.jac : nothing, + jvp = __has_jvp(f) ? f.jvp : nothing, + vjp = __has_vjp(f) ? f.vjp : nothing, + jac_prototype = __has_jac_prototype(f) ? + f.jac_prototype : + nothing, + sparsity = __has_sparsity(f) ? f.sparsity : + jac_prototype, + Wfact = __has_Wfact(f) ? f.Wfact : nothing, + Wfact_t = __has_Wfact_t(f) ? f.Wfact_t : nothing, + paramjac = __has_paramjac(f) ? f.paramjac : nothing, + syms = __has_syms(f) ? f.syms : nothing, + indepsym = __has_indepsym(f) ? f.indepsym : nothing, + paramsyms = __has_paramsyms(f) ? f.paramsyms : + nothing, + observed = __has_observed(f) ? f.observed : + DEFAULT_OBSERVED, + colorvec = __has_colorvec(f) ? f.colorvec : nothing, + sys = __has_sys(f) ? f.sys : nothing) where {iip, specialize} + + if mass_matrix === I && typeof(f) <: Tuple + mass_matrix = ((I for i in 1:length(f))...,) + end + + if (specialize === FunctionWrapperSpecialize) && + !(f isa FunctionWrappersWrappers.FunctionWrappersWrapper) + error("FunctionWrapperSpecialize must be used on the problem constructor for access to u0, p, and t types!") + end + + if jac === nothing && isa(jac_prototype, AbstractDiffEqLinearOperator) + if iip + jac = update_coefficients! #(J,u,p,t) + else + jac = (u, p, t) -> update_coefficients!(deepcopy(jac_prototype), u, p, t) + end + end + + if jac_prototype !== nothing && colorvec === nothing && ArrayInterfaceCore.fast_matrix_colors(jac_prototype) + _colorvec = ArrayInterfaceCore.matrix_colors(jac_prototype) + else + _colorvec = colorvec + end + + jaciip = jac !== nothing ? isinplace(jac, 4, "jac", iip) : iip + tgradiip = tgrad !== nothing ? isinplace(tgrad, 4, "tgrad", iip) : iip + jvpiip = jvp !== nothing ? isinplace(jvp, 5, "jvp", iip) : iip + vjpiip = vjp !== nothing ? isinplace(vjp, 5, "vjp", iip) : iip + Wfactiip = Wfact !== nothing ? isinplace(Wfact, 5, "Wfact", iip) : iip + Wfact_tiip = Wfact_t !== nothing ? isinplace(Wfact_t, 5, "Wfact_t", iip) : iip + paramjaciip = paramjac !== nothing ? isinplace(paramjac, 4, "paramjac", iip) : iip + + nonconforming = (jaciip, tgradiip, jvpiip, vjpiip, Wfactiip, Wfact_tiip, paramjaciip) .!= iip + if any(nonconforming) + nonconforming = findall(nonconforming) + functions = ["jac", "tgrad", "jvp", "vjp", "Wfact", "Wfact_t", "paramjac"][nonconforming] + throw(NonconformingFunctionsError(functions)) + end + + if specialize === NoSpecialize + BVPFunction{iip, specialize, Any, Any, Any, Any, + Any, Any, Any, Any, Any, Any, Any, + Any, typeof(syms), typeof(indepsym), typeof(paramsyms), + Any, typeof(_colorvec), Any}(f, mass_matrix, + analytic, + tgrad, + jac, jvp, vjp, + jac_prototype, + sparsity, Wfact, + Wfact_t, + paramjac, syms, + indepsym, paramsyms, + observed, + _colorvec, sys) + elseif specialize === false + BVPFunction{iip, FunctionWrapperSpecialize, + typeof(f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), + typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), + typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), + typeof(syms), typeof(indepsym), typeof(paramsyms), typeof(observed), + typeof(_colorvec), + typeof(sys)}(f, mass_matrix, analytic, tgrad, jac, + jvp, vjp, jac_prototype, sparsity, Wfact, + Wfact_t, paramjac, syms, indepsym, paramsyms, + observed, _colorvec, sys) + else + BVPFunction{iip, specialize, typeof(f), typeof(mass_matrix), typeof(analytic), + typeof(tgrad), + typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), + typeof(sparsity), typeof(Wfact), typeof(Wfact_t), + typeof(paramjac), typeof(syms), typeof(indepsym), typeof(paramsyms), + typeof(observed), + typeof(_colorvec), typeof(sys)}(f, mass_matrix, analytic, + tgrad, jac, jvp, vjp, + jac_prototype, sparsity, + Wfact, Wfact_t, paramjac, + syms, indepsym, paramsyms, observed, + _colorvec, sys) + end +end + +function BVPFunction{iip}(f; kwargs...) where {iip} + BVPFunction{iip, FullSpecialize}(f; kwargs...) +end +BVPFunction{iip}(f::BVPFunction; kwargs...) where {iip} = f +BVPFunction(f; kwargs...) = BVPFunction{isinplace(f, 4), FullSpecialize}(f; kwargs...) +BVPFunction(f::BVPFunction; kwargs...) = f + ########## Existence Functions # Check that field/property exists (may be nothing) @@ -3728,7 +3956,8 @@ for S in [:ODEFunction :SDDEFunction :NonlinearFunction :IntervalNonlinearFunction - :IncrementingODEFunction] + :IncrementingODEFunction + :BVPFunction] @eval begin function ConstructionBase.constructorof(::Type{<:$S{iip}}) where { iip } diff --git a/test/function_building_error_messages.jl b/test/function_building_error_messages.jl index fe6ce9d18..39e0a83b5 100644 --- a/test/function_building_error_messages.jl +++ b/test/function_building_error_messages.jl @@ -372,3 +372,91 @@ optf(u) = 1.0 optf(u, p) = 1.0 OptimizationFunction(optf) OptimizationProblem(optf, 1.0) + +# BVPFunction + +bfoop(u, p, t) = u +bfiip(du, u, p, t) = du .= u + +bofboth(u, p, t) = u +bofboth(du, u, p, t) = du .= u + +BVPFunction(bofboth) +BVPFunction{true}(bofboth) +BVPFunction{false}(bofboth) + +jac(u, t) = [1.0] +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, jac = jac) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, jac = jac) +jac(u, p, t) = [1.0] +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, jac = jac) +BVPFunction(bfoop, jac = jac) +jac(du, u, p, t) = [1.0] +BVPFunction(bfiip, jac = jac) +BVPFunction(bfoop, jac = jac) + +Wfact(u, t) = [1.0] +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, Wfact = Wfact) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, Wfact = Wfact) +Wfact(u, p, t) = [1.0] +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, Wfact = Wfact) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, Wfact = Wfact) +Wfact(u, p, gamma, t) = [1.0] +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, Wfact = Wfact) +BVPFunction(bfoop, Wfact = Wfact) +Wfact(du, u, p, gamma, t) = [1.0] +BVPFunction(bfiip, Wfact = Wfact) +BVPFunction(bfoop, Wfact = Wfact) + +Wfact_t(u, t) = [1.0] +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, Wfact_t = Wfact_t) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, Wfact_t = Wfact_t) +Wfact_t(u, p, t) = [1.0] +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, Wfact_t = Wfact_t) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, Wfact_t = Wfact_t) +Wfact_t(u, p, gamma, t) = [1.0] +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, Wfact_t = Wfact_t) +BVPFunction(bfoop, Wfact_t = Wfact_t) +Wfact_t(du, u, p, gamma, t) = [1.0] +BVPFunction(bfiip, Wfact_t = Wfact_t) +BVPFunction(bfoop, Wfact_t = Wfact_t) + +tgrad(u, t) = [1.0] +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, tgrad = tgrad) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, tgrad = tgrad) +tgrad(u, p, t) = [1.0] +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, tgrad=tgrad) +BVPFunction(bfoop, tgrad = tgrad) +tgrad(du, u, p, t) = [1.0] +BVPFunction(bfiip, tgrad = tgrad) +BVPFunction(bfoop, tgrad = tgrad) + +paramjac(u, t) = [1.0] +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, paramjac = paramjac) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, paramjac = paramjac) +paramjac(u, p, t) = [1.0] +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, paramjac = paramjac) +BVPFunction(bfoop, paramjac = paramjac) +paramjac(du, u, p, t) = [1.0] +BVPFunction(bfiip, paramjac = paramjac) +BVPFunction(bfoop, paramjac = paramjac) + +jvp(u, p, t) = [1.0] +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, jvp = jvp) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, jvp = jvp) +jvp(u, v, p, t) = [1.0] +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, jvp = jvp) +BVPFunction(bfoop, jvp = jvp) +jvp(du, u, v, p, t) = [1.0] +BVPFunction(bfiip, jvp = jvp) +BVPFunction(bfoop, jvp = jvp) + +vjp(u, p, t) = [1.0] +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, vjp = vjp) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, vjp = vjp) +vjp(u, v, p, t) = [1.0] +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, vjp = vjp) +BVPFunction(bfoop, vjp = vjp) +vjp(du, u, v, p, t) = [1.0] +BVPFunction(bfiip, vjp = vjp) +BVPFunction(bfoop, vjp = vjp) \ No newline at end of file From 7cf85c43aced812d700445537c221758936e217f Mon Sep 17 00:00:00 2001 From: ErikQQY <2283984853@qq.com> Date: Fri, 11 Aug 2023 23:58:09 +0800 Subject: [PATCH 2/9] Add bcjac and bcjac_prototype Signed-off-by: ErikQQY <2283984853@qq.com> --- src/ensemble/basic_ensemble_solve.jl | 2 +- src/ensemble/ensemble_problems.jl | 33 +-- src/ensemble/ensemble_solutions.jl | 13 +- src/problems/bvp_problems.jl | 5 +- src/scimlfunctions.jl | 265 ++++++++++++++--------- test/downstream/ensemble_multi_prob.jl | 10 +- test/function_building_error_messages.jl | 106 ++++----- test/solution_interface.jl | 2 +- 8 files changed, 260 insertions(+), 176 deletions(-) diff --git a/src/ensemble/basic_ensemble_solve.jl b/src/ensemble/basic_ensemble_solve.jl index 0c1047c35..92f7a86d0 100644 --- a/src/ensemble/basic_ensemble_solve.jl +++ b/src/ensemble/basic_ensemble_solve.jl @@ -47,7 +47,7 @@ function __solve(prob::EnsembleProblem{<:AbstractVector{<:AbstractSciMLProblem}} ensemblealg::BasicEnsembleAlgorithm; kwargs...) # TODO: @invoke invoke(__solve, Tuple{AbstractEnsembleProblem, typeof(alg), typeof(ensemblealg)}, - prob, alg, ensemblealg; trajectories=length(prob.prob), kwargs...) + prob, alg, ensemblealg; trajectories = length(prob.prob), kwargs...) end function __solve(prob::AbstractEnsembleProblem, diff --git a/src/ensemble/ensemble_problems.jl b/src/ensemble/ensemble_problems.jl index 9d67bfe9e..c2a6c8ac8 100644 --- a/src/ensemble/ensemble_problems.jl +++ b/src/ensemble/ensemble_problems.jl @@ -16,7 +16,11 @@ DEFAULT_REDUCTION(u, data, I) = append!(u, data), false DEFAULT_VECTOR_PROB_FUNC(prob, i, repeat) = prob[i] function EnsembleProblem(prob::AbstractVector{<:AbstractSciMLProblem}; kwargs...) # TODO: @invoke - invoke(EnsembleProblem, Tuple{Any}, prob; prob_func=DEFAULT_VECTOR_PROB_FUNC, kwargs...) + invoke(EnsembleProblem, + Tuple{Any}, + prob; + prob_func = DEFAULT_VECTOR_PROB_FUNC, + kwargs...) end function EnsembleProblem(prob; output_func = DEFAULT_OUTPUT_FUNC, @@ -36,20 +40,23 @@ function EnsembleProblem(; prob, EnsembleProblem(prob, prob_func, output_func, reduction, u_init, safetycopy) end -struct WeightedEnsembleProblem{T1<:AbstractEnsembleProblem, T2<:AbstractVector} <: AbstractEnsembleProblem - ensembleprob::T1 - weights::T2 +struct WeightedEnsembleProblem{T1 <: AbstractEnsembleProblem, T2 <: AbstractVector} <: + AbstractEnsembleProblem + ensembleprob::T1 + weights::T2 +end +function Base.propertynames(e::WeightedEnsembleProblem) + (Base.propertynames(getfield(e, :ensembleprob))..., :weights) end -Base.propertynames(e::WeightedEnsembleProblem) = (Base.propertynames(getfield(e, :ensembleprob))..., :weights) function Base.getproperty(e::WeightedEnsembleProblem, f::Symbol) - f === :weights && return getfield(e, :weights) - f === :ensembleprob && return getfield(e, :ensembleprob) - return getproperty(getfield(e, :ensembleprob), f) + f === :weights && return getfield(e, :weights) + f === :ensembleprob && return getfield(e, :ensembleprob) + return getproperty(getfield(e, :ensembleprob), f) end function WeightedEnsembleProblem(args...; weights, kwargs...) - # TODO: allow skipping checks? - @assert sum(weights) ≈ 1 - ep = EnsembleProblem(args...; kwargs...) - @assert length(ep.prob) == length(weights) - WeightedEnsembleProblem(ep, weights) + # TODO: allow skipping checks? + @assert sum(weights) ≈ 1 + ep = EnsembleProblem(args...; kwargs...) + @assert length(ep.prob) == length(weights) + WeightedEnsembleProblem(ep, weights) end diff --git a/src/ensemble/ensemble_solutions.jl b/src/ensemble/ensemble_solutions.jl index ee9850e20..265918f3e 100644 --- a/src/ensemble/ensemble_solutions.jl +++ b/src/ensemble/ensemble_solutions.jl @@ -46,7 +46,7 @@ function EnsembleSolution(sim::T, elapsedTime, converged) end -struct WeightedEnsembleSolution{T1<:AbstractEnsembleSolution, T2<:Number} +struct WeightedEnsembleSolution{T1 <: AbstractEnsembleSolution, T2 <: Number} ensol::T1 weights::Vector{T2} function WeightedEnsembleSolution(ensol, weights) @@ -207,13 +207,18 @@ end end end - Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, s, ::Colon) return [xi[s] for xi in x] end -Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, args::Colon...) - return invoke(getindex, Tuple{RecursiveArrayTools.AbstractVectorOfArray, Colon, typeof.(args)...}, x, :, args...) +Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, + ::Colon, + args::Colon...) + return invoke(getindex, + Tuple{RecursiveArrayTools.AbstractVectorOfArray, Colon, typeof.(args)...}, + x, + :, + args...) end function (sol::AbstractEnsembleSolution)(args...; kwargs...) diff --git a/src/problems/bvp_problems.jl b/src/problems/bvp_problems.jl index c1399bdcb..42ff02842 100644 --- a/src/problems/bvp_problems.jl +++ b/src/problems/bvp_problems.jl @@ -108,7 +108,6 @@ end TruncatedStacktraces.@truncate_stacktrace BVProblem 3 1 2 function BVProblem(f::AbstractBVPFunction, bc, u0, tspan, args...; kwargs...) - BVProblem{isinplace(f, 4)}(f, bc, u0, tspan, args...; kwargs...) end @@ -120,12 +119,12 @@ end # Allow any previous timeseries solution function BVProblem(f::AbstractBVPFunction, bc, sol::T, tspan::Tuple, p = NullParameters(); - kwargs...) where {T <: AbstractTimeseriesSolution} + kwargs...) where {T <: AbstractTimeseriesSolution} BVProblem(f, bc, sol.u, tspan, p) end # Allow a function of time for the initial guess function BVProblem(f::AbstractBVPFunction, bc, initialGuess, tspan::AbstractVector, - p = NullParameters(); kwargs...) + p = NullParameters(); kwargs...) u0 = [initialGuess(i) for i in tspan] BVProblem(f, bc, u0, (tspan[1], tspan[end]), p) end diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index b382170e5..706c5d98f 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -390,7 +390,8 @@ See the `modelingtoolkitize` function from automatically symbolically generating the Jacobian and more from the numerically-defined functions. """ -struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ, S, +struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ, + S, S2, S3, O, TCV, SYS} <: AbstractODEFunction{iip} f::F @@ -2125,11 +2126,11 @@ TruncatedStacktraces.@truncate_stacktrace OptimizationFunction 1 2 """ $(TYPEDEF) """ -abstract type AbstractBVPFunction{iip} <: +abstract type AbstractBVPFunction{iip, iip} <: AbstractDiffEqFunction{iip} end @doc doc""" - BVPFunction{iip,F,TMM,Ta,Tt,TJ,JVP,VJP,JP,SP,TW,TWt,TPJ,S,S2,S3,O,TCV} <: AbstractBVPFunction{iip,specialize} + BVPFunction{iip_f,iip_bc,F,BF,TMM,Ta,Tt,TJ,BCTJ,JVP,VJP,JP,BCJP,SP,TW,TWt,TPJ,S,S2,S3,O,TCV,BCTCV} <: AbstractBVPFunction{iip_f,iip_bc,specialize} A representation of a BVP function `f`, defined by: @@ -2137,34 +2138,43 @@ A representation of a BVP function `f`, defined by: \frac{du}{dt}=f(u,p,t) ``` +and the constraints: + +```math +\frac{du}{dt}=g(u,p,t) +``` + and all of its related functions, such as the Jacobian of `f`, its gradient with respect to time, and more. For all cases, `u0` is the initial condition, `p` are the parameters, and `t` is the independent variable. ```julia -BVPFunction{iip,specialize}(f; +BVPFunction{iip_f,iip_bc,specialize}(f, bc; mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : I, analytic = __has_analytic(f) ? f.analytic : nothing, tgrad= __has_tgrad(f) ? f.tgrad : nothing, jac = __has_jac(f) ? f.jac : nothing, + bcjac = __has_jac(bc) ? bc.jac : nothing, jvp = __has_jvp(f) ? f.jvp : nothing, vjp = __has_vjp(f) ? f.vjp : nothing, jac_prototype = __has_jac_prototype(f) ? f.jac_prototype : nothing, + bcjac_prototype = __has_jac_prototype(bc) ? bc.jac_prototype : nothing, sparsity = __has_sparsity(f) ? f.sparsity : jac_prototype, paramjac = __has_paramjac(f) ? f.paramjac : nothing, syms = __has_syms(f) ? f.syms : nothing, indepsym= __has_indepsym(f) ? f.indepsym : nothing, paramsyms = __has_paramsyms(f) ? f.paramsyms : nothing, colorvec = __has_colorvec(f) ? f.colorvec : nothing, + bccolorvec = __has_colorvec(f) ? bc.colorvec : nothing, sys = __has_sys(f) ? f.sys : nothing) ``` -Note that only the function `f` itself is required. This function should -be given as `f!(out,du,u,p,t)` or `out = f(du,u,p,t)`. See the section on `iip` -for more details on in-place vs out-of-place handling. +Note that both the function `f` and boundary condition `bc` are required. `f` should +be given as `f(du,u,p,t)` or `out = f(u,p,t)`. `bc` should be given as `bc(res, u, p, t)`. +See the section on `iip` for more details on in-place vs out-of-place handling. All of the remaining functions are optional for improving or accelerating -the usage of `f`. These include: +the usage of `f` and `bc`. These include: - `mass_matrix`: the mass matrix `M` represented in the BVP function. Can be used to determine that the equation is actually a BVP for differential algebraic equation (DAE) @@ -2173,6 +2183,7 @@ the usage of `f`. These include: solution of the BVP. Generally only used for testing and development of the solvers. - `tgrad(dT,u,h,p,t)` or dT=tgrad(u,p,t): returns ``\frac{\partial f(u,p,t)}{\partial t}`` - `jac(J,du,u,p,gamma,t)` or `J=jac(du,u,p,gamma,t)`: returns ``\frac{df}{du}`` +- `bcjac(J,du,u,p,gamma,t)` or `J=jac(du,u,p,gamma,t)`: erturns ``\frac{dbc}{du}`` - `jvp(Jv,v,du,u,p,gamma,t)` or `Jv=jvp(v,du,u,p,gamma,t)`: returns the directional derivative``\frac{df}{du} v`` - `vjp(Jv,v,du,u,p,gamma,t)` or `Jv=vjp(v,du,u,p,gamma,t)`: returns the adjoint @@ -2182,6 +2193,11 @@ the usage of `f`. These include: as the prototype and integrators will specialize on this structure where possible. Non-structured sparsity patterns should use a `SparseMatrixCSC` with a correct sparsity pattern for the Jacobian. The default is `nothing`, which means a dense Jacobian. +- `bcjac_prototype`: a prototype matrix maching the type that matches the Jacobian. For example, + if the Jacobian is tridiagonal, then an appropriately sized `Tridiagonal` matrix can be used + as the prototype and integrators will specialize on this structure where possible. Non-structured + sparsity patterns should use a `SparseMatrixCSC` with a correct sparsity pattern for the Jacobian. + The default is `nothing`, which means a dense Jacobian. - `paramjac(pJ,u,p,t)`: returns the parameter Jacobian ``\frac{df}{dp}``. - `syms`: the symbol names for the elements of the equation. This should match `u0` in size. For example, if `u0 = [0.0,1.0]` and `syms = [:x, :y]`, this will apply a canonical naming to the @@ -2197,6 +2213,12 @@ the usage of `f`. These include: based on the sparsity pattern. Defaults to `nothing`, which means a color vector will be internally computed on demand when required. The cost of this operation is highly dependent on the sparsity pattern. +- `bccolorvec`: a color vector according to the SparseDiffTools.jl definition for the sparsity + pattern of the `bcjac_prototype`. This specializes the Jacobian construction when using + finite differences and automatic differentiation to be computed in an accelerated manner + based on the sparsity pattern. Defaults to `nothing`, which means a color vector will be + internally computed on demand when required. The cost of this operation is highly dependent + on the sparsity pattern. ## iip: In-Place vs Out-Of-Place @@ -2210,19 +2232,23 @@ For more details on this argument, see the ODEFunction documentation. The fields of the BVPFunction type directly match the names of the inputs. """ -struct BVPFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, - TPJ, - S, S2, S3, O, TCV, - SYS} <: - AbstractBVPFunction{iip} +struct BVPFunction{iip_f, iip_bc, specialize, F, BF, TMM, Ta, Tt, TJ, BCTJ, JVP, VJP, JP, + BCJP, SP, TW, TWt, + TPJ, + S, S2, S3, O, TCV, BCTCV, + SYS} <: + AbstractBVPFunction{iip_f, iip_bc} f::F + bc::BF mass_matrix::TMM analytic::Ta tgrad::Tt jac::TJ + bcjac::BCTJ jvp::JVP vjp::VJP jac_prototype::JP + bcjac_prototype::BCJP sparsity::SP Wfact::TW Wfact_t::TWt @@ -2232,11 +2258,10 @@ struct BVPFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW paramsyms::S3 observed::O colorvec::TCV + bccolorvec::BCTCV sys::SYS end - - ######### Backwards Compatibility Overloads (f::ODEFunction)(args...) = f.f(args...) @@ -2376,7 +2401,8 @@ function ODEFunction{iip, specialize}(f; ODEFunction{iip, FunctionWrapperSpecialize, typeof(f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), - typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(W_prototype), typeof(paramjac), + typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(W_prototype), + typeof(paramjac), typeof(syms), typeof(indepsym), typeof(paramsyms), typeof(observed), typeof(_colorvec), typeof(sys)}(f, mass_matrix, analytic, tgrad, jac, @@ -2387,7 +2413,8 @@ function ODEFunction{iip, specialize}(f; ODEFunction{iip, specialize, typeof(f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), - typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(W_prototype), typeof(paramjac), + typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(W_prototype), + typeof(paramjac), typeof(syms), typeof(indepsym), typeof(paramsyms), typeof(observed), typeof(_colorvec), typeof(sys)}(f, mass_matrix, analytic, tgrad, jac, @@ -3787,31 +3814,35 @@ function OptimizationFunction{iip}(f, adtype::AbstractADType = NoAD(); cons_expr, sys) end -function BVPFunction{iip, specialize}(f; - mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : - I, - analytic = __has_analytic(f) ? f.analytic : nothing, - tgrad = __has_tgrad(f) ? f.tgrad : nothing, - jac = __has_jac(f) ? f.jac : nothing, - jvp = __has_jvp(f) ? f.jvp : nothing, - vjp = __has_vjp(f) ? f.vjp : nothing, - jac_prototype = __has_jac_prototype(f) ? - f.jac_prototype : - nothing, - sparsity = __has_sparsity(f) ? f.sparsity : - jac_prototype, - Wfact = __has_Wfact(f) ? f.Wfact : nothing, - Wfact_t = __has_Wfact_t(f) ? f.Wfact_t : nothing, - paramjac = __has_paramjac(f) ? f.paramjac : nothing, - syms = __has_syms(f) ? f.syms : nothing, - indepsym = __has_indepsym(f) ? f.indepsym : nothing, - paramsyms = __has_paramsyms(f) ? f.paramsyms : - nothing, - observed = __has_observed(f) ? f.observed : - DEFAULT_OBSERVED, - colorvec = __has_colorvec(f) ? f.colorvec : nothing, - sys = __has_sys(f) ? f.sys : nothing) where {iip, specialize} - +function BVPFunction{iip_f, iip_bc, specialize}(f, bc; + mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : + I, + analytic = __has_analytic(f) ? f.analytic : nothing, + tgrad = __has_tgrad(f) ? f.tgrad : nothing, + jac = __has_jac(f) ? f.jac : nothing, + bcjac = __has_jac(bc) ? bc.jac : nothing, + jvp = __has_jvp(f) ? f.jvp : nothing, + vjp = __has_vjp(f) ? f.vjp : nothing, + jac_prototype = __has_jac_prototype(f) ? + f.jac_prototype : + nothing, + bcjac_prototype = __has_jac_prototype(bc) ? + bc.jac_prototype : + nothing, + sparsity = __has_sparsity(f) ? f.sparsity : + jac_prototype, + Wfact = __has_Wfact(f) ? f.Wfact : nothing, + Wfact_t = __has_Wfact_t(f) ? f.Wfact_t : nothing, + paramjac = __has_paramjac(f) ? f.paramjac : nothing, + syms = __has_syms(f) ? f.syms : nothing, + indepsym = __has_indepsym(f) ? f.indepsym : nothing, + paramsyms = __has_paramsyms(f) ? f.paramsyms : + nothing, + observed = __has_observed(f) ? f.observed : + DEFAULT_OBSERVED, + colorvec = __has_colorvec(f) ? f.colorvec : nothing, + bccolorvec = __has_colorvec(bc) ? bc.colorvec : nothing, + sys = __has_sys(f) ? f.sys : nothing) where {iip_f, iip_bc, specialize} if mass_matrix === I && typeof(f) <: Tuple mass_matrix = ((I for i in 1:length(f))...,) end @@ -3822,82 +3853,117 @@ function BVPFunction{iip, specialize}(f; end if jac === nothing && isa(jac_prototype, AbstractDiffEqLinearOperator) - if iip - jac = update_coefficients! #(J,u,p,t) - else - jac = (u, p, t) -> update_coefficients!(deepcopy(jac_prototype), u, p, t) - end + if iip_f + jac = update_coefficients! #(J,u,p,t) + else + jac = (u, p, t) -> update_coefficients!(deepcopy(jac_prototype), u, p, t) + end + end + + if bcjac === nothing && isa(bcjac_prototype, AbstractDiffEqLinearOperator) + if iip_bc + bcjac = update_coefficients! #(J,u,p,t) + else + bcjac = (u, p, t) -> update_coefficients!(deepcopy(bcjac_prototype), u, p, t) + end end - if jac_prototype !== nothing && colorvec === nothing && ArrayInterfaceCore.fast_matrix_colors(jac_prototype) + if jac_prototype !== nothing && colorvec === nothing && + ArrayInterfaceCore.fast_matrix_colors(jac_prototype) _colorvec = ArrayInterfaceCore.matrix_colors(jac_prototype) else _colorvec = colorvec end - jaciip = jac !== nothing ? isinplace(jac, 4, "jac", iip) : iip - tgradiip = tgrad !== nothing ? isinplace(tgrad, 4, "tgrad", iip) : iip - jvpiip = jvp !== nothing ? isinplace(jvp, 5, "jvp", iip) : iip - vjpiip = vjp !== nothing ? isinplace(vjp, 5, "vjp", iip) : iip - Wfactiip = Wfact !== nothing ? isinplace(Wfact, 5, "Wfact", iip) : iip - Wfact_tiip = Wfact_t !== nothing ? isinplace(Wfact_t, 5, "Wfact_t", iip) : iip - paramjaciip = paramjac !== nothing ? isinplace(paramjac, 4, "paramjac", iip) : iip + if bcjac_prototype !== nothing && bccolorvec === nothing && + ArrayInterfaceCore.fast_matrix_colors(bcjac_prototype) + _bccolorvec = ArrayInterfaceCore.matrix_colors(bcjac_prototype) + else + _bccolorvec = bccolorvec + end - nonconforming = (jaciip, tgradiip, jvpiip, vjpiip, Wfactiip, Wfact_tiip, paramjaciip) .!= iip + jaciip = jac !== nothing ? isinplace(jac, 4, "jac", iip_f) : iip_f + bcjaciip = bcjac !== nothing ? isinplace(bcjac, 4, "bcjac", iip_bc) : iip_bc + tgradiip = tgrad !== nothing ? isinplace(tgrad, 4, "tgrad", iip_f) : iip_f + jvpiip = jvp !== nothing ? isinplace(jvp, 5, "jvp", iip_f) : iip_f + vjpiip = vjp !== nothing ? isinplace(vjp, 5, "vjp", iip_f) : iip_f + Wfactiip = Wfact !== nothing ? isinplace(Wfact, 5, "Wfact", iip_f) : iip_f + Wfact_tiip = Wfact_t !== nothing ? isinplace(Wfact_t, 5, "Wfact_t", iip_f) : iip_f + paramjaciip = paramjac !== nothing ? isinplace(paramjac, 4, "paramjac", iip_f) : iip_f + + nonconforming = (jaciip, + tgradiip, + jvpiip, + vjpiip, + Wfactiip, + Wfact_tiip, + paramjaciip) .!= iip_f + bc_nonconforming = bcjaciip .!= iip_bc if any(nonconforming) nonconforming = findall(nonconforming) - functions = ["jac", "tgrad", "jvp", "vjp", "Wfact", "Wfact_t", "paramjac"][nonconforming] + functions = ["jac", "bcjac", "tgrad", "jvp", "vjp", "Wfact", "Wfact_t", "paramjac"][nonconforming] throw(NonconformingFunctionsError(functions)) end - + + if any(bc_nonconforming) + bc_nonconforming = findall(bc_nonconforming) + functions = ["bcjac"][bc_nonconforming] + throw(NonconformingFunctionsError(functions)) + end + if specialize === NoSpecialize - BVPFunction{iip, specialize, Any, Any, Any, Any, - Any, Any, Any, Any, Any, Any, Any, - Any, typeof(syms), typeof(indepsym), typeof(paramsyms), - Any, typeof(_colorvec), Any}(f, mass_matrix, - analytic, - tgrad, - jac, jvp, vjp, - jac_prototype, - sparsity, Wfact, - Wfact_t, - paramjac, syms, - indepsym, paramsyms, - observed, - _colorvec, sys) + BVPFunction{iip_f, iip_bc, specialize, Any, Any, Any, Any, Any, + Any, Any, Any, Any, Any, Any, Any, Any, Any, + Any, typeof(syms), typeof(indepsym), typeof(paramsyms), + Any, typeof(_colorvec), typeof(_bccolorvec), Any}(f, bc, mass_matrix, + analytic, + tgrad, + jac, bcjac, jvp, vjp, + jac_prototype, + bcjac_prototype, + sparsity, Wfact, + Wfact_t, + paramjac, syms, + indepsym, paramsyms, + observed, + _colorvec, _bccolorvec, sys) elseif specialize === false - BVPFunction{iip, FunctionWrapperSpecialize, - typeof(f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), - typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), - typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), - typeof(syms), typeof(indepsym), typeof(paramsyms), typeof(observed), - typeof(_colorvec), - typeof(sys)}(f, mass_matrix, analytic, tgrad, jac, - jvp, vjp, jac_prototype, sparsity, Wfact, - Wfact_t, paramjac, syms, indepsym, paramsyms, - observed, _colorvec, sys) + BVPFunction{iip_f, iip_bc, FunctionWrapperSpecialize, + typeof(f), typeof(bc), typeof(mass_matrix), typeof(analytic), typeof(tgrad), + typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), + typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), + typeof(syms), typeof(indepsym), typeof(paramsyms), typeof(observed), + typeof(_colorvec), typeof(_bccolorvec), + typeof(sys)}(f, bc, mass_matrix, analytic, tgrad, jac, bcjac, + jvp, vjp, jac_prototype, bcjac_prototype, sparsity, Wfact, + Wfact_t, paramjac, syms, indepsym, paramsyms, + observed, _colorvec, _bccolorvec, sys) else - BVPFunction{iip, specialize, typeof(f), typeof(mass_matrix), typeof(analytic), - typeof(tgrad), - typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), - typeof(sparsity), typeof(Wfact), typeof(Wfact_t), - typeof(paramjac), typeof(syms), typeof(indepsym), typeof(paramsyms), - typeof(observed), - typeof(_colorvec), typeof(sys)}(f, mass_matrix, analytic, - tgrad, jac, jvp, vjp, - jac_prototype, sparsity, - Wfact, Wfact_t, paramjac, - syms, indepsym, paramsyms, observed, - _colorvec, sys) + BVPFunction{iip_f, iip_bc, specialize, typeof(f), typeof(bc), typeof(mass_matrix), + typeof(analytic), + typeof(tgrad), + typeof(jac), typeof(bcjac), typeof(jvp), typeof(vjp), typeof(jac_prototype), + typeof(bcjac_prototype), + typeof(sparsity), typeof(Wfact), typeof(Wfact_t), + typeof(paramjac), typeof(syms), typeof(indepsym), typeof(paramsyms), + typeof(observed), + typeof(_colorvec), typeof(_bccolorvec), typeof(sys)}(f, bc, mass_matrix, analytic, + tgrad, jac, bcjac, jvp, vjp, + jac_prototype, bcjac_prototype, sparsity, + Wfact, Wfact_t, paramjac, + syms, indepsym, paramsyms, observed, + _colorvec, _bccolorvec, sys) end end -function BVPFunction{iip}(f; kwargs...) where {iip} - BVPFunction{iip, FullSpecialize}(f; kwargs...) +function BVPFunction{iip_f, iip_bc}(f, bc; kwargs...) where {iip_f, iip_bc} + BVPFunction{iip_f, iip_bc, FullSpecialize}(f, bc; kwargs...) end -BVPFunction{iip}(f::BVPFunction; kwargs...) where {iip} = f -BVPFunction(f; kwargs...) = BVPFunction{isinplace(f, 4), FullSpecialize}(f; kwargs...) -BVPFunction(f::BVPFunction; kwargs...) = f +BVPFunction{iip_f, iip_bc}(f::BVPFunction, bc; kwargs...) where {iip_f, iip_bc} = f +function BVPFunction(f, bc; kwargs...) + BVPFunction{isinplace(f, 4), isinplace(bc, 4), FullSpecialize}(f, bc; kwargs...) +end +#BVPFunction(f::BVPFunction; kwargs...) = f ########## Existence Functions @@ -4015,5 +4081,4 @@ for S in [:ODEFunction (args...) -> $S{iip, FullSpecialize, map(typeof, args)...}(args...) end end - end diff --git a/test/downstream/ensemble_multi_prob.jl b/test/downstream/ensemble_multi_prob.jl index dfb61f90c..9f09a58b2 100644 --- a/test/downstream/ensemble_multi_prob.jl +++ b/test/downstream/ensemble_multi_prob.jl @@ -4,11 +4,11 @@ using ModelingToolkit, OrdinaryDiffEq, Test D = Differential(t) @named sys1 = ODESystem([D(x) ~ x, - D(y) ~ -y]) + D(y) ~ -y]) @named sys2 = ODESystem([D(x) ~ 2x, - D(y) ~ -2y]) + D(y) ~ -2y]) @named sys3 = ODESystem([D(x) ~ 3x, - D(y) ~ -3y]) + D(y) ~ -3y]) prob1 = ODEProblem(sys1, [1.0, 1.0], (0.0, 1.0)) prob2 = ODEProblem(sys2, [2.0, 2.0], (0.0, 1.0)) @@ -22,6 +22,6 @@ for i in 1:3 @test sol[y, :][i] == sol[i][y] end # Ensemble is a recursive array -@test only.(sol(0.0, idxs=[x])) == sol[1, 1, :] == first.(sol[x, :]) +@test only.(sol(0.0, idxs = [x])) == sol[1, 1, :] == first.(sol[x, :]) # TODO: fix the interpolation -@test only.(sol(1.0, idxs=[x])) ≈ last.(sol[x, :]) +@test only.(sol(1.0, idxs = [x])) ≈ last.(sol[x, :]) diff --git a/test/function_building_error_messages.jl b/test/function_building_error_messages.jl index dfa908b70..fd4ebd747 100644 --- a/test/function_building_error_messages.jl +++ b/test/function_building_error_messages.jl @@ -472,82 +472,90 @@ bfiip(du, u, p, t) = du .= u bofboth(u, p, t) = u bofboth(du, u, p, t) = du .= u -BVPFunction(bofboth) -BVPFunction{true}(bofboth) -BVPFunction{false}(bofboth) +bc(res, u, p, t) = res .= u + +BVPFunction(bofboth, bc) +BVPFunction{true, true}(bofboth, bc) +BVPFunction{false, true}(bofboth, bc) jac(u, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, jac = jac) -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, jac = jac) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bc, jac = jac) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bc, jac = jac) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bc, bcjac = jac) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bc, bcjac = jac) jac(u, p, t) = [1.0] -@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, jac = jac) -BVPFunction(bfoop, jac = jac) +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bc, jac = jac) +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bc, bcjac = jac) +BVPFunction(bfoop, bc, jac = jac) +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfoop, bc, bcjac = jac) jac(du, u, p, t) = [1.0] -BVPFunction(bfiip, jac = jac) -BVPFunction(bfoop, jac = jac) +BVPFunction(bfiip, bc, jac = jac) +BVPFunction(bfoop, bc, jac = jac) Wfact(u, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, Wfact = Wfact) -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, Wfact = Wfact) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bc, Wfact = Wfact) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bc, Wfact = Wfact) Wfact(u, p, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, Wfact = Wfact) -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, Wfact = Wfact) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bc, Wfact = Wfact) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bc, Wfact = Wfact) Wfact(u, p, gamma, t) = [1.0] -@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, Wfact = Wfact) -BVPFunction(bfoop, Wfact = Wfact) +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bc, Wfact = Wfact) +BVPFunction(bfoop, bc, Wfact = Wfact) Wfact(du, u, p, gamma, t) = [1.0] -BVPFunction(bfiip, Wfact = Wfact) -BVPFunction(bfoop, Wfact = Wfact) +BVPFunction(bfiip, bc, Wfact = Wfact) +BVPFunction(bfoop, bc, Wfact = Wfact) Wfact_t(u, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, Wfact_t = Wfact_t) -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, Wfact_t = Wfact_t) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bc, Wfact_t = Wfact_t) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bc, Wfact_t = Wfact_t) Wfact_t(u, p, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, Wfact_t = Wfact_t) -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, Wfact_t = Wfact_t) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bc, Wfact_t = Wfact_t) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bc, Wfact_t = Wfact_t) Wfact_t(u, p, gamma, t) = [1.0] -@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, Wfact_t = Wfact_t) -BVPFunction(bfoop, Wfact_t = Wfact_t) +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bc, Wfact_t = Wfact_t) +BVPFunction(bfoop, bc, Wfact_t = Wfact_t) Wfact_t(du, u, p, gamma, t) = [1.0] -BVPFunction(bfiip, Wfact_t = Wfact_t) -BVPFunction(bfoop, Wfact_t = Wfact_t) +BVPFunction(bfiip, bc, Wfact_t = Wfact_t) +BVPFunction(bfoop, bc, Wfact_t = Wfact_t) tgrad(u, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, tgrad = tgrad) -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, tgrad = tgrad) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bc, tgrad = tgrad) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bc, tgrad = tgrad) tgrad(u, p, t) = [1.0] -@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, tgrad=tgrad) -BVPFunction(bfoop, tgrad = tgrad) +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bc, tgrad = tgrad) +BVPFunction(bfoop, bc, tgrad = tgrad) tgrad(du, u, p, t) = [1.0] -BVPFunction(bfiip, tgrad = tgrad) -BVPFunction(bfoop, tgrad = tgrad) +BVPFunction(bfiip, bc, tgrad = tgrad) +BVPFunction(bfoop, bc, tgrad = tgrad) paramjac(u, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, paramjac = paramjac) -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, paramjac = paramjac) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bc, paramjac = paramjac) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bc, paramjac = paramjac) paramjac(u, p, t) = [1.0] -@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, paramjac = paramjac) -BVPFunction(bfoop, paramjac = paramjac) +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, + bc, + paramjac = paramjac) +BVPFunction(bfoop, bc, paramjac = paramjac) paramjac(du, u, p, t) = [1.0] -BVPFunction(bfiip, paramjac = paramjac) -BVPFunction(bfoop, paramjac = paramjac) +BVPFunction(bfiip, bc, paramjac = paramjac) +BVPFunction(bfoop, bc, paramjac = paramjac) jvp(u, p, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, jvp = jvp) -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, jvp = jvp) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bc, jvp = jvp) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bc, jvp = jvp) jvp(u, v, p, t) = [1.0] -@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, jvp = jvp) -BVPFunction(bfoop, jvp = jvp) +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bc, jvp = jvp) +BVPFunction(bfoop, bc, jvp = jvp) jvp(du, u, v, p, t) = [1.0] -BVPFunction(bfiip, jvp = jvp) -BVPFunction(bfoop, jvp = jvp) +BVPFunction(bfiip, bc, jvp = jvp) +BVPFunction(bfoop, bc, jvp = jvp) vjp(u, p, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, vjp = vjp) -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, vjp = vjp) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bc, vjp = vjp) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bc, vjp = vjp) vjp(u, v, p, t) = [1.0] -@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, vjp = vjp) -BVPFunction(bfoop, vjp = vjp) +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bc, vjp = vjp) +BVPFunction(bfoop, bc, vjp = vjp) vjp(du, u, v, p, t) = [1.0] -BVPFunction(bfiip, vjp = vjp) -BVPFunction(bfoop, vjp = vjp) \ No newline at end of file +BVPFunction(bfiip, bc, vjp = vjp) +BVPFunction(bfoop, bc, vjp = vjp) diff --git a/test/solution_interface.jl b/test/solution_interface.jl index 6d1a1ec5c..798aa31d7 100644 --- a/test/solution_interface.jl +++ b/test/solution_interface.jl @@ -34,7 +34,7 @@ end ode = ODEProblem(f, 1.0, (0.0, 1.0)) sol = SciMLBase.build_solution(ode, :NoAlgorithm, [ode.tspan[begin]], [ode.u0]) @test sol(0.0) == 1.0 - @test sol([0.0,0.0]) == [1.0, 1.0] + @test sol([0.0, 0.0]) == [1.0, 1.0] # test that indexing out of bounds doesn't segfault @test_throws ErrorException sol(1) @test_throws ErrorException sol(-0.5) From 8046eb99f3917acdc92ace1fd58417ac44005e2d Mon Sep 17 00:00:00 2001 From: ErikQQY <2283984853@qq.com> Date: Sun, 13 Aug 2023 16:19:09 +0800 Subject: [PATCH 3/9] Cover more tests Signed-off-by: ErikQQY <2283984853@qq.com> --- test/function_building_error_messages.jl | 146 ++++++++++++++--------- 1 file changed, 90 insertions(+), 56 deletions(-) diff --git a/test/function_building_error_messages.jl b/test/function_building_error_messages.jl index fd4ebd747..e7ec1ba72 100644 --- a/test/function_building_error_messages.jl +++ b/test/function_building_error_messages.jl @@ -469,93 +469,127 @@ OptimizationProblem(optf, 1.0) bfoop(u, p, t) = u bfiip(du, u, p, t) = du .= u -bofboth(u, p, t) = u -bofboth(du, u, p, t) = du .= u +bfboth(u, p, t) = u +bfboth(du, u, p, t) = du .= u -bc(res, u, p, t) = res .= u +bcoop(u, p, t) = u +bciip(res, u, p, t) = res .= u -BVPFunction(bofboth, bc) -BVPFunction{true, true}(bofboth, bc) -BVPFunction{false, true}(bofboth, bc) +bcfboth(u, p, t) = u +bcfboth(du, u, p, t) = du .= u + +BVPFunction(bfboth, bcfboth) +BVPFunction{true, true}(bfboth, bcfboth) +BVPFunction{false, true}(bfboth, bcfboth) +BVPFunction{true, false}(bfboth, bcfboth) +BVPFunction{false, false}(bfboth, bcfboth) jac(u, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bc, jac = jac) -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bc, jac = jac) -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bc, bcjac = jac) -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bc, bcjac = jac) +bcjac(u, t) = [1.0] +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, + bciip, + jac = jac, + bcjac = bcjac) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, + bciip, + jac = jac, + bcjac = bcjac) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, + bcoop, + jac = jac, + bcjac = bcjac) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, + bcoop, + jac = jac, + bcjac = bcjac) jac(u, p, t) = [1.0] -@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bc, jac = jac) -@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bc, bcjac = jac) -BVPFunction(bfoop, bc, jac = jac) -@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfoop, bc, bcjac = jac) +bcjac(u, p, t) = [1.0] +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, + bcoop, + jac = jac, + bcjac = bcjac) +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, + bciip, + jac = jac, + bcjac = bcjac) +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfoop, + bciip, + jac = jac, + bcjac = bcjac) +BVPFunction(bfoop, bcoop, jac = jac) jac(du, u, p, t) = [1.0] -BVPFunction(bfiip, bc, jac = jac) -BVPFunction(bfoop, bc, jac = jac) +bcjac(du, u, p, t) = [1.0] +BVPFunction(bfiip, bciip, jac = jac, bcjac = bcjac) +BVPFunction(bfoop, bciip, jac = jac, bcjac = bcjac) +BVPFunction(bfiip, bcoop, jac = jac, bcjac = bcjac) +BVPFunction(bfoop, bcoop, jac = jac, bcjac = bcjac) Wfact(u, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bc, Wfact = Wfact) -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bc, Wfact = Wfact) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bciip, Wfact = Wfact) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bciip, Wfact = Wfact) Wfact(u, p, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bc, Wfact = Wfact) -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bc, Wfact = Wfact) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bciip, Wfact = Wfact) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bciip, Wfact = Wfact) Wfact(u, p, gamma, t) = [1.0] -@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bc, Wfact = Wfact) -BVPFunction(bfoop, bc, Wfact = Wfact) +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bciip, Wfact = Wfact) +BVPFunction(bfoop, bciip, Wfact = Wfact) Wfact(du, u, p, gamma, t) = [1.0] -BVPFunction(bfiip, bc, Wfact = Wfact) -BVPFunction(bfoop, bc, Wfact = Wfact) +BVPFunction(bfiip, bciip, Wfact = Wfact) +BVPFunction(bfoop, bciip, Wfact = Wfact) Wfact_t(u, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bc, Wfact_t = Wfact_t) -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bc, Wfact_t = Wfact_t) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bciip, Wfact_t = Wfact_t) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bciip, Wfact_t = Wfact_t) Wfact_t(u, p, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bc, Wfact_t = Wfact_t) -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bc, Wfact_t = Wfact_t) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bciip, Wfact_t = Wfact_t) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bciip, Wfact_t = Wfact_t) Wfact_t(u, p, gamma, t) = [1.0] -@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bc, Wfact_t = Wfact_t) -BVPFunction(bfoop, bc, Wfact_t = Wfact_t) +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, + bciip, + Wfact_t = Wfact_t) +BVPFunction(bfoop, bciip, Wfact_t = Wfact_t) Wfact_t(du, u, p, gamma, t) = [1.0] -BVPFunction(bfiip, bc, Wfact_t = Wfact_t) -BVPFunction(bfoop, bc, Wfact_t = Wfact_t) +BVPFunction(bfiip, bciip, Wfact_t = Wfact_t) +BVPFunction(bfoop, bciip, Wfact_t = Wfact_t) tgrad(u, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bc, tgrad = tgrad) -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bc, tgrad = tgrad) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bciip, tgrad = tgrad) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bciip, tgrad = tgrad) tgrad(u, p, t) = [1.0] -@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bc, tgrad = tgrad) -BVPFunction(bfoop, bc, tgrad = tgrad) +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bciip, tgrad = tgrad) +BVPFunction(bfoop, bciip, tgrad = tgrad) tgrad(du, u, p, t) = [1.0] -BVPFunction(bfiip, bc, tgrad = tgrad) -BVPFunction(bfoop, bc, tgrad = tgrad) +BVPFunction(bfiip, bciip, tgrad = tgrad) +BVPFunction(bfoop, bciip, tgrad = tgrad) paramjac(u, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bc, paramjac = paramjac) -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bc, paramjac = paramjac) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bciip, paramjac = paramjac) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bciip, paramjac = paramjac) paramjac(u, p, t) = [1.0] @test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, - bc, + bciip, paramjac = paramjac) -BVPFunction(bfoop, bc, paramjac = paramjac) +BVPFunction(bfoop, bciip, paramjac = paramjac) paramjac(du, u, p, t) = [1.0] -BVPFunction(bfiip, bc, paramjac = paramjac) -BVPFunction(bfoop, bc, paramjac = paramjac) +BVPFunction(bfiip, bciip, paramjac = paramjac) +BVPFunction(bfoop, bciip, paramjac = paramjac) jvp(u, p, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bc, jvp = jvp) -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bc, jvp = jvp) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bciip, jvp = jvp) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bciip, jvp = jvp) jvp(u, v, p, t) = [1.0] -@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bc, jvp = jvp) -BVPFunction(bfoop, bc, jvp = jvp) +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bciip, jvp = jvp) +BVPFunction(bfoop, bciip, jvp = jvp) jvp(du, u, v, p, t) = [1.0] -BVPFunction(bfiip, bc, jvp = jvp) -BVPFunction(bfoop, bc, jvp = jvp) +BVPFunction(bfiip, bciip, jvp = jvp) +BVPFunction(bfoop, bciip, jvp = jvp) vjp(u, p, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bc, vjp = vjp) -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bc, vjp = vjp) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bciip, vjp = vjp) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bciip, vjp = vjp) vjp(u, v, p, t) = [1.0] -@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bc, vjp = vjp) -BVPFunction(bfoop, bc, vjp = vjp) +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bciip, vjp = vjp) +BVPFunction(bfoop, bciip, vjp = vjp) vjp(du, u, v, p, t) = [1.0] -BVPFunction(bfiip, bc, vjp = vjp) -BVPFunction(bfoop, bc, vjp = vjp) +BVPFunction(bfiip, bciip, vjp = vjp) +BVPFunction(bfoop, bciip, vjp = vjp) From a734ff8771052996a20166f0f62f172dc882f089 Mon Sep 17 00:00:00 2001 From: ErikQQY <2283984853@qq.com> Date: Sun, 13 Aug 2023 21:43:45 +0800 Subject: [PATCH 4/9] Complete BVPFunction Signed-off-by: ErikQQY <2283984853@qq.com> --- src/problems/bvp_problems.jl | 12 +- src/scimlfunctions.jl | 53 +++---- test/function_building_error_messages.jl | 168 +++++++++++------------ 3 files changed, 116 insertions(+), 117 deletions(-) diff --git a/src/problems/bvp_problems.jl b/src/problems/bvp_problems.jl index 42ff02842..b2a7cbbee 100644 --- a/src/problems/bvp_problems.jl +++ b/src/problems/bvp_problems.jl @@ -78,17 +78,17 @@ every solve call. * `p`: The parameters for the problem. Defaults to `NullParameters` * `kwargs`: The keyword arguments passed onto the solves. """ -struct BVProblem{uType, tType, isinplace, P, F, bF, PT, K} <: +struct BVProblem{uType, tType, isinplace, P, F, BF, PT, K} <: AbstractBVProblem{uType, tType, isinplace} f::F - bc::bF + bc::BF u0::uType tspan::tType p::P problem_type::PT kwargs::K - @add_kwonly function BVProblem{iip}(f::AbstractODEFunction, bc, u0, tspan, + @add_kwonly function BVProblem{iip}(f::AbstractBVPFunction{iip}, bc, u0, tspan, p = NullParameters(), problem_type = StandardBVProblem(); kwargs...) where {iip} @@ -96,12 +96,12 @@ struct BVProblem{uType, tType, isinplace, P, F, bF, PT, K} <: warn_paramtype(p) new{typeof(u0), typeof(_tspan), iip, typeof(p), typeof(f), typeof(bc), - typeof(problem_type), typeof(kwargs)}(f, bc, u0, _tspan, p, + typeof(problem_type), typeof(kwargs)}(f, f.bc, u0, _tspan, p, problem_type, kwargs) end function BVProblem{iip}(f, bc, u0, tspan, p = NullParameters(); kwargs...) where {iip} - BVProblem(BVPFunction{iip}(f), bc, u0, tspan, p; kwargs...) + BVProblem(BVPFunction{iip}(f, bc), bc, u0, tspan, p; kwargs...) end end @@ -112,7 +112,7 @@ function BVProblem(f::AbstractBVPFunction, bc, u0, tspan, args...; kwargs...) end function BVProblem(f, bc, u0, tspan, p = NullParameters(); kwargs...) - BVProblem(BVPFunction(f), bc, u0, tspan, p; kwargs...) + BVProblem(BVPFunction(f, bc), bc, u0, tspan, p; kwargs...) end # convenience interfaces: diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 706c5d98f..e050c9cf4 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -1083,7 +1083,7 @@ SDEFunction{iip,specialize}(f,g; sys = __has_sys(f) ? f.sys : nothing) ``` -Note that only the function `f` itself is required. This function should +Note that both the function `f` and `g` are required. This function should be given as `f!(du,u,p,t)` or `du = f(u,p,t)`. See the section on `iip` for more details on in-place vs out-of-place handling. @@ -2126,11 +2126,11 @@ TruncatedStacktraces.@truncate_stacktrace OptimizationFunction 1 2 """ $(TYPEDEF) """ -abstract type AbstractBVPFunction{iip, iip} <: +abstract type AbstractBVPFunction{iip} <: AbstractDiffEqFunction{iip} end @doc doc""" - BVPFunction{iip_f,iip_bc,F,BF,TMM,Ta,Tt,TJ,BCTJ,JVP,VJP,JP,BCJP,SP,TW,TWt,TPJ,S,S2,S3,O,TCV,BCTCV} <: AbstractBVPFunction{iip_f,iip_bc,specialize} + BVPFunction{iip,F,BF,TMM,Ta,Tt,TJ,BCTJ,JVP,VJP,JP,BCJP,SP,TW,TWt,TPJ,S,S2,S3,O,TCV,BCTCV} <: AbstractBVPFunction{iip,specialize} A representation of a BVP function `f`, defined by: @@ -2149,7 +2149,7 @@ with respect to time, and more. For all cases, `u0` is the initial condition, `p` are the parameters, and `t` is the independent variable. ```julia -BVPFunction{iip_f,iip_bc,specialize}(f, bc; +BVPFunction{iip,specialize}(f, bc; mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : I, analytic = __has_analytic(f) ? f.analytic : nothing, tgrad= __has_tgrad(f) ? f.tgrad : nothing, @@ -2232,12 +2232,12 @@ For more details on this argument, see the ODEFunction documentation. The fields of the BVPFunction type directly match the names of the inputs. """ -struct BVPFunction{iip_f, iip_bc, specialize, F, BF, TMM, Ta, Tt, TJ, BCTJ, JVP, VJP, JP, +struct BVPFunction{iip, specialize, F, BF, TMM, Ta, Tt, TJ, BCTJ, JVP, VJP, JP, BCJP, SP, TW, TWt, TPJ, S, S2, S3, O, TCV, BCTCV, SYS} <: - AbstractBVPFunction{iip_f, iip_bc} + AbstractBVPFunction{iip} f::F bc::BF mass_matrix::TMM @@ -3814,7 +3814,7 @@ function OptimizationFunction{iip}(f, adtype::AbstractADType = NoAD(); cons_expr, sys) end -function BVPFunction{iip_f, iip_bc, specialize}(f, bc; +function BVPFunction{iip, specialize}(f, bc; mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : I, analytic = __has_analytic(f) ? f.analytic : nothing, @@ -3842,7 +3842,7 @@ function BVPFunction{iip_f, iip_bc, specialize}(f, bc; DEFAULT_OBSERVED, colorvec = __has_colorvec(f) ? f.colorvec : nothing, bccolorvec = __has_colorvec(bc) ? bc.colorvec : nothing, - sys = __has_sys(f) ? f.sys : nothing) where {iip_f, iip_bc, specialize} + sys = __has_sys(f) ? f.sys : nothing) where {iip, specialize} if mass_matrix === I && typeof(f) <: Tuple mass_matrix = ((I for i in 1:length(f))...,) end @@ -3882,14 +3882,15 @@ function BVPFunction{iip_f, iip_bc, specialize}(f, bc; _bccolorvec = bccolorvec end - jaciip = jac !== nothing ? isinplace(jac, 4, "jac", iip_f) : iip_f - bcjaciip = bcjac !== nothing ? isinplace(bcjac, 4, "bcjac", iip_bc) : iip_bc - tgradiip = tgrad !== nothing ? isinplace(tgrad, 4, "tgrad", iip_f) : iip_f - jvpiip = jvp !== nothing ? isinplace(jvp, 5, "jvp", iip_f) : iip_f - vjpiip = vjp !== nothing ? isinplace(vjp, 5, "vjp", iip_f) : iip_f - Wfactiip = Wfact !== nothing ? isinplace(Wfact, 5, "Wfact", iip_f) : iip_f - Wfact_tiip = Wfact_t !== nothing ? isinplace(Wfact_t, 5, "Wfact_t", iip_f) : iip_f - paramjaciip = paramjac !== nothing ? isinplace(paramjac, 4, "paramjac", iip_f) : iip_f + bciip = isinplace(bc, 4, "bc", iip) + jaciip = jac !== nothing ? isinplace(jac, 4, "jac", iip) : iip + bcjaciip = bcjac !== nothing ? isinplace(bcjac, 4, "bcjac", bciip) : bciip + tgradiip = tgrad !== nothing ? isinplace(tgrad, 4, "tgrad", iip) : iip + jvpiip = jvp !== nothing ? isinplace(jvp, 5, "jvp", iip) : iip + vjpiip = vjp !== nothing ? isinplace(vjp, 5, "vjp", iip) : iip + Wfactiip = Wfact !== nothing ? isinplace(Wfact, 5, "Wfact", iip) : iip + Wfact_tiip = Wfact_t !== nothing ? isinplace(Wfact_t, 5, "Wfact_t", iip) : iip + paramjaciip = paramjac !== nothing ? isinplace(paramjac, 4, "paramjac", iip) : iip nonconforming = (jaciip, tgradiip, @@ -3897,8 +3898,8 @@ function BVPFunction{iip_f, iip_bc, specialize}(f, bc; vjpiip, Wfactiip, Wfact_tiip, - paramjaciip) .!= iip_f - bc_nonconforming = bcjaciip .!= iip_bc + paramjaciip) .!= iip + bc_nonconforming = bcjaciip .!= bciip if any(nonconforming) nonconforming = findall(nonconforming) functions = ["jac", "bcjac", "tgrad", "jvp", "vjp", "Wfact", "Wfact_t", "paramjac"][nonconforming] @@ -3912,7 +3913,7 @@ function BVPFunction{iip_f, iip_bc, specialize}(f, bc; end if specialize === NoSpecialize - BVPFunction{iip_f, iip_bc, specialize, Any, Any, Any, Any, Any, + BVPFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, typeof(syms), typeof(indepsym), typeof(paramsyms), Any, typeof(_colorvec), typeof(_bccolorvec), Any}(f, bc, mass_matrix, @@ -3928,7 +3929,7 @@ function BVPFunction{iip_f, iip_bc, specialize}(f, bc; observed, _colorvec, _bccolorvec, sys) elseif specialize === false - BVPFunction{iip_f, iip_bc, FunctionWrapperSpecialize, + BVPFunction{iip, FunctionWrapperSpecialize, typeof(f), typeof(bc), typeof(mass_matrix), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), @@ -3939,7 +3940,7 @@ function BVPFunction{iip_f, iip_bc, specialize}(f, bc; Wfact_t, paramjac, syms, indepsym, paramsyms, observed, _colorvec, _bccolorvec, sys) else - BVPFunction{iip_f, iip_bc, specialize, typeof(f), typeof(bc), typeof(mass_matrix), + BVPFunction{iip, specialize, typeof(f), typeof(bc), typeof(mass_matrix), typeof(analytic), typeof(tgrad), typeof(jac), typeof(bcjac), typeof(jvp), typeof(vjp), typeof(jac_prototype), @@ -3956,14 +3957,14 @@ function BVPFunction{iip_f, iip_bc, specialize}(f, bc; end end -function BVPFunction{iip_f, iip_bc}(f, bc; kwargs...) where {iip_f, iip_bc} - BVPFunction{iip_f, iip_bc, FullSpecialize}(f, bc; kwargs...) +function BVPFunction{iip}(f, bc; kwargs...) where {iip} + BVPFunction{iip, FullSpecialize}(f, bc; kwargs...) end -BVPFunction{iip_f, iip_bc}(f::BVPFunction, bc; kwargs...) where {iip_f, iip_bc} = f +BVPFunction{iip}(f::BVPFunction, bc; kwargs...) where {iip} = f function BVPFunction(f, bc; kwargs...) - BVPFunction{isinplace(f, 4), isinplace(bc, 4), FullSpecialize}(f, bc; kwargs...) + BVPFunction{isinplace(f, 4), FullSpecialize}(f, bc; kwargs...) end -#BVPFunction(f::BVPFunction; kwargs...) = f +BVPFunction(f::BVPFunction; kwargs...) = f ########## Existence Functions diff --git a/test/function_building_error_messages.jl b/test/function_building_error_messages.jl index e7ec1ba72..a150371d8 100644 --- a/test/function_building_error_messages.jl +++ b/test/function_building_error_messages.jl @@ -479,117 +479,115 @@ bcfboth(u, p, t) = u bcfboth(du, u, p, t) = du .= u BVPFunction(bfboth, bcfboth) -BVPFunction{true, true}(bfboth, bcfboth) -BVPFunction{false, true}(bfboth, bcfboth) -BVPFunction{true, false}(bfboth, bcfboth) -BVPFunction{false, false}(bfboth, bcfboth) +BVPFunction{true}(bfboth, bcfboth) +BVPFunction{false}(bfboth, bcfboth) -jac(u, t) = [1.0] +bjac(u, t) = [1.0] bcjac(u, t) = [1.0] @test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bciip, - jac = jac, + jac = bjac, bcjac = bcjac) @test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bciip, - jac = jac, + jac = bjac, bcjac = bcjac) @test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bcoop, - jac = jac, + jac = bjac, bcjac = bcjac) @test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bcoop, - jac = jac, + jac = bjac, bcjac = bcjac) -jac(u, p, t) = [1.0] +bjac(u, p, t) = [1.0] bcjac(u, p, t) = [1.0] @test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bcoop, - jac = jac, + jac = bjac, bcjac = bcjac) @test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bciip, - jac = jac, + jac = bjac, bcjac = bcjac) @test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfoop, bciip, - jac = jac, + jac = bjac, bcjac = bcjac) -BVPFunction(bfoop, bcoop, jac = jac) -jac(du, u, p, t) = [1.0] +BVPFunction(bfoop, bcoop, jac = bjac) +bjac(du, u, p, t) = [1.0] bcjac(du, u, p, t) = [1.0] -BVPFunction(bfiip, bciip, jac = jac, bcjac = bcjac) -BVPFunction(bfoop, bciip, jac = jac, bcjac = bcjac) -BVPFunction(bfiip, bcoop, jac = jac, bcjac = bcjac) -BVPFunction(bfoop, bcoop, jac = jac, bcjac = bcjac) - -Wfact(u, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bciip, Wfact = Wfact) -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bciip, Wfact = Wfact) -Wfact(u, p, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bciip, Wfact = Wfact) -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bciip, Wfact = Wfact) -Wfact(u, p, gamma, t) = [1.0] -@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bciip, Wfact = Wfact) -BVPFunction(bfoop, bciip, Wfact = Wfact) -Wfact(du, u, p, gamma, t) = [1.0] -BVPFunction(bfiip, bciip, Wfact = Wfact) -BVPFunction(bfoop, bciip, Wfact = Wfact) - -Wfact_t(u, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bciip, Wfact_t = Wfact_t) -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bciip, Wfact_t = Wfact_t) -Wfact_t(u, p, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bciip, Wfact_t = Wfact_t) -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bciip, Wfact_t = Wfact_t) -Wfact_t(u, p, gamma, t) = [1.0] +BVPFunction(bfiip, bciip, jac = bjac, bcjac = bcjac) +BVPFunction(bfoop, bciip, jac = bjac, bcjac = bcjac) +BVPFunction(bfiip, bcoop, jac = bjac, bcjac = bcjac) +BVPFunction(bfoop, bcoop, jac = bjac, bcjac = bcjac) + +bWfact(u, t) = [1.0] +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bciip, Wfact = bWfact) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bciip, Wfact = bWfact) +bWfact(u, p, t) = [1.0] +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bciip, Wfact = bWfact) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bciip, Wfact = bWfact) +bWfact(u, p, gamma, t) = [1.0] +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bciip, Wfact = bWfact) +BVPFunction(bfoop, bciip, Wfact = bWfact) +bWfact(du, u, p, gamma, t) = [1.0] +BVPFunction(bfiip, bciip, Wfact = bWfact) +BVPFunction(bfoop, bciip, Wfact = bWfact) + +bWfact_t(u, t) = [1.0] +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bciip, Wfact_t = bWfact_t) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bciip, Wfact_t = bWfact_t) +bWfact_t(u, p, t) = [1.0] +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bciip, Wfact_t = bWfact_t) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bciip, Wfact_t = bWfact_t) +bWfact_t(u, p, gamma, t) = [1.0] @test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bciip, - Wfact_t = Wfact_t) -BVPFunction(bfoop, bciip, Wfact_t = Wfact_t) -Wfact_t(du, u, p, gamma, t) = [1.0] -BVPFunction(bfiip, bciip, Wfact_t = Wfact_t) -BVPFunction(bfoop, bciip, Wfact_t = Wfact_t) - -tgrad(u, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bciip, tgrad = tgrad) -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bciip, tgrad = tgrad) -tgrad(u, p, t) = [1.0] -@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bciip, tgrad = tgrad) -BVPFunction(bfoop, bciip, tgrad = tgrad) -tgrad(du, u, p, t) = [1.0] -BVPFunction(bfiip, bciip, tgrad = tgrad) -BVPFunction(bfoop, bciip, tgrad = tgrad) - -paramjac(u, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bciip, paramjac = paramjac) -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bciip, paramjac = paramjac) -paramjac(u, p, t) = [1.0] + Wfact_t = bWfact_t) +BVPFunction(bfoop, bciip, Wfact_t = bWfact_t) +bWfact_t(du, u, p, gamma, t) = [1.0] +BVPFunction(bfiip, bciip, Wfact_t = bWfact_t) +BVPFunction(bfoop, bciip, Wfact_t = bWfact_t) + +btgrad(u, t) = [1.0] +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bciip, tgrad = btgrad) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bciip, tgrad = btgrad) +btgrad(u, p, t) = [1.0] +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bciip, tgrad = btgrad) +BVPFunction(bfoop, bciip, tgrad = btgrad) +btgrad(du, u, p, t) = [1.0] +BVPFunction(bfiip, bciip, tgrad = btgrad) +BVPFunction(bfoop, bciip, tgrad = btgrad) + +bparamjac(u, t) = [1.0] +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bciip, paramjac = bparamjac) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bciip, paramjac = bparamjac) +bparamjac(u, p, t) = [1.0] @test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bciip, - paramjac = paramjac) -BVPFunction(bfoop, bciip, paramjac = paramjac) -paramjac(du, u, p, t) = [1.0] -BVPFunction(bfiip, bciip, paramjac = paramjac) -BVPFunction(bfoop, bciip, paramjac = paramjac) - -jvp(u, p, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bciip, jvp = jvp) -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bciip, jvp = jvp) -jvp(u, v, p, t) = [1.0] -@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bciip, jvp = jvp) -BVPFunction(bfoop, bciip, jvp = jvp) -jvp(du, u, v, p, t) = [1.0] -BVPFunction(bfiip, bciip, jvp = jvp) -BVPFunction(bfoop, bciip, jvp = jvp) - -vjp(u, p, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bciip, vjp = vjp) -@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bciip, vjp = vjp) -vjp(u, v, p, t) = [1.0] -@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bciip, vjp = vjp) -BVPFunction(bfoop, bciip, vjp = vjp) -vjp(du, u, v, p, t) = [1.0] -BVPFunction(bfiip, bciip, vjp = vjp) -BVPFunction(bfoop, bciip, vjp = vjp) + paramjac = bparamjac) +BVPFunction(bfoop, bciip, paramjac = bparamjac) +bparamjac(du, u, p, t) = [1.0] +BVPFunction(bfiip, bciip, paramjac = bparamjac) +BVPFunction(bfoop, bciip, paramjac = bparamjac) + +bjvp(u, p, t) = [1.0] +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bciip, jvp = bjvp) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bciip, jvp = bjvp) +bjvp(u, v, p, t) = [1.0] +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bciip, jvp = bjvp) +BVPFunction(bfoop, bciip, jvp = bjvp) +bjvp(du, u, v, p, t) = [1.0] +BVPFunction(bfiip, bciip, jvp = bjvp) +BVPFunction(bfoop, bciip, jvp = bjvp) + +bvjp(u, p, t) = [1.0] +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfiip, bciip, vjp = bvjp) +@test_throws SciMLBase.TooFewArgumentsError BVPFunction(bfoop, bciip, vjp = bvjp) +bvjp(u, v, p, t) = [1.0] +@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bciip, vjp = bvjp) +BVPFunction(bfoop, bciip, vjp = bvjp) +bvjp(du, u, v, p, t) = [1.0] +BVPFunction(bfiip, bciip, vjp = bvjp) +BVPFunction(bfoop, bciip, vjp = bvjp) From c6fa298f8aecb649e1fd17ee2178bba4dc9bd571 Mon Sep 17 00:00:00 2001 From: ErikQQY <2283984853@qq.com> Date: Wed, 16 Aug 2023 10:09:24 +0800 Subject: [PATCH 5/9] Add new dispatch to problem construction Signed-off-by: ErikQQY <2283984853@qq.com> --- src/problems/bvp_problems.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/problems/bvp_problems.jl b/src/problems/bvp_problems.jl index b2a7cbbee..9e8638ef5 100644 --- a/src/problems/bvp_problems.jl +++ b/src/problems/bvp_problems.jl @@ -115,6 +115,10 @@ function BVProblem(f, bc, u0, tspan, p = NullParameters(); kwargs...) BVProblem(BVPFunction(f, bc), bc, u0, tspan, p; kwargs...) end +function BVProblem(f::AbstractBVPFunction, u0, tspan, p = NullParameters(); kwargs...) + BVProblem(f, f.bc, u0, tspan, p; kwargs...) +end + # convenience interfaces: # Allow any previous timeseries solution From 681b9e1b4b83d7f9733958fa1db9b8abdc44f2f3 Mon Sep 17 00:00:00 2001 From: ErikQQY <2283984853@qq.com> Date: Thu, 17 Aug 2023 20:06:24 +0800 Subject: [PATCH 6/9] Add new dispatch for SDEProblem Signed-off-by: ErikQQY <2283984853@qq.com> --- src/problems/sde_problems.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/problems/sde_problems.jl b/src/problems/sde_problems.jl index d81687f83..b7ab7ca71 100644 --- a/src/problems/sde_problems.jl +++ b/src/problems/sde_problems.jl @@ -124,13 +124,17 @@ end =# function SDEProblem(f::AbstractSDEFunction, g, u0, tspan, p = NullParameters(); kwargs...) - SDEProblem{isinplace(f)}(f, g, u0, tspan, p; kwargs...) + SDEProblem{isinplace(f, 4)}(f, g, u0, tspan, p; kwargs...) end function SDEProblem(f, g, u0, tspan, p = NullParameters(); kwargs...) SDEProblem(SDEFunction(f, g), g, u0, tspan, p; kwargs...) end +function SDEProblem(f::AbstractSDEFunction, u0, tspan, p = NullParameters(); kwargs...) + SDEProblem(f, f.g, u0, tspan, p; kwargs...) +end + """ $(TYPEDEF) """ From 61b726da8f4bd12bf8f1349f055aac610c1236de Mon Sep 17 00:00:00 2001 From: ErikQQY <2283984853@qq.com> Date: Fri, 25 Aug 2023 20:56:31 +0800 Subject: [PATCH 7/9] Done Signed-off-by: ErikQQY <2283984853@qq.com> --- src/problems/bvp_problems.jl | 12 ++++-------- src/problems/sde_problems.jl | 12 ++++-------- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/src/problems/bvp_problems.jl b/src/problems/bvp_problems.jl index 9e8638ef5..de206460e 100644 --- a/src/problems/bvp_problems.jl +++ b/src/problems/bvp_problems.jl @@ -95,8 +95,8 @@ struct BVProblem{uType, tType, isinplace, P, F, BF, PT, K} <: _tspan = promote_tspan(tspan) warn_paramtype(p) new{typeof(u0), typeof(_tspan), iip, typeof(p), - typeof(f), typeof(bc), - typeof(problem_type), typeof(kwargs)}(f, f.bc, u0, _tspan, p, + typeof(f.f), typeof(bc), + typeof(problem_type), typeof(kwargs)}(f.f, bc, u0, _tspan, p, problem_type, kwargs) end @@ -107,16 +107,12 @@ end TruncatedStacktraces.@truncate_stacktrace BVProblem 3 1 2 -function BVProblem(f::AbstractBVPFunction, bc, u0, tspan, args...; kwargs...) - BVProblem{isinplace(f, 4)}(f, bc, u0, tspan, args...; kwargs...) -end - function BVProblem(f, bc, u0, tspan, p = NullParameters(); kwargs...) - BVProblem(BVPFunction(f, bc), bc, u0, tspan, p; kwargs...) + BVProblem(BVPFunction(f, bc), u0, tspan, p; kwargs...) end function BVProblem(f::AbstractBVPFunction, u0, tspan, p = NullParameters(); kwargs...) - BVProblem(f, f.bc, u0, tspan, p; kwargs...) + BVProblem{isinplace(f)}(f.f, f.bc, u0, tspan, p; kwargs...) end # convenience interfaces: diff --git a/src/problems/sde_problems.jl b/src/problems/sde_problems.jl index b7ab7ca71..79d06f732 100644 --- a/src/problems/sde_problems.jl +++ b/src/problems/sde_problems.jl @@ -103,9 +103,9 @@ struct SDEProblem{uType, tType, isinplace, P, NP, F, G, K, ND} <: warn_paramtype(p) new{typeof(u0), typeof(_tspan), isinplace(f), typeof(p), - typeof(noise), typeof(f), typeof(f.g), + typeof(noise), typeof(f.f), typeof(g), typeof(kwargs), - typeof(noise_rate_prototype)}(f, f.g, u0, _tspan, p, + typeof(noise_rate_prototype)}(f.f, g, u0, _tspan, p, noise, kwargs, noise_rate_prototype, seed) end @@ -123,16 +123,12 @@ function SDEProblem(f::AbstractSDEFunction,u0,tspan,p=NullParameters();kwargs... end =# -function SDEProblem(f::AbstractSDEFunction, g, u0, tspan, p = NullParameters(); kwargs...) - SDEProblem{isinplace(f, 4)}(f, g, u0, tspan, p; kwargs...) -end - function SDEProblem(f, g, u0, tspan, p = NullParameters(); kwargs...) - SDEProblem(SDEFunction(f, g), g, u0, tspan, p; kwargs...) + SDEProblem(SDEFunction(f, g), u0, tspan, p; kwargs...) end function SDEProblem(f::AbstractSDEFunction, u0, tspan, p = NullParameters(); kwargs...) - SDEProblem(f, f.g, u0, tspan, p; kwargs...) + SDEProblem{isinplace(f)}(f.f, f.g, u0, tspan, p; kwargs...) end """ From 98afac4ade8c0741d66e66f5c69a35d813b3fde6 Mon Sep 17 00:00:00 2001 From: ErikQQY <2283984853@qq.com> Date: Fri, 25 Aug 2023 22:03:57 +0800 Subject: [PATCH 8/9] Remove convenient constructor for BVProblem Signed-off-by: ErikQQY <2283984853@qq.com> --- src/problems/bvp_problems.jl | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/src/problems/bvp_problems.jl b/src/problems/bvp_problems.jl index de206460e..73f39463a 100644 --- a/src/problems/bvp_problems.jl +++ b/src/problems/bvp_problems.jl @@ -115,20 +115,6 @@ function BVProblem(f::AbstractBVPFunction, u0, tspan, p = NullParameters(); kwar BVProblem{isinplace(f)}(f.f, f.bc, u0, tspan, p; kwargs...) end -# convenience interfaces: -# Allow any previous timeseries solution - -function BVProblem(f::AbstractBVPFunction, bc, sol::T, tspan::Tuple, p = NullParameters(); - kwargs...) where {T <: AbstractTimeseriesSolution} - BVProblem(f, bc, sol.u, tspan, p) -end -# Allow a function of time for the initial guess -function BVProblem(f::AbstractBVPFunction, bc, initialGuess, tspan::AbstractVector, - p = NullParameters(); kwargs...) - u0 = [initialGuess(i) for i in tspan] - BVProblem(f, bc, u0, (tspan[1], tspan[end]), p) -end - """ $(TYPEDEF) """ From 30fb3c2ee317df772da1983a1be6600ba2f63868 Mon Sep 17 00:00:00 2001 From: ErikQQY <2283984853@qq.com> Date: Tue, 5 Sep 2023 00:32:16 +0800 Subject: [PATCH 9/9] BVP only Signed-off-by: ErikQQY <2283984853@qq.com> --- src/problems/sde_problems.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/problems/sde_problems.jl b/src/problems/sde_problems.jl index 79d06f732..d81687f83 100644 --- a/src/problems/sde_problems.jl +++ b/src/problems/sde_problems.jl @@ -103,9 +103,9 @@ struct SDEProblem{uType, tType, isinplace, P, NP, F, G, K, ND} <: warn_paramtype(p) new{typeof(u0), typeof(_tspan), isinplace(f), typeof(p), - typeof(noise), typeof(f.f), typeof(g), + typeof(noise), typeof(f), typeof(f.g), typeof(kwargs), - typeof(noise_rate_prototype)}(f.f, g, u0, _tspan, p, + typeof(noise_rate_prototype)}(f, f.g, u0, _tspan, p, noise, kwargs, noise_rate_prototype, seed) end @@ -123,12 +123,12 @@ function SDEProblem(f::AbstractSDEFunction,u0,tspan,p=NullParameters();kwargs... end =# -function SDEProblem(f, g, u0, tspan, p = NullParameters(); kwargs...) - SDEProblem(SDEFunction(f, g), u0, tspan, p; kwargs...) +function SDEProblem(f::AbstractSDEFunction, g, u0, tspan, p = NullParameters(); kwargs...) + SDEProblem{isinplace(f)}(f, g, u0, tspan, p; kwargs...) end -function SDEProblem(f::AbstractSDEFunction, u0, tspan, p = NullParameters(); kwargs...) - SDEProblem{isinplace(f)}(f.f, f.g, u0, tspan, p; kwargs...) +function SDEProblem(f, g, u0, tspan, p = NullParameters(); kwargs...) + SDEProblem(SDEFunction(f, g), g, u0, tspan, p; kwargs...) end """