Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Choose an SSA if no SSA is passed in JumpProblem. #351

Merged
merged 23 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,38 @@

## JumpProcesses unreleased (master branch)

## 9.12

- Added a default aggregator selection algorithm based on the number of passed
in jumps. i.e. the following now auto-selects an aggregator (`Direct` in this
case):

```julia
using JumpProcesses
rate(u, p, t) = u[1]
affect(integrator) = (integrator.u[1] -= 1; nothing)
crj = ConstantRateJump(rate, affect)
dprob = DiscreteProblem([10], (0.0, 10.0))
jprob = JumpProblem(dprob, crj)
sol = solve(jprob, SSAStepper())
```

- For `JumpProblem`s over `DiscreteProblem`s that only have `MassActionJump`s,
`ConstantRateJump`s, and bounded `VariableRateJump`s, one no longer needs to
specify `SSAStepper()` when calling `solve`, i.e. the following now works for
the previous example and is equivalent to manually passing `SSAStepper()`:

```julia
sol = solve(jprob)
```
- Plotting a solution generated with `save_positions = (false, false)` now uses
piecewise linear plots between any saved time points specified via `saveat`
instead (previously the plots appeared piecewise constant even though each
jump was not being shown). Note that solution objects still use piecewise
constant interpolation, see [the
docs](https://docs.sciml.ai/JumpProcesses/stable/tutorials/discrete_stochastic_example/#save_positions_docs)
for details.

## 9.7

- `Coevolve` was updated to support use with coupled ODEs/SDEs. See the updated
Expand Down
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ using JumpProcesses, Plots
# here we order S = 1, I = 2, and R = 3
# substrate stoichiometry:
substoich = [[1 => 1, 2 => 1], # 1*S + 1*I
[2 => 1]] # 1*I
[2 => 1]] # 1*I
# net change by each jump type
netstoich = [[1 => -1, 2 => 1], # S -> S-1, I -> I+1
[2 => -1, 3 => 1]] # I -> I-1, R -> R+1
[2 => -1, 3 => 1]] # I -> I-1, R -> R+1
# rate constants for each jump
p = (0.1 / 1000, 0.01)

Expand All @@ -91,10 +91,10 @@ tspan = (0.0, 250.0)
dprob = DiscreteProblem(u₀, tspan, p)

# use the Direct method to simulate
jprob = JumpProblem(dprob, Direct(), maj)
jprob = JumpProblem(dprob, maj)

# solve as a pure jump process, i.e. using SSAStepper
sol = solve(jprob, SSAStepper())
sol = solve(jprob)
plot(sol)
```

Expand All @@ -117,8 +117,8 @@ function affect2!(integrator)
integrator.u[3] += 1 # R -> R + 1
end
jump2 = ConstantRateJump(rate2, affect2!)
jprob = JumpProblem(dprob, Direct(), jump, jump2)
sol = solve(jprob, SSAStepper())
jprob = JumpProblem(dprob, jump, jump2)
sol = solve(jprob)
```

### Jump-ODE Example
Expand Down
4 changes: 1 addition & 3 deletions src/SSA_stepper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,7 @@ function DiffEqBase.u_modified!(integrator::SSAIntegrator, bool::Bool)
integrator.u_modified = bool
end

function DiffEqBase.__solve(jump_prob::JumpProblem,
alg::SSAStepper;
kwargs...)
function DiffEqBase.__solve(jump_prob::JumpProblem, alg::SSAStepper; kwargs...)
integrator = init(jump_prob, alg; kwargs...)
solve!(integrator)
integrator.sol
Expand Down
34 changes: 34 additions & 0 deletions src/aggregators/aggregators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,40 @@ needs_vartojumps_map(aggregator::RSSACR) = true
supports_variablerates(aggregator::AbstractAggregatorAlgorithm) = false
supports_variablerates(aggregator::Coevolve) = true

