Source code for plot_experiments

"""Command-line script for plotting MAIS simulation results.

Reads one or more simulation result files (``.csv``, ``.zip``, or
``.feather``), computes aggregate statistics (median/mean with IQR or SD
shading), and saves the resulting plot.  An optional fit/observed-data curve
can be overlaid.

Typical usage::

    python plot_experiments.py results_a.zip results_b.zip \\
        --column I_d --out_file plot.png --label_names "Scenario A,Scenario B"
"""

import glob
import zipfile

import click
import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), '../src'))

import numpy as np
import pandas as pd
import seaborn as sns

import matplotlib.pyplot as plt


def _add_id_column(df, fname):
    """Add an ``"id"`` column to ``df`` derived from the base file name.

    The ``"id"`` value is the base name of ``fname`` with the ``.csv``
    extension removed.  The column is added in-place.

    Args:
        df (pandas.DataFrame): DataFrame to modify in-place.
        fname (str): File path whose base name (without ``.csv``) becomes the
            identifier.
    """
    fname = os.path.basename(fname).replace('.csv', '')
    df["id"] = fname


def process_zip(zip_path: str, save_feather=False):
    """Extract all CSVs from a ZIP archive and concatenate them into one DataFrame.

    Creates a uniquely named temporary directory next to the ZIP file,
    extracts the archive into it, reads every top-level ``*.csv`` entry
    (ignoring comment lines starting with ``#``), adds an ``"id"`` column to
    each, and concatenates the results.  The temporary directory is removed
    together with all of its contents regardless of errors.

    Args:
        zip_path (str): Path to the ``.zip`` archive produced by
            ``run_multi_experiment.py``.
        save_feather (bool): If ``True``, the concatenated DataFrame is also
            saved as a ``.feather`` file with the same base name as the ZIP.
            Defaults to ``False``.

    Returns:
        pandas.DataFrame: Concatenated DataFrame with all replicate results
        and an ``"id"`` column identifying each source file.

    Raises:
        ValueError: If the archive contains no top-level ``*.csv`` files.
    """
    # Function-scope import: only this function needs ``tempfile``.
    import tempfile

    parent_dir = os.path.dirname(os.path.abspath(zip_path))
    tmp_prefix = os.path.basename(zip_path) + '.'

    # A unique directory avoids clashing with (and later deleting files from)
    # a pre-existing ``<zip>.tmp`` directory; the context manager removes the
    # whole tree, so non-CSV entries or sub-directories cannot break cleanup.
    with tempfile.TemporaryDirectory(prefix=tmp_prefix, suffix='.tmp', dir=parent_dir) as tmp_dir:
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(tmp_dir)

        # Sorted for a reproducible row order in the concatenated result.
        pattern = os.path.join(glob.escape(tmp_dir), '*.csv')
        csv_files = sorted(glob.glob(pattern))
        if not csv_files:
            raise ValueError(f"No CSV files found in archive: {zip_path}")

        out_dfs = []
        for csv_file in csv_files:
            df = pd.read_csv(csv_file, comment='#')
            _add_id_column(df, csv_file)
            out_dfs.append(df)

    res = pd.concat(out_dfs, ignore_index=True)

    if save_feather:
        # Swap only the final extension; ``str.replace`` would rewrite
        # ``.zip`` anywhere in the path and, for a differently spelled
        # suffix, leave the path unchanged and overwrite the archive.
        feather_path = os.path.splitext(zip_path)[0] + '.feather'
        res.to_feather(feather_path)

    return res
