# Source code for prob_infection

"""Network contact and infection probability utilities.

This module provides standalone functions that compute which edges become
"active" on a given day and whether susceptible nodes are exposed via those
active edges.  Two main entry points are provided:

* :func:`prob_of_contact` – the current production implementation.
* :func:`prob_of_contact_old` – a legacy implementation retained for
  reference.

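Helper functions :func:`select_active_edges` and :func:`get_relevant_edges`
are used by the legacy :func:`prob_of_contact_old` only;
:func:`archive_active_edges` is currently not called by either entry point
(:func:`prob_of_contact` records the day's contacts itself).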
"""

import time
import numpy as np
import numpy_indexed as npi

import logging

from models.states import STATES


def select_active_edges(model, source_states, source_candidate_states, dest_states, dest_candidate_states):
    """Stochastically select active (contact) edges between two node sets.

    Every edge connecting a node in one of *source_candidate_states* with a
    node in one of *dest_candidate_states* is a candidate.  For each candidate
    a uniform random number is drawn and compared with the edge's contact
    probability; only edges that pass the draw are returned.

    Args:
        model: The active simulation model instance (must expose
            ``memberships``, ``graph`` and ``num_nodes``).
        source_states (list of int): States whose nodes can become exposed.
            Only type-checked here; the state filtering itself is done by
            :func:`get_relevant_edges`.
        source_candidate_states (list of int): States used to pick candidate
            edges on the "may get infected" side.
        dest_states (list of int): Infectious states.  Only type-checked here.
        dest_candidate_states (list of int): States used to pick candidate
            edges on the "may offer infection" side.

    Returns:
        tuple: ``(active_edges, active_edges_dirs)``, both indexed from the
        arrays returned by ``model.graph.get_edges``, or ``(None, None)`` if
        there is no candidate edge or no candidate passed the random draw.

    Raises:
        AssertionError: If *source_states* or *dest_states* is not a list.
    """
    # Kept as an assert so the failure mode (AssertionError) is unchanged;
    # isinstance() instead of an exact type comparison admits list subclasses.
    assert isinstance(dest_states, list) and isinstance(source_states, list)

    # 1. select active edges
    # Candidates for active edges are edges between source_candidate_states
    # and dest_candidate_states:
    #   source - the one who is going to be infected
    #   dest   - the one who can offer infection
    started = time.time()
    # Per-node flag: non-zero iff the node is in any of the candidate states.
    source_candidate_flags = model.memberships[source_candidate_states, :, :].reshape(
        len(source_candidate_states), model.num_nodes).sum(axis=0)
    dest_candidate_flags = model.memberships[dest_candidate_states, :, :].reshape(
        len(dest_candidate_states), model.num_nodes).sum(axis=0)
    logging.info("Create flags %s", time.time() - started)

    started = time.time()
    possibly_active_edges, possibly_active_edges_dirs = model.graph.get_edges(
        source_candidate_flags, dest_candidate_flags
    )
    num_possibly_active_edges = len(possibly_active_edges)
    logging.info("Select possibly active %s", time.time() - started)

    if num_possibly_active_edges == 0:
        return None, None

    # For each possibly active edge flip a coin.
    started = time.time()
    r = np.random.rand(num_possibly_active_edges)
    logging.info("Random numbers: %s", time.time() - started)

    # Contact probability of each candidate edge.
    started = time.time()
    p = model.graph.get_edges_probs(possibly_active_edges)
    logging.info("Get intensities: %s", time.time() - started)

    started = time.time()
    active_indices = (r < p).nonzero()[0]
    if len(active_indices) == 0:
        return None, None

    active_edges = possibly_active_edges[active_indices]
    active_edges_dirs = possibly_active_edges_dirs[active_indices]
    logging.info("Select active edges: %s", time.time() - started)
    return active_edges, active_edges_dirs
def archive_active_edges(model, active_edges, active_edges_dirs):
    """Append today's contacts to the model's contact-history buffer.

    The end nodes of every active edge are looked up through
    ``model.graph.get_edges_nodes`` and one ``(dest_node, source_node, edge)``
    triple per edge is built.  The whole day's list of triples is appended to
    ``model.contact_history`` as a single entry.

    Args:
        model: The active simulation model instance (must expose ``graph``
            and ``contact_history``).
        active_edges (numpy.ndarray): Indices of today's active edges.
        active_edges_dirs (numpy.ndarray): Direction flags for each edge,
            passed through to ``model.graph.get_edges_nodes``.

    Returns:
        None
    """
    t_begin = time.time()

    # The two nodes met today: "dest" is the possibly infectious one,
    # "source" is the one who was possibly infected.
    possibly_infected, possibly_infectious = model.graph.get_edges_nodes(
        active_edges, active_edges_dirs)

    # The infectious node goes first in every stored triple.
    # (Optional per-node / per-layer contact statistics used to be collected
    # here for debugging; they are not part of the production path.)
    todays_contacts = [
        (infectious, infected, edge)
        for infectious, infected, edge in zip(
            possibly_infectious, possibly_infected, active_edges)
    ]
    model.contact_history.append(todays_contacts)

    logging.info(f"Todays contact num: {len(todays_contacts)}")
    elapsed = time.time() - t_begin
    logging.info(f"Archive active edges: {elapsed}")
def get_relevant_edges(model, active_edges, active_edges_dirs, source_states, dest_states):
    """Filter active edges to those connecting relevant source/destination states.

    Asks the graph for all edges between nodes in *source_states* and nodes in
    *dest_states* and intersects them with *active_edges*.  The direction flag
    of every surviving edge is taken from the graph lookup, not from
    *active_edges_dirs* (that argument is accepted but not read).

    Candidate states used when selecting active edges may be broader than the
    relevant ones (e.g. candidates can be both E and I while only I is
    relevant), which is why this second filtering step exists.

    Args:
        model: The active simulation model instance (must expose
            ``memberships``, ``graph`` and ``num_nodes``).
        active_edges (numpy.ndarray): Indices of today's active edges.
        active_edges_dirs (numpy.ndarray): Direction flags for *active_edges*
            (unused).
        source_states (list of int): States eligible to be exposed.
        dest_states (list of int): Infectious states.

    Returns:
        tuple: ``(active_relevant_edges, active_relevant_edges_dirs)`` with
        the edges sorted and unique, or ``(None, None)`` if no relevant edge
        is active.
    """
    started = time.time()

    # Per-node flag: non-zero iff the node is in any of the given states.
    dest_flags = model.memberships[dest_states, :, :].reshape(
        len(dest_states), model.num_nodes).sum(axis=0)
    source_flags = model.memberships[source_states, :, :].reshape(
        len(source_states), model.num_nodes).sum(axis=0)

    relevant_edges, relevant_edges_dirs = model.graph.get_edges(
        source_flags, dest_flags)
    if len(relevant_edges) == 0:
        return None, None

    # One sort-based pass yields both the intersection and, for every common
    # edge, the position of its first occurrence in relevant_edges.  This
    # replaces a separate index lookup whose failure path dumped arrays to
    # disk and terminated the interpreter with exit().
    active_relevant_edges, _, where_indices = np.intersect1d(
        active_edges, relevant_edges, return_indices=True)
    if len(active_relevant_edges) == 0:
        return None, None

    # Exactly one index per edge: source and dest state sets are expected to
    # be disjoint, so an edge appears in relevant_edges at most once.
    active_relevant_edges_dirs = relevant_edges_dirs[where_indices]

    logging.info("Get relevant active edges: %s", time.time() - started)
    return active_relevant_edges, active_relevant_edges_dirs
def prob_of_contact_old(model, source_states, source_candidate_states, dest_states, dest_candidate_states, beta, beta_in_family):
    """Legacy implementation of the contact-probability computation (abandoned).

    Retained for historical reference only.  Prefer :func:`prob_of_contact`.

    Selects today's active edges with :func:`select_active_edges`, narrows
    them with :func:`get_relevant_edges`, and then draws one random number per
    remaining edge against ``transmission rate * edge intensity``.  Unlike
    :func:`prob_of_contact`, this variant does not record anything in
    ``model.contact_history`` (the archiving call is disabled).

    Side effects on success:

    * ``model.successfull_source_of_infection`` is incremented once per
      successful transmission for the transmitting node.
    * ``model.stat_successfull_layers[layer][model.t]`` is incremented once
      per successful transmission.

    Args:
        model: The active simulation model instance.
        source_states (list of int): States that can be exposed.
        source_candidate_states (list of int): Candidate source states for
            edge selection.
        dest_states (list of int): Infectious states.
        dest_candidate_states (list of int): Candidate destination states for
            edge selection.
        beta (numpy.ndarray): Per-node non-family transmission rate,
            shape ``(num_nodes, 1)``.
        beta_in_family (numpy.ndarray): Per-node family transmission rate,
            shape ``(num_nodes, 1)``.

    Returns:
        numpy.ndarray: Per-node binary exposure indicator,
        shape ``(num_nodes, 1)``.
    """
    main_start = time.time()

    active_edges, active_edges_dirs = select_active_edges(
        model, source_states, source_candidate_states,
        dest_states, dest_candidate_states)
    if active_edges is None:
        # We have no active edges today.
        return np.zeros((model.num_nodes, 1))

    active_relevant_edges, active_relevant_edges_dirs = get_relevant_edges(
        model, active_edges, active_edges_dirs, source_states, dest_states)
    if active_relevant_edges is None:
        return np.zeros((model.num_nodes, 1))

    intensities = model.graph.get_edges_intensities(
        active_relevant_edges).reshape(-1, 1)
    relevant_sources, relevant_dests = model.graph.get_edges_nodes(
        active_relevant_edges, active_relevant_edges_dirs)

    # A disabled experiment once also treated class, pub and super edges as
    # family edges; only genuine family edges get the in-family rate.
    is_family_edge = model.graph.is_family_edge(
        active_relevant_edges).reshape(-1, 1)

    # The rate depends on the node that is going to be infected (source) and
    # is switched to the asymptomatic variant when the infectious node (dest)
    # is in state I_n.
    is_A = model.memberships[STATES.I_n][relevant_dests]
    b_original_intensities = (
        beta_in_family[relevant_sources] * (1 - is_A) +
        model.beta_A_in_family[relevant_sources] * is_A
    )
    b_reduced_intensities = (
        beta[relevant_sources] * (1 - is_A) +
        model.beta_A[relevant_sources] * is_A
    )
    b_intensities = (
        b_original_intensities * is_family_edge +
        b_reduced_intensities * (1 - is_family_edge)
    )
    assert b_intensities.shape == intensities.shape

    # One independent draw per relevant edge.
    r = np.random.rand(b_intensities.size).reshape(-1, 1)
    is_exposed = r < (b_intensities * intensities)
    if not is_exposed.any():
        return np.zeros((model.num_nodes, 1))

    is_exposed = is_exposed.ravel()
    exposed_nodes = relevant_sources[is_exposed]
    ret = np.zeros((model.num_nodes, 1))
    ret[exposed_nodes] = 1

    # np.add.at accumulates repeated indices; a plain ``arr[idx] += 1`` would
    # count a node that infected several others today only once.
    sourced_nodes = relevant_dests[is_exposed]
    np.add.at(model.successfull_source_of_infection, sourced_nodes, 1)

    succesfull_edges = active_relevant_edges[is_exposed]
    successfull_layers = model.graph.get_layer_for_edge(succesfull_edges)
    for layer in successfull_layers:
        model.stat_successfull_layers[layer][model.t] += 1

    logging.info("PROBS OF CONTACT %s", time.time() - main_start)
    return ret
    # NOTE: an earlier variant returned a per-node exposure *probability*
    # (1 - product of the per-edge no-infection probabilities) instead of
    # drawing per edge; it was abandoned in favour of the code above.
def _nodes_in_states(model, states, nodes):
    """Return a boolean mask, True where ``nodes[i]`` is in any of *states*."""
    in_states = np.zeros(len(nodes), dtype=bool)
    for state in states:
        in_states |= model.memberships[state, nodes, 0].astype(bool)
    return in_states


def prob_of_contact(model, source_states, source_candidate_states, dest_states, dest_candidate_states, beta, beta_in_family):
    """Evaluate per-node exposure via stochastic edge activation (production).

    Every graph edge is activated independently with its contact probability
    (``model.graph.get_all_edges_probs()``).  Each active edge is then
    considered in both directions; a direction is relevant when its source
    node is in one of *source_states* and its destination node is in one of
    *dest_states*.  For every relevant edge one random number is drawn against
    ``transmission rate * edge intensity``, where the rate is the family or
    non-family rate of the source node, replaced by the asymptomatic rate
    (``model.beta_A`` / ``model.beta_A_in_family``) when the destination node
    is in state ``STATES.I_n``.

    *source_candidate_states* and *dest_candidate_states* are accepted for
    backward compatibility but are not used; the function queries all edges.

    Side effects:

    * ``model.contact_history`` - one entry per call, a tuple of three arrays
      ``(nodes_a, nodes_b, edge_types)`` in which every active edge appears
      twice, once per direction.  Appended even when nobody gets exposed.
    * ``model.successfull_source_of_infection`` - incremented once per
      successful transmission for the transmitting node.
    * ``model.stat_successfull_layers[layer][model.t]`` - incremented once
      per successful transmission.

    Args:
        model: The active simulation model instance.
        source_states (list of int): States that can become exposed.
        source_candidate_states (list of int): Unused (backward compat.).
        dest_states (list of int): Infectious states.
        dest_candidate_states (list of int): Unused (backward compat.).
        beta (numpy.ndarray): Per-node non-family transmission rate,
            shape ``(num_nodes, 1)``.
        beta_in_family (numpy.ndarray): Per-node family transmission rate,
            shape ``(num_nodes, 1)``.

    Returns:
        numpy.ndarray: Per-node binary exposure indicator,
        shape ``(num_nodes, 1)``.  Value is 1 if the node was newly exposed,
        0 otherwise.
    """
    # source_states - states that can be infected
    # dest_states   - states that are infectious
    main_start = time.time()

    # Flip a coin for every edge of the graph.
    edges_probs = model.graph.get_all_edges_probs()
    num_edges = len(edges_probs)
    r = np.random.rand(num_edges)
    active_edges = (r < edges_probs).nonzero()[0]
    logging.info("active_edges %s", len(active_edges))

    # Record today's contacts, each one in both directions.
    source_nodes = model.graph.e_source[active_edges]
    dest_nodes = model.graph.e_dest[active_edges]
    types = model.graph.e_types[active_edges]
    contact_info = (
        np.concatenate([source_nodes, dest_nodes]),
        np.concatenate([dest_nodes, source_nodes]),
        np.concatenate([types, types])
    )
    model.contact_history.append(contact_info)

    # Take every active edge in both directions: first half with direction
    # flag True, second half with direction flag False.
    n = len(active_edges)
    active_edges = np.concatenate([active_edges, active_edges])
    active_edges_dirs = np.ones(2 * n, dtype=bool)
    active_edges_dirs[n:] = False
    source_nodes, dest_nodes = model.graph.get_edges_nodes(
        active_edges, active_edges_dirs
    )

    # Keep only directions going from a feasible source to a feasible dest.
    is_relevant_edge = np.logical_and(
        _nodes_in_states(model, source_states, source_nodes),
        _nodes_in_states(model, dest_states, dest_nodes)
    )
    if not is_relevant_edge.any():
        return np.zeros((model.num_nodes, 1))

    relevant_edges = active_edges[is_relevant_edge]
    intensities = model.graph.get_edges_intensities(
        relevant_edges).reshape(-1, 1)
    relevant_sources, relevant_dests = model.graph.get_edges_nodes(
        relevant_edges, active_edges_dirs[is_relevant_edge])
    is_family_edge = model.graph.is_family_edge(
        relevant_edges).reshape(-1, 1)

    # Use the asymptomatic rates when the infectious node is in state I_n.
    is_A = model.memberships[STATES.I_n][relevant_dests]
    b_original_intensities = (
        beta_in_family[relevant_sources] * (1 - is_A) +
        model.beta_A_in_family[relevant_sources] * is_A
    )
    b_reduced_intensities = (
        beta[relevant_sources] * (1 - is_A) +
        model.beta_A[relevant_sources] * is_A
    )
    b_intensities = (
        b_original_intensities * is_family_edge +
        b_reduced_intensities * (1 - is_family_edge)
    )
    assert b_intensities.shape == intensities.shape

    # One independent draw per relevant edge.
    r = np.random.rand(b_intensities.size).reshape(-1, 1)
    is_exposed = r < (b_intensities * intensities)
    if not is_exposed.any():
        return np.zeros((model.num_nodes, 1))

    is_exposed = is_exposed.ravel()
    exposed_nodes = relevant_sources[is_exposed]
    ret = np.zeros((model.num_nodes, 1))
    ret[exposed_nodes] = 1

    # Save stats.  np.add.at accumulates repeated indices; a plain
    # ``arr[idx] += 1`` would count a node that infected several others
    # today only once.
    sourced_nodes = relevant_dests[is_exposed]
    np.add.at(model.successfull_source_of_infection, sourced_nodes, 1)

    succesfull_edges = relevant_edges[is_exposed]
    successfull_layers = model.graph.get_layer_for_edge(succesfull_edges)
    for layer in successfull_layers:
        model.stat_successfull_layers[layer][model.t] += 1

    logging.info("PROBS OF CONTACT %s", time.time() - main_start)
    return ret