# -*- coding: utf-8 -*-
"""
Generate communities by stratified sampling.
"""
import logging
import random
from itertools import chain
from typing import Iterator
from . import ess
from .types import Community, Ecosystem, Name
# Generate communities by (possibly trivial) sampling.
def gen_sample(
    sources: Ecosystem,
    reps: int = 1,
    sizes: list[int] | None = None,
    pool: Community | None = None,
    name: str | None = None,
) -> Iterator[tuple[Name, Community]]:
    """
    Generate ``reps`` new communities of each of ``sizes`` by stratified sampling
    of a set of ``sources`` communities, possibly of size one.

    Every sample takes ``size // len(sources)`` members from each source; a
    non-zero remainder is drawn from the members of ``pool`` that are not
    already in the sample. When ``sizes`` is empty the sources themselves are
    yielded unchanged (trivial sampling) and ``reps`` is not used.

    :param sources: sequence of named or unnamed source communities
    :param reps: number of repetitions of each sample
    :param sizes: list of sample sizes; if empty, each sample is same size as its source
    :param pool: community to round out samples whose sizes are not a multiple of the number
        of sources; by default, all sources combined
    :param name: optional prefix for the sample name
    :return: iterator of ``(name, community)`` pairs
    :raises ValueError: from :func:`random.sample` when a source has fewer
        members than its share, or when the censored pool has fewer members
        than the remainder
    """
    sizes = sizes or []
    # Unnamed sources (anything that is not a ``(name, community)`` tuple) get
    # a generated name so the rest of the function can unpack pairs uniformly.
    sources = [x if isinstance(x, tuple) else (Name(com=x), x) for x in sources]
    assert all(isinstance(k, Name) for k, _ in sources)

    # NOTE: the sample name is not derived from a lone source; ``name`` is
    # passed through to ``Name`` exactly as given.

    # If not given, pool is all models in source communities.
    if not pool:
        pool = list(chain.from_iterable(s for _, s in sources))

    logging.info(
        "Sample %s: generate from %d source%s, size%s [%s], reps %d for %s; pool %d model%s",
        name,
        len(sources),
        ess(sources),
        ess(sizes),
        " ".join(str(i) for i in sizes),
        reps,
        str(name),
        len(pool),
        ess(pool),
    )

    if not sizes:
        # Trivial sampling: hand back every source community as is.
        logging.info("Trivial sampling should produce %d coms", len(sources))
        for source_name, source_com in sources:
            yield Name(name=source_name, com=source_com), source_com
        return

    for size in sizes:
        for rep in range(1, reps + 1):
            # Equal share per source, plus a remainder topped up from the pool.
            share, remainder = divmod(size, len(sources))
            com = list(
                chain.from_iterable(random.sample(src, share) for _, src in sources)
            )
            logging.debug(
                "Gen %d rep %d: split %dR%d → (%d) part%s",
                size,
                rep,
                share,
                remainder,
                len(com),
                ess(com),
            )
            if remainder > 0:
                # Build the lookup once instead of once per pool member; the
                # censored list keeps pool order so sampling is unchanged.
                chosen = set(com)
                censored = [m for m in pool if m not in chosen]
                extra = random.sample(censored, remainder)
                com += extra
                logging.debug(
                    "Gen %d rep %d: (%d) censored → (%d) extra",
                    size,
                    rep,
                    len(censored),
                    len(extra),
                )
            assert len(com) == size
            yield Name(eco=name, com=com, rep=rep), com
# Generate communities by leave-one-out and add-one-in.
def gen_added_value(
    sources: Ecosystem,
    reps: int = 1,
    pool: Community | None = None,
    name: str | None = "added_value",
) -> Iterator[tuple[Name, Community]]:
    """
    Generate new sub communities from ``sources`` by leave-one-out (sub
    **minus**) or by add-one-in (sub **added**).

    For each **original** source community of size _N_, generate
    _N_ **minus** communities of size _N_-1 by leaving one member out, and
    ``reps`` **added** communities of size _N_+1 by adding a member from the pool.
    All **minus** communities are yielded first, then all **added**, then the
    **original** communities themselves.

    :param sources: sequence of named or unnamed source communities
    :param reps: number of **added** communities to generate per source
    :param pool: community from which the extra member of each **added**
        community is drawn (members already in the source are excluded);
        by default, all sources combined
    :param name: label for this run; only used in the log message
    :return: iterator of ``(name, community)`` pairs
    """
    # Same normalisation as ``gen_sample``: anything that is not already a
    # ``(name, community)`` tuple is an unnamed community and gets a name.
    sources = [x if isinstance(x, tuple) else (Name(com=x), x) for x in sources]
    if not pool:
        pool = list(chain.from_iterable(c for _, c in sources))
    logging.info(
        "Sample %s: added value for %d sources, reps %s", name, len(sources), reps
    )

    # Minus
    for source_name, source_com in sources:
        logging.debug(
            "Minus %s generates %d communities", str(source_name), len(source_com)
        )
        for index in range(len(source_com)):
            minus = source_com[:]
            del minus[index]
            yield Name(name=source_name, sub="minus", rep=index + 1, com=minus), minus

    # Added
    for source_name, source_com in sources:
        logging.debug("Added %s generates %d communities", str(source_name), reps)
        # Build the lookup once instead of once per pool member; the censored
        # list keeps pool order so the random draw is unchanged.
        members = set(source_com)
        censored = [m for m in pool if m not in members]
        if not censored:
            # E.g. a single source with the default pool: there is nobody left
            # to add. random.choices would raise IndexError on an empty
            # population, so skip this source instead of aborting the generator.
            logging.warning(
                "Added %s: pool has no member outside the source, skipping",
                str(source_name),
            )
            continue
        for rep in range(1, reps + 1):
            # Each repetition draws independently, so repeats are possible.
            added = source_com[:] + random.choices(censored, k=1)
            yield Name(name=source_name, sub="added", rep=rep, com=added), added

    # Original
    for source_name, source_com in sources:
        logging.debug("Original %s generates 1 community", str(source_name))
        yield Name(name=source_name, sub="original", rep=None), source_com