Source code for validation.robustness.sensitivity

"""Univariate sensitivity analysis (Section 3.10.1, Part 2).

Varies one parameter at a time while holding all others at baseline
values. For each parameter value, runs multiple simulations with
different seeds and computes the same statistics as the internal
validity analysis.

The five parameter groups from the book are:
    (i)   H — local credit markets
    (ii)  Z — local consumption goods markets
    (iii) M — local labour markets (applications)
    (iv)  theta — employment contracts duration
    (v)   Economy size and composition
"""

from __future__ import annotations

from collections.abc import Callable
from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import dataclass, field
from typing import Any

import numpy as np
from numpy.typing import NDArray

import bamengine as bam
from validation.robustness.experiments import (
    ALL_EXPERIMENT_NAMES,
    EXPERIMENTS,
    Experiment,
)
from validation.robustness.internal_validity import (
    COMOVEMENT_VARIABLES,
    SeedAnalysis,
    _run_seed,
)
from validation.robustness.stats import impulse_response
from validation.scenarios._utils import adjust_burn_in

# ─── Result Types ───────────────────────────────────────────────────────────


[docs] @dataclass class ValueResult: """Aggregated results for one parameter value across multiple seeds.""" label: str config_overrides: dict[str, Any] n_seeds: int n_collapsed: int # Mean co-movements across seeds mean_comovements: dict[str, NDArray[np.floating]] std_comovements: dict[str, NDArray[np.floating]] # Mean AR fit mean_ar_coeffs: NDArray[np.floating] mean_ar_r_squared: float mean_irf: NDArray[np.floating] # Cross-seed summary statistics stats: dict[str, dict[str, float]] = field(default_factory=dict) # Peak timing for co-movement classification mean_peak_lags: dict[str, int] = field(default_factory=dict) # Degenerate dynamics count n_degenerate: int = 0 @property def collapse_rate(self) -> float: return self.n_collapsed / self.n_seeds if self.n_seeds > 0 else 0.0 @property def degenerate_rate(self) -> float: return self.n_degenerate / self.n_seeds if self.n_seeds > 0 else 0.0
[docs] @dataclass class ExperimentResult: """Results for one sensitivity experiment (all parameter values).""" experiment: Experiment value_results: list[ValueResult] baseline_idx: int # Index of baseline value in value_results @property def baseline(self) -> ValueResult: return self.value_results[self.baseline_idx]
[docs] def get_stat_table(self, stat_name: str) -> list[tuple[str, float, float]]: """Get a statistic across all values as (label, mean, std) tuples.""" return [ ( vr.label, vr.stats.get(stat_name, {}).get("mean", float("nan")), vr.stats.get(stat_name, {}).get("std", float("nan")), ) for vr in self.value_results ]
[docs] @dataclass class SensitivityResult: """Full sensitivity analysis result across all experiments.""" experiments: dict[str, ExperimentResult] n_seeds_per_value: int n_periods: int burn_in: int
# ─── Aggregation Helper ──────────────────────────────────────────────────── def _aggregate_seed_analyses( seed_analyses: list[SeedAnalysis], irf_periods: int = 20, ) -> ValueResult: """Aggregate multiple SeedAnalysis objects into a ValueResult.""" valid = [a for a in seed_analyses if not a.degenerate] n_collapsed = sum(1 for a in seed_analyses if a.collapsed) n_degenerate = sum(1 for a in seed_analyses if a.degenerate) # Co-movements mean_comovements: dict[str, NDArray[np.floating]] = {} std_comovements: dict[str, NDArray[np.floating]] = {} # Determine expected array length from the first available seed any_seed = seed_analyses[0] if seed_analyses else None n_lags = len(any_seed.comovements[COMOVEMENT_VARIABLES[0]]) if any_seed else 9 for var in COMOVEMENT_VARIABLES: if valid: all_corrs = np.array([a.comovements[var] for a in valid]) mean_comovements[var] = np.nanmean(all_corrs, axis=0) std_comovements[var] = np.nanstd(all_corrs, axis=0) else: mean_comovements[var] = np.full(n_lags, np.nan) std_comovements[var] = np.full(n_lags, np.nan) # AR fit (average of individual fits) if valid: mean_phi1 = np.mean([a.ar_coeffs[1] for a in valid]) mean_const = np.mean([a.ar_coeffs[0] for a in valid]) mean_ar_coeffs = np.array([mean_const, mean_phi1]) mean_ar_r2 = float(np.mean([a.ar_r_squared for a in valid])) mean_irf = impulse_response(mean_ar_coeffs, n_periods=irf_periods) else: mean_ar_coeffs = np.zeros(2) mean_ar_r2 = 0.0 mean_irf = np.zeros(irf_periods) # Summary statistics stat_fields = [ "unemployment_mean", "unemployment_std", "inflation_mean", "inflation_std", "gdp_growth_mean", "gdp_growth_std", "real_wage_mean", "productivity_mean", "phillips_corr", "okun_corr", "beveridge_corr", "firm_size_skewness_sales", "firm_size_skewness_net_worth", "firm_size_kurtosis_sales", "firm_size_kurtosis_net_worth", "firm_size_tail_index", "wage_productivity_ratio", ] stats_dict: dict[str, dict[str, float]] = {} for attr_name in stat_fields: values = [ getattr(a, attr_name) for a in valid if not np.isnan(getattr(a, attr_name)) ] if values: mean_val = float(np.mean(values)) stats_dict[attr_name] = { "mean": mean_val, "std": float(np.std(values)), "min": float(np.min(values)), "max": float(np.max(values)), "cv": float(np.std(values) / abs(mean_val)) if abs(mean_val) > 1e-10 else 0.0, } # Peak-lag aggregation: modal peak lag per variable mean_peak_lags: dict[str, int] = {} if valid: for var in COMOVEMENT_VARIABLES: lags_for_var = [a.peak_lags[var] for a in valid if var in a.peak_lags] if lags_for_var: # Mode: most common peak lag across seeds values, counts = np.unique(lags_for_var, return_counts=True) mean_peak_lags[var] = int(values[np.argmax(counts)]) return ValueResult( label="", # Set by caller config_overrides={}, # Set by caller n_seeds=len(seed_analyses), n_collapsed=n_collapsed, mean_comovements=mean_comovements, std_comovements=std_comovements, mean_ar_coeffs=mean_ar_coeffs, mean_ar_r_squared=mean_ar_r2, mean_irf=mean_irf, stats=stats_dict, mean_peak_lags=mean_peak_lags, n_degenerate=n_degenerate, ) # ─── Main Entry Point ──────────────────────────────────────────────────────
[docs] def run_sensitivity_analysis( experiments: list[str] | None = None, n_seeds: int = 20, n_periods: int = 1000, burn_in: int = 500, n_workers: int = 10, max_lag: int = 4, ar_order: int = 2, irf_periods: int = 20, verbose: bool = True, setup_hook: Callable[[bam.Simulation], None] | None = None, collect_config: dict[str, Any] | None = None, **config_overrides: Any, ) -> SensitivityResult: """Run univariate sensitivity analysis (Section 3.10.1, Part 2). For each experiment, varies one parameter while holding all others at baseline values. Runs ``n_seeds`` simulations per parameter value. Parameters ---------- experiments : list[str] or None Experiment names to run. None means all experiments. n_seeds : int Number of random seeds per parameter value. n_periods : int Simulation periods per seed. burn_in : int Burn-in periods to discard. n_workers : int Parallel workers for simulation execution. max_lag : int Maximum lead/lag for cross-correlation. ar_order : int AR order for GDP cycle fitting. irf_periods : int Impulse-response function horizon. verbose : bool Print progress messages. setup_hook : callable or None Optional function ``(sim) -> None`` called after ``Simulation.init()`` to attach extension roles, events, and config. Must be a module-level function for ``ProcessPoolExecutor`` pickling. collect_config : dict or None Custom collection configuration. When *None*, uses the default ``ROBUSTNESS_COLLECT_CONFIG``. **config_overrides Additional simulation config overrides applied to all runs. Returns ------- SensitivityResult Results for all requested experiments. """ burn_in = adjust_burn_in(burn_in, n_periods) if experiments is None: experiments = ALL_EXPERIMENT_NAMES experiment_results: dict[str, ExperimentResult] = {} for exp_name in experiments: if exp_name not in EXPERIMENTS: raise ValueError( f"Unknown experiment: {exp_name}. Available: {ALL_EXPERIMENT_NAMES}" ) exp = EXPERIMENTS[exp_name] labels = exp.get_labels() n_values = len(exp.values) if verbose: print(f"\n{'=' * 60}") print(f"Experiment: {exp.description}") print( f" {n_values} values x {n_seeds} seeds = " f"{n_values * n_seeds} simulations" ) print(f"{'=' * 60}") value_results: list[ValueResult] = [] baseline_idx = 0 for val_idx in range(n_values): label = labels[val_idx] val_config = exp.get_config(val_idx) # Merge with any global overrides run_config = {**config_overrides, **val_config} if verbose: print(f"\n [{val_idx + 1}/{n_values}] {label}") # Check if this is the baseline value if exp.values[val_idx] == exp.baseline_value: baseline_idx = val_idx # Run all seeds (parallel or sequential) seed_analyses: list[SeedAnalysis] = [] seeds = list(range(n_seeds)) if n_workers == 1: for i, seed in enumerate(seeds, 1): analysis = _run_seed( seed, n_periods, burn_in, run_config, max_lag, ar_order, irf_periods, setup_hook, collect_config, exp.setup_fn, ) seed_analyses.append(analysis) if verbose and i % 5 == 0: print(f" {i}/{n_seeds} seeds done") else: with ProcessPoolExecutor(max_workers=n_workers) as executor: futures = { executor.submit( _run_seed, seed, n_periods, burn_in, run_config, max_lag, ar_order, irf_periods, setup_hook, collect_config, exp.setup_fn, ): seed for seed in seeds } for completed, future in enumerate(as_completed(futures), 1): analysis = future.result() seed_analyses.append(analysis) if verbose and completed % 5 == 0: print(f" {completed}/{n_seeds} seeds done") # Sort by seed seed_analyses.sort(key=lambda a: a.seed) # Aggregate vr = _aggregate_seed_analyses(seed_analyses, irf_periods) vr.label = label vr.config_overrides = val_config n_ok = vr.n_seeds - vr.n_collapsed if verbose: parts = [f"{n_ok} valid", f"{vr.n_collapsed} collapsed"] if vr.n_degenerate > 0: parts.append(f"{vr.n_degenerate} degenerate") print(f" Result: {', '.join(parts)}") if "unemployment_mean" in vr.stats: u = vr.stats["unemployment_mean"]["mean"] print(f" Unemployment: {u:.1%}") value_results.append(vr) experiment_results[exp_name] = ExperimentResult( experiment=exp, value_results=value_results, baseline_idx=baseline_idx, ) return SensitivityResult( experiments=experiment_results, n_seeds_per_value=n_seeds, n_periods=n_periods, burn_in=burn_in, )