"""
Utility functions for BAM Engine.
This module provides common utility functions used throughout BAM Engine,
including statistical operations (trimmed means), random sampling utilities,
efficient array operations, and batch market matching primitives.
Constants
---------
EPS : float
Small numerical tolerance (1.0e-9); not the floating-point machine epsilon.
Used to avoid division by zero and detect near-zero values.
Functions
---------
trim_mean
Calculate two-sided trimmed mean (SciPy-style).
trimmed_weighted_mean
Calculate weighted trimmed mean with optional weight filtering.
sample_beta_with_mean
Draw samples from Beta distribution with specified mean.
select_top_k_indices_sorted
Efficiently select and sort top-k elements using argpartition.
grouped_cumsum
Per-group prefix sums via the subtract-offset trick.
resolve_conflicts
Batch conflict resolution for oversubscribed targets.
See Also
--------
bamengine.typing : Type aliases used in this module
scipy.stats.trim_mean : SciPy implementation of trimmed mean
Notes
-----
- Trimmed means are robust statistics that exclude extreme values
- Beta sampling is used for initialization with controlled variance
- Top-k selection uses argpartition for O(n) performance vs O(n log n) for sort
- grouped_cumsum and resolve_conflicts underpin vectorized market matching
"""
from __future__ import annotations
import numpy as np
from numpy.random import Generator, default_rng
from bamengine import Rng
from bamengine.typing import Bool1D, Float1D, Idx1D, Int1D
# Small numerical tolerance shared across the package. Note that this is far
# larger than the float64 machine epsilon (~2.2e-16); the module docstring
# describes it as a guard against division by zero and near-zero values.
EPS = 1.0e-9
# ── trim_mean ─────────────────────────────────────────────────────────────────
def trim_mean(values: Float1D, trim_pct: float = 0.05) -> float:
    """
    Calculate two-sided trimmed mean (robust statistic).

    Computes the mean after removing a percentage of the smallest and largest
    values. This provides a robust estimate of central tendency that is less
    sensitive to outliers than the arithmetic mean.

    Parameters
    ----------
    values : Float1D
        1D array (or array-like) of values to average.
    trim_pct : float, optional
        Proportion of values to trim from each tail (default: 0.05 = 5%).
        For example, 0.05 removes the bottom 5% and top 5% of values.

    Returns
    -------
    float
        Trimmed mean of the values. Returns 0.0 if the input is empty or if
        the two trimmed tails together cover every element.

    Raises
    ------
    ValueError
        If ``trim_pct`` is negative enough that the per-tail count rounds
        below zero.

    Examples
    --------
    Calculate trimmed mean removing 10% from each tail:

    >>> import numpy as np
    >>> from bamengine.utils import trim_mean
    >>> values = np.array([1, 2, 3, 4, 5, 100])  # 100 is an outlier
    >>> trim_mean(values, trim_pct=0.10)  # round(0.6) = 1 value per tail
    3.5

    Default 5% trimming:

    >>> values = np.arange(1, 101)  # 1 to 100
    >>> trim_mean(values)  # Removes bottom/top 5% (5 values each)
    50.5

    Notes
    -----
    - Uses `np.argpartition` for selection instead of a full sort
    - If trim_pct results in k=0, returns regular mean
    - The per-tail count is ``round(trim_pct * n)``; `scipy.stats.trim_mean`
      truncates instead, so results can differ when ``trim_pct * n`` is not
      an integer

    See Also
    --------
    trimmed_weighted_mean : Weighted version with optional weight filtering
    scipy.stats.trim_mean : SciPy implementation
    """
    values = np.asarray(values)
    n = values.size
    if n == 0:
        return 0.0

    # Number of elements dropped from EACH tail.
    k = int(round(trim_pct * n))
    if k < 0:
        raise ValueError("trim_pct must be non-negative")
    if k == 0:
        return float(values.mean())
    if 2 * k >= n:
        # Nothing survives the trim. Bail out before argpartition, which
        # rejects kth >= n with "kth out of bounds".
        return 0.0

    # Partition around both cut points so that positions k .. n-k-1 hold the
    # central values (in arbitrary order, which is fine for a mean).
    idx = np.argpartition(values, (k, n - k - 1))
    core = values[idx[k : n - k]]
    return float(core.mean())
