"""Plan-based discrete-time simulation engine.
This module defines :class:`SimulationEngine`, which replaces the Gillespie
propensity loop with a *planning* approach: each node holds a pre-scheduled
``(time_to_go, state_to_go)`` pair. Every day the engine decrements counters
and moves nodes whose countdown has reached zero, then invokes
:meth:`daily_update` for state-dependent daily checks (e.g. infection
attempts).
"""
import numpy as np
import pandas as pd
import logging
import time
from models.engine import BaseEngine
from utils.history_utils import TimeSeries, TransitionHistory
import utils.global_configs as global_configs
from utils.global_configs import monitor
# Initial length, in days, of the per-day buffers created in
# SimulationEngine.setup_series_and_time_keeping (state counts, state
# increments, N, and the per-node state history when SAVE_NODES is on).
# run() calls increase_data_series_length() once a simulation outgrows the
# state-count buffers.
EXPECTED_NUM_DAYS = 300
[docs]
class SimulationEngine(BaseEngine):
"""Discrete-time plan-based epidemic engine.
Each simulated agent (node) carries a *plan*: a countdown
(``time_to_go``) and a target state (``state_to_go``). On every
simulated day:
1. :meth:`daily_update` is called for nodes that need a check
(e.g. susceptible nodes that might be infected).
2. All countdown timers are decremented.
3. Nodes whose timer hits zero are moved to their planned state via
:meth:`change_states`.
4. :meth:`update_plan` sets new plans for the nodes that just moved.
Subclasses implement :meth:`daily_update` and :meth:`update_plan` to
define the model's disease-progression logic.
Class-level attributes (override in subclasses):
states (list): Ordered list of state codes.
num_states (int): Number of states.
state_str_dict (dict): State-code → label mapping.
ext_code (int): State code used for external nodes.
transitions (list): Allowed ``(from, to)`` pairs.
num_transitions (int): Number of transitions.
final_states (list): Absorbing states.
invisible_states (list): States excluded from population count.
unstable_states (list): States that can still change.
fixed_model_parameters (dict): Scalar constructor parameters.
model_parameters (dict): Per-node constructor parameters.
common_arguments (dict): Common constructor parameters (seed, etc.).
"""
states = []
num_states = len(states)
state_str_dict = {}
ext_code = 0
transitions = []
num_transitions = len(transitions)
final_states = []
invisible_states = []
unstable_states = []
fixed_model_parameters = {}
model_parameters = {}
common_arguments = {
"random_seed": (None, "random seed value"),
"start_day": (1, "day to start")
}
[docs]
def __init__(self, G, **kwargs):
    """Build the engine on contact graph *G*.

    The graph is stored under both ``self.G`` (legacy name) and
    ``self.graph``, and *kwargs* is kept verbatim in ``self.init_kwargs``
    before any initialisation runs. Initialisation then proceeds in this
    order: :meth:`inicialization`, :meth:`setup_series_and_time_keeping`,
    :meth:`states_and_counts_init` (with the node count of external nodes
    and ``self.ext_code``). Finally the periodic-update callback is cleared
    and the calendar day ``self.T`` is set to ``start_day - 1``.

    Args:
        G: Contact graph or multi-layer graph object.
        **kwargs: Overrides for defaults declared in
            ``fixed_model_parameters``, ``model_parameters`` or
            ``common_arguments``; initial state counts may be given as
            ``init_<STATE_LABEL>=<count>``.
    """
    # Same object under the legacy and the current attribute name
    # (``self.G`` is assigned first, then ``self.graph``).
    self.G = self.graph = G
    self.init_kwargs = kwargs

    self.inicialization()                 # model parameters + node index array
    self.setup_series_and_time_keeping()  # history buffers and time keeping
    self.states_and_counts_init(
        ext_nodes=self.num_ext_nodes,
        ext_code=self.ext_code,
    )

    # No periodic callback until one is attached by the caller.
    self.periodic_update_callback = None
    self.T = self.start_day - 1
