Stochastic simulations
Gillespie Algorithm
using Plots
using DataInterpolations: ConstantInterpolation, LinearInterpolation
using StatsBase: Weights, sample
using Statistics: mean
using Random
using DisplayAs: PNG ## For faster rendering
Random.seed!(2024)
Random.TaskLocalRNG()
Do-it-yourself
Stochastic chemical reaction simulated with the Gillespie algorithm (direct and first reaction methods). Adapted from: Chemical and Biomedical Engineering Calculations Using Python, Ch. 4-3.
"""
    ssa_alg(model, u0, tend, p, stoich; tstart=zero(tend), method=:direct,
            rng=Random.default_rng())

Simulate a well-mixed reaction network with Gillespie's stochastic simulation
algorithm (SSA), using either the direct method or the first reaction method.

# Arguments
- `model`: propensity function called as `model(u, p, t)`; it must return one
  propensity per reaction, in the same order as `stoich`.
- `u0`: initial state (copied, never mutated).
- `tend`: final simulation time.
- `p`: parameters, forwarded unchanged to `model`.
- `stoich`: collection of state-change vectors, one per reaction.

# Keywords
- `tstart`: initial time.
- `method`: `:direct` (draw one waiting time from the total propensity, then pick the
  reaction with probability proportional to its propensity) or `:first` (draw a
  waiting time for every reaction and fire the earliest one).
- `rng`: random number generator; pass a seeded one for reproducible trajectories.

# Returns
A named tuple `(t=ts, u=us)`: `ts` holds the time points and row `k` of the matrix
`us` is the state at `ts[k]`. A reaction whose firing time would fall after `tend` is
not applied, and the simulation also stops once every propensity is zero; in both
cases the unchanged state is recorded once more at `tend`.

Throws an `ErrorException` if `method` is neither `:direct` nor `:first`.
"""
function ssa_alg(model, u0::AbstractVector, tend, p, stoich; tstart=zero(tend),
                 method=:direct, rng::AbstractRNG=Random.default_rng())
    # Validate up front so a bad `method` is reported even if the loop never runs.
    method in (:direct, :first) || error("Method should be either :direct or :first")
    t = tstart                  # Current time
    ts = [t]                    # Time points
    u = copy(u0)                # Current state
    states = [copy(u)]          # State snapshots, stacked into a matrix at the end
    while t < tend
        a = model(u, p, t)      # Propensities
        atot = sum(a)
        # Absorbing state: no reaction can fire, so the state stays constant until
        # tend. Without this check a zero-propensity reaction would still be applied.
        iszero(atot) && break
        if method == :direct
            dt = randexp(rng) / atot                # Waiting time to the next reaction
            du = sample(rng, stoich, Weights(a))    # Which reaction fires
        else
            dts = randexp(rng, length(a)) ./ a      # Candidate waiting time per reaction
            i = argmin(dts)                         # The earliest reaction fires
            dt = dts[i]
            du = stoich[i]
        end
        # The next reaction happens after tend, so it must not alter the state at tend.
        t + dt > tend && break
        u .+= du                # Update state
        t += dt                 # Update time
        push!(states, copy(u))
        push!(ts, t)
    end
    # Close the trajectory at tend so it spans the whole interval [tstart, tend].
    if t < tend
        push!(states, copy(u))
        push!(ts, tend)
    end
    # Stacking the collected rows once avoids the quadratic cost of re-concatenating
    # the state matrix at every step.
    us = permutedims(reduce(hcat, states))
    return (t=ts, u=us)
end
ssa_alg (generic function with 1 method)
Propensity model for this example: the reversible reaction A <-> B, with rate constant k1 for A -> B and k2 for B -> A.
function model(u, p, t)
    # Propensities of the two reactions: A -> B (rate k1 * A) and B -> A (rate k2 * B).
    forward = p.k1 * u[1]
    backward = p.k2 * u[2]
    return [forward, backward]
