# Source code for random_utils

"""Random number generation utilities for the MAIS simulation.

This module provides helpers for reproducible random sampling, ordered-tuple
generation, and discrete duration sampling. It is used across the simulation
to draw stochastic values for disease progression durations and other
time-varying random quantities.
"""

import numpy as np
from numpy.random import Generator, SFC64
import logging


class RandomGenerator:
    """Reproducible random-number source backed by NumPy's ``SFC64`` bit generator.

    Holds a ``numpy.random.Generator`` built on the fast ``SFC64`` bit
    generator, so that separately seeded, reproducible streams can be handed
    to different parts of the model.

    Note:
        Attaching dedicated generators to individual model components is
        still in progress.

    Args:
        seed (int): Seed passed straight to ``SFC64``.

    Attributes:
        rng (numpy.random.Generator): The seeded generator instance.
    """

    def __init__(self, seed):
        # TODO: give individual parts of the model their own generators (in progress).
        bit_gen = SFC64(seed)
        self.rng = Generator(bit_gen)
def rand(n, rng=None):
    """Generate ``n`` uniformly distributed random floats in ``[0, 1)``.

    The original implementation referenced an undefined module-level ``rng``
    and raised ``NameError`` on every call. By default the global NumPy random
    state is now used, consistent with ``_random_from_probs``. A specific
    ``numpy.random.Generator`` (e.g. ``RandomGenerator(seed).rng``) can be
    supplied for a reproducible, isolated stream.

    Args:
        n (int or tuple): Number (or shape) of random values to generate.
        rng (numpy.random.Generator, optional): Generator to draw from. When
            ``None`` (default), ``numpy.random`` is used.

    Returns:
        numpy.ndarray: Array of shape ``(n,)`` (or ``n`` if a tuple) with
        values in ``[0, 1)``.
    """
    source = np.random if rng is None else rng
    return source.random(n)
def _random_from_probs(what, p, n=1):
    """Draw ``n`` random samples from ``what`` according to probabilities ``p``.

    Uses the global NumPy random state (``np.random.choice``).

    Args:
        what (int or array-like): If an integer, samples are drawn from
            ``range(what)``. Otherwise samples are drawn from the provided
            sequence.
        p (array-like): Probability weights for each element of ``what``.
            Must sum to 1.
        n (int or tuple, optional): Number (or shape) of samples to draw.
            Defaults to ``1``.

    Returns:
        numpy.ndarray: Array of ``n`` sampled values.
    """
    return np.random.choice(what, p=p, size=n)


def _check_sorted(value_list):
    """Check that corresponding elements across a list of arrays are strictly increasing.

    Given arrays ``[x1, x2, ..., xn]`` (each of the same shape), verifies
    element-wise that ``x1[i] < x2[i] < ... < xn[i]`` for every index ``i``.

    Args:
        value_list (list of numpy.ndarray): Ordered list of arrays to compare
            pairwise. All arrays must have the same shape.

    Returns:
        numpy.ndarray: Boolean array of the same shape as each input array.
        ``True`` at position ``i`` means all consecutive pairs satisfy the
        strict ordering at that position; ``False`` indicates a violation.
        A single-array input is trivially ordered (all ``True``).

    Raises:
        ValueError: If ``value_list`` is empty.
    """
    if len(value_list) == 0:
        raise ValueError("value_list must contain at least one array")
    if len(value_list) == 1:
        # Nothing to compare against: every position is trivially ordered.
        return np.ones(np.shape(np.atleast_1d(value_list[0])), dtype=bool)

    partial_results_list = [
        np.atleast_1d(x < y)
        for x, y in zip(value_list[:-1], value_list[1:])
    ]
    # Fold pairwise: np.logical_and is a binary ufunc, so unpacking more than
    # two arrays into it would treat the third as the ``out`` argument (and
    # silently drop that comparison) or raise for four or more.
    combined = partial_results_list[0]
    for partial in partial_results_list[1:]:
        combined = np.logical_and(combined, partial)
    return combined
