# Stochastic simulations

## Gillespie Algorithm

using Plots
using DataInterpolations: LinearInterpolation
using StatsBase: Weights, sample
using Statistics: mean
using Random
using DisplayAs: PNG    ## For faster rendering
Random.seed!(2024)
Random.TaskLocalRNG()

## Do-it-yourself

Stochastic chemical reactions simulated with the Gillespie algorithm (direct method and first-reaction method). Adapted from: Chemical and Biomedical Engineering Calculations Using Python, Ch. 4-3.

"""
    ssa_alg(model, u0, tend, p, stoich; tstart=zero(tend), method=:direct)

Simulate one trajectory of a stochastic reaction system with the Gillespie
stochastic simulation algorithm (SSA).

# Arguments
- `model`: callable `model(u, p, t)` returning the vector of reaction propensities.
- `u0::AbstractVector`: initial state; it is copied and never mutated.
- `tend`: the loop runs while the current time is below `tend`.
- `p`: parameters forwarded unchanged to `model`.
- `stoich`: one state-change vector per reaction, in the same order as the propensities.

# Keywords
- `tstart`: initial time (defaults to `zero(tend)`).
- `method`: `:direct` (one exponential waiting time from the total propensity, then a
  reaction sampled by propensity) or `:first` (one waiting time per reaction, the
  smallest one fires).

# Returns
A named tuple `(t, u)`: `t` is the vector of event times (starting with `tstart`) and
`u` is a matrix with one row per time point and one column per state variable.

# Notes
- The last event is recorded even when it lands beyond `tend`.
- When every propensity is zero no reaction can fire; the current state is then
  recorded once more at `tend` and the simulation stops.

# Throws
- `ErrorException` if `method` is neither `:direct` nor `:first`.
"""
function ssa_alg(model, u0::AbstractVector, tend, p, stoich; tstart=zero(tend), method=:direct)
    method in (:direct, :first) || error("Method should be either :direct or :first")

    t = float(tstart)    ## Current time; floating point so `t += dt` never changes its type
    ts = [t]             ## Time points
    u = copy(u0)         ## Current state
    states = [copy(u)]   ## State history; stacked into a matrix once at the end
    while t < tend
        a = model(u, p, t)   ## Propensities
        atot = sum(a)
        if iszero(atot)
            ## Absorbing state: nothing can react any more, so hold the state until `tend`
            ## instead of firing a zero-propensity reaction at an infinite time.
            t = oftype(t, tend)
            push!(ts, t)
            push!(states, copy(u))
            break
        end
        if method == :direct
            dt = randexp() / atot            ## Waiting time until the next reaction
            du = sample(stoich, Weights(a))  ## Stoichiometry of the reaction that fires
        else
            dts = randexp(length(a)) ./ a    ## Putative waiting time of every reaction
            i = argmin(dts)                  ## The earliest reaction is the one that fires
            dt = dts[i]
            du = stoich[i]
        end
        u .+= du                 ## Update state
        t += dt                  ## Update time
        push!(states, copy(u))   ## Append state
        push!(ts, t)             ## Append time point
    end
    ## Build the (time points × state variables) matrix in one pass; appending a row to
    ## a matrix inside the loop would copy the whole history at every step.
    us = permutedims(reduce(hcat, states))
    return (t=ts, u=us)
end
ssa_alg (generic function with 1 method)

Propensity model for this example: the reversible reaction A <-> B, with rate constant k1 for A -> B and k2 for B -> A.

"""
    model(u, p, t)

Return the propensities `[k1 * A, k2 * B]` of the reversible reaction A <-> B, where
`A = u[1]`, `B = u[2]` and the rate constants are read from `p.k1` and `p.k2`.
The time `t` is accepted for interface compatibility but not used.
"""
function model(u, p, t)
    forward = p.k1 * u[1]    ## A -> B
    backward = p.k2 * u[2]   ## B -> A
    return [forward, backward]
end
model (generic function with 1 method)
## Rate constants: k1 for A -> B, k2 for B -> A
parameters = (; k1=1.0, k2=0.5)
(k1 = 1.0, k2 = 0.5)

Stoichiometry for each reaction