end
model (generic function with 1 method)
parameters = (; k1=1.0, k2=0.5)  # forward (k1) and backward (k2) rate constants
(k1 = 1.0, k2 = 0.5)
Stoichiometry for each reaction
stoich = Vector{Int}[[-1, 1], [1, -1]]  # state changes for A -> B and for B -> A
2-element Vector{Vector{Int64}}:
[-1, 1]
[1, -1]
Initial conditions (usually integer molecule counts)
u0 = Int[200, 0]  # initial counts of A and B
2-element Vector{Int64}:
200
0
Simulation time
tend = 1e1  # final simulation time
10.0
Solve the problem using both the direct and the first reaction methods
# Same model, same initial state, two SSA variants; each call is timed.
soldirect = @time ssa_alg(model, u0, tend, parameters, stoich; method=:direct)
solfirst = @time ssa_alg(model, u0, tend, parameters, stoich; method=:first)
0.368385 seconds (432.49 k allocations: 44.325 MiB, 4.89% gc time, 94.44% compilation time)
0.002050 seconds (9.95 k allocations: 14.957 MiB)
(t = [0.0, 0.006874672879606227, 0.0070050915014185, 0.01329115182676454, 0.015335530116774018, 0.01942319769041779, 0.021346508038833215, 0.02345030020299408, 0.025497753122539372, 0.02734483697039726 … 9.937773197389458, 9.938360255993572, 9.964450740930666, 9.96564917153764, 9.975440668199402, 9.978594329657025, 9.982318425245289, 9.987213207313044, 9.995833233887383, 10.024199143896238], u = [200 0; 199 1; … ; 75 125; 74 126])
Plot the solution from the direct method
PNG(plot(soldirect.t, soldirect.u;
    xlabel="time", ylabel="# of molecules",
    title="SSA (direct method)", label=["A" "B"]))

Plot the solution from the first reaction method
PNG(plot(solfirst.t, solfirst.u;
    xlabel="time", ylabel="# of molecules",
    title="SSA (1st reaction method)", label=["A" "B"]))

Running 50 simulations
numRuns = 50
# Independent direct-method trajectories, all starting from the same initial state.
@time sols = [ssa_alg(model, u0, tend, parameters, stoich; method=:direct)
              for _ in 1:numRuns];
0.428256 seconds (654.10 k allocations: 761.752 MiB, 57.59% gc time, 17.54% compilation time)
Average values and interpolation
ts = range(0, tend, 101)
# Each SSA trajectory is piecewise constant: the state changes only at reaction events
# and holds its value in between. A step interpolant (default `dir=:left`, i.e. each
# interval takes the value at its left knot) reproduces a trajectory exactly, whereas
# linear interpolation would invent fractional molecule counts between events.
# The interpolants are built once here rather than on every evaluation of the averages.
interps_a = [ConstantInterpolation(sol.u[:, 1], sol.t) for sol in sols]
interps_b = [ConstantInterpolation(sol.u[:, 2], sol.t) for sol in sols]
# Ensemble mean of species A (column 1) and B (column 2) at time t.
a_avg(t) = mean(itp -> itp(t), interps_a)
b_avg(t) = mean(itp -> itp(t), interps_b)
b_avg (generic function with 1 method)
Plot the solution
# Overlay every trajectory (A in blue, B in red) with low opacity. The ensemble in
# `sols` was generated with `method=:direct`, so the title names that method.
fig1 = plot(xlabel="Time", ylabel="# of molecules", title="SSA (direct method) ensemble")
for sol in sols
    plot!(fig1, sol.t, sol.u, linecolor=[:blue :red], linealpha=0.05, label=false)
end
fig1 |> PNG

Plot averages
# Overlay the ensemble means on the individual trajectories. The averages are evaluated
# on the fixed grid `ts`, so each curve costs exactly length(ts) evaluations. Only the
# final figure needs converting for display.
plot!(fig1, ts, a_avg.(ts), linecolor=:black, linewidth=3, linestyle=:solid,
      label="Average [A]")
plot!(fig1, ts, b_avg.(ts), linecolor=:black, linewidth=3, linestyle=:dash,
      label="Average [B]")
fig1 |> PNG

Using Catalyst (recommended)
SciML/Catalyst.jl is a domain-specific language (DSL) package to simulate chemical reaction networks.
using Catalyst
using JumpProcesses
using Plots
# Reversible reaction: k1 is the rate constant of A --> B, k2 that of B --> A.
two_state_model = @reaction_network begin
    (k1, k2), A <--> B