def gen_tuple1(n, shape, *args):
    """Generate an ``n``-tuple of random values satisfying a strict ordering.

    Draws values ``(r_1, r_2, ..., r_n)`` such that ``r_1 < r_2 < ... < r_n``
    element-wise across a batch of size ``shape``. Any positions (along the
    first axis) that violate the ordering are resampled repeatedly until all
    positions satisfy the constraint. The loop does not terminate if the
    generators can never produce an ordered tuple.

    Args:
        n (int): Number of elements in the tuple. Must equal ``len(args)``.
        shape (int or tuple): Shape of the batch to generate (passed as the
            ``n`` argument to each generator's ``get`` method).
        *args: Exactly ``n`` random duration generator objects, each exposing
            a ``get(n=...)`` method that returns a NumPy array of samples.

    Returns:
        list of numpy.ndarray: List of ``n`` arrays, each of shape ``shape``,
        satisfying ``result[0] < result[1] < ... < result[n-1]`` element-wise.

    Example:
        >>> gen_tuple1(3, (100, 1), dur1, dur2, dur3)
    """
    assert len(args) == n

    def _gen(size):
        # One fresh batch from each of the n generators, in tuple order.
        return [args[i].get(n=size) for i in range(n)]

    result = _gen(shape)
    check = _check_sorted(result)
    while not np.all(check):
        logging.info("gen_tuple: condition no satisfied, repairing")
        # First-axis positions where some consecutive pair is out of order.
        indices_to_fix = np.where(~check)[0]
        new_values = _gen(indices_to_fix.shape[0])  # n shorter arrays
        for i in range(n):
            # Match the slice being overwritten, so both 1-D batches and
            # column-vector batches such as (N, 1) are supported.
            target_shape = result[i][indices_to_fix].shape
            result[i][indices_to_fix] = new_values[i].reshape(target_shape)
        check = _check_sorted(result)
    return result
def gen_tuple2(n, shape, *args):
    """Generate an ``n``-tuple of strictly ordered random values by clipping.

    Each generator in ``args`` is sampled once with ``get(n=shape)``. Every
    array after the first is then raised element-wise to at least
    ``previous + 1``, which guarantees ``r_1 < r_2 < ... < r_n``. Unlike
    :func:`gen_tuple1`, no resampling takes place.

    Args:
        n (int): Number of elements in the tuple. The first ``n`` entries of
            ``args`` are used.
        shape (int or tuple): Batch shape, forwarded as the ``n`` argument of
            each generator's ``get`` method.
        *args: Random duration generator objects exposing ``get(n=...)`` and
            returning NumPy arrays.

    Returns:
        list of numpy.ndarray: ``n`` arrays with
        ``result[0] < result[1] < ... < result[n-1]`` element-wise.

    Example:
        >>> gen_tuple2(3, 100, dur1, dur2, dur3)
    """
    tuple_values = []
    for position in range(n):
        drawn = args[position].get(n=shape)
        if tuple_values:
            # Enforce strict ordering against the previously accepted array.
            drawn = np.clip(drawn, tuple_values[-1] + 1, None)
        tuple_values.append(drawn)
    return tuple_values
def gen_tuple(n, shape, *args):
    """Return ``n`` arrays of random values that are strictly increasing element-wise.

    Public entry point for ordered-tuple generation. The work is delegated to
    :func:`gen_tuple2`, the clipping variant.

    Args:
        n (int): Number of elements in the tuple.
        shape (int or tuple): Batch shape handed to each generator's ``get``.
        *args: ``n`` random duration generator objects.

    Returns:
        list of numpy.ndarray: ``n`` strictly ordered arrays.
    """
    ordered = gen_tuple2(n, shape, *args)
    return ordered