[docs]
def update_graph(self, new_G):
    """Point the engine at a new graph and refresh the derived node metadata.

    Does nothing when *new_G* is ``None``. Otherwise sets ``self.graph``
    (and the legacy alias ``self.G``), ``self.num_nodes``,
    ``self.num_ext_nodes`` and ``self.nodes``.

    Args:
        new_G: Replacement graph object, or ``None`` to keep the current one.
    """
    if new_G is None:
        return
    self.G = new_G  # legacy alias, kept for backward compatibility
    self.graph = new_G
    self.num_nodes = self.graph.num_nodes
    try:
        self.num_ext_nodes = self.graph.num_nodes - self.graph.num_base_nodes
    except AttributeError:
        # Graphs saved by older versions have no ``num_base_nodes``.
        self.num_ext_nodes = 0
    # Build the index column from ``num_nodes``, the same attribute
    # ``inicialization`` uses, rather than ``graph.number_of_nodes``
    # (absent on graphs exposing only ``num_nodes``; a method on
    # networkx-style graphs).
    self.nodes = np.arange(self.num_nodes).reshape(-1, 1)
[docs]
def inicialization(self):
    """Run the parent initialisation, then cache node metadata.

    After the parent ``inicialization`` returns, ``self.nodes`` becomes a
    ``(num_nodes, 1)`` column holding node indices ``0 .. num_nodes - 1``
    and ``self.num_nodes`` caches ``self.graph.num_nodes``.
    """
    super().inicialization()
    node_count = self.graph.num_nodes
    self.nodes = np.arange(node_count).reshape(-1, 1)
    self.num_nodes = node_count
[docs]
def setup_series_and_time_keeping(self):
    """Allocate the history buffers and per-node counters.

    On top of the parent setup this creates:

    * ``tseries`` (float) and ``history`` -- event buffers of length
      ``num_transitions * num_nodes``;
    * ``states_history`` -- ``EXPECTED_NUM_DAYS`` rows of width
      ``num_nodes`` when ``global_configs.SAVE_NODES`` is set, a single
      row otherwise;
    * ``states_durations`` -- an empty list per state, only when
      ``global_configs.SAVE_DURATIONS`` is set;
    * ``durations`` -- a zeroed integer counter per node;
    * ``state_counts`` / ``state_increments`` -- an integer
      :class:`~utils.history_utils.TimeSeries` per state, and ``N`` -- a
      float TimeSeries, all pre-allocated to ``EXPECTED_NUM_DAYS`` entries.
    """
    super().setup_series_and_time_keeping()

    # Event-log buffers sized num_transitions * num_nodes.
    event_capacity = self.num_transitions * self.num_nodes
    self.tseries = TimeSeries(event_capacity, dtype=float)
    self.history = TransitionHistory(event_capacity)

    # Daily per-node snapshots are only kept when SAVE_NODES is on; otherwise
    # one row is enough (row 0 is still read by states_and_counts_init).
    snapshot_rows = EXPECTED_NUM_DAYS if global_configs.SAVE_NODES else 1
    self.states_history = TransitionHistory(
        snapshot_rows, width=self.num_nodes)

    if global_configs.SAVE_DURATIONS:
        self.states_durations = {state: [] for state in self.states}
    self.durations = np.zeros(self.num_nodes, dtype=int)

    # Per state: number of individuals in it each day, and daily inflow.
    self.state_counts = {}
    self.state_increments = {}
    for state in self.states:
        self.state_counts[state] = TimeSeries(EXPECTED_NUM_DAYS, dtype=int)
        self.state_increments[state] = TimeSeries(EXPECTED_NUM_DAYS, dtype=int)

    # Actual number of individuals in the population over time.
    self.N = TimeSeries(EXPECTED_NUM_DAYS, dtype=float)