# ── trimmed_weighted_mean ─────────────────────────────────────────────────────
def trimmed_weighted_mean(
    values: Float1D,
    weights: Float1D | None = None,
    trim_pct: float = 0.05,
    min_weight: float = 1e-3,
) -> float:
    """
    Calculate trimmed weighted mean with optional weight filtering.

    Computes a weighted mean after (1) filtering out entries with negligible
    weights, and (2) trimming extreme values. If no weights are provided,
    falls back to unweighted trimmed mean.

    Parameters
    ----------
    values : Float1D
        1D array of values to average.
    weights : Float1D, optional
        1D array of weights (same length as values). If None, computes
        unweighted trimmed mean (ignores min_weight parameter).
    trim_pct : float, optional
        Proportion of values to trim from each tail (default: 0.05 = 5%).
        Trimming is applied after weight filtering.
    min_weight : float, optional
        Minimum weight threshold (default: 1e-3). Entries with weights below
        this are excluded before computing mean. Ignored if weights is None.

    Returns
    -------
    float
        Trimmed weighted mean. Returns 0.0 if no valid entries remain after
        filtering and trimming.

    Examples
    --------
    Weighted mean with weight filtering:

    >>> import numpy as np
    >>> from bamengine.utils import trimmed_weighted_mean
    >>> values = np.array([10, 20, 30, 40])
    >>> weights = np.array([0.5, 1.0, 1.5, 0.0001])  # Last weight too small
    >>> trimmed_weighted_mean(values, weights, trim_pct=0.0, min_weight=0.01)
    23.333333333333332

    Only the first three values are included: (5 + 20 + 45) / 3.

    Trimmed weighted mean:

    >>> values = np.array([1, 2, 3, 100])  # 100 is outlier
    >>> weights = np.array([1, 1, 1, 1])
    >>> trimmed_weighted_mean(values, weights, trim_pct=0.25)  # Trims 1 per end
    2.5

    Unweighted mode (weights=None):

    >>> values = np.array([1, 2, 3, 4, 5])
    >>> trimmed_weighted_mean(values, weights=None, trim_pct=0.20)
    3.0

    Notes
    -----
    - Weight filtering occurs before trimming
    - Trimming is based on value ordering (not weight ordering)
    - Falls back to unweighted mean if the remaining weights sum to zero,
      whether or not any trimming took place
    - If weights is None, min_weight parameter is ignored

    See Also
    --------
    trim_mean : Unweighted trimmed mean (faster if no weights needed)
    numpy.average : NumPy weighted average (no trimming)
    """
    values = np.asarray(values)
    if weights is None:
        return trim_mean(values, trim_pct)

    # Drop entries whose weight is negligible.
    weights = np.asarray(weights)
    keep = weights >= min_weight
    values = values[keep]
    weights = weights[keep]
    if values.size == 0:
        return 0.0

    # Order by value so the extremes sit at the two ends of the arrays.
    order = np.argsort(values)
    values = values[order]
    weights = weights[order]

    k = int(round(trim_pct * values.size))
    if k != 0:
        n = values.size
        values = values[k : n - k]
        weights = weights[k : n - k]
        if values.size == 0:
            return 0.0

    # np.average raises ZeroDivisionError when the weights sum to zero, so
    # the documented unweighted fallback has to cover both branches.
    if weights.sum() == 0:
        return float(values.mean())
    return float(np.average(values, weights=weights))
