Skip to content

Commit

Permalink
Merge pull request #351 from SciML/default-ssa
Browse files Browse the repository at this point in the history
Choose an SSA if no SSA is passed in `JumpProblem`.
  • Loading branch information
isaacsas authored Aug 6, 2024
2 parents afae6f6 + 331ea57 commit 4f262cb
Show file tree
Hide file tree
Showing 15 changed files with 182 additions and 94 deletions.
32 changes: 32 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,38 @@

## JumpProcesses unreleased (master branch)

## 9.12

- Added a default aggregator selection algorithm based on the number of passed
in jumps. i.e. the following now auto-selects an aggregator (`Direct` in this
case):

```julia
using JumpProcesses
rate(u, p, t) = u[1]
affect(integrator) = (integrator.u[1] -= 1; nothing)
crj = ConstantRateJump(rate, affect)
dprob = DiscreteProblem([10], (0.0, 10.0))
jprob = JumpProblem(dprob, crj)
sol = solve(jprob, SSAStepper())
```

- For `JumpProblem`s over `DiscreteProblem`s that only have `MassActionJump`s,
`ConstantRateJump`s, and bounded `VariableRateJump`s, one no longer needs to
specify `SSAStepper()` when calling `solve`, i.e. the following now works for
the previous example and is equivalent to manually passing `SSAStepper()`:

```julia
sol = solve(jprob)
```
- Plotting a solution generated with `save_positions = (false, false)` now uses
piecewise linear plots between any saved time points specified via `saveat`
instead (previously the plots appeared piecewise constant even though each
jump was not being shown). Note that solution objects still use piecewise
constant interpolation, see [the
docs](https://docs.sciml.ai/JumpProcesses/stable/tutorials/discrete_stochastic_example/#save_positions_docs)
for details.

## 9.7

- `Coevolve` was updated to support use with coupled ODEs/SDEs. See the updated
Expand Down
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,10 @@ using JumpProcesses, Plots
# here we order S = 1, I = 2, and R = 3
# substrate stoichiometry:
substoich = [[1 => 1, 2 => 1], # 1*S + 1*I
[2 => 1]] # 1*I
[2 => 1]] # 1*I
# net change by each jump type
netstoich = [[1 => -1, 2 => 1], # S -> S-1, I -> I+1
[2 => -1, 3 => 1]] # I -> I-1, R -> R+1
[2 => -1, 3 => 1]] # I -> I-1, R -> R+1
# rate constants for each jump
p = (0.1 / 1000, 0.01)

Expand All @@ -96,10 +96,10 @@ tspan = (0.0, 250.0)
dprob = DiscreteProblem(u₀, tspan, p)

# use the Direct method to simulate
jprob = JumpProblem(dprob, Direct(), maj)
jprob = JumpProblem(dprob, maj)

# solve as a pure jump process, i.e. using SSAStepper
sol = solve(jprob, SSAStepper())
sol = solve(jprob)
plot(sol)
```

Expand All @@ -122,8 +122,8 @@ function affect2!(integrator)
integrator.u[3] += 1 # R -> R + 1
end
jump2 = ConstantRateJump(rate2, affect2!)
jprob = JumpProblem(dprob, Direct(), jump, jump2)
sol = solve(jprob, SSAStepper())
jprob = JumpProblem(dprob, jump, jump2)
sol = solve(jprob)
```

### Jump-ODE Example
Expand Down
4 changes: 1 addition & 3 deletions src/SSA_stepper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,7 @@ function DiffEqBase.u_modified!(integrator::SSAIntegrator, bool::Bool)
integrator.u_modified = bool
end

function DiffEqBase.__solve(jump_prob::JumpProblem,
alg::SSAStepper;
kwargs...)
function DiffEqBase.__solve(jump_prob::JumpProblem, alg::SSAStepper; kwargs...)
integrator = init(jump_prob, alg; kwargs...)
solve!(integrator)
integrator.sol
Expand Down
34 changes: 34 additions & 0 deletions src/aggregators/aggregators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,40 @@ needs_vartojumps_map(aggregator::RSSACR) = true
supports_variablerates(aggregator::AbstractAggregatorAlgorithm) = false
supports_variablerates(aggregator::Coevolve) = true