## One state-change vector per reaction, in the same order as the propensities
stoich = [
    [-1, 1],   ## A -> B
    [1, -1],   ## B -> A
]
2-element Vector{Vector{Int64}}:
 [-1, 1]
 [1, -1]

Initial conditions (Usually discrete values)

## Initial molecule counts, ordered as [A, B]
u0 = [200, 0]
2-element Vector{Int64}:
 200
   0

Simulation time

## End time of the simulation
tend = 10.0
10.0

Solve the problem using both the direct method and the first-reaction method

## Run each method once. The first `@time` call also measures compilation, so the two
## timings are not directly comparable.
@time soldirect = ssa_alg(model, u0, tend, parameters, stoich; method=:direct)
@time solfirst = ssa_alg(model, u0, tend, parameters, stoich; method=:first)
  0.368385 seconds (432.49 k allocations: 44.325 MiB, 4.89% gc time, 94.44% compilation time)
  0.002050 seconds (9.95 k allocations: 14.957 MiB)
(t = [0.0, 0.006874672879606227, 0.0070050915014185, 0.01329115182676454, 0.015335530116774018, 0.01942319769041779, 0.021346508038833215, 0.02345030020299408, 0.025497753122539372, 0.02734483697039726  …  9.937773197389458, 9.938360255993572, 9.964450740930666, 9.96564917153764, 9.975440668199402, 9.978594329657025, 9.982318425245289, 9.987213207313044, 9.995833233887383, 10.024199143896238], u = [200 0; 199 1; … ; 75 125; 74 126])

Plot the solution from the direct method

## One line per species (columns of `u`) against the event times
PNG(plot(
    soldirect.t, soldirect.u;
    xlabel="time",
    ylabel="# of molecules",
    title="SSA (direct method)",
    label=["A" "B"],
))
_images/3b1fa6a16368f92991f6e610608827619a0c57ebb66dcb2a6d8c19099967c795.png

Plot the solution by the first reaction method

## One line per species (columns of `u`) against the event times
PNG(plot(
    solfirst.t, solfirst.u;
    xlabel="time",
    ylabel="# of molecules",
    title="SSA (1st reaction method)",
    label=["A" "B"],
))
_images/309e984f5796524430fcf1189c0aa111fa16853c7e7db6816108766a59821189.png

Running 50 simulations

## Number of independent trajectories in the ensemble
numRuns = 50

## Each run draws fresh random numbers, so every trajectory is different
@time sols = [
    ssa_alg(model, u0, tend, parameters, stoich; method=:direct) for _ in 1:numRuns
];
  0.428256 seconds (654.10 k allocations: 761.752 MiB, 57.59% gc time, 17.54% compilation time)

Average values and interpolation

ts = range(0, tend, 101)

## Build one interpolant per trajectory up front. Constructing them inside
## `a_avg`/`b_avg` would redo the work for every trajectory at every evaluated time.
## The interpolants are built from `sols` as it is when this cell runs.
a_itps = [LinearInterpolation(sol.u[:, 1], sol.t; cache_parameters=true) for sol in sols]
b_itps = [LinearInterpolation(sol.u[:, 2], sol.t; cache_parameters=true) for sol in sols]

## Ensemble average of species A (column 1) at time `t`
a_avg(t) = mean(itp -> itp(t), a_itps)

## Ensemble average of species B (column 2) at time `t`
b_avg(t) = mean(itp -> itp(t), b_itps)
b_avg (generic function with 1 method)

Plot the solution

## The ensemble `sols` was generated with `method=:direct`, so the title says so
fig1 = plot(xlabel="Time", ylabel="# of molecules", title="SSA (direct method) ensemble")

## Overlay every trajectory with a low opacity: A in blue, B in red
for sol in sols
    plot!(fig1, sol.t, sol.u, linecolor=[:blue :red], linealpha=0.05, label=false)
end

fig1 |> PNG
_images/06c5dfb2e7a53cb0bf6409715c823318d10a02e91fa34573c95220b4e195e890.png

Plot averages

## Add both ensemble averages on top of the individual trajectories
for (avg, style, name) in ((a_avg, :solid, "Average [A]"), (b_avg, :dash, "Average [B]"))
    plot!(fig1, avg, 0.0, tend; linecolor=:black, linewidth=3, linestyle=style, label=name)