# true if aggregator supports hops, e.g. diffusion
is_spatial(aggregator::AbstractAggregatorAlgorithm) = false
is_spatial(aggregator::NSM) = true
is_spatial(aggregator::DirectCRDirect) = true

# return the fastest aggregator out of the available ones
function select_aggregator(jumps::JumpSet; vartojumps_map = nothing,
jumptovars_map = nothing, dep_graph = nothing, spatial_system = nothing,
hopping_constants = nothing)

# detect if a spatial SSA should be used
!isnothing(spatial_system) && !isnothing(hopping_constants) && return DirectCRDirect

# if variable rate jumps are present, return one of the two SSAs that support them
if num_vrjs(jumps) > 0
(num_bndvrjs(jumps) > 0) && return Coevolve
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't they all need bounds?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In theory we support mixing bounded and non-bounded vrjs (but if there are any non-bounded vrjs you need to use a JumpProblem over an ODEProblem or such). JumpProblem has code to check these cases, but I'm not sure how well this is actually tested currently.

return Direct
end

# if the number of jumps is small, return the Direct
num_jumps(jumps) < 20 && return Direct

# if there are only massaction jumps, we can build the species-jumps dependency graphs
can_build_dgs = num_crjs(jumps) == 0 && num_vrjs(jumps) == 0
have_species_to_jumps_dgs = !isnothing(vartojumps_map) && !isnothing(jumptovars_map)

# if we have the species-jumps dgs or can build them, use a Rejection-based methods
if can_build_dgs || have_species_to_jumps_dgs
(num_jumps(jumps) < 100) && return RSSA
return RSSACR
elseif !isnothing(dep_graph) # if only have a normal dg
(num_jumps(jumps) < 200) && return SortingDirect
return DirectCR
else
return Direct
end
end
2 changes: 1 addition & 1 deletion src/jumps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ end
using_params(maj::MassActionJump{T, S, U, Nothing}) where {T, S, U} = false
using_params(maj::MassActionJump) = true
using_params(maj::Nothing) = false
@inline get_num_majumps(maj::MassActionJump) = length(maj.scaled_rates)
@inline get_num_majumps(maj::MassActionJump) = length(maj.net_stoch)
@inline get_num_majumps(maj::Nothing) = 0

struct MassActionJumpParamMapper{U}
Expand Down
13 changes: 11 additions & 2 deletions src/problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,17 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::Abstr
kwargs...)
JumpProblem(prob, aggregator, JumpSet(jumps...); kwargs...)
end
function JumpProblem(prob, jumps::JumpSet; kwargs...)
JumpProblem(prob, NullAggregator(), jumps; kwargs...)
function JumpProblem(prob, jumps::JumpSet; vartojumps_map = nothing,
jumptovars_map = nothing, dep_graph = nothing,
spatial_system = nothing, hopping_constants = nothing, kwargs...)
ps = (; vartojumps_map, jumptovars_map, dep_graph, spatial_system, hopping_constants)
aggtype = select_aggregator(jumps; ps...)
return JumpProblem(prob, aggtype(), jumps; ps..., kwargs...)
end

# this makes it easier to test the aggregator selection
function JumpProblem(prob, aggregator::NullAggregator, jumps::JumpSet; kwargs...)
JumpProblem(prob, jumps; kwargs...)
end