end
State variables (molecule counts) are integers
tspan = (0.0, 10.0)
u0 = [:A => 200, :B => 0]
params = [:k1 => 1.0, :k2 => 0.5]
# Discrete (integer-state) problem, wrapped in a JumpProblem that uses the Direct
# aggregator.
prob = DiscreteProblem(two_state_model, u0, tspan, params)
jump_prob = JumpProblem(two_state_model, prob, Direct())
JumpProblem with problem DiscreteProblem with aggregator JumpProcesses.Direct
Number of jumps with discrete aggregation: 0
Number of jumps with continuous aggregation: 0
Number of mass action jumps: 2
Here we solve the JumpProblem using Gillespie’s Direct stochastic simulation algorithm (SSA).
sol = @time solve(jump_prob, SSAStepper())
PNG(plot(sol))
0.358158 seconds (495.79 k allocations: 33.739 MiB, 5.71% gc time, 99.96% compilation time)

Parallel ensemble simulation
ensprob = EnsembleProblem(jump_prob)
# 50 trajectories, distributed over threads via EnsembleThreads().
sim = @time solve(ensprob, SSAStepper(), EnsembleThreads(); trajectories=50)
1.091052 seconds (2.33 M allocations: 162.244 MiB, 3.23% gc time, 104.06% compilation time)
EnsembleSolution Solution of length 50 with uType:
SciMLBase.ODESolution{Int64, 2, Vector{Vector{Int64}}, Nothing, Nothing, Vector{Float64}, Nothing, Nothing, SciMLBase.DiscreteProblem{Vector{Int64}, Tuple{Float64, Float64}, true, ModelingToolkit.MTKParameters{Vector{Float64}, StaticArraysCore.SizedVector{0, Float64, Vector{Float64}}, Tuple{}, Tuple{}, Tuple{}, Tuple{}}, SciMLBase.DiscreteFunction{true, true, SciMLBase.DiscreteFunction{true, SciMLBase.FullSpecialize, SciMLBase.var"#236#237", Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Nothing, ModelingToolkit.ObservedFunctionCache{ModelingToolkit.JumpSystem{RecursiveArrayTools.ArrayPartition{Any, Tuple{Vector{JumpProcesses.MassActionJump}, Vector{JumpProcesses.ConstantRateJump}, Vector{JumpProcesses.VariableRateJump}, Vector{Symbolics.Equation}}}}}, ModelingToolkit.JumpSystem{RecursiveArrayTools.ArrayPartition{Any, Tuple{Vector{JumpProcesses.MassActionJump}, Vector{JumpProcesses.ConstantRateJump}, Vector{JumpProcesses.VariableRateJump}, Vector{Symbolics.Equation}}}}, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}}, JumpProcesses.SSAStepper, SciMLBase.ConstantInterpolation{Vector{Float64}, Vector{Vector{Int64}}}, SciMLBase.DEStats, Nothing, Nothing, Nothing, Nothing}
PNG(plot(sim; alpha=0.1, color=[:blue :red]))

# Ensemble summary evaluated at times 0, 0.1, ..., 10.
summ = EnsembleSummary(sim, 0:0.1:10)
PNG(plot(summ; fillalpha=0.5))

SIR model
# SIR model: infection S + I --> 2I with rate constant beta, recovery I --> R with gamma.
sir_model = @reaction_network begin
    beta, S + I --> 2I
    gamma, I --> R
end
tspan = (0.0, 250.0)
u0 = [:S => 990, :I => 10, :R => 0]
p = (:beta => 0.1 / 1000, :gamma => 0.01)
prob = DiscreteProblem(sir_model, u0, tspan, p)
jump_prob = JumpProblem(sir_model, prob, Direct())
JumpProblem with problem DiscreteProblem with aggregator JumpProcesses.Direct
Number of jumps with discrete aggregation: 0
Number of jumps with continuous aggregation: 0
Number of mass action jumps: 2
sol = @time solve(jump_prob, SSAStepper())
0.000226 seconds (1.84 k allocations: 199.969 KiB)
retcode: Success
Interpolation: Piecewise constant interpolation
t: 1827-element Vector{Float64}:
0.0
0.473096636084029
0.8557259166742717
1.4082894926246763
2.1737896809645885
2.802916732568468
3.66834025220687
4.22743341800077
5.085723781291026
5.265530069473048
⋮
244.4319869337805
244.62973767946696
244.78493504620218
246.25931998717863
247.2586443495297
247.94699762893305
248.0812845871357
248.7097144416353
250.0
u: 1827-element Vector{Vector{Int64}}:
[990, 10, 0]
[989, 11, 0]
[988, 12, 0]
[987, 13, 0]
[986, 14, 0]
[985, 15, 0]
[985, 14, 1]
[984, 15, 1]
[983, 16, 1]
[982, 17, 1]
⋮
[0, 172, 828]
[0, 171, 829]
[0, 170, 830]
[0, 169, 831]
[0, 168, 832]
[0, 167, 833]
[0, 166, 834]
[0, 165, 835]
[0, 165, 835]
PNG(plot(sol))

See also the JumpProcesses.jl docs about discrete stochastic algorithm examples.
Grid simulation
https://docs.sciml.ai/Catalyst/stable/spatial_modelling/lattice_reaction_systems/ https://docs.sciml.ai/JumpProcesses/stable/tutorials/spatial/
using Catalyst
using JumpProcesses
using Plots
sir_model = @reaction_network begin
    beta, S + I --> 2I
    gamma, I --> R
end
# Transport reactions (rate D) for S and I only; R has no transport reaction.
dS = @transport_reaction D S
dI = @transport_reaction D I
lattice = CartesianGrid((3, 3))
lrs = LatticeReactionSystem(sir_model, [dS, dI], lattice)
# Per-site initial counts on the 3×3 grid: 110 susceptibles everywhere, 10 infected at
# site (1, 1), nobody recovered yet.
s0 = fill(110, 3, 3)
i0 = zeros(Int, 3, 3)
i0[1, 1] = 10
r0 = zero(i0)
u0 = [:S => s0, :I => i0, :R => r0]
ps = [:beta => 0.1 / 100, :gamma => 0.03, :D => 1.0]
tspan = (0.0, 250.0)
prob = DiscreteProblem(lrs, u0, tspan, ps)
# NSM (next subvolume method) aggregator for the spatial jump problem.
jump_prob = JumpProblem(lrs, prob, NSM())
sol = @time solve(jump_prob, SSAStepper())
0.511041 seconds (921.88 k allocations: 119.758 MiB, 85.22% compilation time)
retcode: Success
Interpolation: Piecewise constant interpolation
t: 260891-element Vector{Float64}:
0.0
0.0012069567021312662
0.0013538966171038392
0.0020068570594335643
0.002033189009452479
0.002060173972575241
0.0022720922444916165
0.002347500104400583
0.002462111910372735
0.002814697169192512
⋮
249.94521048289837
249.95618014937278
249.95842893059444
249.95919147682466
249.96323233960823
249.98288384615506
249.9902229272222
249.9981588859541
250.0
u: 260891-element Vector{Matrix{Int64}}:
[110 110 … 110 110; 10 0 … 0 0; 0 0 … 0 0]
[110 110 … 110 110; 10 0 … 0 0; 0 0 … 0 0]
[110 110 … 111 110; 10 0 … 0 0; 0 0 … 0 0]
[109 110 … 111 110; 10 0 … 0 0; 0 0 … 0 0]
[109 111 … 111 110; 10 0 … 0 0; 0 0 … 0 0]
[109 111 … 110 110; 10 0 … 0 0; 0 0 … 0 0]
[109 111 … 110 109; 10 0 … 0 0; 0 0 … 0 0]
[109 111 … 110 109; 10 0 … 0 0; 0 0 … 0 0]
[109 112 … 110 109; 10 0 … 0 0; 0 0 … 0 0]
[109 112 … 111 109; 10 0 … 0 0; 0 0 … 0 0]
⋮
[2 5 … 6 3; 0 0 … 0 0; 112 114 … 115 81]
[2 5 … 5 3; 0 0 … 0 0; 112 114 … 115 81]
[2 5 … 5 3; 0 0 … 0 0; 112 114 … 115 81]
[2 6 … 5 3; 0 0 … 0 0; 112 114 … 115 81]
[2 6 … 4 4; 0 0 … 0 0; 112 114 … 115 81]
[2 6 … 4 4; 0 0 … 0 0; 112 114 … 115 81]
[2 6 … 4 5; 0 0 … 0 0; 112 114 … 115 81]
[2 5 … 4 5; 0 0 … 0 0; 112 114 … 115 81]
[2 5 … 4 5; 0 0 … 0 0; 112 114 … 115 81]
# Per-site counts of S at every saved time point (one 3×3 matrix per time point).
lat_getu(sol, :S, lrs)
260891-element Vector{Matrix{Int64}}:
[110 110 110; 110 110 110; 110 110 110]
[110 110 110; 110 110 110; 111 109 110]
[110 110 110; 110 109 111; 111 109 110]
[109 111 110; 110 109 111; 111 109 110]
[109 111 110; 111 109 111; 110 109 110]
[109 111 110; 111 110 110; 110 109 110]
[109 111 110; 111 110 110; 110 110 109]
[109 111 110; 111 110 110; 109 111 109]
[109 111 110; 112 110 110; 108 111 109]
[109 111 110; 112 109 111; 108 111 109]
⋮
[2 2 7; 5 3 6; 7 8 3]
[2 2 7; 5 4 5; 7 8 3]
[2 3 6; 5 4 5; 7 8 3]
[2 3 6; 6 3 5; 7 8 3]
[2 3 6; 6 3 4; 7 8 4]
[2 4 5; 6 3 4; 7 8 4]
[2 4 5; 6 3 4; 7 7 5]
[2 4 5; 5 4 4; 7 7 5]
[2 4 5; 5 4 4; 7 7 5]
# Per-site counts of I at every saved time point (one 3×3 matrix per time point).
lat_getu(sol, :I, lrs)
260891-element Vector{Matrix{Int64}}:
[10 0 0; 0 0 0; 0 0 0]
[10 0 0; 0 0 0; 0 0 0]
[10 0 0; 0 0 0; 0 0 0]
[10 0 0; 0 0 0; 0 0 0]
[10 0 0; 0 0 0; 0 0 0]
[10 0 0; 0 0 0; 0 0 0]
[10 0 0; 0 0 0; 0 0 0]
[10 0 0; 0 0 0; 0 0 0]
[10 0 0; 0 0 0; 0 0 0]
[10 0 0; 0 0 0; 0 0 0]
⋮
[0 0 0; 0 1 0; 1 0 0]
[0 0 0; 0 1 0; 1 0 0]
[0 0 0; 0 1 0; 1 0 0]
[0 0 0; 0 1 0; 1 0 0]
[0 0 0; 0 1 0; 1 0 0]
[0 0 0; 0 1 0; 1 0 0]
[0 0 0; 0 1 0; 1 0 0]
[0 0 0; 0 1 0; 1 0 0]
[0 0 0; 0 1 0; 1 0 0]
# Per-site counts of R at every saved time point (one 3×3 matrix per time point).
lat_getu(sol, :R, lrs)
260891-element Vector{Matrix{Int64}}:
[0 0 0; 0 0 0; 0 0 0]
[0 0 0; 0 0 0; 0 0 0]
[0 0 0; 0 0 0; 0 0 0]
[0 0 0; 0 0 0; 0 0 0]
[0 0 0; 0 0 0; 0 0 0]
[0 0 0; 0 0 0; 0 0 0]
[0 0 0; 0 0 0; 0 0 0]
[0 0 0; 0 0 0; 0 0 0]
[0 0 0; 0 0 0; 0 0 0]
[0 0 0; 0 0 0; 0 0 0]
⋮
[112 94 124; 114 105 115; 100 110 81]
[112 94 124; 114 105 115; 100 110 81]
[112 94 124; 114 105 115; 100 110 81]
[112 94 124; 114 105 115; 100 110 81]
[112 94 124; 114 105 115; 100 110 81]
[112 94 124; 114 105 115; 100 110 81]
[112 94 124; 114 105 115; 100 110 81]
[112 94 124; 114 105 115; 100 110 81]
[112 94 124; 114 105 115; 100 110 81]
t = sol.t
# Totals of S, I and R summed over all lattice sites at each saved time point.
s, i, r = (sum.(lat_getu(sol, species, lrs)) for species in (:S, :I, :R))
PNG(plot(t, [s i r]; label=["S" "I" "R"], lw=1.5))

This notebook was generated using Literate.jl.