# true if aggregator supports hops, e.g. diffusion
is_spatial(aggregator::AbstractAggregatorAlgorithm) = false
is_spatial(aggregator::NSM) = true
is_spatial(aggregator::DirectCRDirect) = true

# return the fastest aggregator out of the available ones
function select_aggregator(jumps::JumpSet; vartojumps_map = nothing,
jumptovars_map = nothing, dep_graph = nothing, spatial_system = nothing,
hopping_constants = nothing)

# detect if a spatial SSA should be used
!isnothing(spatial_system) && !isnothing(hopping_constants) && return DirectCRDirect

# if variable rate jumps are present, return one of the two SSAs that support them
if num_vrjs(jumps) > 0
(num_bndvrjs(jumps) > 0) && return Coevolve
return Direct
end

# if the number of jumps is small, return the Direct
num_jumps(jumps) < 20 && return Direct

# if there are only massaction jumps, we can build the species-jumps dependency graphs
can_build_dgs = num_crjs(jumps) == 0 && num_vrjs(jumps) == 0
have_species_to_jumps_dgs = !isnothing(vartojumps_map) && !isnothing(jumptovars_map)

# if we have the species-jumps dgs or can build them, use a Rejection-based methods
if can_build_dgs || have_species_to_jumps_dgs
(num_jumps(jumps) < 100) && return RSSA
return RSSACR
elseif !isnothing(dep_graph) # if only have a normal dg
(num_jumps(jumps) < 200) && return SortingDirect
return DirectCR
else
return Direct
end
end
2 changes: 1 addition & 1 deletion src/jumps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ end
using_params(maj::MassActionJump{T, S, U, Nothing}) where {T, S, U} = false
using_params(maj::MassActionJump) = true
using_params(maj::Nothing) = false
@inline get_num_majumps(maj::MassActionJump) = length(maj.scaled_rates)
@inline get_num_majumps(maj::MassActionJump) = length(maj.net_stoch)
@inline get_num_majumps(maj::Nothing) = 0

struct MassActionJumpParamMapper{U}
Expand Down
13 changes: 11 additions & 2 deletions src/problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,17 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::Abstr
kwargs...)
JumpProblem(prob, aggregator, JumpSet(jumps...); kwargs...)
end
function JumpProblem(prob, jumps::JumpSet; kwargs...)
JumpProblem(prob, NullAggregator(), jumps; kwargs...)
function JumpProblem(prob, jumps::JumpSet; vartojumps_map = nothing,
jumptovars_map = nothing, dep_graph = nothing,
spatial_system = nothing, hopping_constants = nothing, kwargs...)
ps = (; vartojumps_map, jumptovars_map, dep_graph, spatial_system, hopping_constants)
aggtype = select_aggregator(jumps; ps...)
return JumpProblem(prob, aggtype(), jumps; ps..., kwargs...)
end

# this makes it easier to test the aggregator selection
function JumpProblem(prob, aggregator::NullAggregator, jumps::JumpSet; kwargs...)
JumpProblem(prob, jumps; kwargs...)
end

