"""Sequential (discrete-time daily-step) epidemic simulation engine.
This module defines :class:`SequentialEngine`, which overrides the continuous-
time Gillespie loop of :class:`~models.engine_seirspluslike.SeirsPlusLikeEngine`
with a discrete daily-step update: all nodes draw new states simultaneously
once per day using roulette-wheel selection over the propensity matrix.
It also provides an extended :class:`STATES` enumeration and the helper
:func:`_searchsorted2d` used for vectorised roulette-wheel selection.
"""
import pandas as pd
import numpy as np
import scipy as scipy
import scipy.integrate
import networkx as nx
import time
import os
import gc
from utils.history_utils import TimeSeries, TransitionHistory
from models.engine_seirspluslike import SeirsPlusLikeEngine
# from extended_network_model import STATES as s
# [S SS E In Ia Ids Js Jn Ed Ida 0 0 0 0 0 0 0 0]
[docs]
class STATES():
    """Extended integer state codes used by :class:`SequentialEngine` models.

    Attributes:
        S (int): Susceptible.
        S_s (int): Susceptible with false symptoms.
        E (int): Exposed.
        I_n (int): Infectious, asymptomatic non-symptomatic track.
        I_a (int): Infectious, pre-symptomatic.
        I_s (int): Infectious, symptomatic.
        I_ds (int): Infectious, symptomatic, detected.
        J_s (int): Post-infectious, symptomatic.
        J_n (int): Post-infectious, asymptomatic.
        E_d (int): Exposed, detected.
        I_da (int): Infectious, pre-symptomatic, detected.
        I_dn (int): Infectious, asymptomatic, detected.
        J_ds (int): Post-infectious, symptomatic, detected.
        J_dn (int): Post-infectious, asymptomatic, detected.
        R_d (int): Recovered, detected.
        R_u (int): Recovered, undetected.
        D_d (int): Dead, detected.
        D_u (int): Dead, undetected.
    """
    # Codes 0-10 form the original compartment set; 11-17 were appended
    # later, which is why the numbering is not monotone in the listing.
    S = 0
    S_s = 1
    E = 2
    I_n = 3
    I_a = 4
    I_s = 5
    I_ds = 6
    R_d = 7
    R_u = 8
    D_d = 9
    D_u = 10
    J_s = 11
    J_n = 12
    E_d = 13
    I_da = 14
    I_dn = 15
    J_ds = 16
    J_dn = 17
def _searchsorted2d(a, b):
"""Vectorised row-wise ``numpy.searchsorted``.
For each row *i*, finds the insertion position of ``b[i]`` in ``a[i]``.
Used for efficient roulette-wheel selection across all nodes in one call.
Args:
a (numpy.ndarray): 2-D sorted array of shape ``(m, n)``.
b (numpy.ndarray): 1-D or 2-D query values of shape ``(m,)`` or
``(m, 1)``.
Returns:
numpy.ndarray: 1-D array of shape ``(m,)`` with insertion indices.
"""
m, n = a.shape
max_num = np.maximum(a.max() - a.min(), b.max() - b.min()) + 1
r = max_num * np.arange(a.shape[0])[:, None]
p = np.searchsorted((a+r).ravel(), (b+r).ravel(), side="right")
return p - n*np.arange(m)
[docs]
class SequentialEngine(SeirsPlusLikeEngine):
    """Discrete daily-step engine (roulette-wheel selection over propensities).

    Replaces the continuous-time Gillespie loop of
    :class:`~models.engine_seirspluslike.SeirsPlusLikeEngine` with a
    synchronous, per-node roulette-wheel update executed once per simulated
    day.  All propensities must therefore sum to 1 across transitions for
    each node (i.e. they are interpreted as probabilities rather than rates).

    Additional public methods allow manual state manipulation for scenario
    modelling: :meth:`move_to_E`, :meth:`move_to_R`, :meth:`force_infect`,
    :meth:`detected_node`.
    """

    def inicialization(self):
        """Initialise the sequential engine and allocate the testable array.

        Calls the parent :meth:`inicialization`, then creates
        ``self.testable`` -- a per-node boolean column vector tracking
        whether a node will seek a test when symptomatic.
        """
        super().inicialization()
        # Use ``self.num_nodes`` like every other per-node allocation in
        # this class: ``graph.number_of_nodes`` is a *method* on networkx
        # graphs, so passing it unbound as an array dimension would fail.
        self.testable = np.zeros(
            shape=(self.num_nodes, 1), dtype=bool)
