Source code for calibration.grid

"""Grid building, loading, validation, and combination generation.

This module handles parameter grid operations:
- Building focused grids from sensitivity analysis results
- Loading grids from YAML/JSON files
- Validating grid structure
- Generating and counting parameter combinations
"""

from __future__ import annotations

import json
from collections.abc import Iterator
from itertools import product
from pathlib import Path
from typing import Any

import yaml

from calibration.parameter_space import get_parameter_grid
from calibration.sensitivity import SensitivityResult


[docs] def build_focused_grid( sensitivity: SensitivityResult, full_grid: dict[str, list[Any]] | None = None, scenario: str = "baseline", sensitivity_threshold: float = 0.02, pruning_threshold: float | None = 0.04, ) -> tuple[dict[str, list[Any]], dict[str, Any]]: """Build focused grid from sensitivity analysis. Parameters ---------- sensitivity : SensitivityResult Result from run_sensitivity_analysis(). full_grid : dict, optional Full parameter grid. Defaults to scenario-specific grid. scenario : str Scenario name. sensitivity_threshold : float Minimum sensitivity (delta) for inclusion in grid search. pruning_threshold : float or None Maximum score gap from best value for keeping a grid value. ``None`` disables pruning. Returns ------- tuple[dict, dict] (grid_to_search, fixed_params) - INCLUDE params (delta > threshold): all grid values (pruned if enabled) - FIX params (delta <= threshold): fix at best value """ if full_grid is None: full_grid = get_parameter_grid(scenario) included, _ = sensitivity.get_important(sensitivity_threshold) param_best = {p.name: p.best_value for p in sensitivity.parameters} grid_to_search: dict[str, list[Any]] = {} fixed_params: dict[str, Any] = {} for name, values in full_grid.items(): if name in included: grid_to_search[name] = values else: fixed_params[name] = param_best[name] grid_to_search = sensitivity.prune_grid(grid_to_search, pruning_threshold) return grid_to_search, fixed_params
[docs] def load_grid(path: Path) -> dict[str, list[Any]]: """Load parameter grid from YAML/JSON file. Light validation: check dict-of-lists structure, warn about empty values. Supports both .yaml/.yml and .json extensions. Parameters ---------- path : Path Path to grid file. Returns ------- dict[str, list[Any]] Parameter grid (param_name -> list of values). Raises ------ ValueError If the file contents are not a dict-of-lists structure. FileNotFoundError If the file does not exist. """ path = Path(path) if not path.exists(): raise FileNotFoundError(f"Grid file not found: {path}") if path.suffix == ".json": with open(path) as f: data = json.load(f) else: with open(path) as f: data = yaml.safe_load(f) if not isinstance(data, dict): raise ValueError(f"Grid file must contain a dict, got {type(data).__name__}") # Validate and normalize: ensure all values are lists grid: dict[str, list[Any]] = {} for key, values in data.items(): if not isinstance(values, list): raise ValueError( f"Grid values for '{key}' must be a list, got {type(values).__name__}" ) grid[key] = values warnings = validate_grid(grid) for w in warnings: print(f" Warning: {w}") return grid
[docs] def validate_grid(grid: dict[str, list[Any]]) -> list[str]: """Light validation of grid structure. Parameters ---------- grid : dict[str, list[Any]] Parameter grid to validate. Returns ------- list[str] List of warnings (empty = OK). """ warnings: list[str] = [] for name, values in grid.items(): if not values: warnings.append(f"Parameter '{name}' has empty values list") elif len(values) == 1: warnings.append(f"Parameter '{name}' has only one value: {values[0]}") return warnings
[docs] def count_combinations(grid: dict[str, list[Any]]) -> int: """Count total combinations in grid. Parameters ---------- grid : dict[str, list[Any]] Parameter grid. Returns ------- int Number of combinations in the grid. """ count = 1 for values in grid.values(): count *= len(values) return count
[docs] def generate_combinations( grid: dict[str, list[Any]], fixed: dict[str, Any] | None = None, constraints: list[Any] | None = None, ) -> Iterator[dict[str, Any]]: """Generate all parameter combinations, merged with fixed params. Parameters ---------- grid : dict[str, list[Any]] Parameter grid to generate combinations from. fixed : dict, optional Fixed parameter values to merge into each combination. constraints : list[callable], optional List of callables that take a combo dict and return bool. A combination is yielded only if ALL constraints return True. Useful for coupled params (e.g., ``lambda c: c['nfpf'] >= c['nfsf']``). Yields ------ dict[str, Any] Dictionary mapping parameter names to values. """ keys = list(grid.keys()) fixed = fixed or {} if fixed: overlap = set(keys) & set(fixed.keys()) if overlap: raise ValueError(f"Grid and fixed params overlap on: {overlap}") for values in product(*grid.values()): combo = dict(zip(keys, values, strict=True)) if fixed: combo = {**fixed, **combo} if constraints: if all(fn(combo) for fn in constraints): yield combo else: yield combo