# Source code for validation.types
"""Core types and dataclasses for the validation package.
This module defines all the type definitions used across the validation
and calibration packages.
"""
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass, field
from enum import Enum, auto
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal
if TYPE_CHECKING:
import bamengine as bam
from bamengine import SimulationResults
# =============================================================================
# Constants
# =============================================================================
# Default seeds for stability testing across validation/calibration:
# the 100 consecutive integers 0..99.
DEFAULT_STABILITY_SEEDS: list[int] = list(range(100))
# Type alias for validation status: exactly one of the three literal
# strings "PASS", "WARN" or "FAIL".
Status = Literal["PASS", "WARN", "FAIL"]
# =============================================================================
# Enums for MetricSpec configuration
# =============================================================================
class CheckType(Enum):
    """Kinds of check that can be applied to a metric.

    Member values are assigned by ``auto()`` in declaration order.
    """

    # Value lies within target ± tolerance.
    MEAN_TOLERANCE = auto()
    # Value lies within [min, max].
    RANGE = auto()
    # Percentage meeting the target threshold.
    PCT_WITHIN = auto()
    # Outlier percentage penalty.
    OUTLIER = auto()
    # Simple true/false check (e.g., > 0).
    BOOLEAN = auto()
class MetricGroup(Enum):
    """Report sections that a metric can be grouped under.

    Member values are assigned by ``auto()`` in declaration order.
    """

    TIME_SERIES = auto()
    CURVES = auto()
    DISTRIBUTION = auto()
    GROWTH = auto()
    FINANCIAL = auto()
    GROWTH_RATE_DIST = auto()
    # Improvement over the baseline scenario.
    IMPROVEMENT = auto()
class MetricFormat(Enum):
    """How a metric value should be rendered for display.

    Member values are assigned by ``auto()`` in declaration order.
    """

    # Standard decimal format.
    DEFAULT = auto()
    # Multiply by 100 and add %.
    PERCENT = auto()
    # High precision for trend coefficients.
    TREND = auto()
    # Round to integer.
    INTEGER = auto()
# =============================================================================
# MetricSpec - The core abstraction
# =============================================================================
[docs]
@dataclass
class MetricSpec:
    """Unified specification for a validation metric.

    One instance holds everything needed to validate a single metric.
    Target values are read from YAML under standardized keys that depend
    on ``check_type``:

    - MEAN_TOLERANCE: expects 'target' and 'tolerance' keys
    - RANGE: expects 'min' and 'max' keys
    - PCT_WITHIN: expects 'target' and 'min' keys
    - OUTLIER: expects 'max_outlier' key (and optional 'penalty_weight')
    - BOOLEAN: uses the ``threshold`` defined on this spec
    """

    # Metric identifier, e.g. "unemployment_rate_mean".
    name: str
    # Name of the attribute on the Metrics dataclass.
    field: str
    check_type: CheckType
    # Dot-separated path in YAML, e.g. "time_series.unemployment.mean".
    target_path: str
    weight: float = 1.0
    group: MetricGroup = MetricGroup.TIME_SERIES
    format: MetricFormat = MetricFormat.DEFAULT

    # --- BOOLEAN checks only ---
    # The value must be > threshold ...
    threshold: float = 0.0
    # ... unless invert is True, in which case the value must be < threshold.
    invert: bool = False

    # Custom target description; auto-generated when None.
    target_desc: str | None = None
# =============================================================================
# Validation Result Types
# =============================================================================
[docs]
@dataclass
class MetricResult:
    """Outcome of validating one metric."""

    name: str
    status: Status
    actual: float
    target_desc: str
    # Score in the range 0-1, where 1 is a perfect match.
    score: float
    # Weight used in the total score calculation.
    weight: float = 1.0
    message: str = ""
    group: MetricGroup = MetricGroup.TIME_SERIES
    format: MetricFormat = MetricFormat.DEFAULT