make_kwarg(; kwargs...) = kwargs
Expand Down
11 changes: 11 additions & 0 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@ function DiffEqBase.__solve(jump_prob::DiffEqBase.AbstractJumpProblem{P},
integrator.sol
end

# if passed a JumpProblem over a DiscreteProblem, and no aggregator is selected use
# SSAStepper
function DiffEqBase.__solve(jump_prob::DiffEqBase.AbstractJumpProblem{P};
kwargs...) where {P <: DiscreteProblem}
DiffEqBase.__solve(jump_prob, SSAStepper(); kwargs...)
end

function DiffEqBase.__solve(jump_prob::DiffEqBase.AbstractJumpProblem; kwargs...)
error("Auto-solver selection is currently only implemented for JumpProblems defined over DiscreteProblems. Please explicitly specify a solver algorithm in calling solve.")
end

function DiffEqBase.__init(_jump_prob::DiffEqBase.AbstractJumpProblem{P},
alg::DiffEqBase.DEAlgorithm, timeseries = [], ts = [], ks = [],
recompile::Type{Val{recompile_flag}} = Val{true};
Expand Down
40 changes: 13 additions & 27 deletions test/bimolerx_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ dotestmean = true
doprintmeans = false

# SSAs to test
SSAalgs = JumpProcesses.JUMP_AGGREGATORS
SSAalgs = (JumpProcesses.JUMP_AGGREGATORS..., JumpProcesses.NullAggregator())

Nsims = 32000
tf = 0.01
Expand Down Expand Up @@ -55,10 +55,10 @@ jump_to_dep_specs = [[1, 2], [1, 2], [1, 2, 3], [1, 2, 3], [1, 3]]
majumps = MassActionJump(rates, reactstoch, netstoch)

# average number of proteins in a simulation
function runSSAs(jump_prob)
function runSSAs(jump_prob; use_stepper = true)
Psamp = zeros(Int, Nsims)
for i in 1:Nsims
sol = solve(jump_prob, SSAStepper())
sol = use_stepper ? solve(jump_prob, SSAStepper()) : solve(jump_prob)
Psamp[i] = sol[1, end]
end
mean(Psamp)
Expand All @@ -81,37 +81,23 @@ end

# test the means
if dotestmean
means = zeros(Float64, length(SSAalgs))
for (i, alg) in enumerate(SSAalgs)
local jump_prob = JumpProblem(prob, alg, majumps, save_positions = (false, false),
vartojumps_map = spec_to_dep_jumps,
jumptovars_map = jump_to_dep_specs, rng = rng)
means[i] = runSSAs(jump_prob)
relerr = abs(means[i] - expected_avg) / expected_avg
if doprintmeans
println("Mean from method: ", typeof(alg), " is = ", means[i], ", rel err = ",
relerr)
end

# if dobenchmark
# @btime (runSSAs($jump_prob);)
# end

@test abs(means[i] - expected_avg) < reltol * expected_avg
means = runSSAs(jump_prob)
relerr = abs(means - expected_avg) / expected_avg
doprintmeans && println("Mean from method: ", typeof(alg), " is = ", means,
", rel err = ", relerr)
@test abs(means - expected_avg) < reltol * expected_avg

# test not specifying SSAStepper
means = runSSAs(jump_prob; use_stepper = false)
relerr = abs(means - expected_avg) / expected_avg
@test abs(means - expected_avg) < reltol * expected_avg
end
end

# benchmark performance
# if dobenchmark
# # exact methods
# for alg in SSAalgs
# println("Solving with method: ", typeof(alg), ", using SSAStepper")
# jump_prob = JumpProblem(prob, alg, majumps, vartojumps_map=spec_to_dep_jumps, jumptovars_map=jump_to_dep_specs, rng=rng)
# @btime solve($jump_prob, SSAStepper())
# end
# println()
# end

# add a test for passing MassActionJumps individually (tests combining)
if dotestmean
majump_vec = Vector{MassActionJump{Float64, Vector{Pair{Int, Int}}}}()
Expand Down
3 changes: 1 addition & 2 deletions test/degenerate_rx_cases.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ doprint = false
#using Plots; plotlyjs()
doplot = false

methods = (RDirect(), RSSACR(), Direct(), DirectFW(), FRM(), FRMFW(), SortingDirect(),
NRM(), RSSA(), DirectCR(), Coevolve())
methods = (JumpProcesses.JUMP_AGGREGATORS..., JumpProcesses.NullAggregator())

# one reaction case, mass action jump, vector of data
rate = [2.0]
Expand Down
2 changes: 1 addition & 1 deletion test/extinction_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ dg = [[1]]
majump = MassActionJump(rates, reactstoch, netstoch)
u0 = [100000]
dprob = DiscreteProblem(u0, (0.0, 1e5), rates)
algs = JumpProcesses.JUMP_AGGREGATORS
algs = (JumpProcesses.JUMP_AGGREGATORS..., JumpProcesses.NullAggregator())

for n in 1:Nsims
for ssa in algs
Expand Down
46 changes: 20 additions & 26 deletions test/geneexpr_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ dotestmean = true
doprintmeans = false

# SSAs to test
SSAalgs = (RDirect(), RSSACR(), Direct(), DirectFW(), FRM(), FRMFW(), SortingDirect(),
NRM(), RSSA(), DirectCR(), Coevolve())
SSAalgs = (JumpProcesses.JUMP_AGGREGATORS..., JumpProcesses.NullAggregator())

# numerical parameters
Nsims = 8000
Expand All @@ -23,10 +22,10 @@ expected_avg = 5.926553750000000e+02
reltol = 0.01

# average number of proteins in a simulation
function runSSAs(jump_prob)
function runSSAs(jump_prob; use_stepper = true)
Psamp = zeros(Int, Nsims)
for i in 1:Nsims
sol = solve(jump_prob, SSAStepper())
sol = use_stepper ? solve(jump_prob, SSAStepper()) : solve(jump_prob)
Psamp[i] = sol[3, end]
end
mean(Psamp)
Expand Down Expand Up @@ -86,33 +85,28 @@ end

# test the means
if dotestmean
means = zeros(Float64, length(SSAalgs))
for (i, alg) in enumerate(SSAalgs)
local jump_prob = JumpProblem(prob, alg, majumps, save_positions = (false, false),
vartojumps_map = spec_to_dep_jumps,
jumptovars_map = jump_to_dep_specs, rng = rng)
means[i] = runSSAs(jump_prob)
relerr = abs(means[i] - expected_avg) / expected_avg
if doprintmeans
println("Mean from method: ", typeof(alg), " is = ", means[i], ", rel err = ",
relerr)
end
means = runSSAs(jump_prob)
relerr = abs(means - expected_avg) / expected_avg
doprintmeans && println("Mean from method: ", typeof(alg), " is = ", means,
", rel err = ", relerr)
@test abs(means - expected_avg) < reltol * expected_avg

# if dobenchmark
# @btime (runSSAs($jump_prob);)
# end

@test abs(means[i] - expected_avg) < reltol * expected_avg
means = runSSAs(jump_prob; use_stepper = false)
relerr = abs(means - expected_avg) / expected_avg
@test abs(means - expected_avg) < reltol * expected_avg
end
end

# benchmark performance
# if dobenchmark
# # exact methods
# for alg in SSAalgs
# println("Solving with method: ", typeof(alg), ", using SSAStepper")
# jump_prob = JumpProblem(prob, alg, majumps, vartojumps_map=spec_to_dep_jumps, jumptovars_map=jump_to_dep_specs, rng=rng)
# @btime solve($jump_prob, SSAStepper())
# end
# println()
# end
# no-aggregator tests
jump_prob = JumpProblem(prob, majumps; save_positions = (false, false),
vartojumps_map = spec_to_dep_jumps, jumptovars_map = jump_to_dep_specs, rng)
@test abs(runSSAs(jump_prob) - expected_avg) < reltol * expected_avg
@test abs(runSSAs(jump_prob; use_stepper = false) - expected_avg) < reltol * expected_avg

jump_prob = JumpProblem(prob, majumps, save_positions = (false, false), rng = rng)
@test abs(runSSAs(jump_prob) - expected_avg) < reltol * expected_avg
@test abs(runSSAs(jump_prob; use_stepper = false) - expected_avg) < reltol * expected_avg
Loading
Loading