[docs]
def run_iteration(self):
    """Perform one day of simulation: compute propensities and update states.

    For every node draws a uniform random number and selects its next state
    via roulette-wheel selection over its row of the propensity matrix.
    State changes are accumulated in ``self.delta`` and applied to
    ``self.memberships`` in a single batch at the end of the day.

    Returns:
        bool: Always ``True`` (termination is managed by :meth:`run`).
    """
    # debug: how many currently-symptomatic nodes intend to seek a test
    all_testable = (
        self.testable[self.memberships[STATES.I_s] == 1].sum() +
        self.testable[self.memberships[STATES.J_s] == 1].sum()
    )
    print(f"DBG testable {all_testable}")

    # Carry yesterday's counts forward; per-day accumulators start at zero.
    for state in self.states:
        self.state_counts[state][self.t] = self.state_counts[state][self.t-1]
        self.state_increments[state][self.t] = 0
    self.num_tests[self.t] = 0
    self.num_qtests[self.t] = 0
    self.w_times[self.t] = 0
    self.all_positive_tests[self.t] = 0
    self.durations += 1
    self.N[self.t] = self.N[self.t-1]

    # Undetected symptomatic nodes accumulate one more day of test waiting.
    symptomatic_states = [STATES.I_s, STATES.J_s]
    symptomatic_flags = self.memberships[symptomatic_states, :, :].reshape(
        len(symptomatic_states), self.num_nodes).sum(axis=0)
    self.test_waiting[symptomatic_flags == 1] += 1

    plist = self.calc_propensities()
    propensities = np.column_stack(plist)

    # Per-node transition probabilities must form a distribution; dump
    # diagnostics for the first offending node before the assert fires.
    if not np.allclose(propensities.sum(axis=1), 1.0):
        print(propensities.sum(axis=1))
        print(np.logical_not(np.isclose(propensities.sum(axis=1), 1.0)).nonzero())
        index = np.logical_not(np.isclose(
            propensities.sum(axis=1), 1.0)).nonzero()[0][0]
        print(propensities.sum(axis=1)[index])
        index2 = propensities[index].nonzero()[0]
        for i in index2:
            print(
                self.state_str_dict[self.transitions[i][0]],
                self.state_str_dict[self.transitions[i][1]],
                propensities[index][i]
            )
        print(self.memberships[:, index])
        print("Hey, better exit ... ")
    assert np.allclose(propensities.sum(axis=1), 1.0)

    # Roulette-wheel selection: one uniform draw per node against the
    # cumulative distribution of its transitions.
    cumsum = np.cumsum(propensities, axis=1)
    r = np.random.rand(self.num_nodes).reshape(-1, 1)
    transition_idx = _searchsorted2d(cumsum, r)

    self.delta.fill(0)
    for node, idx in enumerate(transition_idx):
        # searchsorted(..., side="right") can return exactly
        # len(self.transitions) when round-off makes the cumulative sum end
        # slightly below r.  That index is already out of range, so the
        # guard must be ">=" (the original ">" could never fire and the
        # fall-through indexing would raise IndexError).
        if idx >= len(self.transitions):
            print(
                "DBG WARNING idx from searchsorted on the edge of the array, may be round-off error, better to check ")
            print(
                "DBG WARNING r > sum, but r = np.random.rand() returns [0,1) and sum should be == 1 (except rounding)")
            idx = len(self.transitions) - 1
        s, e = self.transitions[idx]
        if s == e:
            # self-transition: node stays where it is
            continue
        # sanity check: the chosen source state must match the membership
        if self.memberships[s, node, 0] != 1:
            print(f"node not in state {self.state_str_dict[s]}")
            print(self.memberships[:, node, 0])
            print(propensities[node, :], idx)
            exit()
        if node == 21105:
            # stalking one node for debugging purposes
            print(f"ACTION LOG ({self.t}): node {node} changing state from {self.state_str_dict[s]} to {self.state_str_dict[e]}")
        self.states_durations[s].append(self.durations[node])
        # entering any detected state counts as a consumed test
        if e in (
            STATES.I_ds,
            STATES.E_d,
            STATES.I_da,
            STATES.I_dn,
            STATES.J_ds,
            STATES.J_dn
        ):
            self.num_tests[self.t] += 1
            if self.test_waiting[node] > 0:
                self.w_times[self.t] += self.test_waiting[node]
            self.all_positive_tests[self.t] += 1
        # node developed symptoms: decide whether it will seek a test
        if e == STATES.I_s:
            if np.random.rand() < self.test_rate[node]:
                self.testable[node] = True
        # node starts to be infectious: remember the day
        if (
            (s == STATES.E and e in (STATES.I_a, STATES.I_n))
            or (s == STATES.E_d and e in (STATES.I_da, STATES.I_dn))
        ):
            self.infect_start[node] = self.t
        # node stops being infectious: record the infectious period length
        if (s, e) in [
            (STATES.I_n, STATES.J_n),
            (STATES.I_s, STATES.J_s),
            (STATES.I_dn, STATES.J_dn),
            (STATES.I_ds, STATES.J_ds)
        ]:
            assert self.infect_start[node] != 0
            self.infect_time[node] = self.t - self.infect_start[node]
        self.durations[node] = 0
        # accumulate the membership change; applied in one batch below
        self.delta[s, node, :] = -1
        self.delta[e, node, :] = 1
        self.state_counts[s][self.t] -= 1
        self.state_counts[e][self.t] += 1
        self.state_increments[e][self.t] += 1
        # append to the event log (grow on demand)
        self.tidx += 1
        if self.tidx >= len(self.history):
            self.increase_history_len()
        self.tseries[self.tidx] = self.t
        self.history[self.tidx] = (node, s, e)
        # if node died it leaves the effective population
        if e in (self.invisible_states):
            self.N[self.t] -= 1

    # the real states update
    self.memberships += self.delta
    return True
