# Source code for calibration.cost

"""Targeted cost analysis -- measure the cost of swapping values into a base config.

Evaluates the impact of substituting preferred parameter values into an
optimized base configuration. Classifies each swap by cost:
FREE (<0.002), CHEAP (<0.005), MODERATE (<0.010), EXPENSIVE (>=0.010).
"""

from __future__ import annotations

import argparse
import json
import statistics
from dataclasses import dataclass
from pathlib import Path
from typing import Any

from calibration.grid import generate_combinations
from calibration.io import OUTPUT_DIR, load_stability, save_stability
from calibration.screening import run_screening
from calibration.stability import (
    _evaluate_single_seed,
    parse_stability_tiers,
    run_tiered_stability,
)


@dataclass
class SwapResult:
    """Outcome of substituting one candidate value into the base config.

    Attributes
    ----------
    param : str
        Name of the parameter that was overridden.
    value : Any
        Candidate value that replaced the base value.
    base_combined : float
        Combined score of the unmodified base configuration.
    swap_combined : float
        Combined score obtained with ``value`` swapped in.
    delta : float
        ``swap_combined - base_combined``; a negative number means the
        swap scored lower than the base.
    classification : str
        Cost label: "FREE", "CHEAP", "MODERATE", or "EXPENSIVE".
    pass_rate : float
        Pass rate observed with the swapped value.
    """

    # Identity of the swap.
    param: str
    value: Any

    # Scores before and after the swap, and their difference.
    base_combined: float
    swap_combined: float
    delta: float

    # Derived label and the pass rate of the swapped configuration.
    classification: str
    pass_rate: float
def classify_cost(delta_abs: float) -> str:
    """Map an absolute score difference onto a cost label.

    Parameters
    ----------
    delta_abs : float
        Absolute combined score difference.

    Returns
    -------
    str
        "FREE" below 0.002, "CHEAP" below 0.005, "MODERATE" below 0.010,
        otherwise "EXPENSIVE".
    """
    # Exclusive upper bounds, checked from the cheapest tier upwards.
    tiers = (
        (0.002, "FREE"),
        (0.005, "CHEAP"),
        (0.010, "MODERATE"),
    )
    for upper_bound, label in tiers:
        if delta_abs < upper_bound:
            return label
    return "EXPENSIVE"
def _coerce_swap_value(text: str) -> Any:
    """Convert a CLI token to ``int``, else ``float``, else leave it a string."""
    for convert in (int, float):
        try:
            return convert(text)
        except ValueError:
            continue
    return text


def parse_swaps(swap_args: list[str]) -> dict[str, list[Any]]:
    """Parse swap arguments from CLI.

    Each argument has the form ``"param=v1,v2,v3"``. The text is split on
    the first ``=`` only, so values may themselves contain ``=``. Whitespace
    around the parameter name and around every value is stripped. Each value
    is converted to ``int`` when possible, otherwise to ``float``, otherwise
    kept as a string.

    When the same parameter appears in several arguments, the value lists
    are concatenated in argument order rather than the later argument
    silently replacing the earlier one.

    Parameters
    ----------
    swap_args : list[str]
        List of "param=v1,v2,v3" strings.

    Returns
    -------
    dict[str, list]
        Parameter -> list of values to try.

    Raises
    ------
    ValueError
        If an argument has no ``=`` separator or an empty parameter name.
    """
    swaps: dict[str, list[Any]] = {}
    for arg in swap_args:
        name, separator, values_str = arg.partition("=")
        name = name.strip()
        if not separator or not name:
            # Name the offending argument instead of surfacing a bare
            # tuple-unpacking error.
            raise ValueError(
                f"Invalid swap argument {arg!r}: expected 'param=v1,v2,...'"
            )
        parsed = [_coerce_swap_value(tok.strip()) for tok in values_str.split(",")]
        swaps.setdefault(name, []).extend(parsed)
    return swaps
def _combined_score(scores: list[float], fails: list[int]) -> tuple[float, float]:
    """Collapse per-seed outcomes into ``(combined_score, pass_rate)``."""
    mean_score = statistics.mean(scores)
    # stdev needs at least two samples; a single seed counts as zero spread.
    spread = statistics.stdev(scores) if len(scores) > 1 else 0.0
    # A seed passes when it reports zero failures.
    pass_rate = sum(1 for n_fail in fails if n_fail == 0) / len(fails)
    return mean_score * pass_rate * (1.0 - spread), pass_rate