class RandomDuration():
    """Discrete random duration sampler driven by a full probability distribution.

    Intended for generating the time (in discrete steps, e.g. days) that an
    agent spends in a particular disease state. The distribution is specified
    as a probability mass function (PMF) over non-negative integer durations
    starting from zero.

    Args:
        probs (array-like): Probabilities for durations
            ``0, 1, 2, ..., len(probs)-1``. Values must be non-negative and
            sum to 1.
        precompute (bool, optional): If ``True``, samples are drawn in
            batches of ``10**6`` ahead of time and ``get`` serves values from
            that buffer, refilling it when exhausted. If ``False`` (default),
            every ``get`` call samples on the fly.
    """

    # Number of samples drawn per refill of the precomputed buffer.
    _BUFFER_SIZE = 10**6

    def __init__(self, probs, precompute=False):
        self.N = len(probs)
        self.probs = probs
        self.precompute = precompute
        self._buffer = None
        self._pos = 0
        if precompute:
            # Previously the buffer was drawn and discarded (wasting 10**6
            # draws from the global RNG); it is now kept and used by ``get``.
            self._refill()

    def _refill(self):
        """Draw a fresh buffer of samples and reset the read position."""
        self._buffer = _random_from_probs(self.N, self.probs, self._BUFFER_SIZE)
        self._pos = 0

    def get(self, n=1):
        """Draw ``n`` random duration values from the distribution.

        Args:
            n (int or tuple, optional): Number (or shape) of samples to draw.
                Defaults to ``1``.

        Returns:
            numpy.ndarray: Integer duration values with shape ``n``, drawn
            according to ``self.probs``.
        """
        if not self.precompute:
            return _random_from_probs(self.N, self.probs, n)

        size = int(np.prod(n))
        chunks = []
        remaining = size
        while remaining > 0:
            if self._pos >= len(self._buffer):
                self._refill()
            take = min(remaining, len(self._buffer) - self._pos)
            chunks.append(self._buffer[self._pos:self._pos + take])
            self._pos += take
            remaining -= take
        if chunks:
            # concatenate always copies, so callers never alias the buffer.
            values = np.concatenate(chunks)
        else:
            values = np.empty(0, dtype=self._buffer.dtype)
        return values.reshape(n)
if __name__ == "__main__":
    # Demo: compare on-the-fly vs. precomputed sampling of incubation times.
    import matplotlib.pyplot as plt

    def uncumulate(cdf):
        """Convert a tabulated CDF into a PMF.

        The last entry absorbs the remaining mass (1 - sum), so the returned
        probabilities sum to 1.
        """
        res = [x - y for (x, y) in zip(cdf, [0] + cdf[:-1])]
        s = sum(res)
        res[-1] += 1.0 - s
        return res

    cdf_incubation = [0, 0.002079467, 0.045532967, 0.158206035, 0.303711753,
                      0.446245776, 0.569141375, 0.668484586, 0.746107988,
                      0.805692525, 0.851037774, 0.885435436, 0.911529759,
                      0.931365997, 0.946495014, 0.958080947, 0.966993762,
                      0.973882948, 0.979233968, 0.983410614, 0.986686454,
                      0.98926803, 0.991311965, 0.992937571, 0.994236158,
                      0.995277934, 0.996117131, 0.996795835, 0.997346849,
                      0.997795859, 0.998163058, 0.998464392, 0.998712499,
                      0.998917441, 0.999087255, 0.999228384, 0.999346016,
                      0.999444337, 0.999526742, 0.999595989, 0.999654327]
    p_incubation = uncumulate(cdf_incubation)

    values = []
    values2 = []
    durations = RandomDuration(p_incubation)
    pre_durations = RandomDuration(p_incubation, precompute=True)
    for _ in range(10000):
        values.extend(durations.get(100))
        # The "precomputed" series must come from the precomputed sampler
        # (it previously re-used ``durations``).
        values2.extend(pre_durations.get(100))

    print(np.mean(values), np.median(values))
    print(np.mean(values2), np.median(values2))
    print(np.max(values), np.max(values2))

    max_value = max(values + values2)
    min_value = min(values + values2)
    bins = range(min_value, max_value + 1)

    fig, axs = plt.subplots(nrows=2, figsize=(10, 7))
    axs[0].hist(values, color="pink", label="onfly", bins=bins)
    axs[0].hist(values2, color="blue", label="precomputed", bins=bins)
    axs[1].hist(values2, color="blue", label="precomputed", bins=bins)
    axs[1].hist(values, color="pink", label="onfly", bins=bins)
    axs[0].set_xticks(bins)
    axs[1].set_xticks(bins)
    axs[0].legend()
    axs[1].legend()
    fig.suptitle("days in E")

    # Show plot
    plt.show()