make_kwarg(; kwargs...) = kwargs
Expand Down
11 changes: 11 additions & 0 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@ function DiffEqBase.__solve(jump_prob::DiffEqBase.AbstractJumpProblem{P},
integrator.sol
end

# if passed a JumpProblem over a DiscreteProblem, and no aggregator is selected use
# SSAStepper
function DiffEqBase.__solve(jump_prob::DiffEqBase.AbstractJumpProblem{P};
kwargs...) where {P <: DiscreteProblem}
DiffEqBase.__solve(jump_prob, SSAStepper(); kwargs...)
end

function DiffEqBase.__solve(jump_prob::DiffEqBase.AbstractJumpProblem; kwargs...)
error("Auto-solver selection is currently only implemented for JumpProblems defined over DiscreteProblems. Please explicitly specify a solver algorithm in calling solve.")
end

function DiffEqBase.__init(_jump_prob::DiffEqBase.AbstractJumpProblem{P},
alg::DiffEqBase.DEAlgorithm, timeseries = [], ts = [], ks = [],
recompile::Type{Val{recompile_flag}} = Val{true};
Expand Down
40 changes: 13 additions & 27 deletions test/bimolerx_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ dotestmean = true
doprintmeans = false

# SSAs to test
SSAalgs = JumpProcesses.JUMP_AGGREGATORS
SSAalgs = (JumpProcesses.JUMP_AGGREGATORS..., JumpProcesses.NullAggregator())

Nsims = 32000
tf = 0.01
Expand Down Expand Up @@ -55,10 +55,10 @@ jump_to_dep_specs = [[1, 2], [1, 2], [1, 2, 3], [1, 2, 3], [1, 3]]
majumps = MassActionJump(rates, reactstoch, netstoch)

# average number of proteins in a simulation
function runSSAs(jump_prob)
function runSSAs(jump_prob; use_stepper = true)
Psamp = zeros(Int, Nsims)
for i in 1:Nsims
sol = solve(jump_prob, SSAStepper())
sol = use_stepper ? solve(jump_prob, SSAStepper()) : solve(jump_prob)
Psamp[i] = sol[1, end]
end
mean(Psamp)
Expand All @@ -81,37 +81,23 @@ end

# test the means
if dotestmean
means = zeros(Float64, length(SSAalgs))
for (i, alg) in enumerate(SSAalgs)
local jump_prob = JumpProblem(prob, alg, majumps, save_positions = (false, false),
vartojumps_map = spec_to_dep_jumps,
jumptovars_map = jump_to_dep_specs, rng = rng)
means[i] = runSSAs(jump_prob)
relerr = abs(means[i] - expected_avg) / expected_avg
if doprintmeans
println("Mean from method: ", typeof(alg), " is = ", means[i], ", rel err = ",
relerr)
end

# if dobenchmark
# @btime (runSSAs($jump_prob);)
# end

@test abs(means[i] - expected_avg) < reltol * expected_avg
means = runSSAs(jump_prob)
relerr = abs(means - expected_avg) / expected_avg
doprintmeans && println("Mean from method: ", typeof(alg), " is = ", means,
", rel err = ", relerr)
@test abs(means - expected_avg) < reltol * expected_avg

# test not specifying SSAStepper
means = runSSAs(jump_prob; use_stepper = false)
relerr = abs(means - expected_avg) / expected_avg
@test abs(means - expected_avg) < reltol * expected_avg
end
end

# benchmark performance
# if dobenchmark
# # exact methods
# for alg in SSAalgs
# println("Solving with method: ", typeof(alg), ", using SSAStepper")
# jump_prob = JumpProblem(prob, alg, majumps, vartojumps_map=spec_to_dep_jumps, jumptovars_map=jump_to_dep_specs, rng=rng)
# @btime solve($jump_prob, SSAStepper())
# end
# println()
# end

# add a test for passing MassActionJumps individually (tests combining)
if dotestmean
majump_vec = Vector{MassActionJump{Float64, Vector{Pair{Int, Int}}}}()
Expand Down
3 changes: 1 addition & 2 deletions test/degenerate_rx_cases.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ doprint = false
#using Plots; plotlyjs()
doplot = false

methods = (RDirect(), RSSACR(), Direct(), DirectFW(), FRM(), FRMFW(), SortingDirect(),
NRM(), RSSA(), DirectCR(), Coevolve())
methods = (JumpProcesses.JUMP_AGGREGATORS..., JumpProcesses.NullAggregator())

# one reaction case, mass action jump, vector of data
rate = [2.0]
Expand Down
2 changes: 1 addition & 1 deletion test/extinction_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ dg = [[1]]
majump = MassActionJump(rates, reactstoch, netstoch)
u0 = [100000]
dprob = DiscreteProblem(u0, (0.0, 1e5), rates)
algs = JumpProcesses.JUMP_AGGREGATORS
algs = (JumpProcesses.JUMP_AGGREGATORS..., JumpProcesses.NullAggregator())

for n in 1:Nsims
for ssa in algs
Expand Down
46 changes: 20 additions & 26 deletions test/geneexpr_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ dotestmean = true
doprintmeans = false

# SSAs to test
SSAalgs = (RDirect(), RSSACR(), Direct(), DirectFW(), FRM(), FRMFW(), SortingDirect(),
NRM(), RSSA(), DirectCR(), Coevolve())
SSAalgs = (JumpProcesses.JUMP_AGGREGATORS..., JumpProcesses.NullAggregator())

# numerical parameters
Nsims = 8000
Expand All @@ -23,10 +22,10 @@ expected_avg = 5.926553750000000e+02
reltol = 0.01

# average number of proteins in a simulation
function runSSAs(jump_prob)
function runSSAs(jump_prob; use_stepper = true)
Psamp = zeros(Int, Nsims)
for i in 1:Nsims
sol = solve(jump_prob, SSAStepper())
sol = use_stepper ? solve(jump_prob, SSAStepper()) : solve(jump_prob)
Psamp[i] = sol[3, end]
end
mean(Psamp)
Expand Down Expand Up @@ -86,33 +85,28 @@ end

# test the means
if dotestmean
means = zeros(Float64, length(SSAalgs))
for (i, alg) in enumerate(SSAalgs)
local jump_prob = JumpProblem(prob, alg, majumps, save_positions = (false, false),
vartojumps_map = spec_to_dep_jumps,
jumptovars_map = jump_to_dep_specs, rng = rng)
means[i] = runSSAs(jump_prob)
relerr = abs(means[i] - expected_avg) / expected_avg
if doprintmeans
println("Mean from method: ", typeof(alg), " is = ", means[i], ", rel err = ",
relerr)
end
means = runSSAs(jump_prob)
relerr = abs(means - expected_avg) / expected_avg
doprintmeans && println("Mean from method: ", typeof(alg), " is = ", means,
", rel err = ", relerr)
@test abs(means - expected_avg) < reltol * expected_avg

# if dobenchmark
# @btime (runSSAs($jump_prob);)
# end

@test abs(means[i] - expected_avg) < reltol * expected_avg
means = runSSAs(jump_prob; use_stepper = false)
relerr = abs(means - expected_avg) / expected_avg
@test abs(means - expected_avg) < reltol * expected_avg
end
end

# benchmark performance
# if dobenchmark
# # exact methods
# for alg in SSAalgs
# println("Solving with method: ", typeof(alg), ", using SSAStepper")
# jump_prob = JumpProblem(prob, alg, majumps, vartojumps_map=spec_to_dep_jumps, jumptovars_map=jump_to_dep_specs, rng=rng)
# @btime solve($jump_prob, SSAStepper())
# end
# println()
# end
# no-aggregator tests
jump_prob = JumpProblem(prob, majumps; save_positions = (false, false),
vartojumps_map = spec_to_dep_jumps, jumptovars_map = jump_to_dep_specs, rng)
@test abs(runSSAs(jump_prob) - expected_avg) < reltol * expected_avg
@test abs(runSSAs(jump_prob; use_stepper = false) - expected_avg) < reltol * expected_avg

jump_prob = JumpProblem(prob, majumps, save_positions = (false, false), rng = rng)
@test abs(runSSAs(jump_prob) - expected_avg) < reltol * expected_avg
@test abs(runSSAs(jump_prob; use_stepper = false) - expected_avg) < reltol * expected_avg
Loading

0 comments on commit 4f262cb

Please sign in to comment.