Skip to content

Commit

Permalink
Merge pull request #463 from isaacsas/refactor_callbacks
Browse files Browse the repository at this point in the history
make affects wrapping vrjs return nothing
  • Loading branch information
isaacsas authored Nov 12, 2024
2 parents 0a6835f + b280674 commit d8a5177
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 30 deletions.
5 changes: 3 additions & 2 deletions src/SSA_stepper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,11 @@ function DiffEqBase.__init(jump_prob::JumpProblem,
end
else
cb = deepcopy(jump_prob.jump_callback.discrete_callbacks[end])
rng = cb.condition.rng
if seed === nothing
Random.seed!(cb.condition.rng, rand(UInt64))
Random.seed!(rng, rand(UInt64))
else
Random.seed!(cb.condition.rng, seed)
Random.seed!(rng, seed)
end
end
opts = (callback = CallbackSet(callback),)
Expand Down
34 changes: 13 additions & 21 deletions src/problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ then be passed within a single [`JumpSet`](@ref) or as subsequent sequential arg
$(FIELDS)
## Keyword Arguments
- `rng`, the random number generator to use. On 1.7 and up defaults to Julia's built-in
generator, below 1.7 uses RandomNumbers.jl's `Xorshifts.Xoroshiro128Star(rand(UInt64))`.
- `rng`, the random number generator to use. Defaults to Julia's built-in
generator.
- `save_positions=(true,true)`, specifies whether to save the system's state (before, after)
the jump occurs.
- `spatial_system`, for spatial problems the underlying spatial structure.
Expand Down Expand Up @@ -430,14 +430,14 @@ function extend_problem(prob::DiffEqBase.AbstractDAEProblem, jumps; rng = DEFAUL
remake(prob; f, u0)
end

function build_variable_callback(cb, idx, jump, jumps...; rng = DEFAULT_RNG)
idx += 1
condition = function (u, t, integrator)
function wrap_jump_in_callback(idx, jump; rng = DEFAULT_RNG)
condition = function(u, t, integrator)
u.jump_u[idx]
end
affect! = function (integrator)
affect! = function(integrator)
jump.affect!(integrator)
integrator.u.jump_u[idx] = -randexp(rng, typeof(integrator.t))
nothing
end
new_cb = ContinuousCallback(condition, affect!;
idxs = jump.idxs,
Expand All @@ -446,26 +446,18 @@ function build_variable_callback(cb, idx, jump, jumps...; rng = DEFAULT_RNG)
save_positions = jump.save_positions,
abstol = jump.abstol,
reltol = jump.reltol)
return new_cb
end

function build_variable_callback(cb, idx, jump, jumps...; rng = DEFAULT_RNG)
idx += 1
new_cb = wrap_jump_in_callback(idx, jump; rng)
build_variable_callback(CallbackSet(cb, new_cb), idx, jumps...; rng = DEFAULT_RNG)
end

function build_variable_callback(cb, idx, jump; rng = DEFAULT_RNG)
idx += 1
condition = function (u, t, integrator)
u.jump_u[idx]
end
affect! = function (integrator)
jump.affect!(integrator)
integrator.u.jump_u[idx] = -randexp(rng, typeof(integrator.t))
end
new_cb = ContinuousCallback(condition, affect!;
idxs = jump.idxs,
rootfind = jump.rootfind,
interp_points = jump.interp_points,
save_positions = jump.save_positions,
abstol = jump.abstol,
reltol = jump.reltol)
CallbackSet(cb, new_cb)
CallbackSet(cb, wrap_jump_in_callback(idx, jump; rng))
end

aggregator(jp::JumpProblem{iip, P, A, C, J}) where {iip, P, A, C, J} = A
Expand Down
6 changes: 3 additions & 3 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ end

function resetted_jump_problem(_jump_prob, seed)
jump_prob = deepcopy(_jump_prob)
rng = jump_prob.jump_callback.discrete_callbacks[1].condition.rng
if !isempty(jump_prob.jump_callback.discrete_callbacks)
if seed === nothing
Random.seed!(jump_prob.jump_callback.discrete_callbacks[1].condition.rng,
rand(UInt64))
Random.seed!(rng, rand(UInt64))
else
Random.seed!(jump_prob.jump_callback.discrete_callbacks[1].condition.rng, seed)
Random.seed!(rng, seed)
end
end

Expand Down
11 changes: 7 additions & 4 deletions test/variable_rate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -275,18 +275,20 @@ end
# https://github.com/SciML/JumpProcesses.jl/issues/320
# note that even with the seeded StableRNG this test is not
# deterministic for some reason.
function getmean(Nsims, prob, alg, dt, tsave)
function getmean(Nsims, prob, alg, dt, tsave, seed)
umean = zeros(length(tsave))
for i in 1:Nsims
sol = solve(prob, alg; saveat = dt)
sol = solve(prob, alg; saveat = dt, seed)
umean .+= Array(sol(tsave; idxs = 1))
seed += 1
end
umean ./= Nsims
return umean
end

let
rng = StableRNG(12345)
seed = 12345
rng = StableRNG(seed)
b = 2.0
d = 1.0
n0 = 1
Expand Down Expand Up @@ -320,7 +322,8 @@ let
dt = 0.1
tsave = range(tspan[1], tspan[2]; step = dt)
for alg in (Tsit5(), Rodas5P(linsolve = QRFactorization()))
umean = getmean(Nsims, sjm_prob, alg, dt, tsave)
umean = getmean(Nsims, sjm_prob, alg, dt, tsave, seed)
@test all(abs.(umean .- n.(tsave)) .< 0.05 * n.(tsave))
seed += Nsims
end
end

0 comments on commit d8a5177

Please sign in to comment.