Skip to content

Commit

Permalink
Fix promoted tspan for BVProblem and ODEProblem
Browse files Browse the repository at this point in the history
Signed-off-by: ErikQQY <[email protected]>
  • Loading branch information
ErikQQY committed Jul 28, 2023
1 parent 5d0b887 commit 841a0a7
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 36 deletions.
2 changes: 1 addition & 1 deletion src/ensemble/basic_ensemble_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ function __solve(prob::EnsembleProblem{<:AbstractVector{<:AbstractSciMLProblem}}
ensemblealg::BasicEnsembleAlgorithm; kwargs...)
# TODO: @invoke
invoke(__solve, Tuple{AbstractEnsembleProblem, typeof(alg), typeof(ensemblealg)},
prob, alg, ensemblealg; trajectories=length(prob.prob), kwargs...)
prob, alg, ensemblealg; trajectories = length(prob.prob), kwargs...)
end

function __solve(prob::AbstractEnsembleProblem,
Expand Down
34 changes: 20 additions & 14 deletions src/ensemble/ensemble_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@ DEFAULT_REDUCTION(u, data, I) = append!(u, data), false
DEFAULT_VECTOR_PROB_FUNC(prob, i, repeat) = prob[i]
function EnsembleProblem(prob::AbstractVector{<:AbstractSciMLProblem}; kwargs...)
# TODO: @invoke
invoke(EnsembleProblem, Tuple{Any}, prob; prob_func=DEFAULT_VECTOR_PROB_FUNC, kwargs...)
invoke(EnsembleProblem,
Tuple{Any},
prob;
prob_func = DEFAULT_VECTOR_PROB_FUNC,
kwargs...)
end
function EnsembleProblem(prob;
output_func = DEFAULT_OUTPUT_FUNC,
Expand All @@ -36,21 +40,23 @@ function EnsembleProblem(; prob,
EnsembleProblem(prob, prob_func, output_func, reduction, u_init, safetycopy)
end

struct WeightedEnsembleProblem{T1<:AbstractEnsembleProblem, T2<:AbstractVector} <: AbstractEnsembleProblem
ensembleprob::T1
weights::T2
struct WeightedEnsembleProblem{T1 <: AbstractEnsembleProblem, T2 <: AbstractVector} <:
AbstractEnsembleProblem
ensembleprob::T1
weights::T2
end
function Base.propertynames(e::WeightedEnsembleProblem)
(Base.propertynames(getfield(e, :ensembleprob))..., :weights)

Check warning on line 49 in src/ensemble/ensemble_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/ensemble_problems.jl#L48-L49

Added lines #L48 - L49 were not covered by tests
end
Base.propertynames(e::WeightedEnsembleProblem) = (Base.propertynames(getfield(e, :ensembleprob))..., :weights)
function Base.getproperty(e::WeightedEnsembleProblem, f::Symbol)
f === :weights && return getfield(e, :weights)
f === :ensembleprob && return getfield(e, :ensembleprob)
return getproperty(getfield(e, :ensembleprob), f)
f === :weights && return getfield(e, :weights)
f === :ensembleprob && return getfield(e, :ensembleprob)
return getproperty(getfield(e, :ensembleprob), f)

Check warning on line 54 in src/ensemble/ensemble_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/ensemble_problems.jl#L52-L54

Added lines #L52 - L54 were not covered by tests
end
function WeightedEnsembleProblem(args...; weights, kwargs...)
# TODO: allow skipping checks?
@assert sum(weights) 1
ep = EnsembleProblem(args...; kwargs...)
@assert length(ep.prob) == length(weights)
WeightedEnsembleProblem(ep, weights)
# TODO: allow skipping checks?
@assert sum(weights) 1
ep = EnsembleProblem(args...; kwargs...)
@assert length(ep.prob) == length(weights)
WeightedEnsembleProblem(ep, weights)

Check warning on line 61 in src/ensemble/ensemble_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/ensemble_problems.jl#L58-L61

Added lines #L58 - L61 were not covered by tests
end

13 changes: 9 additions & 4 deletions src/ensemble/ensemble_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ function EnsembleSolution(sim::T, elapsedTime,
converged)
end

struct WeightedEnsembleSolution{T1<:AbstractEnsembleSolution, T2<:Number}
struct WeightedEnsembleSolution{T1 <: AbstractEnsembleSolution, T2 <: Number}
ensol::T1
weights::Vector{T2}
function WeightedEnsembleSolution(ensol, weights)
Expand Down Expand Up @@ -207,13 +207,18 @@ end
end
end


Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, s, ::Colon)
return [xi[s] for xi in x]
end

Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, args::Colon...)
return invoke(getindex, Tuple{RecursiveArrayTools.AbstractVectorOfArray, Colon, typeof.(args)...}, x, :, args...)
Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution,

Check warning on line 214 in src/ensemble/ensemble_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/ensemble_solutions.jl#L214

Added line #L214 was not covered by tests
::Colon,
args::Colon...)
return invoke(getindex,

Check warning on line 217 in src/ensemble/ensemble_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/ensemble_solutions.jl#L217

Added line #L217 was not covered by tests
Tuple{RecursiveArrayTools.AbstractVectorOfArray, Colon, typeof.(args)...},
x,
:,
args...)
end