[docs]
def states_and_counts_init(self, ext_nodes=None, ext_code=None):
    """Initialise state counts, then the per-node planning columns.

    After the parent :meth:`states_and_counts_init` this sets:

    * ``self.time_to_go`` -- ``(num_nodes, 1)`` int32 countdowns, all
      ``-1`` ("no scheduled transition");
    * ``self.state_to_go`` -- ``(num_nodes, 1)`` int32 planned states,
      all ``-1``;
    * ``self.current_state`` -- a ``(num_nodes, 1)`` copy of row 0 of
      ``self.states_history``;
    * ``self.need_update`` -- a per-node boolean flag, all ``True``
      (not read by the code in this class; presumably consumed by
      subclasses -- confirm).

    Args:
        ext_nodes (int, optional): Number of external nodes, forwarded to
            the parent. Defaults to ``None``.
        ext_code (int, optional): State code for external nodes, forwarded
            to the parent. Defaults to ``None``.
    """
    super().states_and_counts_init(ext_nodes, ext_code)
    n = self.num_nodes
    # -1 everywhere: nothing is planned yet.
    self.time_to_go = np.full((n, 1), -1, dtype="int32")
    self.state_to_go = np.full((n, 1), -1, dtype="int32")
    # Column copy of the day-0 state row (assumes the parent filled it).
    self.current_state = self.states_history[0].copy().reshape(-1, 1)
    self.need_update = np.ones(n, dtype=bool)
[docs]
def daily_update(self, nodes):
    """Hook for state-dependent daily checks; does nothing here.

    :meth:`run_iteration` calls it once per day, before countdowns are
    decremented, passing ``self.need_check``. Subclasses override it to
    implement infection attempts and other daily events.

    Args:
        nodes (numpy.ndarray): Boolean bitmap of nodes that need a check.
    """
[docs]
def change_states(self, nodes, target_state=None):
    """Move the flagged nodes into a new state and give them fresh plans.

    For each flagged node this clears all of its memberships and sets the
    new one. It adjusts today's ``state_counts`` (+1 for the new state,
    -1 for the old one) and today's ``state_increments`` (+1 for the new
    state). With ``global_configs.SAVE_NODES`` it also writes today's row
    of ``states_history``. Afterwards ``self.current_state`` is
    overwritten for all flagged nodes and :meth:`update_plan` is called
    with *nodes*.

    Args:
        nodes (numpy.ndarray): Boolean bitmap of the nodes to move.
        target_state (int, optional): Common destination for every flagged
            node. When ``None`` (default), each node goes to its own
            ``self.state_to_go`` entry.
    """
    planned = target_state is None
    # Wipe every membership of the moving nodes; the new one is set below.
    self.memberships[:, nodes == True] = 0  # noqa: E712 (kept: explicit mask)
    for idx in nodes.nonzero()[0]:
        dest = self.state_to_go[idx][0] if planned else target_state
        src = self.current_state[idx, 0]
        self.memberships[dest, idx] = 1
        self.state_counts[dest][self.t] += 1
        self.state_counts[src][self.t] -= 1
        self.state_increments[dest][self.t] += 1
        if global_configs.SAVE_NODES:
            self.states_history[self.t][idx] = dest
    # Per-node state column is updated in bulk once the bookkeeping is done.
    self.current_state[nodes] = self.state_to_go[nodes] if planned else target_state
    self.update_plan(nodes)
[docs]
def update_plan(self, nodes):
    """Hook that assigns new plans to nodes that just moved; no-op here.

    :meth:`change_states` calls it after updating ``self.current_state``.
    Subclasses override it to set ``self.time_to_go`` and
    ``self.state_to_go`` according to each node's new state.

    Args:
        nodes (numpy.ndarray): Boolean bitmap of nodes needing a new plan.
    """
def _get_target_nodes(self, nodes, state):
    """Intersect a node bitmap with membership of *state*.

    Only the memberships of candidate nodes are read, with
    ``self.memberships`` indexed as ``[state, node, 0]``. *nodes* itself
    is not modified.

    Args:
        nodes (numpy.ndarray): Boolean bitmap of candidate nodes (any
            shape; it is flattened).
        state (int): State code to intersect with.

    Returns:
        numpy.ndarray: Flat boolean bitmap, ``True`` exactly where *nodes*
        is ``True`` and the node belongs to *state*.
    """
    flat = nodes.flatten()  # independent 1-D copy of the candidates
    picked = self.memberships[state, flat, 0]
    selected = flat.copy()
    selected[flat] = picked
    return selected