def plot_dfs(dfs, column, figsize, out_path, xlabel, ylabel, labels=None, title=None, ymax=None, use_median=True,
             use_sd=False, fit_me=None, show_whole_fit=False, day_indices=None, day_labels=None):
    """Create and save a multi-series line plot from a list of DataFrames.

    Each DataFrame in ``dfs`` is plotted as one line with an uncertainty band
    (IQR or SD).  An optional fit/observed data curve can be overlaid.  Axis
    limits are inferred from the data unless overridden.

    Args:
        dfs (list[pandas.DataFrame]): List of result DataFrames, each
            containing at minimum a ``"T"`` column and the column named by
            ``column``.
        column (str): Name of the y-axis column to plot.
        figsize (tuple[int, int]): Figure size ``(width, height)`` in inches.
        out_path (str): File path where the plot image is saved.
        xlabel (str): Label for the x-axis.
        ylabel (str): Label for the y-axis.
        labels (list[str] or None): Legend labels, one per DataFrame in
            ``dfs``.  Defaults to ``None`` (no legend).
        title (str or None): Plot title.  Defaults to ``None``.
        ymax (int or None): Upper limit of the y-axis.  Inferred from data if
            ``None``.  Defaults to ``None``.
        use_median (bool): Use median as the central estimator when ``True``;
            use mean otherwise.  Defaults to ``True``.
        use_sd (bool): Use standard-deviation shading when ``True``; use
            interquartile-range shading otherwise.  Defaults to ``False``.
        fit_me (pandas.DataFrame or None): Optional DataFrame with columns
            ``"T"`` and ``column`` representing observed/fit data to overlay
            as an unshaded line.  Defaults to ``None``.
        show_whole_fit (bool): If ``True``, include ``fit_me`` in the x-axis
            range calculation.  Defaults to ``False``.
        day_indices (list[int] or None): Positions along the x-axis at which
            to place custom tick marks.  Must be combined with
            ``day_labels``.  Defaults to ``None``.
        day_labels (list[str] or None): String labels for the ticks at
            ``day_indices``.  Defaults to ``None``.

    Raises:
        ValueError: If ``dfs`` is empty.
    """
    dfs = list(dfs)
    if not dfs:
        raise ValueError("At least one DataFrame is required to create a plot.")

    fig, ax = plt.subplots(figsize=figsize)

    estimator = np.median if use_median else np.mean
    ci = 'sd' if use_sd else None

    if title is not None:
        ax.set_title(title, fontsize=14)

    # All frames including the fit data - used to infer correct plot limits.
    lim_dfs = dfs if fit_me is None else [*dfs, fit_me]
    xlim_dfs = lim_dfs if show_whole_fit else dfs

    xmin = min(df['T'].min() for df in xlim_dfs)
    xmax = max(df['T'].max() for df in xlim_dfs)
    if ymax is None:
        ymax = max(df[column].max() for df in lim_dfs)
    ax.set_ylim(ymin=0.0, ymax=ymax)
    ax.set_xlim(xmin=xmin, xmax=xmax)

    for i, df in enumerate(dfs):
        label = labels[i] if labels is not None else None

        sns.lineplot(x='T', y=column, data=df, label=label, estimator=estimator, errorbar=ci, ax=ax)

        if not use_sd:
            # Quartiles of the plotted column only.  The band's x values are
            # taken from the group index so they always line up with q1/q3,
            # even when "T" does not first appear in ascending order.
            per_step = df.groupby("T")[column]
            q1 = per_step.quantile(0.25)
            q3 = per_step.quantile(0.75)
            ax.fill_between(q1.index.to_numpy(), q1, q3, alpha=0.3, label='_nolegend_')

    # seaborn attaches a per-axes legend when ``label`` is given; it is
    # replaced by a single figure-level legend below.
    if ax.legend_:
        ax.legend_.remove()

    if fit_me is not None:
        ax.plot(fit_me['T'], fit_me[column])

    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)

    if day_indices is not None:
        # Hide the numeric major ticks and show the custom labels as minor ticks.
        ax.set_xticks([])
        ax.xaxis.set_minor_locator(plt.FixedLocator(day_indices))
        ax.xaxis.set_minor_formatter(plt.FixedFormatter(day_labels))
        ax.grid(which="minor", axis="x", linestyle="--", linewidth=1)
        plt.setp(ax.xaxis.get_minorticklabels(), rotation=70)

    if labels is not None:
        fig.legend(labels, loc="upper right", framealpha=1.0)

    fig.tight_layout()
    # Save this figure explicitly rather than whatever pyplot considers current.
    fig.savefig(out_path)
def _load_results(file, zip_to_feather):
    """Read one results file into a DataFrame, dispatching on its extension."""
    if file.endswith('.feather'):
        return pd.read_feather(file)
    if file.endswith('.zip'):
        return process_zip(file, save_feather=zip_to_feather)
    if file.endswith('.csv'):
        df = pd.read_csv(file, comment='#')
        _add_id_column(df, file)
        return df
    raise ValueError(f"Unsupported file: {file}, supported extensions - .csv, .zip, .feather")


def _add_all_infected(df, candidates):
    """Add an ``"all_infected"`` column summing the first candidate column set found in ``df``."""
    for states in candidates:
        try:
            df["all_infected"] = df[states].sum(axis=1)
        except KeyError:
            # Some of these columns are missing - try the next candidate set.
            continue
        return
    raise ValueError("cannot initialize target column")


@click.command()
@click.argument('plot_files', nargs=-1)
@click.option('--label_names', default=None,
              help="Labels for each file to show in legend (separated by comma), "
                   "called as --label_names l_1,l_2, ... "
                   "If --fit_me is used, the last label in this argument is the label of"
                   " the fit data.")
@click.option('--out_file', default='./plot_out.png',
              help="Path where to save the plot.")
@click.option('--fit_me', default=None,
              help="Path to a .csv file with fit data.")
@click.option('--show_whole_fit/--show_partial_fit', default=False,
              help="If true, plot the whole fit data on the x "
                   "axis (may be longer than other data).")
@click.option('--title', default=None,
              help="Title of the plot.")
@click.option('--column', default='all_infected',
              help="Column to use for plotting.")
