Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add BVPFunction #370

Merged
merged 10 commits into from
Sep 4, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -655,9 +655,9 @@ function specialization(::Union{ODEFunction{iip, specialize},
ImplicitDiscreteFunction{iip, specialize},
RODEFunction{iip, specialize},
NonlinearFunction{iip, specialize},
OptimizationFunction{iip, specialize}}) where {iip,
OptimizationFunction{iip, specialize},
BVPFunction{iip, specialize}}) where {iip,
specialize}
specialize
end

specialization(f::AbstractSciMLFunction) = FullSpecialize
Expand Down Expand Up @@ -784,7 +784,7 @@ export remake

export ODEFunction, DiscreteFunction, ImplicitDiscreteFunction, SplitFunction, DAEFunction,
DDEFunction, SDEFunction, SplitSDEFunction, RODEFunction, SDDEFunction,
IncrementingODEFunction, NonlinearFunction, IntervalNonlinearFunction
IncrementingODEFunction, NonlinearFunction, IntervalNonlinearFunction, BVPFunction

export OptimizationFunction

Expand Down
2 changes: 1 addition & 1 deletion src/ensemble/basic_ensemble_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ function __solve(prob::EnsembleProblem{<:AbstractVector{<:AbstractSciMLProblem}}
ensemblealg::BasicEnsembleAlgorithm; kwargs...)
# TODO: @invoke
invoke(__solve, Tuple{AbstractEnsembleProblem, typeof(alg), typeof(ensemblealg)},
prob, alg, ensemblealg; trajectories=length(prob.prob), kwargs...)
prob, alg, ensemblealg; trajectories = length(prob.prob), kwargs...)
end

function __solve(prob::AbstractEnsembleProblem,
Expand Down
33 changes: 20 additions & 13 deletions src/ensemble/ensemble_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
DEFAULT_VECTOR_PROB_FUNC(prob, i, repeat) = prob[i]
function EnsembleProblem(prob::AbstractVector{<:AbstractSciMLProblem}; kwargs...)
# TODO: @invoke
invoke(EnsembleProblem, Tuple{Any}, prob; prob_func=DEFAULT_VECTOR_PROB_FUNC, kwargs...)
invoke(EnsembleProblem,
Tuple{Any},
prob;
prob_func = DEFAULT_VECTOR_PROB_FUNC,
kwargs...)
end
function EnsembleProblem(prob;
output_func = DEFAULT_OUTPUT_FUNC,
Expand All @@ -36,20 +40,23 @@
EnsembleProblem(prob, prob_func, output_func, reduction, u_init, safetycopy)
end

struct WeightedEnsembleProblem{T1<:AbstractEnsembleProblem, T2<:AbstractVector} <: AbstractEnsembleProblem
ensembleprob::T1
weights::T2
struct WeightedEnsembleProblem{T1 <: AbstractEnsembleProblem, T2 <: AbstractVector} <:
AbstractEnsembleProblem
ensembleprob::T1
weights::T2
end
function Base.propertynames(e::WeightedEnsembleProblem)
(Base.propertynames(getfield(e, :ensembleprob))..., :weights)

Check warning on line 49 in src/ensemble/ensemble_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/ensemble_problems.jl#L48-L49

Added lines #L48 - L49 were not covered by tests
end
Base.propertynames(e::WeightedEnsembleProblem) = (Base.propertynames(getfield(e, :ensembleprob))..., :weights)
function Base.getproperty(e::WeightedEnsembleProblem, f::Symbol)
f === :weights && return getfield(e, :weights)
f === :ensembleprob && return getfield(e, :ensembleprob)
return getproperty(getfield(e, :ensembleprob), f)
f === :weights && return getfield(e, :weights)
f === :ensembleprob && return getfield(e, :ensembleprob)
return getproperty(getfield(e, :ensembleprob), f)

Check warning on line 54 in src/ensemble/ensemble_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/ensemble_problems.jl#L52-L54

Added lines #L52 - L54 were not covered by tests
end
function WeightedEnsembleProblem(args...; weights, kwargs...)
# TODO: allow skipping checks?
@assert sum(weights) ≈ 1
ep = EnsembleProblem(args...; kwargs...)
@assert length(ep.prob) == length(weights)
WeightedEnsembleProblem(ep, weights)
# TODO: allow skipping checks?
@assert sum(weights) ≈ 1
ep = EnsembleProblem(args...; kwargs...)
@assert length(ep.prob) == length(weights)
WeightedEnsembleProblem(ep, weights)

Check warning on line 61 in src/ensemble/ensemble_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/ensemble_problems.jl#L58-L61

Added lines #L58 - L61 were not covered by tests
end
13 changes: 9 additions & 4 deletions src/ensemble/ensemble_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
converged)
end

struct WeightedEnsembleSolution{T1<:AbstractEnsembleSolution, T2<:Number}
struct WeightedEnsembleSolution{T1 <: AbstractEnsembleSolution, T2 <: Number}
ensol::T1
weights::Vector{T2}
function WeightedEnsembleSolution(ensol, weights)
Expand Down Expand Up @@ -207,13 +207,18 @@
end
end


Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, s, ::Colon)
return [xi[s] for xi in x]
end

Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, args::Colon...)
return invoke(getindex, Tuple{RecursiveArrayTools.AbstractVectorOfArray, Colon, typeof.(args)...}, x, :, args...)
Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution,