# def print(self, verbose=False):
# print("t = %.2f" % self.t)
# if verbose:
# for state in self.states:
# print(f"\t {self.state_str_dict[state]} = {self.current_state_count(state)}")
# # print(flush=True)
[docs]
def run(self, T, print_interval=0, verbose=False):
    """Run the simulation for *T* days with daily-step updates.

    Iterates over days 1..T calling :meth:`run_iteration` each day.
    Fires ``self.periodic_update_callback`` when set.  If the loop ends
    early, the remaining days are filled with the last observed counts.

    Args:
        T (int): Number of days to simulate.
        print_interval (int, optional): Print status every this many days.
            ``0`` still prints initial/final status; a negative value
            suppresses all status output.  Defaults to ``0``.
        verbose (bool, optional): If ``True``, print per-state counts at
            each interval.  Defaults to ``False``.

    Returns:
        bool: Always ``True``.
    """
    # pre-allocated once and reused by run_iteration every day (saves time)
    self.delta = np.empty((self.num_states, self.num_nodes, 1), dtype=int)
    self.node_ids = np.arange(self.num_nodes)
    running = True
    self.tidx = 0
    if print_interval >= 0:
        self.print(verbose)
    for self.t in range(1, T+1):
        if __debug__ and print_interval >= 0 and verbose:
            print(flush=True)
        if (self.t >= len(self.state_counts[0])):
            # room has run out in the timeseries storage arrays; enlarge them
            self.increase_data_series_length()
        if print_interval > 0 and verbose:
            start = time.time()
        running = self.run_iteration()
        # run periodical update (policy/intervention callback)
        if self.periodic_update_callback is not None:
            self.periodic_update_callback.run()
        if print_interval > 0:
            if verbose:
                end = time.time()
                print("Last day took: ", end - start, "seconds")
            if (self.t % print_interval == 0):
                self.print(verbose)
        # NOTE(review): numI is currently unused -- the early-termination
        # "break" that consumed it is disabled; kept for easy re-enabling.
        numI = sum([self.current_state_count(s)
                    for s in self.unstable_states
                    ])
    # Fill remaining days with the last observed counts (only reachable
    # when the early-termination break above is re-enabled).
    if self.t < T:
        for t in range(self.t+1, T+1):
            if (t >= len(self.state_counts[0])):
                self.increase_data_series_length()
            for state in self.states:
                self.state_counts[state][t] = self.state_counts[state][t-1]
                self.state_increments[state][t] = 0
        self.t = T
    # finalize durations of nodes still sitting in their current state
    for s in self.states:
        durations = self.durations[self.memberships[s].flatten() == 1]
        self.states_durations[s].extend(list(durations))
    if print_interval >= 0:
        self.print(verbose)
    self.finalize_data_series()
    return True
[docs]
def increase_data_series_length(self):
    """Extend all daily time-series buffers by 100 entries.

    Called automatically when the pre-allocated storage is exhausted.
    """
    # per-state daily series
    for state in self.states:
        self.state_counts[state].bloat(100)
        self.state_increments[state].bloat(100)
    # global daily series (extended once, outside the per-state loop)
    self.num_tests.bloat(100)
    self.num_qtests.bloat(100)
    self.w_times.bloat(100)
    self.all_positive_tests.bloat(100)
    self.N.bloat(100)
    # NOTE(review): meaneprobs/medianeprobs are not updated anywhere in
    # this engine (the writes in run_iteration are commented out) -- they
    # are presumably allocated by the parent class; confirm before removal.
    self.meaneprobs.bloat(100)
    self.medianeprobs.bloat(100)
