"""Plotting utilities for MAIS simulation histories.
This module provides functions for loading simulation output CSV files and
visualising epidemic curves and other time-series metrics via Matplotlib and
Seaborn. It supports single-run plots, multi-run aggregated line plots, and
animated state-histogram views.
Public API:
- :func:`plot_history`: Quick plot of a single history file.
- :func:`plot_histories`: Aggregate multiple runs on one axis.
- :func:`plot_mutliple_policies`: Compare policies on a single metric.
- :func:`plot_mutliple_policies_everything`: Multi-panel comparison across
all tracked metrics.
- :func:`plot_state_histogram`: Animated per-state bar chart over time.
"""
import math
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import animation
from typing import Dict, List
[docs]
def plot_history(filename: str):
"""Plot the ``all_infectious`` curve from a single simulation history CSV.
Loads the history file, plots ``all_infectious`` against ``T`` using the
default Pandas/Matplotlib backend, and displays the figure interactively.
Args:
filename (str): Path to the simulation output CSV file.
"""
history = _load_history(filename)
history.plot(x="T", y="all_infectious")
plt.show()
[docs]
def plot_histories(*args, group_days: int = None, group_func: str = "max", **kwargs):
"""Plot ``all_infectious`` from multiple simulation history CSV files.
Loads each history file, optionally groups records into day-buckets, then
overlays all runs on a single line plot using :func:`_plot_lineplot`.
Args:
*args (str): One or more paths to simulation output CSV files.
group_days (int, optional): If set, aggregates rows into buckets of
this many days before plotting. Defaults to ``None`` (no
grouping).
group_func (str, optional): Aggregation function applied within each
day bucket (e.g., ``"max"``, ``"mean"``). Defaults to ``"max"``.
**kwargs: Additional keyword arguments forwarded to
:func:`_plot_lineplot` (e.g., ``title``, ``save_path``).
"""
histories = [_history_with_fname(
filename, group_days=group_days, group_func=group_func) for filename in args]
history_one_df = pd.concat(histories)
_plot_lineplot(history_one_df, "day", "all_infectious", **kwargs)
[docs]
def plot_mutliple_policies(policy_dict: Dict[str, List[str]],
group_days: int = None, group_func: str = "max", value="all_infectious", max_days=None, **kwargs):
"""Compare a single metric across multiple policies on one line plot.
For each policy, loads all associated history files, concatenates them,
and renders a median line with inter-quartile shading using
:func:`_plot_lineplot`.
Args:
policy_dict (Dict[str, List[str]]): Mapping of policy name to a list
of history CSV file paths for that policy.
group_days (int, optional): Day-bucket size for temporal aggregation.
Defaults to ``None`` (no grouping).
group_func (str, optional): Aggregation function applied per bucket.
Defaults to ``"max"``.
value (str, optional): Column name of the metric to plot on the
y-axis. Defaults to ``"all_infectious"``.
max_days (int, optional): If set, truncates each history to the first
``max_days`` rows. Defaults to ``None``.
**kwargs: Additional keyword arguments forwarded to
:func:`_plot_lineplot`.
"""
histories = []
for policy_key, history_list in policy_dict.items():
histories.extend([_history_with_fname(filename,
group_days=group_days,
group_func=group_func,
policy_name=policy_key,
max_days=max_days)
for filename in history_list])
history_one_df = pd.concat(histories)
_plot_lineplot(history_one_df, "day", value,
hue="policy_name", **kwargs)
[docs]
def plot_mutliple_policies_everything(policy_dict: Dict[str, List[str]],
group_days: int = None, group_func: str = "max",
max_days=None, **kwargs):
"""Render a multi-panel comparison of all tracked metrics across multiple policies.
Loads and concatenates histories for every policy, then delegates to
either :func:`_plot_lineplot2` (variant 2) or :func:`_plot_lineplot3`
(default) depending on the optional ``variant`` keyword argument.
Args:
policy_dict (Dict[str, List[str]]): Mapping of policy name to a list
of history CSV file paths for that policy.
group_days (int, optional): Day-bucket size for temporal aggregation.
Defaults to ``None`` (no grouping).
group_func (str, optional): Aggregation function applied per bucket.
Defaults to ``"max"``.
max_days (int, optional): If set, truncates each history to the first
``max_days`` rows. Defaults to ``None``.
**kwargs: Additional keyword arguments forwarded to the chosen plot
function. The special key ``variant`` (int) selects the plot
layout (``2`` selects :func:`_plot_lineplot2`; any other value
selects :func:`_plot_lineplot3`). The ``title`` key is required
by the underlying plot functions.
"""
histories = []
for policy_key, history_list in policy_dict.items():
histories.extend([_history_with_fname(filename,
group_days=group_days,
group_func=group_func,
policy_name=policy_key,
max_days=max_days)
for filename in history_list])
history_one_df = pd.concat(histories)
if "variant" in kwargs and kwargs["variant"] == 2:
plot_function = _plot_lineplot2
del kwargs["variant"]
else:
plot_function = _plot_lineplot3
plot_function(history_one_df, "day",
hue="policy_name", **kwargs)
[docs]
def plot_state_histogram(filename: str, title: str = "Simulation", states: List[str] = None, save_path: str = None):
"""Render an animated bar chart showing the per-state population over time.
Reads the given history CSV and produces an animation where each frame
corresponds to one simulation day. Each bar represents a disease/model
state, and its height equals the number of nodes in that state on the
corresponding day.
Args:
filename (str): Path to the simulation output CSV file.
title (str, optional): Base title displayed in the figure. The current
day number is appended dynamically per frame. Defaults to
``"Simulation"``.
states (List[str], optional): Subset of state column names to include
in the histogram. If ``None``, all state columns present in the
CSV (excluding metadata columns) are shown. Defaults to ``None``.
save_path (str, optional): If provided, the animation is saved to this
file path using FFMpeg at 10 fps before being displayed. Defaults
to ``None``.
"""
def animate(i):
fig.suptitle(f"{title} - day {day_labels.iloc[i]}")
data_i = data.iloc[i]
for d, b in zip(data_i, bars):
b.set_height(math.ceil(d))
fig, ax = plt.subplots()
history = _history_with_fname(filename, group_days=1, keep_only_all=False)
day_labels = history["day"]
data = history.drop(["T", "day", "all_infectious", "filename"], axis=1)
if states is not None:
data = data[states]
bars = plt.barplot(range(data.shape[1]),
data.values.max(), tick_label=data.columns)
anim = animation.FuncAnimation(fig, animate, repeat=False, blit=False, frames=history.shape[0],
interval=100)
if save_path is not None:
anim.save(save_path, writer=animation.FFMpegWriter(fps=10))
plt.show()
def _plot_lineplot(history_df, x, y, hue=None, save_path=None, **kwargs):
"""Render a median line plot with IQR shading, grouped by a hue column.
For each unique value in the ``hue`` column the function draws the median
trajectory and fills between the 25th and 75th percentiles. Y-axis limits
are hard-coded based on the metric name (``mean_waiting`` → ``[0, 10]``;
everything else → ``[0, 150]``).
Args:
history_df (pandas.DataFrame): Combined history data for all policies/
runs.
x (str): Column name to use as the x-axis (typically ``"day"``).
y (str): Column name to use as the y-axis metric.
hue (str, optional): Column name used to separate runs into groups
(e.g., ``"policy_name"``). Defaults to ``None``.
save_path (str, optional): If provided, the figure is saved to this
path before display. Defaults to ``None``.
**kwargs: Additional keyword arguments. The ``title`` key (str) is
extracted and applied as the axes title; remaining keys are
currently unused.
"""
if "title" in kwargs:
title = kwargs["title"]
del kwargs["title"]
else:
title = ""
fig, axs = plt.subplots()
for policy in history_df[hue].unique():
policy_df = history_df[history_df[hue] == policy]
policy_stats = policy_df.groupby([x, policy]).describe()
q1 = policy_stats[y]["25%"]
q3 = policy_stats[y]["75%"]
sns.lineplot(x=x, y=y, data=policy_df, label=policy,
estimator=np.median, ci=None, ax=axs)
axs.fill_between(np.arange(len(policy_df[x].unique())), q1, q3, alpha=0.3)
# dirty hack (ro)
if y == "mean_waiting":
axs.set(ylim=(0, 10))
else:
axs.set(ylim=(0, 150))
axs.set_title(title)
if save_path is not None:
fig.savefig(save_path)
plt.show()
def _plot_lineplot2(history_df, x, hue=None, save_path=None, plotall=True, **kwargs):
"""Render a two-panel line plot comparing detected and total infectious cases.
Produces a side-by-side figure with:
- Left panel: median ``I_d`` (detected active cases).
- Right panel: median ``all_infectious`` (all active cases, excluding the
``"Czech Republic (scaled down)"`` policy).
Args:
history_df (pandas.DataFrame): Combined history data for all policies/
runs.
x (str): Column name for the x-axis (typically ``"day"``).
hue (str, optional): Column name used to colour lines by group.
Defaults to ``None``.
save_path (str, optional): If provided, the figure is saved to this
path. Defaults to ``None``.
plotall (bool, optional): Unused in this variant; kept for API
consistency. Defaults to ``True``.
**kwargs: Must contain ``title`` (str) for the figure super-title.
Optional keys: ``maxy`` (int, y-axis upper limit) and ``maxx``
(int, x-axis upper limit).
"""
title = kwargs["title"]
del kwargs["title"]
maxy = kwargs.get("maxy", None)
if "maxy" in kwargs:
del kwargs["maxy"]
maxx = kwargs.get("maxx", None)
if "maxx" in kwargs:
del kwargs["maxx"]
fig = plt.figure()
axs = [None] * 2
axs[0] = fig.add_subplot(121)
axs[1] = fig.add_subplot(122)
# axs[2] = fig.add_subplot(223)
# axs[3] = fig.add_subplot(224)
# dirty hack to get rid of stupid legend title
kwargs["legend"] = False
sns_plot = sns.lineplot(x=x, y="I_d", data=history_df,
hue=hue, estimator=np.median, ci='sd', ax=axs[0], **kwargs)
history_df_r = history_df[history_df["policy_name"]
!= "Czech Republic (scaled down)"]
# maxy = 5000
# dirty hack (ro)
axs[0].set(ylim=(0, maxy))
axs[0].set(xlim=(0, maxx))
axs[0].set_ylabel("all detected states")
axs[0].set_title("detected - active cases - median")
axs[0].legend(history_df[hue].unique(), title=None, fancybox=True, )
sns_plot2 = sns.lineplot(x=x, y="all_infectious", data=history_df_r,
hue=hue, estimator=np.median, ci='sd', ax=axs[1], **kwargs)
# maxy = 25000
# dirty hack (ro)
axs[1].set(ylim=(0, maxy))
axs[1].set(xlim=(0, maxx))
axs[1].set_ylabel("all infected states")
axs[1].set_title("all active cases - median")
axs[1].legend(history_df_r[hue].unique(), title=None, fancybox=True, )
"""
sns_plot = sns.lineplot(x=x, y="I_d", data=history_df,
hue=hue, estimator=np.mean, ci='sd', ax=axs[2], **kwargs)
history_df_r = history_df[history_df["policy_name"] != "Czech Republic (scaled down)"]
maxy = 5000
# dirty hack (ro)
axs[2].set(ylim=(0, maxy))
axs[2].set(xlim=(0, maxx))
axs[2].set_ylabel("all detected states")
axs[2].set_title("detected - active cases - mean")
axs[2].legend(history_df[hue].unique(), title=None, fancybox=True, )
sns_plot2 = sns.lineplot(x=x, y="all_infectious", data=history_df_r,
hue=hue, estimator=np.mean, ci='sd', ax=axs[3], **kwargs)
maxy = 25000
# dirty hack (ro)
axs[3].set(ylim=(0, maxy))
axs[3].set(xlim=(0, maxx))
axs[3].set_ylabel("all infected states")
axs[3].set_title("all active cases - mean")
axs[3].legend(history_df_r[hue].unique(), title=None, fancybox=True, )
"""
fig.suptitle(title, fontsize=20)
if save_path is not None:
plt.savefig(save_path)
def _plot_lineplot3(history_df, x, hue=None, save_path=None, plotall=True, **kwargs):
"""Render a multi-panel line plot covering all key epidemic metrics.
Produces up to six panels depending on ``plotall``:
- Panel 0: Median detected active cases (``I_d``).
- Panel 1: Median total infectious cases (``all_infectious``).
- Panel 2 (if ``plotall``): Median mean waiting time (``mean_waiting``).
- Panel 3: Median detection ratio (``detected_ratio``).
- Panel 4 (if ``plotall``): Median total tests (``all_tests``).
- Panel 5 (if ``plotall``): Median mean infection probability
(``mean_p_infection``).
Vertical reference lines are drawn at days 5, 36, 66, and 97 on
the panels that show them.
Args:
history_df (pandas.DataFrame): Combined history data for all policies/
runs.
x (str): Column name for the x-axis (typically ``"day"``).
hue (str, optional): Column name used to colour lines by group.
Defaults to ``None``.
save_path (str, optional): If provided, the figure is saved to this
path. Defaults to ``None``.
plotall (bool, optional): Whether to include the additional three
panels (waiting time, tests, infection probability). Defaults to
``True``.
**kwargs: Must contain ``title`` (str) for the figure super-title.
Optional key: ``maxy`` (int, y-axis upper limit; default ``300``).
"""
title = kwargs["title"]
del kwargs["title"]
maxy = kwargs.get("maxy", 300)
if "maxy" in kwargs:
del kwargs["maxy"]
fig = plt.figure()
axs = [None] * 6
axs[0] = fig.add_subplot(131)
axs[1] = fig.add_subplot(132)
if plotall:
axs[2] = fig.add_subplot(433)
axs[4] = fig.add_subplot(436)
axs[3] = fig.add_subplot(439)
axs[5] = fig.add_subplot(4, 3, 12)
else:
axs[3] = fig.add_subplot(133)
# axs[2] = fig.add_subplot(433)
# axs[3] = fig.add_subplot(436)
# axs[4] = fig.add_subplot(439)
# axs[5] = fig.add_subplot(4,3,12)
# dirty hack to get rid of stupid legend title
kwargs["legend"] = False
sns_plot = sns.lineplot(x=x, y="I_d", data=history_df,
hue=hue, estimator=np.median, ci='sd', ax=axs[0], **kwargs)
history_df_r = history_df[history_df["policy_name"]
!= "Czech Republic (scaled down)"]
# dirty hack (ro)
axs[0].set(ylim=(0, 40))
# axs[0].set(xlim=(1, 150))
axs[0].set_ylabel("all detected states")
axs[0].set_title("detected - active cases")
axs[0].legend(history_df[hue].unique(), title=None, fancybox=True, )
axs[0].axvline(x=5, color="gray")
axs[0].axvline(x=36, color="gray")
axs[0].axvline(x=66, color="gray")
axs[0].axvline(x=97, color="gray")
sns_plot2 = sns.lineplot(x=x, y="all_infectious", data=history_df_r,
hue=hue, estimator=np.median, ci='sd', ax=axs[1], **kwargs)
# dirty hack (ro)
axs[1].set(ylim=(0, 150))
# axs[1].set(xlim=(1, 150))
axs[1].set_ylabel("all infected states")
axs[1].set_title("all active cases")
axs[1].legend(history_df_r[hue].unique(), title=None, fancybox=True, )
if plotall:
sns_plot3 = sns.lineplot(x=x, y="mean_waiting", data=history_df,
hue=hue, estimator=np.median, ci='sd', ax=axs[2], **kwargs)
axs[2].set(ylim=(0, 15))
axs[2].set(xlim=(1, 120))
axs[2].legend(history_df[hue].unique(), title=None, fancybox=True, )
# axs[2].set_title("waiting times")
sns.lineplot(x=x, y="detected_ratio", data=history_df_r,
hue=hue, estimator=np.median, ci=None, ax=axs[3], **kwargs)
axs[3].set(ylim=(0, 15))
axs[3].set(xlim=(1, 120))
axs[3].legend(history_df_r[hue].unique(), title=None, fancybox=True, )
# axs[3].set_title("detected_ratio")
axs[3].axvline(x=5)
axs[3].axvline(x=36)
axs[3].axvline(x=66)
axs[3].axvline(x=97)
axs[3].axhline(y=10)
sns.lineplot(x=x, y="all_tests", data=history_df,
hue=hue, estimator=np.median, ci='sd', ax=axs[4],
**kwargs)
axs[4].set(ylim=(0, 52))
axs[4].set(xlim=(1, 120))
axs[4].legend(history_df[hue].unique(), title=None, fancybox=True, )
sns.lineplot(x=x, y="mean_p_infection", data=history_df_r,
hue=hue, estimator=np.median, ci='sd', ax=axs[5],
**kwargs)
axs[5].set(xlim=(1, 120))
axs[5].legend(history_df_r[hue].unique(), title=None, fancybox=True, )
# sns.lineplot(x=x, y="nodes_in_quarantine", data=history_df,
# hue=hue, estimator=np.mean, ci='sd', ax=axs[3], **kwargs)
# sns.lineplot(x=x, y="contacts_collected", data=history_df,
# hue=hue, estimator=np.mean, ci='sd', ax=axs[4], **kwargs)
# sns.lineplot(x=x, y="released_nodes", data=history_df,
# hue=hue, estimator=np.mean, ci='sd', ax=axs[5], **kwargs)
# axs[3].set(ylim=(0, 200))
# axs[3].set_title("nodes_in_quarantines")
# axs[4].set(ylim=(0, 50))
# axs[4].set_title("contacts_collected")
# axs[5].set(ylim=(0, 50))
# axs[5].set_title("released_nodes")
# sns.lineplot(x=x, y="", data=history_df,
# hue=hue, estimator=np.median, ci='sd', ax=axs[3], **kwargs)
# axs[3].set(ylim=(0, 15))
# axs[3].set_title("waiting times")
"""
sns.lineplot(x=x, y="tests_ratio", data=history_df,
hue=hue, estimator=np.median, ci='sd', ax=axs[3], **kwargs)
sns.lineplot(x=x, y="tests_ratio_to_s", data=history_df,
hue=hue, estimator=np.median, ci='sd', ax=axs[3], **kwargs)
axs[3].set_title("tests ratio to all infected, symptomatic infected")
axs[3].set(ylim=(0, None))
"""
fig.suptitle(title, fontsize=20)
if save_path is not None:
plt.savefig(save_path)
# plt.show()
def _history_with_fname(filename, group_days: int = None, group_func: str = "max", policy_name: str = None,
keep_only_all: bool = False, max_days=None):
"""Load a history CSV and enrich it with metadata columns.
Reads the history file via :func:`_load_history`, optionally restricts
columns to ``["day", "all_infectious"]``, applies temporal bucketing, and
inserts ``filename`` (and optionally ``policy_name``) as new columns.
Args:
filename (str): Path to the simulation output CSV file.
group_days (int, optional): If set to a positive integer, rows are
bucketed into intervals of this many days and aggregated using
``group_func``. Defaults to ``None`` (no grouping).
group_func (str, optional): Pandas aggregation function applied within
each day bucket. Defaults to ``"max"``.
policy_name (str, optional): If provided, a ``policy_name`` column is
added to the resulting DataFrame. Defaults to ``None``.
keep_only_all (bool, optional): If ``True``, all columns except
``"day"`` and ``"all_infectious"`` are dropped before grouping.
Defaults to ``False``.
max_days (int, optional): Maximum number of rows to retain (passed to
:func:`_load_history`). Defaults to ``None``.
Returns:
pandas.DataFrame: Processed history DataFrame with at minimum the
columns ``"filename"``, ``"day"``, and ``"all_infectious"``.
"""
history = _load_history(filename, max_days=max_days)
if keep_only_all:
history = history[["day", "all_infectious"]]
if group_days is not None and group_days > 0:
history["day"] = history["day"] // group_days * group_days
history = history.groupby(
"day", as_index=False).agg(func=group_func)
history.insert(0, "filename", filename)
if policy_name is not None:
history["policy_name"] = policy_name
return history
def _load_history(filename: str, max_days=None) -> pd.DataFrame:
"""Load and pre-process a simulation history CSV into a DataFrame.
Reads the CSV, computes derived aggregate columns (e.g.,
``all_infectious``, ``I_d``, ``all_tests``, ``detected_ratio``,
``tests_ratio``, ``mean_waiting``, ``nodes_in_quarantine``,
``released_nodes``, ``contacts_collected``), and ensures a ``"day"``
column is always present.
Derivations are performed only when an ``"E"`` column is present in the
CSV (indicating a full SEIR-type output format). For simpler outputs,
several columns are filled with zeros so downstream code can always
reference them.
Args:
filename (str): Path to the simulation output CSV file. Lines
starting with ``'#'`` are treated as comments and skipped.
max_days (int, optional): If set, truncates the DataFrame to the
first ``max_days`` rows after loading. Defaults to ``None``
(no truncation).
Returns:
pandas.DataFrame: Pre-processed history with derived aggregate columns
added. A ``"day"`` column is guaranteed to exist (falls back to a
sequential integer range if not present in the source file).
"""
print(filename)
history = pd.read_csv(filename, comment="#")
if "E" in history.columns:
all_infectious = [
s
for s in [
"I_n", "I_a", "I_s", "E",
"I_dn", "I_da", "I_ds", "E_d", "J_ds", "J_dn",
"J_n", "J_s"]
if s in history.columns
]
history["all_infectious"] = history[all_infectious].sum(axis=1)
try:
history["I_d"] = history[[
"I_dn", "I_da", "I_ds", "E_d", "J_ds", "J_dn"]].sum(axis=1)
history["all_tests"] = history[[
"tests", "quarantine_tests"]].sum(axis=1)
history["detected_ratio"] = history["all_infectious"] / history["I_d"]
history["tests_ratio"] = history["tests"] / \
history["all_infectious"]
history["all_s"] = history[[
"I_s", "I_ds", "J_s", "J_ds"]].sum(axis=1)
history["tests_ratio_to_s"] = history["tests"] / history["all_s"]
history["mean_waiting"] = history["sum_of_waiting"] / \
history["all_positive_tests"]
selected_cols = [
col
for col in history.columns
if "nodes_in_quarantine" in col
]
history["nodes_in_quarantine"] = history[selected_cols].sum(axis=1)
selected_cols = [
col
for col in history.columns
if "released_nodes" in col
]
history["released_nodes"] = history[selected_cols].sum(axis=1)
selected_cols = [
col
for col in history.columns
if "contacts_collected" in col
]
history["contacts_collected"] = history[selected_cols].sum(axis=1)
except KeyError:
print("Warning something is missing in data frame")
else:
history["nodes_in_quarantine"] = 0
history["released_nodes"] = 0
history["contacts_collected"] = 0
history["mean_p_infection"] = 0
if max_days is not None:
history = history[:max_days]
if "day" not in history.columns:
history["day"] = range(len(history))
# print(history)
return history
if __name__ == "__main__":
history = pd.read_csv(
"../result_storage/tmp/history_seirsplus_quarantine_1.csv")
plot_history(history)