[docs]
def print(self, verbose=False):
    """Report the current day on stdout, plus per-state counts if verbose.

    Always prints ``T = <calendar day> (<simulation day>)``. With
    *verbose*, it then prints one tab-indented ``<label> = <count>`` line
    per entry of ``self.states``, using the counts for day ``self.t``.

    Args:
        verbose (bool, optional): Also list per-state counts. Defaults to
            ``False``.
    """
    print(f"T = {self.T} ({self.t})")
    if not verbose:
        return
    for code in self.states:
        label = self.state_str_dict[code]
        print(f"\t {label} = {self.state_counts[code][self.t]}")
[docs]
def save_durations(self, f):
    """Write the recorded per-state durations to *f*, one CSV row per state.

    Each row is ``<state label>,<d1>,<d2>,...`` in the order of
    ``self.states``. ``self.states_durations`` only exists when
    ``global_configs.SAVE_DURATIONS`` was enabled at setup. If it is
    missing, a warning is logged and nothing is written, mirroring
    :meth:`save_node_states`, instead of raising ``AttributeError``.

    Args:
        f (file-like): Open, writable text file object.
    """
    durations_by_state = getattr(self, "states_durations", None)
    if durations_by_state is None:
        logging.warning("State durations were not saved, nothing to write.")
        return
    for s in self.states:
        values = ",".join(str(x) for x in durations_by_state[s])
        print(f"{self.state_str_dict[s]},{values}", file=f)
[docs]
def save_node_states(self, filename):
    """Write the per-node daily state history to *filename* as CSV.

    Rows are simulation days ``0 .. self.t`` (row 0 holds the initial
    states), columns are node indices and cells are state codes.

    When node states were not recorded (``global_configs.SAVE_NODES`` is
    falsy), a warning is logged, nothing is written and an empty
    DataFrame is returned.

    Args:
        filename (str): Destination file path.

    Returns:
        pandas.DataFrame: The table that was written, or an empty
        DataFrame when node states were not saved.
    """
    # Truthiness test, as setup_series_and_time_keeping uses for the same
    # flag; ``is False`` let 0/None fall through to a shape error below.
    if not global_configs.SAVE_NODES:
        logging.warning(
            "Nodes states were not saved, returning empty data frame.")
        return pd.DataFrame()
    num_days = self.t + 1
    # Keep only simulated days in case the buffer still has preallocated
    # spare rows (assumes ``values`` is day-major and sliceable -- confirm).
    rows = self.states_history.values[:num_days]
    df = pd.DataFrame(rows, index=range(num_days))
    df.to_csv(filename)
    return df
[docs]
def to_df(self):
    """Return the parent's output table shifted to calendar days.

    When ``self.start_day`` is 1 the parent's DataFrame is returned as is.
    Otherwise both the ``day`` column and the index are offset so that
    simulation day 1 maps to ``start_day``.

    Returns:
        pandas.DataFrame: The table built by the parent ``to_df``, indexed
        by calendar day.
    """
    table = super().to_df()
    if self.start_day == 1:
        return table
    table["day"] = self.start_day + table["day"] - 1
    table.index = self.start_day + table.index - 1
    return table