[docs]
@dataclass
class ValidationScore:
    """Overall validation result, with scoring so runs can be compared."""

    metric_results: list[MetricResult]
    # Weighted average of all metric scores.
    total_score: float
    n_pass: int
    n_warn: int
    n_fail: int
    # Config used for this run; each instance gets its own fresh dict.
    config: dict[str, Any] = field(default_factory=dict)

    @property
    def passed(self) -> bool:
        """Whether the run has zero failed metrics (``n_fail == 0``)."""
        return self.n_fail == 0

    def __str__(self) -> str:
        """Return a one-line summary: total score (3 decimals) and counts."""
        counts = f"pass={self.n_pass}, warn={self.n_warn}, fail={self.n_fail}"
        return f"ValidationScore(total={self.total_score:.3f}, {counts})"
[docs]
@dataclass
class BufferStockValidationScore(ValidationScore):
    """Buffer-stock validation result with improvement tracking over Growth+.

    Per-seed PASS/FAIL is determined solely by the 8 unique buffer-stock
    metrics (wealth distribution fits, MPC, dissaving). Improvement over
    Growth+ is assessed at the aggregate level after stability testing.

    The ``improvement_deltas`` are computed per seed (informational) but
    do not affect ``passed`` or ``total_score``.
    """

    baseline_score: ValidationScore | None = None
    """Growth+ baseline result used for comparison (same seed)."""

    improvement_deltas: dict[str, float] = field(default_factory=dict)
    """Per-metric score deltas: ``bs_score - gp_score`` (informational)."""

    degraded_metrics: list[str] = field(default_factory=list)
    """Growth+ metrics with systematic degradation (populated at aggregate level
    by :func:`~validation.run_buffer_stock_stability_test`, not per seed)."""

    blend_alpha: float = 0.6
    """Informational only. Not used in score computation."""
@dataclass
class MetricStats:
    """Per-metric statistics gathered across multiple seeds."""

    name: str
    mean_value: float
    std_value: float
    mean_score: float
    std_score: float
    # Fraction of seeds where this metric passed (i.e. was not FAIL).
    pass_rate: float
    format: MetricFormat = MetricFormat.DEFAULT
[docs]
@dataclass
class StabilityResult:
    """Result of multi-seed stability testing."""

    # Individual seed results.
    seed_results: list[ValidationScore]

    # --- Aggregate score metrics ---
    # Mean total score across seeds.
    mean_score: float
    # Standard deviation of scores.
    std_score: float
    # Worst seed.
    min_score: float
    # Best seed.
    max_score: float
    # Fraction of seeds that passed (no FAILs).
    pass_rate: float
    # Number of seeds tested.
    n_seeds: int

    # --- Per-metric stability ---
    # Stats for each metric.
    metric_stats: dict[str, MetricStats]

    @property
    def is_stable(self) -> bool:
        """Whether ``pass_rate`` is at least 0.9 and ``std_score`` at most 0.15."""
        enough_passes = self.pass_rate >= 0.9
        low_spread = self.std_score <= 0.15
        return enough_passes and low_spread

    def __str__(self) -> str:
        """Return a one-line summary: mean±std score, pass rate and seed count."""
        spread = f"{self.mean_score:.3f}±{self.std_score:.3f}"
        return (
            f"StabilityResult(mean={spread}, "
            f"pass_rate={self.pass_rate:.0%}, seeds={self.n_seeds})"
        )
# =============================================================================
# Scenario Configuration
# =============================================================================
[docs]
@dataclass
class Scenario:
    """Configuration for a validation scenario.

    Groups together everything needed to run validation for one specific
    scenario (baseline, growth_plus, or buffer_stock).
    """

    name: str
    metric_specs: list[MetricSpec]
    collect_config: dict[str, Any]
    # Absolute path to the scenario's targets.yaml.
    targets_path: Path
    compute_metrics: Callable[[bam.Simulation, SimulationResults, int], Any]
    default_config: dict[str, Any] = field(default_factory=dict)

    setup_hook: Callable[[bam.Simulation | None], None] | None = None
    """Optional hook called twice: first with ``None`` (to trigger imports/registration),
    then with the ``Simulation`` instance (to attach roles/extensions)."""

    # Report title, e.g. "BASELINE SCENARIO VALIDATION".
    title: str = ""
    # Stability report title, e.g. "SEED STABILITY TEST".
    stability_title: str = ""