end
PNG(fig1)
_images/3976497b463d33c283b4e633a5b2a390792b8f95c729fa40db9f16f3123cad7a.png

## Grid simulation

References: Catalyst lattice reaction systems (https://docs.sciml.ai/Catalyst/stable/spatial_modelling/lattice_reaction_systems/) and the JumpProcesses spatial tutorial (https://docs.sciml.ai/JumpProcesses/stable/tutorials/spatial/).

using Catalyst
using JumpProcesses
using Plots

## SIR reactions: infection (S + I --> 2I, rate beta) and recovery (I --> R, rate gamma)
sir_model = @reaction_network begin
    beta, S + I --> 2I
    gamma, I --> R
end

## Transport reactions for S and I, both with rate parameter D; none is declared for R
dS = @transport_reaction D S
dI = @transport_reaction D I
## 3×3 Cartesian grid on which the reaction system is placed
lattice = CartesianGrid((3,3))
lrs = LatticeReactionSystem(sir_model, [dS, dI], lattice)

## Initial populations per lattice site: 110 susceptible everywhere, 10 infected at
## site (1, 1) only, nobody recovered
s0 = fill(110, 3, 3)
i0 = zeros(Int, size(s0))
i0[1, 1] = 10
r0 = zeros(Int, size(s0))
u0 = [:S => s0, :I => i0, :R => r0]

## Rate parameters and simulated time span
ps = [:beta => 0.1 / 100, :gamma => 0.03, :D => 1.0]
tspan = (0.0, 250.0)

## Discrete-state problem wrapped into a jump problem that uses the `NSM()` aggregator
prob = DiscreteProblem(lrs, u0, tspan, ps)
jump_prob = JumpProblem(lrs, prob, NSM())

@time sol = solve(jump_prob, SSAStepper())
  0.511041 seconds (921.88 k allocations: 119.758 MiB, 85.22% compilation time)
retcode: Success
Interpolation: Piecewise constant interpolation
t: 260891-element Vector{Float64}:
   0.0
   0.0012069567021312662
   0.0013538966171038392
   0.0020068570594335643
   0.002033189009452479
   0.002060173972575241
   0.0022720922444916165
   0.002347500104400583
   0.002462111910372735
   0.002814697169192512
   ⋮
 249.94521048289837
 249.95618014937278
 249.95842893059444
 249.95919147682466
 249.96323233960823
 249.98288384615506
 249.9902229272222
 249.9981588859541
 250.0
u: 260891-element Vector{Matrix{Int64}}:
 [110 110 … 110 110; 10 0 … 0 0; 0 0 … 0 0]
 [110 110 … 110 110; 10 0 … 0 0; 0 0 … 0 0]
 [110 110 … 111 110; 10 0 … 0 0; 0 0 … 0 0]
 [109 110 … 111 110; 10 0 … 0 0; 0 0 … 0 0]
 [109 111 … 111 110; 10 0 … 0 0; 0 0 … 0 0]
 [109 111 … 110 110; 10 0 … 0 0; 0 0 … 0 0]
 [109 111 … 110 109; 10 0 … 0 0; 0 0 … 0 0]
 [109 111 … 110 109; 10 0 … 0 0; 0 0 … 0 0]
 [109 112 … 110 109; 10 0 … 0 0; 0 0 … 0 0]
 [109 112 … 111 109; 10 0 … 0 0; 0 0 … 0 0]
 ⋮
 [2 5 … 6 3; 0 0 … 0 0; 112 114 … 115 81]
 [2 5 … 5 3; 0 0 … 0 0; 112 114 … 115 81]
 [2 5 … 5 3; 0 0 … 0 0; 112 114 … 115 81]
 [2 6 … 5 3; 0 0 … 0 0; 112 114 … 115 81]
 [2 6 … 4 4; 0 0 … 0 0; 112 114 … 115 81]
 [2 6 … 4 4; 0 0 … 0 0; 112 114 … 115 81]
 [2 6 … 4 5; 0 0 … 0 0; 112 114 … 115 81]
 [2 5 … 4 5; 0 0 … 0 0; 112 114 … 115 81]
 [2 5 … 4 5; 0 0 … 0 0; 112 114 … 115 81]
## S counts on the lattice: one 3×3 matrix per saved time point
lat_getu(sol, :S, lrs)
260891-element Vector{Matrix{Int64}}:
 [110 110 110; 110 110 110; 110 110 110]
 [110 110 110; 110 110 110; 111 109 110]
 [110 110 110; 110 109 111; 111 109 110]
 [109 111 110; 110 109 111; 111 109 110]
 [109 111 110; 111 109 111; 110 109 110]
 [109 111 110; 111 110 110; 110 109 110]
 [109 111 110; 111 110 110; 110 110 109]
 [109 111 110; 111 110 110; 109 111 109]
 [109 111 110; 112 110 110; 108 111 109]
 [109 111 110; 112 109 111; 108 111 109]
 ⋮
 [2 2 7; 5 3 6; 7 8 3]
 [2 2 7; 5 4 5; 7 8 3]
 [2 3 6; 5 4 5; 7 8 3]
 [2 3 6; 6 3 5; 7 8 3]
 [2 3 6; 6 3 4; 7 8 4]
 [2 4 5; 6 3 4; 7 8 4]
 [2 4 5; 6 3 4; 7 7 5]
 [2 4 5; 5 4 4; 7 7 5]
 [2 4 5; 5 4 4; 7 7 5]
## I counts on the lattice: one 3×3 matrix per saved time point
lat_getu(sol, :I, lrs)
260891-element Vector{Matrix{Int64}}:
 [10 0 0; 0 0 0; 0 0 0]
 [10 0 0; 0 0 0; 0 0 0]
 [10 0 0; 0 0 0; 0 0 0]
 [10 0 0; 0 0 0; 0 0 0]
 [10 0 0; 0 0 0; 0 0 0]
 [10 0 0; 0 0 0; 0 0 0]
 [10 0 0; 0 0 0; 0 0 0]
 [10 0 0; 0 0 0; 0 0 0]
 [10 0 0; 0 0 0; 0 0 0]
 [10 0 0; 0 0 0; 0 0 0]
 ⋮
 [0 0 0; 0 1 0; 1 0 0]
 [0 0 0; 0 1 0; 1 0 0]
 [0 0 0; 0 1 0; 1 0 0]
 [0 0 0; 0 1 0; 1 0 0]
 [0 0 0; 0 1 0; 1 0 0]
 [0 0 0; 0 1 0; 1 0 0]
 [0 0 0; 0 1 0; 1 0 0]
 [0 0 0; 0 1 0; 1 0 0]
 [0 0 0; 0 1 0; 1 0 0]
## R counts on the lattice: one 3×3 matrix per saved time point
lat_getu(sol, :R, lrs)
260891-element Vector{Matrix{Int64}}:
 [0 0 0; 0 0 0; 0 0 0]
 [0 0 0; 0 0 0; 0 0 0]
 [0 0 0; 0 0 0; 0 0 0]
 [0 0 0; 0 0 0; 0 0 0]
 [0 0 0; 0 0 0; 0 0 0]
 [0 0 0; 0 0 0; 0 0 0]
 [0 0 0; 0 0 0; 0 0 0]
 [0 0 0; 0 0 0; 0 0 0]
 [0 0 0; 0 0 0; 0 0 0]
 [0 0 0; 0 0 0; 0 0 0]
 ⋮
 [112 94 124; 114 105 115; 100 110 81]
 [112 94 124; 114 105 115; 100 110 81]
 [112 94 124; 114 105 115; 100 110 81]
 [112 94 124; 114 105 115; 100 110 81]
 [112 94 124; 114 105 115; 100 110 81]
 [112 94 124; 114 105 115; 100 110 81]
 [112 94 124; 114 105 115; 100 110 81]
 [112 94 124; 114 105 115; 100 110 81]
 [112 94 124; 114 105 115; 100 110 81]
## Total population of each species: sum over all lattice sites at every time point
t = sol.t
s, i, r = (sum.(lat_getu(sol, species, lrs)) for species in (:S, :I, :R))
PNG(plot(t, [s i r]; label=["S" "I" "R"], lw=1.5))
_images/d09f5f533eaf713274a9e64c27da8da4de42e2cee61446895b377a71d3cd38e7.png

This notebook was generated using Literate.jl.