[docs]
def increase_history_len(self):
    """Grow the event-log buffers.

    Enlarges both ``self.tseries`` and ``self.history`` by
    ``10 * num_nodes`` entries.
    """
    growth = 10 * self.num_nodes
    self.tseries.bloat(growth)
    self.history.bloat(growth)
[docs]
def finalize_data_series(self):
    """Trim all time-series to the actually consumed length.

    Event-log series are cut at ``self.tidx`` (event index); every daily
    :class:`~utils.history_utils.TimeSeries` is cut at ``self.t`` (day
    index).
    """
    # event-indexed logs
    self.tseries.finalize(self.tidx)
    self.history.finalize(self.tidx)
    # day-indexed global series
    for daily_series in (self.num_tests, self.num_qtests,
                         self.w_times, self.all_positive_tests):
        daily_series.finalize(self.t)
    # day-indexed per-state series
    for state in self.states:
        self.state_counts[state].finalize(self.t)
        self.state_increments[state].finalize(self.t)
    for daily_series in (self.N, self.meaneprobs, self.medianeprobs):
        daily_series.finalize(self.t)
[docs]
def current_state_count(self, state):
    """Return the count of nodes in *state* at the current day.

    Overrides the parent method: indexes by ``self.t`` (day index) rather
    than ``self.tidx`` (event index).

    Args:
        state (int): State code.

    Returns:
        int: Node count on the current day.
    """
    today = self.t
    return self.state_counts[state][today]
[docs]
def current_N(self):
    """Return the effective population size on the current day.

    Overrides the parent method: indexes by ``self.t`` rather than
    ``self.tidx``.

    Returns:
        float: Population size (excluding invisible-state nodes).
    """
    today = self.t
    return self.N[today]
[docs]
def get_state_count(self, state=None):
    """Return per-day count time-series for a given state or all states.

    Args:
        state (int, optional): State code.  If ``None``, returns the full
            ``state_counts`` dictionary.  Defaults to ``None``.

    Returns:
        TimeSeries or dict: The per-day count series for *state*, or the
        complete ``state_counts`` mapping when *state* is ``None``.
    """
    return self.state_counts if state is None else self.state_counts[state]
[docs]
def to_df(self):
    """Convert simulation output to a :class:`pandas.DataFrame`.

    Extends the parent :meth:`to_df` with additional test-related columns:
    ``tests``, ``quarantine_tests``, ``sum_of_waiting``, and
    ``all_positive_tests``.

    Returns:
        pandas.DataFrame: Combined state-count, increment, and test
        statistics indexed by day.
    """
    base = super().to_df()
    test_columns = {
        "tests": self.num_tests,
        "quarantine_tests": self.num_qtests,
        "sum_of_waiting": self.w_times,
        "all_positive_tests": self.all_positive_tests,
    }
    return base.assign(**test_columns)
[docs]
def save_durations(self, f):
    """Write per-state duration lists to an open file as CSV rows.

    Each row contains the state label followed by comma-separated integer
    durations (in days).

    Args:
        f (file-like): Open writable file object.
    """
    for state in self.states:
        durations_csv = ",".join(str(d) for d in self.states_durations[state])
        f.write(f"{self.state_str_dict[state]},{durations_csv}\n")
[docs]
def save_node_states(self, filename):
    """Write the per-node state history to a CSV file.

    Args:
        filename (str): Destination file path.

    NOTE(review): depends on ``self.states_history``, whose updates are
    commented out in :meth:`run_iteration` and whose buffer management is
    commented out in the resize/finalize helpers -- confirm the attribute
    is still allocated and populated before calling this method.
    """
    index = range(0, self.t+1)
    columns = self.states_history.values
    df = pd.DataFrame(columns, index=index)
    df.to_csv(filename)
[docs]
def move_to_E(self, num):
    """Randomly expose *num* susceptible nodes by moving them to state E.

    Draws *num* distinct nodes uniformly at random; only those currently
    in S or S_s are actually moved (others are skipped, so fewer than
    *num* nodes may end up exposed).  Updates state counts, increments,
    membership arrays, and duration tracking for the current day.

    Args:
        num (int): Number of nodes to sample for exposure.
    """
    target = STATES.E
    sampled = np.random.choice(self.num_nodes, num, replace=False)
    for node in sampled:
        current = self.memberships[:, node].nonzero()[0][0]
        # only susceptible nodes are eligible
        if current not in (STATES.S_s, STATES.S):
            continue
        print(f"DBG Moving node {node} from {self.state_str_dict[current]} to {self.state_str_dict[target]}")
        # close out the time spent in the old state
        self.states_durations[current].append(self.durations[node])
        self.durations[node] = 0
        # bookkeeping: counts, increments, memberships
        self.state_counts[target][self.t] += 1
        self.state_counts[current][self.t] -= 1
        self.state_increments[target][self.t] += 1
        self.memberships[target][node] = 1
        self.memberships[current][node] = 0