# ── sample_beta_with_mean ─────────────────────────────────────────────────────
def sample_beta_with_mean(
    mean: float,
    n: int = 1,
    low: float | None = None,
    high: float | None = None,
    concentration: float = 12.0,
    *,
    relative_margin: float = 0.50,
    rng: Generator | None = None,
) -> float | Float1D:
    """
    Draw Beta-distributed samples rescaled to ``[low, high]`` around *mean*.

    A Beta(a, b) variate with ``a + b == concentration`` is drawn on the unit
    interval and mapped linearly onto ``[low, high]``. The shape parameters
    are chosen so that the expected value of the rescaled draw is *mean*.

    Parameters
    ----------
    mean : float
        Target mean of the returned samples.
    n : int, default 1
        Number of samples to draw.
    low, high : float or None, optional
        Interval bounds. A bound left as None is derived from *mean*::

            low  = mean - abs(mean) * relative_margin
            high = mean + abs(mean) * relative_margin

        Any bound that does not lie strictly on its side of *mean* (given
        explicitly or derived) is moved to ``mean -/+ eps`` with
        ``eps = max(abs(mean), 1) * 1e-12``.
    concentration : float, default 12
        Sum of the Beta shape parameters. Larger values concentrate the
        draws more tightly around *mean*.
    relative_margin : float, default 0.50
        Half-width of the derived interval as a fraction of ``abs(mean)``.
    rng : np.random.Generator, optional
        Random number generator; a fresh ``default_rng()`` is used if None.

    Returns
    -------
    float or ndarray
        A Python scalar when ``n == 1``, otherwise an array of *n* samples.

    Raises
    ------
    ValueError
        If ``n < 1``, ``concentration <= 0``, ``relative_margin <= 0``, or
        *mean* cannot be placed strictly between the bounds.

    Notes
    -----
    With ``mean == 0`` and no explicit bounds the derived interval collapses
    to a point and is widened to ``(-eps, eps)``.
    """
    if n < 1:
        raise ValueError("n must be at least 1")
    if concentration <= 0:
        raise ValueError("concentration must be > 0")
    if relative_margin <= 0:
        raise ValueError("relative_margin must be > 0")

    # Fill in whichever bound the caller left unspecified.
    auto_half_width = abs(mean) * relative_margin
    if low is None:
        low = mean - auto_half_width
    if high is None:
        high = mean + auto_half_width

    # Nudge any bound that is on the wrong side of (or equal to) the mean.
    eps = max(abs(mean), 1.0) * 1e-12  # machine-safe gap
    if low >= mean:
        low = mean - eps
    if high <= mean:
        high = mean + eps
    if not (low < mean < high):  # pragma: no cover
        raise ValueError(
            f"Could not place mean ({mean}) strictly inside (low, high): "
            f"low={low}, high={high}"
        )

    # Relative position of the mean inside the interval is the Beta mean.
    width = high - low
    unit_mean = (mean - low) / width
    alpha = unit_mean * concentration
    beta = (1.0 - unit_mean) * concentration

    if rng is None:
        rng = default_rng()
    draws = low + width * rng.beta(alpha, beta, size=n)

    if n == 1:
        return draws.item()
    return draws
# ── select_top_k_indices_sorted ───────────────────────────────────────────────
def select_top_k_indices_sorted(
    values: Float1D, k: int, descending: bool = True
) -> Idx1D:
    """
    Return indices of the k largest/smallest entries along the last axis.

    The returned indices are ordered so that
    ``np.take_along_axis(values, result, axis=-1)`` is sorted: largest first
    when `descending` is True, smallest first otherwise.

    Parameters
    ----------
    values : numpy.ndarray
        N-dimensional numeric array; selection runs along the last axis.
        Non-array input is converted to a float array, and a 0-d array is
        treated as a one-element array.
    k : int
        Number of indices to select along the last axis.
    descending : bool, optional
        True (default) selects the k largest entries, False the k smallest.

    Returns
    -------
    numpy.ndarray
        Integer indices of shape ``values.shape[:-1] + (selected_k,)`` where
        ``selected_k`` is 0 when ``k <= 0`` or the input is empty, the full
        last-axis length when ``k`` is at least that length, and ``k``
        otherwise.

    Notes
    -----
    - Descending order is obtained by working on the negated values.
    - When ``k`` is smaller than the last-axis length, `np.argpartition`
      isolates the k candidates and only those are sorted.
    """
    # Defensive: accept list input at runtime.
    if not isinstance(values, np.ndarray):
        values = np.array(values, dtype=float)  # type: ignore[unreachable]
    if values.ndim == 0:
        values = values.reshape(1)

    # Nothing requested, or nothing to choose from.
    if k <= 0 or values.size == 0:
        return np.empty(values.shape[:-1] + (0,), dtype=np.intp)

    # Ascending order of the key gives the requested order of the values.
    key = -values if descending else values

    if k >= values.shape[-1]:
        return np.argsort(key, axis=-1)

    # The first k slots of the partition hold the winners in arbitrary order.
    candidates = np.argpartition(key, kth=k - 1, axis=-1)[..., :k]
    candidate_keys = np.take_along_axis(key, candidates, axis=-1)
    ranking = np.argsort(candidate_keys, axis=-1)
    return np.take_along_axis(candidates, ranking, axis=-1)