Check warning on line 214 in src/ensemble/ensemble_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/ensemble_solutions.jl#L214

Added line #L214 was not covered by tests
::Colon,
args::Colon...)
return invoke(getindex,

Check warning on line 217 in src/ensemble/ensemble_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/ensemble_solutions.jl#L217

Added line #L217 was not covered by tests
Tuple{RecursiveArrayTools.AbstractVectorOfArray, Colon, typeof.(args)...},
x,
:,
args...)
end

function (sol::AbstractEnsembleSolution)(args...; kwargs...)
Expand Down
20 changes: 11 additions & 9 deletions src/problems/bvp_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,50 +78,52 @@
* `p`: The parameters for the problem. Defaults to `NullParameters`
* `kwargs`: The keyword arguments passed onto the solves.
"""
struct BVProblem{uType, tType, isinplace, P, F, bF, PT, K} <:
struct BVProblem{uType, tType, isinplace, P, F, BF, PT, K} <:
AbstractBVProblem{uType, tType, isinplace}
f::F
bc::bF
bc::BF
u0::uType
tspan::tType
p::P
problem_type::PT
kwargs::K
@add_kwonly function BVProblem{iip}(f::AbstractODEFunction, bc, u0, tspan,

@add_kwonly function BVProblem{iip}(f::AbstractBVPFunction{iip}, bc, u0, tspan,
p = NullParameters(),
problem_type = StandardBVProblem();
kwargs...) where {iip}
_tspan = promote_tspan(tspan)
warn_paramtype(p)
new{typeof(u0), typeof(_tspan), iip, typeof(p),
typeof(f), typeof(bc),
typeof(problem_type), typeof(kwargs)}(f, bc, u0, _tspan, p,
typeof(problem_type), typeof(kwargs)}(f, f.bc, u0, _tspan, p,
problem_type, kwargs)
end

function BVProblem{iip}(f, bc, u0, tspan, p = NullParameters(); kwargs...) where {iip}
BVProblem(ODEFunction{iip}(f), bc, u0, tspan, p; kwargs...)
BVProblem(BVPFunction{iip}(f, bc), bc, u0, tspan, p; kwargs...)
end
end

TruncatedStacktraces.@truncate_stacktrace BVProblem 3 1 2

function BVProblem(f::AbstractODEFunction, bc, u0, tspan, args...; kwargs...)
function BVProblem(f::AbstractBVPFunction, bc, u0, tspan, args...; kwargs...)
BVProblem{isinplace(f, 4)}(f, bc, u0, tspan, args...; kwargs...)
end

function BVProblem(f, bc, u0, tspan, p = NullParameters(); kwargs...)
BVProblem(ODEFunction(f), bc, u0, tspan, p; kwargs...)
BVProblem(BVPFunction(f, bc), bc, u0, tspan, p; kwargs...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If BVPFunction stores bc should be duplicate the storage in BVProblem again?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, indeed, the problem construction process in BVProblem is similar with SDEProblem, they both have duplicate bc in BVProblem and g in SDEProblem, but if we want to have a BVPFunction stores both f and bc, duplicating bc could really making the problem constructor concise. We need to note that when we are unpacking BVProblem, we actually get a BVPFunction but not f.

function SDEProblem(f, g, u0, tspan, p = NullParameters(); kwargs...)
SDEProblem(SDEFunction(f, g), g, u0, tspan, p; kwargs...)
end

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case it's fair to stay consistent.

Maybe we can add an additional dispatch on BVPFunction where we automatically pull out the bc during problem construction that way user wont have to do BVProblem(BVPFunction(f, bc), bc...) and instead can specify BVProblem(BVPFunction(f, bc)....)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you mean a dispatch on BVProblem right? Just updated and now we can directly specify BVProblem(BVPFunction(f, bc).....) to construct a BVProblem.
Do we need to do the same for SDEProblem and SDEFunction? The problem construct of SDEProblem also need users to do somthing like SDEProblem(SDEFunction(f, g), g), see here: https://docs.sciml.ai/DiffEqDocs/stable/tutorials/sde_example/#Using-Higher-Order-Methods

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ChrisRackauckas do you think this is a valid API choice? Is there a particular reason other Problem Types don't have this?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you handle downstream?

Copy link
Member Author

@ErikQQY ErikQQY Aug 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I am working on fixing downstream errors

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a little lost here, there are three ways of constructing a BVProblem:
(1). Directly construction from scratch

prob = BVProblem(f, bc, u0, tspan)

(2). Use BVPFunction

prob = BVProblem(BVPFunction(f, bc), u0, tspan)

(3). Another way of using BVPFunction

prob = BVProblem(BVPFunction(f, bc), bc, u0, tspan)

As for SDEProblem, the definitions are similar, so my question is that it looks the (2) and (3) dispatches can't exist at the same time, so I think we are deprecating (3) and using (2) instead?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like deprecating (3) and using (2) in SDEProblem and SDEFunction is way more complicated than BVProblem and BVPFunction. SplitSDEProblem, DynamicalSDEProblem and maybe some functions in ModelingToolkit.jl and JumpProcess.jl etc. are all relying on (3) in problem constructor, the new change in this PR would cause a lot of errors and break a lot of APIs.

end

# convenience interfaces:
# Allow any previous timeseries solution
function BVProblem(f::AbstractODEFunction, bc, sol::T, tspan::Tuple, p = NullParameters();

function BVProblem(f::AbstractBVPFunction, bc, sol::T, tspan::Tuple, p = NullParameters();

Check warning on line 121 in src/problems/bvp_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/bvp_problems.jl#L121

Added line #L121 was not covered by tests
kwargs...) where {T <: AbstractTimeseriesSolution}
BVProblem(f, bc, sol.u, tspan, p)
end
# Allow a function of time for the initial guess
function BVProblem(f::AbstractODEFunction, bc, initialGuess, tspan::AbstractVector,
function BVProblem(f::AbstractBVPFunction, bc, initialGuess, tspan::AbstractVector,

Check warning on line 126 in src/problems/bvp_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/bvp_problems.jl#L126

Added line #L126 was not covered by tests
p = NullParameters(); kwargs...)
u0 = [initialGuess(i) for i in tspan]
BVProblem(f, bc, u0, (tspan[1], tspan[end]), p)
Expand Down
Loading
Loading