def run_cost_analysis(
    base_params: dict[str, Any],
    swaps: dict[str, list[Any]],
    scenario: str,
    n_seeds: int = 20,
    n_periods: int = 1000,
    n_workers: int = 10,
    base_combined: float | None = None,
) -> list[SwapResult]:
    """Run targeted cost analysis for parameter swaps.

    Every candidate value is substituted into a copy of ``base_params`` and
    evaluated on seeds ``0 .. n_seeds - 1``. The combined score of a
    configuration is ``mean(score) * pass_rate * (1 - stdev(score))``, where
    the pass rate is the fraction of seeds reporting zero failures. The cost
    of a swap is its combined score minus the base combined score.

    Parameters
    ----------
    base_params : dict
        Base configuration (the stability winner).
    swaps : dict[str, list]
        Parameters to swap and their candidate values.
    scenario : str
        Scenario name.
    n_seeds : int
        Seeds per evaluation. Must be at least 1 whenever something has to
        be evaluated.
    n_periods : int
        Simulation periods.
    n_workers : int
        Parallel workers. Values of 1 or less evaluate seeds sequentially
        in the current process.
    base_combined : float, optional
        Pre-computed base combined score. If None, evaluates the base.

    Returns
    -------
    list[SwapResult]
        Results for each swap, sorted by absolute delta.

    Raises
    ------
    ValueError
        If an evaluation is required but ``n_seeds`` is less than 1.
    """
    from concurrent.futures import ProcessPoolExecutor, as_completed

    needs_evaluation = base_combined is None or any(swaps.values())
    if needs_evaluation and n_seeds < 1:
        # Without this guard the failure surfaces later as an opaque
        # StatisticsError raised from statistics.mean on an empty list.
        raise ValueError(f"n_seeds must be at least 1, got {n_seeds}")

    def _eval_seeds(params: dict[str, Any]) -> tuple[list[float], list[int]]:
        """Evaluate a config across seeds, returning (scores, fails)."""
        scores: list[float] = []
        fails: list[int] = []
        seeds = range(n_seeds)
        if n_workers > 1:
            # Never request more worker processes than there are seeds.
            pool_size = min(n_workers, n_seeds)
            with ProcessPoolExecutor(max_workers=pool_size) as executor:
                futures = [
                    executor.submit(
                        _evaluate_single_seed, params, scenario, seed, n_periods
                    )
                    for seed in seeds
                ]
                for future in as_completed(futures):
                    # Only the last two items of the returned 4-tuple are used.
                    _, _, score, n_fail = future.result()
                    scores.append(score)
                    fails.append(n_fail)
        else:
            for seed in seeds:
                _, _, score, n_fail = _evaluate_single_seed(
                    params, scenario, seed, n_periods
                )
                scores.append(score)
                fails.append(n_fail)
        return scores, fails

    # Evaluate base if needed
    if base_combined is None:
        print(f" Evaluating base config ({n_seeds} seeds)...")
        base_combined, _ = _combined_score(*_eval_seeds(base_params))
        print(f" Base combined: {base_combined:.4f}")

    # Evaluate each swap
    results: list[SwapResult] = []
    total_swaps = sum(len(vals) for vals in swaps.values())
    done = 0
    for param, values in swaps.items():
        for value in values:
            done += 1
            # Override a single key; the base config itself is left untouched.
            swap_params = {**base_params, param: value}
            swap_combined, swap_pr = _combined_score(*_eval_seeds(swap_params))
            delta = swap_combined - base_combined
            classification = classify_cost(abs(delta))
            results.append(
                SwapResult(
                    param=param,
                    value=value,
                    base_combined=base_combined,
                    swap_combined=swap_combined,
                    delta=delta,
                    classification=classification,
                    pass_rate=swap_pr,
                )
            )
            print(
                f" [{done}/{total_swaps}] {param}={value}: "
                f"delta={delta:+.4f} ({classification})"
            )

    # Smallest score change first, regardless of sign.
    results.sort(key=lambda r: abs(r.delta))
    return results
def save_cost_results(
    results: list[SwapResult],
    scenario: str,
    path: Path,
) -> None:
    """Write the swap results for ``scenario`` to ``path`` as indented JSON.

    The document has two keys: ``"scenario"`` and ``"results"``, the latter
    holding one object per entry of ``results``. Missing parent directories
    of ``path`` are created before the file is written.
    """
    exported_fields = (
        "param",
        "value",
        "base_combined",
        "swap_combined",
        "delta",
        "classification",
        "pass_rate",
    )
    records = [
        {name: getattr(entry, name) for name in exported_fields}
        for entry in results
    ]
    payload = {"scenario": scenario, "results": records}

    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w") as handle:
        json.dump(payload, handle, indent=2)