@click.option('--zip_to_feather/--no_zip_to_feather', default=False,
              help="If True, save processed .zips to .feather. "
                   "The file name is the same except the "
                   "extension.")
@click.option('--figsize', nargs=2, default=(6, 5),
              help="Size of the plot (specified without commas: --figsize 6 5).")
@click.option('--ymax', default=None, type=int,
              help="Y axis upper limit. By default the limit is inferred from source "
                   "data.")
@click.option('--xlabel', default='Day',
              help="X axis label.")
@click.option('--ylabel', default=None,
              help="Y axis label.")
@click.option('--use_median/--use_mean', default=True,
              help="Use median or mean in the plot (default median).")
@click.option('--use_sd/--use_iqr', default=False,
              help="Use sd or iqr for shades (default iqr).")
@click.option('--day_indices', default=None,
              help="Use dates on x axis - maps indices to labels (e.g. 5,36,66).")
@click.option('--day_labels', default=None,
              help="Use dates on x axis - string labels (e.g. "
                   "\"March 1,April 1,May 1\").")
@click.option('--nodes_counts', default=None,
              help="Comma separated nodes_counts. If provided, all 'all_infected' values are "
                   "normalized to 100 000 individuals. Other columns are untouched! \n"
                   "fit_me is not normalized!!!")
@click.option('--infected_states', default="I_n,I_a,I_s,E,J_n,J_s",
              help="Comma separated list of infected states.")
def run(plot_files, label_names, out_file, fit_me, show_whole_fit, title, column, zip_to_feather, figsize, ymax,
        xlabel, ylabel, use_median, use_sd, day_indices, day_labels, nodes_counts, infected_states):
    """ Create plot using an arbitrary number of input files. Optionally, plot a fit curve
    specified by --fit_me

    \b
    PLOT_FILES   name of the input files - either .csv, .zip or .feather
    """
    # NOTE: the docstring above is the text click prints for --help, so
    # implementation notes are kept in comments instead.
    infected_states = infected_states.split(',')

    if ylabel is None:
        ylabel = f"Number of cases ({column})"

    if label_names is not None:
        label_names = label_names.split(',')

    # Validation raises explicit errors instead of using ``assert``, which is
    # stripped when Python runs with -O.
    if day_indices is not None or day_labels is not None:
        if day_indices is None or day_labels is None:
            raise click.UsageError("Both --day_indices and --day_labels must be passed to the script "
                                   "if string x axis labels are used.")

        day_labels = day_labels.split(',')

        # check if indices are valid
        try:
            day_indices = [int(i) for i in day_indices.split(',')]
        except ValueError as e:
            raise ValueError("Argument --day_indices must be a comma separated list of ints (e.g. 5,10,25)") from e

        if len(day_labels) != len(day_indices):
            raise click.UsageError("Arguments --day_indices and --day_labels must have the same "
                                   "number of values.")

    if not plot_files:
        raise click.UsageError("No input files were passed to the script.")

    # Fail before any file is read rather than with an IndexError while plotting.
    if label_names is not None and len(label_names) < len(plot_files):
        raise click.UsageError("--label_names must contain at least one label per input file.")

    if nodes_counts is not None:
        try:
            nodes_counts = [int(count) for count in nodes_counts.split(",")]
        except ValueError:
            print("--nodes_counts should be comma separated integers.")
            sys.exit(1)

        # The counts are only used when the derived 'all_infected' column is plotted.
        if column == "all_infected" and len(nodes_counts) < len(plot_files):
            raise click.UsageError("--nodes_counts must contain at least one count per input file.")

    # Column sets tried in order when deriving 'all_infected'.
    infected_candidates = (
        infected_states,
        ["I_n", "I_a", "I_s", "E", "J_n", "J_s"],  # for infection spread
        ["I"],  # for SIR model
        ["Active"],  # for Tipping model
    )

    dfs = []
    for i, file in enumerate(plot_files):
        print(f"Processing file {file}...")
        df = _load_results(file, zip_to_feather)

        if column == "all_infected":
            _add_all_infected(df, infected_candidates)

            if nodes_counts is not None:
                # Normalize to cases per 100 000 individuals.
                df["all_infected"] /= nodes_counts[i]
                df["all_infected"] *= 100000
                print(f"Divided by {nodes_counts[i]}")

        dfs.append(df)

    if fit_me is not None:
        print(f"Reading fit file - {fit_me}...")
        fit_df = pd.read_csv(fit_me)
    else:
        fit_df = None

    print("All files processed.")
    print("Creating plot file...")
    plot_dfs(dfs, column, figsize, out_file, xlabel, ylabel, labels=label_names, title=title, ymax=ymax,
             use_median=use_median, use_sd=use_sd, day_indices=day_indices, day_labels=day_labels,
             fit_me=fit_df, show_whole_fit=show_whole_fit)
    print(f"Plot file saved to {out_file}.")


if __name__ == "__main__":
    run()