[docs]
def move_to_R(self, nodes):
    """Move a specific set of exposed nodes directly to the detected-recovered state.

    Only nodes currently in state E are supported.

    Args:
        nodes (iterable of int): Node indices to move.

    Raises:
        ValueError: If any node in *nodes* is not currently in state E.
    """
    target = STATES.R_d
    for node in nodes:
        current = self.memberships[:, node].nonzero()[0][0]
        if current != STATES.E:
            raise ValueError()
        print(f"DBG Moving node {node} from {self.state_str_dict[current]} to {self.state_str_dict[target]}")
        # close out the time spent in the old state
        self.states_durations[current].append(self.durations[node])
        self.durations[node] = 0
        # bookkeeping: counts, increments, memberships
        self.state_counts[target][self.t] += 1
        self.state_counts[current][self.t] -= 1
        self.state_increments[target][self.t] += 1
        self.memberships[target][node] = 1
        self.memberships[current][node] = 0
[docs]
def force_infect(self, nodes):
    """Forcibly move a set of nodes to the symptomatic-infectious state I_s.

    Dead nodes (D_d, D_u) are silently skipped.  All other nodes are moved
    regardless of their current state.  Intended for scenario seeding.

    Args:
        nodes (iterable of int): Node indices to infect.
    """
    target = STATES.I_s
    for node in nodes:
        current = self.memberships[:, node].nonzero()[0][0]
        # the dead stay dead
        if current in (STATES.D_d, STATES.D_u):
            continue
        print(f"DBG Moving node {node} from {self.state_str_dict[current]} to {self.state_str_dict[target]}")
        # close out the time spent in the old state
        self.states_durations[current].append(self.durations[node])
        self.durations[node] = 0
        # bookkeeping: counts, increments, memberships
        self.state_counts[target][self.t] += 1
        self.state_counts[current][self.t] -= 1
        self.state_increments[target][self.t] += 1
        self.memberships[target][node] = 1
        self.memberships[current][node] = 0
[docs]
def detected_node(self, node_number):
    """Mark a node as detected (positive test) and transition it accordingly.

    Maps the node's current undetected state to its detected counterpart
    (e.g. E -> E_d, I_s -> I_ds).  If the node is already in a detected or
    terminal-detected state, the call is a no-op.  Also updates
    test-waiting statistics.

    Args:
        node_number (int): Index of the node that tested positive.

    Raises:
        ValueError: If the node is in an unexpected state.
    """
    orig_state = self.memberships[:, node_number].nonzero()[0][0]
    already_detected = (
        STATES.E_d, STATES.I_da, STATES.I_dn, STATES.I_ds,
        STATES.J_dn, STATES.J_ds, STATES.R_d, STATES.D_d,
    )
    if orig_state in already_detected:
        return
    detected_counterpart = {
        STATES.E: STATES.E_d,
        STATES.I_a: STATES.I_da,
        STATES.I_n: STATES.I_dn,
        STATES.I_s: STATES.I_ds,
        STATES.J_n: STATES.J_dn,
        STATES.J_s: STATES.J_ds,
        STATES.R_u: STATES.R_d,
        STATES.D_u: STATES.D_d,
    }
    # Waiting-time statistics are recorded before the transition lookup,
    # matching the original control flow (they update even when the state
    # below turns out to be unexpected).
    if self.test_waiting[node_number] > 0:
        self.w_times[self.t] += self.test_waiting[node_number]
    self.all_positive_tests[self.t] += 1
    if orig_state not in detected_counterpart:
        raise ValueError(f"Unexpected state: {self.state_str_dict[orig_state]}")
    new_state = detected_counterpart[orig_state]
    if 29691 == node_number:
        print(f"ACTION LOG({self.t}): node 29691 forced to change state to {self.state_str_dict[new_state]} from {self.state_str_dict[orig_state]}")
    # close out the time spent in the old state
    self.states_durations[orig_state].append(self.durations[node_number])
    self.durations[node_number] = 0
    # bookkeeping: counts, increments, memberships
    self.state_counts[new_state][self.t] += 1
    self.state_counts[orig_state][self.t] -= 1
    self.state_increments[new_state][self.t] += 1
    self.memberships[new_state][node_number] = 1
    self.memberships[orig_state][node_number] = 0