# ── grouped_cumsum ────────────────────────────────────────────────────────────
def grouped_cumsum(values: Float1D, group_starts: Int1D) -> Float1D:
    """Per-group prefix sums using the subtract-offset trick.

    Given a flat array *values* that is logically partitioned into contiguous
    groups whose boundaries are indicated by *group_starts*, return an array
    of the same length where each element is the cumulative sum within its
    group.

    Parameters
    ----------
    values : Float1D
        Values to accumulate (length *n*).
    group_starts : Int1D
        Sorted indices into *values* where each new group begins.
        ``group_starts[0]`` is typically 0; if it is larger, the elements
        before it form an implicit leading group. Repeated entries (empty
        groups) and entries ``>= n`` (trailing empty groups) are allowed.

    Returns
    -------
    Float1D
        Per-group cumulative sums as float64, same shape as *values*.

    Examples
    --------
    >>> import numpy as np
    >>> vals = np.array([1.0, 2.0, 3.0, 10.0, 20.0])
    >>> starts = np.array([0, 3])  # two groups: [0:3], [3:5]
    >>> grouped_cumsum(vals, starts)
    array([ 1.,  3.,  6., 10., 30.])
    """
    n = values.size
    if n == 0:
        return np.empty(0, dtype=np.float64)

    running = np.cumsum(values, dtype=np.float64)
    if group_starts.size == 0:
        # No boundaries at all: the global cumsum is the answer.
        return running

    # offsets[0] belongs to the implicit group before the first start;
    # offsets[g + 1] is the running total just before group g begins.
    # Starts at 0 need no offset, and starts >= n own no elements. Masking
    # both out also keeps `start - 1` from wrapping around to running[-1].
    offsets = np.zeros(group_starts.size + 1, dtype=np.float64)
    has_prefix = (group_starts > 0) & (group_starts < n)
    offsets[1:][has_prefix] = running[group_starts[has_prefix] - 1]

    # side="right" counts how many starts are <= position, which is exactly
    # the index into `offsets` (0 for elements before the first start).
    slot = np.searchsorted(group_starts, np.arange(n), side="right")
    return running - offsets[slot]