def _load_base_config(
    base_path: Path,
) -> tuple[dict[str, Any], float | None]:
    """Load base config from stability JSON or YAML config file.

    A path whose suffix is ``.yml`` or ``.yaml`` (compared case-insensitively)
    is parsed as YAML and used directly as the parameter mapping; an empty
    document yields an empty mapping. Any other path is passed to
    ``load_stability`` and the first returned entry supplies both the
    parameters and the pre-computed combined score.

    Parameters
    ----------
    base_path : Path
        Location of the YAML config or stability results file.

    Returns
    -------
    tuple[dict, float | None]
        (base_params, base_combined_score_or_None)

    Raises
    ------
    ValueError
        If the YAML document's top level is not a mapping, or if the
        stability file contains no results.
    """
    if base_path.suffix.lower() in (".yml", ".yaml"):
        import yaml

        # Decode explicitly as UTF-8 instead of relying on the platform's
        # locale encoding.
        with open(base_path, encoding="utf-8") as f:
            params = yaml.safe_load(f) or {}
        if not isinstance(params, dict):
            # A list or scalar here would only fail much later, when swaps
            # are merged into the base parameters.
            raise ValueError(
                f"Expected a mapping at the top level of {base_path}, "
                f"got {type(params).__name__}"
            )
        return params, None  # no pre-computed combined score

    results = load_stability(base_path)
    if not results:
        raise ValueError(f"No results in {base_path}")
    first = results[0]
    # Copy the params so callers cannot mutate the loaded result object.
    return dict(first.params), first.combined_score
def _run_combo_grid(
    args: argparse.Namespace,
    results: list[SwapResult],
    base_params: dict[str, Any],
    out: Path,
) -> None:
    """Screen and stability-test every combination of FREE/CHEAP swap values."""
    cheap_swaps: dict[str, list[Any]] = {}
    for entry in results:
        if entry.classification not in ("FREE", "CHEAP"):
            continue
        cheap_swaps.setdefault(entry.param, []).append(entry.value)

    if not cheap_swaps:
        print("\n[cost] No FREE/CHEAP swaps found -- skipping combo grid")
        return

    print(f"\n[cost] Combo grid: {cheap_swaps}")
    # Exclude swap keys from fixed to avoid overlap ValueError
    fixed_for_combo = {
        key: val for key, val in base_params.items() if key not in cheap_swaps
    }
    combos = list(generate_combinations(cheap_swaps, fixed=fixed_for_combo))
    print(f"[cost] {len(combos)} combinations")

    screening = run_screening(
        combos, args.scenario, n_workers=args.workers, n_periods=args.periods
    )
    tiers = parse_stability_tiers(args.stability_tiers)
    combo_results = run_tiered_stability(
        screening,
        args.scenario,
        tiers=tiers,
        n_workers=args.workers,
        n_periods=args.periods,
    )
    if not combo_results:
        return

    print(f"[cost] Combo winner: combined={combo_results[0].combined_score:.4f}")
    save_stability(
        combo_results,
        args.scenario,
        out / f"{args.scenario}_cost_combo.json",
    )


def run_cost_phase(args: argparse.Namespace, run_dir: Path | None = None) -> None:
    """CLI entry point for cost phase.

    Loads the base configuration named by ``args.base``, measures the cost
    of every swap listed in ``args.swaps``, prints a summary table, runs the
    combo grid over the FREE/CHEAP swaps when ``args.combo_grid`` is set,
    and finally writes ``<scenario>_cost.json`` into ``run_dir`` (or
    ``OUTPUT_DIR`` when ``run_dir`` is not given).

    Raises
    ------
    SystemExit
        If ``args.base`` or ``args.swaps`` is missing or empty.
    """
    if not args.base:
        raise SystemExit("--base is required for cost phase")
    if not args.swaps:
        raise SystemExit("--swaps is required for cost phase")

    base_path = Path(args.base)
    base_params, base_combined = _load_base_config(base_path)
    swaps = parse_swaps(args.swaps)
    swap_count = sum(len(candidates) for candidates in swaps.values())

    print(f"[cost] Base config from {base_path}")
    print(f"[cost] Testing {swap_count} swaps")

    results = run_cost_analysis(
        base_params=base_params,
        swaps=swaps,
        scenario=args.scenario,
        n_seeds=args.sensitivity_seeds,
        n_periods=args.periods,
        n_workers=args.workers,
        base_combined=base_combined,
    )

    # Print summary table
    print(f"\n {'Param':<25} {'Value':>8} {'Delta':>8} {'Class':>10} {'Pass%':>6}")
    print(" " + "-" * 59)
    for row in results:
        print(
            f" {row.param:<25} {row.value!s:>8} {row.delta:>+8.4f} "
            f"{row.classification:>10} {row.pass_rate:>5.0%}"
        )

    out = run_dir or OUTPUT_DIR
    out.mkdir(parents=True, exist_ok=True)

    # Combo grid: run all FREE+CHEAP swaps combined
    if args.combo_grid:
        _run_combo_grid(args, results, base_params, out)

    # Save
    save_cost_results(results, args.scenario, out / f"{args.scenario}_cost.json")
    print(f"\nCost results saved to {out}")