"""Daily-batched Gillespie engine.
This module defines :class:`DailyEngine`, which runs the Gillespie event
selection intra-day but batches all state transitions to midnight, so the
observable state changes once per simulated day.
"""
import numpy as np
import scipy as scipy
import scipy.integrate
import networkx as nx
import time
from utils.history_utils import TimeSeries, TransitionHistory
from models.engine_seirspluslike import SeirsPlusLikeEngine
class DailyEngine(SeirsPlusLikeEngine):
    """Gillespie engine that applies state transitions only at midnight.

    Inherits from :class:`~models.engine_seirspluslike.SeirsPlusLikeEngine`.
    During the day, events are sampled using the propensities computed at the
    previous midnight and queued in a to-do list. At midnight
    (:meth:`midnight`, which calls :meth:`update_states`) the queued
    transitions are committed in bulk. The observable model state therefore
    changes once per simulated day.
    """

    def inicialization(self):
        """Initialise the engine and allocate the daily to-do accumulators.

        Creates three empty accumulators, then delegates to the parent
        initialiser:

        * ``self.todo_list`` – pending ``(node, transition)`` pairs.
        * ``self.todo_t`` – the event times of those pairs.
        * ``self._todo_nodes`` – private set of nodes already queued today.
        """
        self.todo_list = []
        self.todo_t = []
        # Mirrors the nodes present in todo_list for O(1) membership tests.
        self._todo_nodes = set()
        super().inicialization()

    def run_iteration(self, alpha, cumsum, transition_types):
        """Sample the next Gillespie event and queue it in the to-do list.

        The transition is *not* applied immediately; it is deferred to the
        next :meth:`update_states`. At most one pending transition per node
        is kept (first event wins within a day); later events for an already
        queued node only advance the clock.

        Args:
            alpha (float): Total propensity; expected to equal ``cumsum[-1]``.
            cumsum (numpy.ndarray): Cumulative sum over the column-major
                flattened propensities. Index ``i`` encodes node
                ``i % num_nodes`` and transition ``i // num_nodes``.
            transition_types (list): Ordered ``(from_state, to_state)``
                pairs matching the propensity columns.

        Returns:
            bool: Always ``True`` (termination is handled by :meth:`run`).
        """
        r1 = np.random.rand()
        r2 = np.random.rand()

        # Exponentially distributed waiting time with rate alpha.
        tau = (1 / alpha) * np.log(float(1 / r1))
        self.t += tau

        # Clamp the target so float round-off can never push the search past
        # the last element; searchsorted (side="left") then still lands on an
        # entry with positive propensity.
        target = min(r2 * alpha, cumsum[-1])
        transition_idx = np.searchsorted(cumsum, target)
        type_idx, transition_node = divmod(transition_idx, self.num_nodes)
        transition_type = transition_types[type_idx]

        if transition_node not in self._todo_nodes:
            self._todo_nodes.add(transition_node)
            self.todo_t.append(self.t)
            self.todo_list.append((transition_node, transition_type))
        return True

    def update_states(self):
        """Commit all pending transitions accumulated during the current day.

        For each queued ``(node, transition)`` pair, in queue order, this:

        * advances ``self.tidx``, growing storage when it is full;
        * moves the node between states in ``self.memberships``;
        * records the event time and the transition in ``self.tseries`` and
          ``self.history``;
        * updates ``self.state_counts`` and ``self.N``.

        The to-do accumulators are reset afterwards.

        Raises:
            AssertionError: If a node is not in the source state of its queued
                transition.
        """
        for t, (transition_node, transition_type) in zip(self.todo_t, self.todo_list):
            self.tidx += 1
            if self.tidx >= self.tseries.len() - 1:
                # Room has run out in the time-series storage; enlarge it.
                self.increase_data_series_length()

            assert (self.memberships[transition_type[0], transition_node] == 1), (f"Assertion error: Node {transition_node} has unexpected current state, given the intended transition of {transition_type}.")
            self.memberships[transition_type[0], transition_node] = 0
            self.memberships[transition_type[1], transition_node] = 1

            self.tseries[self.tidx] = t
            self.history[self.tidx] = (transition_node, *transition_type)

            # Carry every count forward, then apply this transition's delta.
            for state in self.states:
                self.state_counts[state][self.tidx] = self.state_counts[state][self.tidx - 1]
            self.state_counts[transition_type[0]][self.tidx] -= 1
            self.state_counts[transition_type[1]][self.tidx] += 1

            # Moving into an invisible state (e.g. death, per the original
            # comment) shrinks the counted population N by one.
            removed = 1 if transition_type[1] in self.invisible_states else 0
            self.N[self.tidx] = self.N[self.tidx - 1] - removed

        self.todo_list = []
        self.todo_t = []
        self._todo_nodes = set()

    def midnight(self, verbose):
        """Run end-of-day actions: commit transitions, callback, recompute rates.

        Calls :meth:`update_states`. If ``self.periodic_update_callback`` is
        set, it is called with the history, the used part of the time series
        and the current time. When its result contains a ``"graph"`` entry,
        the graph is replaced via ``self.update_graph``. Finally, propensities
        for the next day are recomputed.

        Args:
            verbose (bool): Accepted for interface compatibility; unused.

        Returns:
            tuple: ``(alpha, cumsum, has_events, transition_types)`` as
            returned by :meth:`propensities_recalc`.
        """
        self.update_states()

        if self.periodic_update_callback:
            changes = self.periodic_update_callback(
                self.history, self.tseries[:self.tidx + 1], self.t)
            # The callback may return None (or an empty mapping) when nothing
            # changed; guard before the membership test.
            if changes and "graph" in changes:
                print("CHANGING GRAPH")
                self.update_graph(changes["graph"])

        return self.propensities_recalc()

    def print(self, verbose=False):
        """Print the current simulation time and optionally per-state counts.

        Args:
            verbose (bool, optional): If ``True``, also print the current
                count of every state. Defaults to ``False``.
        """
        print(f"t = {self.t:.2f}")
        if verbose:
            for state in self.states:
                print(f"\t {self.state_str_dict[state]} = {self.current_state_count(state)}")
            print(flush=True)

    def propensities_recalc(self):
        """Recalculate propensities and return the flattened cumulative sum.

        Returns:
            tuple: A 4-tuple ``(alpha, cumsum, has_events, transition_types)``:

            * ``alpha`` (float) – total propensity, taken as ``cumsum[-1]``
              (``0.0`` when there are no entries).
            * ``cumsum`` (numpy.ndarray) – cumulative sum of the column-major
              flattened propensity matrix.
            * ``has_events`` (bool) – ``True`` if ``alpha > 0``.
            * ``transition_types`` (list) – ordered ``(from, to)`` pairs.
        """
        propensities = np.hstack(self.calc_propensities())
        transition_types = self.transitions

        # Column-major flattening lays out all nodes of transition 0, then all
        # nodes of transition 1, and so on. run_iteration decodes an index
        # with divmod(idx, num_nodes). (Open question from the original
        # author: would order="C" with swapped // and % be faster?)
        propensities_flat = propensities.ravel(order="F")
        cumsum = propensities_flat.cumsum()

        # Use the last cumsum entry rather than a separate .sum() call.
        # NumPy's sum uses pairwise summation and can differ from the
        # sequential cumsum in the last bits, which would let r2 * alpha
        # overshoot cumsum[-1] in run_iteration.
        alpha = cumsum[-1] if cumsum.size else 0.0
        return alpha, cumsum, alpha > 0.0, transition_types

    def run(self, T, print_interval=10, verbose=False):
        """Run the daily-batched simulation for up to *T* time units.

        Computes propensities once, then repeatedly calls
        :meth:`run_iteration`. On the first event sampled in a new day,
        :meth:`midnight` commits the queued transitions and recomputes
        propensities. That crossing event has already been queued, so it is
        committed with the preceding batch.

        The loop stops when ``self.tmax`` is reached, when no node is in an
        unstable state, or when the total propensity drops to zero.
        Transitions queued after the last midnight stay in the to-do list.

        Args:
            T (int or float): Duration to simulate (added to ``self.tmax``).
            print_interval (int, optional): Print status every this many
                days; ``0`` disables periodic printing. Defaults to ``10``.
            verbose (bool, optional): If ``True``, include per-state detail
                and per-day wall-clock timing. Defaults to ``False``.

        Returns:
            bool: ``True`` on completion, ``False`` if *T* <= 0.
        """
        if not T > 0:
            return False

        self.tmax += T
        # Start from the current day rather than -1. A resumed run (t > 0)
        # must not fire an extra mid-day "midnight" (and a second periodic
        # callback) for a day whose midnight has already been processed.
        # For a fresh run starting at t == 0 this is equivalent.
        day = int(self.t)
        self.print(verbose=True)
        if print_interval > 0 and verbose:
            start = time.time()

        alpha, cumsum, running, transition_types = self.propensities_recalc()
        while running:
            running = self.run_iteration(alpha, cumsum, transition_types)

            # True on the first event sampled after midnight.
            new_day = int(self.t)
            day_changed = new_day != day
            day = new_day

            if day_changed and day != 0:
                alpha, cumsum, running, transition_types = self.midnight(verbose)
                if print_interval > 0 and day % print_interval == 0:
                    self.print(verbose)
                    if verbose:
                        end = time.time()
                        print("Last day took: ", end - start, "seconds")
                        start = time.time()

            # Terminate if tmax is reached or no node remains in an unstable
            # state (per the original comment: exposed/infectious).
            num_unstable = sum(
                self.current_state_count(s) for s in self.unstable_states)
            if self.t >= self.tmax or num_unstable < 1:
                running = False

        self.print(verbose)
        # Finalise exactly once, whichever condition ended the loop.
        self.finalize_data_series()
        return True
# def increase_data_series_length(self):
# self.tseries.bloat()
# self.history.bloat()
# for state in self.states:
# self.state_counts[state].bloat()
# self.N.bloat()
# def finalize_data_series(self):
# self.tseries.finalize(self.tidx)
# self.history.finalize(self.tidx)
# for state in self.states:
# self.state_counts[state].finalize(self.tidx)
# self.N.finalize(self.tidx)