[docs]
def run(self, T, print_interval=10, verbose=False):
    """Simulate days ``1..T`` of the plan-based model.

    For every day ``self.t`` the calendar day ``self.T`` becomes
    ``start_day + t - 1``. The time-series buffers are grown once the
    state-count buffer is full, then :meth:`run_iteration` runs, followed
    by ``self.periodic_update_callback.run()`` when a callback is set.

    Status is printed through :meth:`print`:

    * before the loop and at the end, when *print_interval* >= 0;
    * every *print_interval* days, when *print_interval* > 0, followed by
      the wall-clock time of that day if *verbose*.

    After the loop, any days between ``self.t`` and *T* are padded with
    the last counts. The day loop has no early exit, so this normally does
    nothing. With ``global_configs.SAVE_DURATIONS``, the non-zero
    ``durations`` of nodes still sitting in each state are added to
    ``self.states_durations``. Finally :meth:`finalize_data_series` is
    called.

    Args:
        T (int): Number of days to simulate.
        print_interval (int, optional): Status period in days; ``0``
            prints only at start/end, negative suppresses status output.
            Defaults to ``10``.
        verbose (bool, optional): Include per-state detail and timings.
            Defaults to ``False``.

    Returns:
        bool: Always ``True``.
    """
    monitored = global_configs.MONITOR_NODE
    if monitored is not None:
        monitor(0, f" being monitored, now in {self.state_str_dict[self.current_state[monitored,0]]}")

    self.tidx = 0
    self.T = self.start_day - 1
    report = print_interval >= 0
    if report:
        self.print(verbose)

    for self.t in range(1, T + 1):
        self.T = self.start_day + self.t - 1
        if __debug__ and print_interval >= 0 and verbose:
            print(flush=True)
        if self.t >= len(self.state_counts[0]):
            # Out of preallocated room: grow the time-series storage.
            self.increase_data_series_length()
        if print_interval > 0 and verbose:
            started = time.time()
        self.run_iteration()
        if self.periodic_update_callback is not None:
            self.periodic_update_callback.run()
        if print_interval > 0 and self.t % print_interval == 0:
            self.print(verbose)
            if verbose:
                print(f"Last day took: {time.time() - started} seconds")

    if self.t < T:
        for day in range(self.t + 1, T + 1):
            if day >= len(self.state_counts[0]):
                self.increase_data_series_length()
            for state in self.states:
                self.state_counts[state][day] = self.state_counts[state][day - 1]
                self.state_increments[state][day] = 0

    # Record the stays still in progress when the simulation stops.
    if global_configs.SAVE_DURATIONS:
        for state in self.states:
            in_state = self.memberships[state].flatten() == 1
            spans = self.durations[in_state]
            self.states_durations[state].extend(list(spans[spans != 0]))

    if report:
        self.print(verbose)
    self.finalize_data_series()
    return True
[docs]
def run_iteration(self):
    """Advance the simulation by one day (day ``self.t``).

    Steps:
      1. Copy the previous day's ``state_counts`` and ``N`` to today, zero
         today's ``state_increments`` and add one day to every node's
         ``durations`` counter; with ``global_configs.SAVE_NODES`` also
         copy yesterday's row of ``states_history``.
      2. Call :meth:`daily_update` with ``self.need_check``.
      3. Decrement every node's ``time_to_go``.
      4. Move the nodes whose countdown is now exactly zero via
         :meth:`change_states` and reset their ``durations`` counters.
      5. With ``global_configs.SAVE_DURATIONS``, append each moved node's
         duration to the list of the state it just left.
    """
    logging.debug("DBG run iteration")

    # 1. today's bookkeeping starts from yesterday's values
    for state in self.states:
        self.state_counts[state][self.t] = self.state_counts[state][self.t - 1]
        self.state_increments[state][self.t] = 0
    self.N[self.t] = self.N[self.t - 1]
    self.durations += 1
    if global_configs.SAVE_NODES:
        self.states_history[self.t] = self.states_history[self.t - 1]

    # 2. state-dependent daily checks (e.g. infection attempts)
    self.daily_update(self.need_check)

    # 3./4. countdowns; nodes initialised to -1 ("no plan") only move
    # further below zero, so they never match here.
    self.time_to_go -= 1
    nodes_to_move = self.time_to_go == 0

    monitored = global_configs.MONITOR_NODE
    # ``is not None`` rather than truthiness: node 0 is a valid node to
    # monitor (run() already tests MONITOR_NODE this way).
    if monitored is not None and nodes_to_move[monitored]:
        monitor(self.t,
                f"changing state from {self.state_str_dict[self.current_state[monitored,0]]} to {self.state_str_dict[self.state_to_go[monitored,0]]}")

    # Capture the states being left and their durations before
    # change_states overwrites current_state (boolean indexing copies).
    orig_states = self.current_state[nodes_to_move]
    durs = self.durations[nodes_to_move.flatten()]
    self.change_states(nodes_to_move)
    self.durations[nodes_to_move.flatten()] = 0

    # 5. per-state duration statistics
    if global_configs.SAVE_DURATIONS:
        for s, d in zip(orig_states, durs):
            assert d > 0  # every counter was incremented in step 1
            self.states_durations[s].append(d)