Source code for calibration.screening

"""Single-seed grid screening with progress tracking and checkpointing.

This module handles the grid screening phase of calibration: testing
many parameter combinations quickly using a single seed, with progress
reporting and checkpoint-based resumption.
"""

from __future__ import annotations

import json
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
from typing import Any

from calibration.analysis import CalibrationResult, format_eta, format_progress
from calibration.io import OUTPUT_DIR
from validation import get_validation_funcs


[docs] def screen_single_seed( params: dict[str, Any], scenario: str, seed: int, n_periods: int, ) -> tuple[CalibrationResult, float]: """Run single-seed validation for quick screening. Parameters ---------- params : dict Parameter configuration. scenario : str Scenario name. seed : int Random seed. n_periods : int Number of simulation periods. Returns ------- tuple[CalibrationResult, float] Result with single-seed metrics and elapsed wall-clock seconds. """ t0 = time.monotonic() validate, _, _, _ = get_validation_funcs(scenario) result = validate(seed=seed, n_periods=n_periods, **params) elapsed = time.monotonic() - t0 return CalibrationResult( params=params, single_score=result.total_score, n_pass=result.n_pass, n_warn=result.n_warn, n_fail=result.n_fail, ), elapsed
# ============================================================================= # Checkpointing # ============================================================================= def _params_key(params: dict[str, Any]) -> str: """Create a deterministic string key from params dict.""" return json.dumps(dict(sorted(params.items())), sort_keys=True, default=str)
[docs] def save_checkpoint( results: list[CalibrationResult], scenario: str, phase: str = "screening", ) -> Path: """Save intermediate results to a checkpoint file. Parameters ---------- results : list[CalibrationResult] Results to checkpoint. scenario : str Scenario name. phase : str Phase name for filename. Returns ------- Path Path to checkpoint file. """ OUTPUT_DIR.mkdir(exist_ok=True) path = OUTPUT_DIR / f"{scenario}_{phase}_checkpoint.json" data = [ { "params": r.params, "single_score": r.single_score, "n_pass": r.n_pass, "n_warn": r.n_warn, "n_fail": r.n_fail, "seed_scores": r.seed_scores, "seed_fails": r.seed_fails, "mean_score": r.mean_score, "std_score": r.std_score, "pass_rate": r.pass_rate, "combined_score": r.combined_score, } for r in results ] with open(path, "w") as f: json.dump(data, f, indent=2, default=str) return path
[docs] def load_checkpoint( scenario: str, phase: str = "screening", ) -> list[CalibrationResult] | None: """Load checkpoint if it exists. Parameters ---------- scenario : str Scenario name. phase : str Phase name. Returns ------- list[CalibrationResult] or None Previously checkpointed results, or None if no checkpoint. """ path = OUTPUT_DIR / f"{scenario}_{phase}_checkpoint.json" if not path.exists(): return None with open(path) as f: data = json.load(f) return [ CalibrationResult( params=d["params"], single_score=d["single_score"], n_pass=d["n_pass"], n_warn=d["n_warn"], n_fail=d["n_fail"], seed_scores=d.get("seed_scores"), seed_fails=d.get("seed_fails"), mean_score=d.get("mean_score"), std_score=d.get("std_score"), pass_rate=d.get("pass_rate"), combined_score=d.get("combined_score"), ) for d in data ]
[docs] def delete_checkpoint(scenario: str, phase: str = "screening") -> None: """Delete checkpoint file if it exists.""" path = OUTPUT_DIR / f"{scenario}_{phase}_checkpoint.json" if path.exists(): path.unlink()
# ============================================================================= # Grid screening with progress and checkpointing # =============================================================================
[docs] def run_screening( combinations: list[dict[str, Any]], scenario: str, n_workers: int = 10, n_periods: int = 1000, avg_time_per_run: float = 0.0, checkpoint_every: int = 50, resume: bool = False, ) -> list[CalibrationResult]: """Screen parameter combinations with progress tracking and checkpointing. Parameters ---------- combinations : list[dict] Parameter combinations to test. scenario : str Scenario name. n_workers : int Parallel workers. n_periods : int Simulation periods. avg_time_per_run : float Estimated time per run (from sensitivity). 0 = measure during warmup. checkpoint_every : int Save checkpoint every N completions. resume : bool If True, load checkpoint and skip already-evaluated configs. Returns ------- list[CalibrationResult] Results sorted by single_score (best first). """ total = len(combinations) results: list[CalibrationResult] = [] done_keys: set[str] = set() # Resume from checkpoint if available if resume: checkpoint = load_checkpoint(scenario, "screening") if checkpoint: results = checkpoint done_keys = {_params_key(r.params) for r in results} print(f" Resumed from checkpoint: {len(results)} already evaluated") remaining_combos = [c for c in combinations if _params_key(c) not in done_keys] if not remaining_combos: print(" All combinations already evaluated") results.sort(key=lambda r: r.single_score, reverse=True) return results completed = len(results) run_times: list[float] = [] if n_workers > 1: with ProcessPoolExecutor(max_workers=n_workers) as executor: futures = { executor.submit(screen_single_seed, p, scenario, 0, n_periods): p for p in remaining_combos } for future in as_completed(futures): result, elapsed = future.result() run_times.append(elapsed) results.append(result) completed += 1 # Compute ETA if avg_time_per_run > 0: est_time = avg_time_per_run elif len(run_times) >= 5: est_time = sum(run_times) / len(run_times) else: est_time = 0.0 remaining = total - completed eta = format_eta(remaining, est_time, n_workers) print(f" {format_progress(completed, total, remaining, eta)}") # Checkpoint periodically if checkpoint_every > 0 and completed % checkpoint_every == 0: save_checkpoint(results, scenario, "screening") else: for p in remaining_combos: result, elapsed = screen_single_seed(p, scenario, 0, n_periods) run_times.append(elapsed) results.append(result) completed += 1 if avg_time_per_run > 0: est_time = avg_time_per_run elif len(run_times) >= 5: est_time = sum(run_times) / len(run_times) else: est_time = 0.0 remaining = total - completed eta = format_eta(remaining, est_time, n_workers) print(f" {format_progress(completed, total, remaining, eta)}") if checkpoint_every > 0 and completed % checkpoint_every == 0: save_checkpoint(results, scenario, "screening") results.sort(key=lambda r: r.single_score, reverse=True) # Clean up checkpoint on completion delete_checkpoint(scenario, "screening") return results