Skip to content

Commit

Permalink
try to make rngs deterministic
Browse files Browse the repository at this point in the history
  • Loading branch information
isaacsas committed Nov 12, 2024
1 parent 0614f1b commit b280674
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 11 deletions.
5 changes: 3 additions & 2 deletions src/SSA_stepper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,11 @@ function DiffEqBase.__init(jump_prob::JumpProblem,
end
else
cb = deepcopy(jump_prob.jump_callback.discrete_callbacks[end])
rng = cb.condition.rng
if seed === nothing
Random.seed!(cb.condition.rng, rand(UInt64))
Random.seed!(rng, rand(UInt64))
else
Random.seed!(cb.condition.rng, seed)
Random.seed!(rng, seed)
end
end
opts = (callback = CallbackSet(callback),)
Expand Down
4 changes: 2 additions & 2 deletions src/problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ then be passed within a single [`JumpSet`](@ref) or as subsequent sequential arg
$(FIELDS)
## Keyword Arguments
- `rng`, the random number generator to use. On 1.7 and up defaults to Julia's built-in
generator, below 1.7 uses RandomNumbers.jl's `Xorshifts.Xoroshiro128Star(rand(UInt64))`.
- `rng`, the random number generator to use. Defaults to Julia's built-in
generator.
- `save_positions=(true,true)`, specifies whether to save the system's state (before, after)
the jump occurs.
- `spatial_system`, for spatial problems the underlying spatial structure.
Expand Down
6 changes: 3 additions & 3 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ end

function resetted_jump_problem(_jump_prob, seed)
jump_prob = deepcopy(_jump_prob)
rng = jump_prob.jump_callback.discrete_callbacks[1].condition.rng
if !isempty(jump_prob.jump_callback.discrete_callbacks)
if seed === nothing
Random.seed!(jump_prob.jump_callback.discrete_callbacks[1].condition.rng,
rand(UInt64))
Random.seed!(rng, rand(UInt64))
else
Random.seed!(jump_prob.jump_callback.discrete_callbacks[1].condition.rng, seed)
Random.seed!(rng, seed)
end
end

Expand Down
11 changes: 7 additions & 4 deletions test/variable_rate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -275,18 +275,20 @@ end
# https://github.com/SciML/JumpProcesses.jl/issues/320
# note that even with the seeded StableRNG this test is not
# deterministic for some reason.
function getmean(Nsims, prob, alg, dt, tsave)
function getmean(Nsims, prob, alg, dt, tsave, seed)
umean = zeros(length(tsave))
for i in 1:Nsims
sol = solve(prob, alg; saveat = dt)
sol = solve(prob, alg; saveat = dt, seed)
umean .+= Array(sol(tsave; idxs = 1))
seed += 1
end
umean ./= Nsims
return umean
end

let
rng = StableRNG(12345)
seed = 12345
rng = StableRNG(seed)
b = 2.0
d = 1.0
n0 = 1
Expand Down Expand Up @@ -320,7 +322,8 @@ let
dt = 0.1
tsave = range(tspan[1], tspan[2]; step = dt)
for alg in (Tsit5(), Rodas5P(linsolve = QRFactorization()))
umean = getmean(Nsims, sjm_prob, alg, dt, tsave)
umean = getmean(Nsims, sjm_prob, alg, dt, tsave, seed)
@test all(abs.(umean .- n.(tsave)) .< 0.05 * n.(tsave))
seed += Nsims
end
end

0 comments on commit b280674

Please sign in to comment.