# ── resolve_conflicts ─────────────────────────────────────────────────────────
def resolve_conflicts(
    sender_ids: Idx1D,
    target_ids: Idx1D,
    capacity_per_target: Int1D,
    n_targets: int,
    rng: Rng,
) -> Bool1D:
    """Randomly admit senders to targets without exceeding any capacity.

    Every sender names one target. For each target *t*, at most
    ``capacity_per_target[t]`` of the senders that chose it are accepted,
    picked uniformly at random; the rest are rejected.

    Parameters
    ----------
    sender_ids : Idx1D
        Sender indices (length *m*). Only the length is read here; the
        returned mask lines up with this array element for element.
    target_ids : Idx1D
        Target chosen by each sender (length *m*).
    capacity_per_target : Int1D
        Acceptance limit for each target (length *n_targets*).
    n_targets : int
        Total number of possible targets.
    rng : Rng
        NumPy random generator supplying the random priorities.

    Returns
    -------
    Bool1D
        Boolean mask of length *m*; ``True`` ⇒ sender is accepted.

    Notes
    -----
    The algorithm is fully vectorized:

    1. Draw one random priority per sender and ``np.lexsort`` by
       (target, priority), which shuffles senders within each target.
    2. Derive each target's first position in that order from
       ``np.bincount`` followed by a cumulative sum.
    3. Accept a sender when its position within its target's group is
       below that target's capacity.
    """
    n_senders = sender_ids.size
    if n_senders == 0:
        return np.empty(0, dtype=np.bool_)

    # Random priority as secondary key => random order inside each target.
    tie_breakers = rng.random(n_senders)
    by_target = np.lexsort((tie_breakers, target_ids))
    targets_in_order = target_ids[by_target]

    # first_slot[t] is where target t's senders begin in the sorted order.
    demand = np.bincount(targets_in_order, minlength=n_targets)
    first_slot = np.zeros(n_targets + 1, dtype=np.intp)
    np.cumsum(demand, out=first_slot[1:])

    # Position of each sender within its own target's queue.
    queue_pos = np.arange(n_senders, dtype=np.intp) - first_slot[targets_in_order]
    room = capacity_per_target[targets_in_order]
    wins_sorted = np.logical_and(room > 0, queue_pos < room)

    # Scatter the decisions back to the callers' original ordering.
    accepted = np.empty(n_senders, dtype=np.bool_)
    accepted[by_target] = wins_sorted
    return accepted
# ── _flatten_and_shuffle_groups ──────────────────────────────────────────────
def _flatten_and_shuffle_groups(
    source: Idx1D,
    boundaries_lo: Idx1D,
    boundaries_hi: Idx1D,
    rng: Rng,
) -> tuple[Idx1D, Idx1D, Idx1D, Idx1D, Idx1D]:
    """Gather ragged slices of *source* into one array, shuffled per group.

    Group *g* consists of ``source[boundaries_lo[g]:boundaries_hi[g]]``. All
    groups are laid out back to back in a single flat array and the items
    inside each group are put in random order; groups themselves keep their
    order and stay contiguous.

    Parameters
    ----------
    source : Idx1D
        Array the boundaries index into (e.g. sorted worker IDs).
    boundaries_lo, boundaries_hi : Idx1D
        Half-open ``[lo, hi)`` slice bounds per group (length *n_groups*).
    rng : Rng
        Random generator used for the within-group shuffle.

    Returns
    -------
    items : Idx1D
        Gathered items, shuffled within each group (length *total*).
    group_idx : Idx1D
        Group number (0..n_groups-1) of each item.
    rank : Idx1D
        Position of each item inside its group (0..group_size-1).
    group_sizes : Idx1D
        Number of items per group.
    group_starts : Idx1D
        Start of each group in the flat output (length *n_groups + 1*);
        ``group_starts[-1] == total``.
    """
    n_groups = boundaries_lo.size
    group_sizes = boundaries_hi - boundaries_lo
    total = int(group_sizes.sum())

    if total == 0:
        nothing = np.empty(0, dtype=np.intp)
        no_starts = np.zeros(n_groups + 1, dtype=np.intp)
        return nothing, nothing, nothing, group_sizes, no_starts

    # Where each group begins in the flat output (exclusive prefix sum).
    group_starts = np.zeros(n_groups + 1, dtype=np.intp)
    np.cumsum(group_sizes, out=group_starts[1:])

    # For every flat slot: owning group and offset inside that group.
    owner = np.repeat(np.arange(n_groups, dtype=np.intp), group_sizes)
    first_slot = np.repeat(group_starts[:-1], group_sizes)
    rank = np.arange(total, dtype=np.intp) - first_slot
    gathered = source[np.repeat(boundaries_lo, group_sizes) + rank]

    # Sorting by (group, random priority) permutes items only within their
    # own group, so the slot -> rank mapping computed above still holds.
    shuffle = np.lexsort((rng.random(total), owner))
    return gathered[shuffle], owner[shuffle], rank, group_sizes, group_starts