function (sol::AbstractEnsembleSolution)(args...; kwargs...)
Expand Down
2 changes: 1 addition & 1 deletion src/problems/bvp_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ struct BVProblem{uType, tType, isinplace, P, F, bF, PT, K} <:
kwargs...) where {iip}
_tspan = promote_tspan(tspan)
warn_paramtype(p)
new{typeof(u0), typeof(tspan), iip, typeof(p),
new{typeof(u0), typeof(_tspan), iip, typeof(p),
typeof(f), typeof(bc),
typeof(problem_type), typeof(kwargs)}(f, bc, u0, _tspan, p,
problem_type, kwargs)
Expand Down
16 changes: 8 additions & 8 deletions src/problems/ode_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,9 @@ struct ODEProblem{uType, tType, isinplace, P, F, K, PT} <:
This is determined automatically, but not inferred.
"""
function ODEProblem{iip}(f, u0, tspan, p = NullParameters(); kwargs...) where {iip}
ptspan = promote_tspan(tspan)
_tspan = promote_tspan(tspan)
_f = ODEFunction{iip, DEFAULT_SPECIALIZATION}(f)
ODEProblem(_f, u0, tspan, p; kwargs...)
ODEProblem(_f, u0, _tspan, p; kwargs...)
end

@add_kwonly function ODEProblem{iip, recompile}(f, u0, tspan, p = NullParameters();
Expand All @@ -145,19 +145,19 @@ struct ODEProblem{uType, tType, isinplace, P, F, K, PT} <:

function ODEProblem{iip, FunctionWrapperSpecialize}(f, u0, tspan, p = NullParameters();
kwargs...) where {iip}
ptspan = promote_tspan(tspan)
_tspan = promote_tspan(tspan)
if !(f isa FunctionWrappersWrappers.FunctionWrappersWrapper)
if iip
ff = ODEFunction{iip, FunctionWrapperSpecialize}(wrapfun_iip(f,
(u0, u0, p,
ptspan[1])))
_tspan[1])))
else
ff = ODEFunction{iip, FunctionWrapperSpecialize}(wrapfun_oop(f,
(u0, p,
ptspan[1])))
_tspan[1])))
end
end
ODEProblem{iip}(ff, u0, tspan, p; kwargs...)
ODEProblem{iip}(ff, u0, _tspan, p; kwargs...)
end
end
TruncatedStacktraces.@truncate_stacktrace ODEProblem 3 1 2
Expand All @@ -173,9 +173,9 @@ end

function ODEProblem(f, u0, tspan, p = NullParameters(); kwargs...)
iip = isinplace(f, 4)
ptspan = promote_tspan(tspan)
_tspan = promote_tspan(tspan)
_f = ODEFunction{iip, DEFAULT_SPECIALIZATION}(f)
ODEProblem(_f, u0, tspan, p; kwargs...)
ODEProblem(_f, u0, _tspan, p; kwargs...)
end

"""
Expand Down
9 changes: 6 additions & 3 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,8 @@ See the `modelingtoolkitize` function from
automatically symbolically generating the Jacobian and more from the
numerically-defined functions.
"""
struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ, S,
struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ,
S,
S2, S3, O, TCV,
SYS} <: AbstractODEFunction{iip}
f::F
Expand Down Expand Up @@ -2259,7 +2260,8 @@ function ODEFunction{iip, specialize}(f;
ODEFunction{iip, FunctionWrapperSpecialize,
typeof(f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype),
typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(W_prototype), typeof(paramjac),
typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(W_prototype),
typeof(paramjac),
typeof(syms), typeof(indepsym), typeof(paramsyms), typeof(observed),
typeof(_colorvec),
typeof(sys)}(f, mass_matrix, analytic, tgrad, jac,
Expand All @@ -2270,7 +2272,8 @@ function ODEFunction{iip, specialize}(f;
ODEFunction{iip, specialize,
typeof(f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype),
typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(W_prototype), typeof(paramjac),
typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(W_prototype),
typeof(paramjac),
typeof(syms), typeof(indepsym), typeof(paramsyms), typeof(observed),
typeof(_colorvec),
typeof(sys)}(f, mass_matrix, analytic, tgrad, jac,
Expand Down
10 changes: 5 additions & 5 deletions test/downstream/ensemble_multi_prob.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ using ModelingToolkit, OrdinaryDiffEq, Test
D = Differential(t)

@named sys1 = ODESystem([D(x) ~ x,
D(y) ~ -y])
D(y) ~ -y])
@named sys2 = ODESystem([D(x) ~ 2x,
D(y) ~ -2y])
D(y) ~ -2y])
@named sys3 = ODESystem([D(x) ~ 3x,
D(y) ~ -3y])
D(y) ~ -3y])

prob1 = ODEProblem(sys1, [1.0, 1.0], (0.0, 1.0))
prob2 = ODEProblem(sys2, [2.0, 2.0], (0.0, 1.0))
Expand All @@ -22,6 +22,6 @@ for i in 1:3
@test sol[y, :][i] == sol[i][y]
end
# Ensemble is a recursive array
@test only.(sol(0.0, idxs=[x])) == sol[1, 1, :] == first.(sol[x, :])
@test only.(sol(0.0, idxs = [x])) == sol[1, 1, :] == first.(sol[x, :])
# TODO: fix the interpolation
@test only.(sol(1.0, idxs=[x])) last.(sol[x, :])
@test only.(sol(1.0, idxs = [x])) last.(sol[x, :])

0 comments on commit 841a0a7

Please sign in to comment.