Source code for sampling.generate

# -*- coding: utf-8 -*-

"""
Generate communities by stratified sampling.
"""

import logging
import random
from itertools import chain
from typing import Iterator

from . import ess
from .types import Community, Ecosystem, Name


# Generate communities by (possibly trivial) sampling.
def gen_sample(
    sources: Ecosystem,
    reps: int = 1,
    sizes: list[int] | None = None,
    pool: Community | None = None,
    name: str | None = None,
) -> Iterator[tuple[Name, Community]]:
    """
    Generate ``reps`` new communities of ``sizes`` by stratified sampling of a
    set of ``sources`` communities, possibly of size one.

    Each requested size is divided evenly over the sources: every source
    contributes ``size // len(sources)`` randomly chosen members, and the
    remaining ``size % len(sources)`` members are drawn from ``pool`` after
    removing the members already selected.

    If ``sizes`` is empty, sampling is trivial: each source community is
    yielded once, unchanged, and ``reps`` is not used.

    :param sources: sequence of named or unnamed source communities; items
        that are not tuples are wrapped as ``(Name(com=x), x)``
    :param reps: number of repetitions of each sample
    :param sizes: list of sample sizes; if empty, each sample is same size as
        its source
    :param pool: community to round out samples whose sizes are not a multiple
        of the number of sources; by default, all sources combined
    :param name: optional prefix for the sample name
    :return: iterator of ``(name, community)`` pairs
    :raises ValueError: from ``random.sample`` when a source (or the censored
        pool) has fewer members than must be drawn from it
    :raises ZeroDivisionError: when ``sizes`` is non-empty, ``reps`` is at
        least one and ``sources`` is empty
    """
    sizes = sizes or []
    sources = [x if isinstance(x, tuple) else (Name(com=x), x) for x in sources]
    assert all(isinstance(k, Name) for k, _ in sources)

    # # Get community name from the source if there is only one, or from an
    # # eco_name if given. Otherwise the name will be generated on the fly.
    # if not name and len(sources) == 1:
    #     name = str(sources[0][0].ident)

    # If not given, pool is all models in source communities.
    if not pool:
        pool = list(chain.from_iterable(s for _, s in sources))

    logging.info(
        "Sample %s: generate from %d source%s, size%s [%s], reps %d for %s; pool %d model%s",
        name,
        len(sources),
        ess(sources),
        ess(sizes),
        " ".join(str(i) for i in sizes),
        reps,
        str(name),
        len(pool),
        ess(pool),
    )

    # Trivial sampling: no sizes requested, so pass each source through as is.
    # The sizes loop below is then a no-op.
    if not sizes:
        logging.info("Trivial sampling should produce %d coms", len(sources))
        for source_name, source_com in sources:
            yield Name(name=source_name, com=source_com), source_com

    for size in sizes:
        for rep in range(1, reps + 1):
            # Equal share per source, plus a remainder filled from the pool.
            share, remainder = divmod(size, len(sources))
            com = list(
                chain.from_iterable(random.sample(src, share) for _, src in sources)
            )
            logging.debug(
                "Gen %d rep %d: split %dR%d → (%d) part%s",
                size,
                rep,
                share,
                remainder,
                len(com),
                ess(com),
            )
            if remainder > 0:
                # Build the lookup once instead of once per pool member;
                # members must be hashable, as before.
                chosen = set(com)
                censored = [m for m in pool if m not in chosen]
                extra = random.sample(censored, remainder)
                com += extra
                logging.debug(
                    "Gen %d rep %d: (%d) censored → (%d) extra",
                    size,
                    rep,
                    len(censored),
                    len(extra),
                )
            assert len(com) == size
            yield Name(eco=name, com=com, rep=rep), com
def gen_added_value(
    sources: Ecosystem,
    reps: int = 1,
    pool: Community | None = None,
    name: str | None = "added_value",
) -> Iterator[tuple[Name, Community]]:
    """
    Generate new sub communities from ``sources`` by leave-one-out (sub
    **minus**) or by add-one-in (sub **added**).

    For each **original** source community of size *N*, generate *N*
    **minus** communities of size *N*-1 by leaving one member out, and
    ``reps`` **added** communities of size *N*+1 by adding a member from the
    pool. Communities are yielded in three passes over the sources: all
    **minus** communities first, then all **added** ones, then each
    **original** community unchanged.

    :param sources: sequence of named or unnamed source communities; items
        that are lists are wrapped as ``(Name(com=x), x)``
    :param reps: number of **added** communities to generate per source
    :param pool: community from which the added member is drawn, after
        removing the members already in the source; by default, all sources
        combined
    :param name: label used in the log message for this run
    :return: iterator of ``(name, community)`` pairs
    :raises IndexError: from ``random.choices`` when ``reps`` is positive and
        the pool has no member outside a source community
    """
    sources = [(Name(com=x), x) if isinstance(x, list) else x for x in sources]
    if not pool:
        pool = list(chain.from_iterable(c for _, c in sources))

    logging.info(
        "Sample %s: added value for %d sources, reps %s", name, len(sources), reps
    )

    # Minus: drop each member in turn.
    for source_name, source_com in sources:
        logging.debug(
            "Minus %s generates %d communities", str(source_name), len(source_com)
        )
        for rep in range(len(source_com)):
            minus = source_com[:]
            del minus[rep]
            yield Name(name=source_name, sub="minus", rep=rep + 1, com=minus), minus

    # Added: append one random pool member that is not already in the source.
    for source_name, source_com in sources:
        logging.debug("Added %s generates %d communities", str(source_name), reps)
        # Build the lookup once per source instead of once per pool member;
        # members must be hashable, as before.
        members = set(source_com)
        censored = [m for m in pool if m not in members]
        for rep in range(reps):
            added = source_com[:] + random.choices(censored, k=1)
            yield Name(name=source_name, sub="added", rep=rep + 1, com=added), added

    # Original: the source itself (the name is built without a ``com``).
    for source_name, source_com in sources:
        logging.debug("Original %s generates 1 community", str(source_name))
        yield Name(name=source_name, sub="original", rep=None), source_com