"""
Simulation results container for BAM Engine.
This module provides the SimulationResults class that encapsulates
simulation output data and provides convenient methods for data access
and export to pandas DataFrames.
Note: pandas is an optional dependency. It is only required when using
DataFrame export methods (to_dataframe, get_role_data, economy_metrics, summary).
Install with: pip install bamengine[pandas] or pip install pandas
"""
from __future__ import annotations
from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal, cast
import numpy as np
from numpy.typing import NDArray
if TYPE_CHECKING: # pragma: no cover
from pandas import DataFrame
from bamengine.simulation import Simulation
def _import_pandas() -> Any:
"""
Lazily import pandas with helpful error message if not installed.
Returns
-------
module
The pandas module.
Raises
------
ImportError
If pandas is not installed.
"""
try:
import pandas as pd
return pd
except ImportError: # pragma: no cover
raise ImportError(
"pandas is required for DataFrame export methods. "
"Install it with: pip install pandas"
) from None
class _DataCollector:
"""
Internal helper to collect data during simulation.
This class captures per-period snapshots of role and economy data
during simulation execution. It's used by Simulation.run() when
collect=True or collect={...} is specified.
Parameters
----------
variables : dict
Mapping of role/component name to variables to capture.
Keys are role names (e.g., 'Producer', 'Worker') or 'Economy'.
Values are either:
- list[str]: specific variables to capture
- True: capture all variables for that role/component
aggregate : str or None, default=None
Aggregation method ('mean', 'median', 'sum', 'std') or None for full data.
capture_after : str or None, default=None
Default event name after which to capture data. If None, captures
at end of period (after all events).
capture_timing : dict or None, default=None
Per-variable capture timing overrides. Maps "RoleName.var_name" to
event name. Variables not in this dict use capture_after default.
Examples
--------
Collect all variables from Producer and Worker, economy metrics:
>>> collector = _DataCollector(
... variables={"Producer": True, "Worker": True, "Economy": True},
... aggregate="mean",
... )
Collect specific variables with custom capture timing:
>>> collector = _DataCollector(
... variables={"Producer": ["production"], "Worker": ["employed", "wage"]},
... aggregate=None,
... capture_after="firms_update_net_worth", # Default capture event
... capture_timing={
... "Producer.production": "firms_run_production", # Before bankruptcy
... "Worker.wage": "workers_receive_wage",
... },
... )
"""
# Available economy metrics (unemployment_rate removed - calculate from Worker.employed)
ECONOMY_METRICS = [
"avg_price",
"inflation",
"n_firm_bankruptcies",
"n_bank_bankruptcies",
]
def __init__(
self,
variables: dict[str, list[str] | Literal[True]],
aggregate: str | None = None,
capture_after: str | None = None,
capture_timing: dict[str, str] | None = None,
) -> None:
self.variables = variables
self.aggregate = aggregate
self.capture_after = capture_after
self.capture_timing = capture_timing or {}
# Storage: role_data[role_name][var_name] = list of arrays/scalars
self.role_data: dict[str, dict[str, list[Any]]] = defaultdict(
lambda: defaultdict(list)
)
self.economy_data: dict[str, list[float]] = defaultdict(list)
# Storage for relationship data: rel_data[rel_name][var_name] = list
self.relationship_data: dict[str, dict[str, list[Any]]] = defaultdict(
lambda: defaultdict(list)
)
# Track which variables have been captured this period
self._captured_this_period: set[str] = set()
# Flag to indicate if timed capture is active
self._use_timed_capture = bool(capture_after or capture_timing)
# Cache for relationship names (populated on first use)
self._relationship_names: set[str] | None = None
def setup_pipeline_callbacks(self, pipeline: Any) -> None:
"""
Register capture callbacks with the pipeline for timed data capture.
This method groups variables by their capture event and registers
callbacks that will fire after each relevant event during pipeline
execution.
Parameters
----------
pipeline : Pipeline
The pipeline to register callbacks with.
Notes
-----
This method should be called before starting the simulation run.
The callbacks will capture data at the appropriate events, and
`capture_remaining()` should be called at end-of-period to capture
any variables that weren't captured by callbacks.
"""
from bamengine.core import Pipeline
if not isinstance(pipeline, Pipeline):
raise TypeError(f"Expected Pipeline, got {type(pipeline)}")
# Group variables by their capture event
# Each entry is (name, var_name, is_relationship)
event_to_vars: dict[str, list[tuple[str, str, bool]]] = defaultdict(list)
for name, var_spec in self.variables.items():
is_rel = self._is_relationship(name)
if name == "Economy":
# Economy metrics: check capture_timing first, then capture_after
if var_spec is True:
vars_to_capture = self.ECONOMY_METRICS
else:
vars_to_capture = var_spec
for var_name in vars_to_capture:
key = f"Economy.{var_name}"
event = self.capture_timing.get(key, self.capture_after)
if event:
event_to_vars[event].append(("Economy", var_name, False))
else:
# Role or Relationship data
try:
# Can't check variables until we have sim, so skip validation
if var_spec is True:
# Will capture all at runtime
if self.capture_after:
event_to_vars[self.capture_after].append(
(name, "*", is_rel)
)
else:
for var_name in var_spec:
key = f"{name}.{var_name}"
event = self.capture_timing.get(key, self.capture_after)
if event:
event_to_vars[event].append((name, var_name, is_rel))
except Exception: # pragma: no cover
pass # Will capture at end-of-period
# Register callbacks for each event
for event_name, vars_list in event_to_vars.items():
# Create callback with closure over vars_list
def make_callback(
vars_to_capture: list[tuple[str, str, bool]],
) -> Callable[[Simulation], None]:
def callback(sim: Simulation) -> None:
for name, var_name, is_rel in vars_to_capture:
if var_name == "*": # pragma: no cover
# Capture all variables from this role/relationship
if is_rel:
self._capture_relationship_all(sim, name)
else:
self._capture_role_all(sim, name)
elif name == "Economy":
self._capture_economy_single(sim, var_name)
elif is_rel:
self._capture_relationship_single(sim, name, var_name)
else:
self._capture_role_single(sim, name, var_name)
return callback
pipeline.register_after_event(event_name, make_callback(vars_list))
def _capture_role_single(
self, sim: Simulation, role_name: str, var_name: str
) -> None:
"""Capture a single variable from a role."""
key = f"{role_name}.{var_name}"
if key in self._captured_this_period:
return # Already captured
try:
role = sim.get_role(role_name)
except KeyError:
return
if not hasattr(role, var_name):
return
data = getattr(role, var_name)
if not isinstance(data, np.ndarray):
return
# Apply aggregation if requested
if self.aggregate:
if self.aggregate == "mean":
value = float(np.mean(data))
elif self.aggregate == "median":
value = float(np.median(data))
elif self.aggregate == "sum":
value = float(np.sum(data))
elif self.aggregate == "std":
value = float(np.std(data))
else:
value = float(np.mean(data)) # fallback
self.role_data[role_name][var_name].append(value)
else:
# Store full array (copy to avoid mutation issues)
self.role_data[role_name][var_name].append(data.copy())
self._captured_this_period.add(key)
def _capture_role_all(self, sim: Simulation, role_name: str) -> None:
"""Capture all variables from a role."""
try:
role = sim.get_role(role_name)
except KeyError:
return
var_names = [f for f in role.__dataclass_fields__ if not f.startswith("_")]
for var_name in var_names:
self._capture_role_single(sim, role_name, var_name)
def _capture_economy_single(self, sim: Simulation, metric_name: str) -> None:
"""Capture a single economy metric."""
key = f"Economy.{metric_name}"
if key in self._captured_this_period:
return # Already captured
ec = sim.ec
# History-based metrics (take last value from history array)
history_sources = {
"avg_price": ec.avg_mkt_price_history,
"inflation": ec.inflation_history,
}
if metric_name in history_sources:
history = history_sources[metric_name]
if len(history) > 0:
self.economy_data[metric_name].append(float(history[-1]))
self._captured_this_period.add(key)
return
# Transient metrics (capture current value, not from history)
if metric_name == "n_firm_bankruptcies":
self.economy_data[metric_name].append(len(ec.exiting_firms))
self._captured_this_period.add(key)
elif metric_name == "n_bank_bankruptcies":
self.economy_data[metric_name].append(len(ec.exiting_banks))
self._captured_this_period.add(key)
def _is_relationship(self, name: str) -> bool:
"""Check if a name refers to a registered relationship."""
if self._relationship_names is None:
from bamengine.core.registry import list_relationships
self._relationship_names = set(list_relationships())
return name in self._relationship_names
def _capture_relationship_single(
self, sim: Simulation, rel_name: str, field_name: str
) -> None:
"""Capture a single field from a relationship."""
key = f"{rel_name}.{field_name}"
if key in self._captured_this_period:
return # Already captured
try:
rel = sim.get_relationship(rel_name)
except KeyError:
return
if not hasattr(rel, field_name):
return
data = getattr(rel, field_name)
if not isinstance(data, np.ndarray):
return
# Slice to only valid entries (up to rel.size)
valid_data = data[: rel.size]
# Apply aggregation if requested
if self.aggregate:
if len(valid_data) == 0:
# Empty relationship, store NaN or 0
value = 0.0
elif self.aggregate == "mean":
value = float(np.mean(valid_data))
elif self.aggregate == "median":
value = float(np.median(valid_data))
elif self.aggregate == "sum":
value = float(np.sum(valid_data))
elif self.aggregate == "std":
value = float(np.std(valid_data))
else:
value = float(np.mean(valid_data)) # fallback
self.relationship_data[rel_name][field_name].append(value)
else:
# Store full array (copy to avoid mutation issues)
self.relationship_data[rel_name][field_name].append(valid_data.copy())
self._captured_this_period.add(key)
def _capture_relationship_all(self, sim: Simulation, rel_name: str) -> None:
"""Capture all fields from a relationship."""
try:
rel = sim.get_relationship(rel_name)
except KeyError:
return
# Get edge-specific fields (exclude base fields)
base_fields = {"source_ids", "target_ids", "size", "capacity"}
fields = getattr(rel, "__dataclass_fields__", {})
field_names = [
f for f in fields if f not in base_fields and not f.startswith("_")
]
for field_name in field_names:
self._capture_relationship_single(sim, rel_name, field_name)
def _capture_relationship(
self, sim: Simulation, rel_name: str, var_spec: list[str] | Literal[True]
) -> None:
"""Capture data from a relationship."""
if var_spec is True:
self._capture_relationship_all(sim, rel_name)
else:
for field_name in var_spec:
self._capture_relationship_single(sim, rel_name, field_name)
def capture_remaining(self, sim: Simulation) -> None:
"""
Capture any variables not yet captured this period.
This is called at the end of each period to capture variables that
weren't captured by timed callbacks (either because they have no
capture_timing specified, or timed capture is not being used).
After capturing, resets the captured tracking set for the next period.
Parameters
----------
sim : Simulation
Simulation instance to capture data from.
"""
for name, var_spec in self.variables.items():
if name == "Economy":
# Capture remaining economy metrics
if var_spec is True:
metrics = self.ECONOMY_METRICS
else:
metrics = var_spec
for metric in metrics:
key = f"Economy.{metric}"
if key not in self._captured_this_period:
self._capture_economy_single(sim, metric)
elif self._is_relationship(name):
# Capture remaining relationship fields
if var_spec is True:
self._capture_relationship_all(sim, name)
else:
for field_name in var_spec:
key = f"{name}.{field_name}"
if key not in self._captured_this_period:
self._capture_relationship_single(sim, name, field_name)
else:
# Capture remaining role variables
if var_spec is True:
self._capture_role_all(sim, name)
else:
for var_name in var_spec:
key = f"{name}.{var_name}"
if key not in self._captured_this_period: # pragma: no cover
self._capture_role_single(sim, name, var_name)
# Reset for next period
self._captured_this_period.clear()
def capture(self, sim: Simulation) -> None:
"""
Capture one period of data from simulation.
This is the original capture method for non-timed capture (when
capture_after and capture_timing are not specified). All data is
captured at the same point (end of period).
For timed capture (when capture_after or capture_timing are specified),
use `setup_pipeline_callbacks()` before the run and `capture_remaining()`
at the end of each period instead.
Parameters
----------
sim : Simulation
Simulation instance to capture data from.
"""
for name, var_spec in self.variables.items():
if name == "Economy":
# Handle Economy as a pseudo-role
self._capture_economy(sim, var_spec)
elif self._is_relationship(name):
# Handle relationships
self._capture_relationship(sim, name, var_spec)
else:
# Handle regular roles
self._capture_role(sim, name, var_spec)
# Clear tracking for next period
self._captured_this_period.clear()
def _capture_role(
self, sim: Simulation, role_name: str, var_spec: list[str] | Literal[True]
) -> None:
"""Capture data from a single role."""
try:
role = sim.get_role(role_name)
except KeyError:
return
# Determine which variables to capture
if var_spec is True:
# Capture all public fields (those not starting with underscore)
var_names = [f for f in role.__dataclass_fields__ if not f.startswith("_")]
else:
var_names = var_spec
for var_name in var_names:
if not hasattr(role, var_name):
continue
data = getattr(role, var_name)
if not isinstance(data, np.ndarray):
continue
# Apply aggregation if requested
if self.aggregate:
if self.aggregate == "mean":
value = float(np.mean(data))
elif self.aggregate == "median":
value = float(np.median(data))
elif self.aggregate == "sum":
value = float(np.sum(data))
elif self.aggregate == "std":
value = float(np.std(data))
else:
value = float(np.mean(data)) # fallback
self.role_data[role_name][var_name].append(value)
else:
# Store full array (copy to avoid mutation issues)
self.role_data[role_name][var_name].append(data.copy())
def _capture_economy(
self, sim: Simulation, var_spec: list[str] | Literal[True]
) -> None:
"""Capture economy metrics."""
ec = sim.ec
# Determine which metrics to capture
if var_spec is True:
metrics_to_capture = self.ECONOMY_METRICS
else:
metrics_to_capture = var_spec
# History-based metrics (take last value from history array)
history_sources = {
"avg_price": ec.avg_mkt_price_history,
"inflation": ec.inflation_history,
}
for metric_name in metrics_to_capture:
if metric_name in history_sources:
history = history_sources[metric_name]
if len(history) > 0:
self.economy_data[metric_name].append(float(history[-1]))
elif metric_name == "n_firm_bankruptcies":
self.economy_data[metric_name].append(len(ec.exiting_firms))
elif metric_name == "n_bank_bankruptcies":
self.economy_data[metric_name].append(len(ec.exiting_banks))
def finalize(
self, config: dict[str, Any], metadata: dict[str, Any]
) -> SimulationResults:
"""
Convert collected data to SimulationResults.
Parameters
----------
config : dict
Simulation configuration parameters.
metadata : dict
Run metadata (n_periods, seed, runtime, etc.).
Returns
-------
SimulationResults
Results container with collected data as NumPy arrays.
"""
# Convert role data lists to arrays
final_role_data: dict[str, dict[str, NDArray[Any]]] = {}
for role_name, role_vars in self.role_data.items():
final_role_data[role_name] = {}
for var_name, data_list in role_vars.items():
if not data_list:
continue
if self.aggregate:
# List of scalars -> 1D array
final_role_data[role_name][var_name] = np.array(data_list)
else:
# List of arrays -> 2D array (n_periods, n_agents)
final_role_data[role_name][var_name] = np.stack(data_list, axis=0)
# Convert economy data lists to arrays
final_economy_data: dict[str, NDArray[Any]] = {}
for metric_name, data_list in self.economy_data.items():
if data_list:
final_economy_data[metric_name] = np.array(data_list)
# Convert relationship data lists to arrays or keep as list
final_relationship_data: dict[
str, dict[str, NDArray[Any] | list[NDArray[Any]]]
] = {}
for rel_name, rel_vars in self.relationship_data.items():
final_relationship_data[rel_name] = {}
for field_name, data_list in rel_vars.items():
if not data_list:
continue
if self.aggregate:
# List of scalars -> 1D array
final_relationship_data[rel_name][field_name] = np.array(data_list)
else:
# List of variable-length arrays -> keep as list
# (cannot stack into 2D because edge counts vary per period)
final_relationship_data[rel_name][field_name] = data_list
return SimulationResults(
role_data=final_role_data,
economy_data=final_economy_data,
relationship_data=final_relationship_data,
config=config,
metadata=metadata,
)
class _Namespace:
"""Lightweight read-only proxy for attribute-style access to results data.
Returned by ``SimulationResults.__getattr__`` for role, economy, and
relationship namespaces. Supports tab-completion in IPython/Jupyter.
"""
__slots__ = ("_data", "_name")
def __init__(self, data: dict[str, Any], name: str) -> None:
object.__setattr__(self, "_data", data)
object.__setattr__(self, "_name", name)
def __getattr__(self, var_name: str) -> Any:
data = object.__getattribute__(self, "_data")
if var_name in data:
return data[var_name]
name = object.__getattribute__(self, "_name")
available = sorted(data.keys())
raise AttributeError(
f"'{var_name}' not found in {name}. Available: {', '.join(available)}"
)
def __dir__(self) -> list[str]:
return sorted(object.__getattribute__(self, "_data").keys())
def __repr__(self) -> str:
name = object.__getattribute__(self, "_name")
data = object.__getattribute__(self, "_data")
vars_str = ", ".join(sorted(data.keys()))
return f"Namespace({name}: {vars_str})"
[docs]
@dataclass
class SimulationResults:
"""
Container for simulation results with convenient data access methods.
This class is returned by Simulation.run() and provides structured
access to simulation data, including time series of role states,
economy-wide metrics, relationship edge data, and metadata about the
simulation run.
Attributes
----------
role_data : dict
Time series data for each role, keyed by role name.
Each value is a dict of arrays with shape (n_periods, n_agents).
economy_data : dict
Time series of economy-wide metrics with shape (n_periods,).
relationship_data : dict
Time series data for each relationship, keyed by relationship name.
Each value is a dict of arrays. When aggregated, arrays have shape
(n_periods,). When not aggregated, values are lists of variable-length
arrays (one per period).
config : dict
Configuration parameters used for this simulation.
metadata : dict
Run metadata (seed, runtime, n_periods, etc.).
Examples
--------
>>> sim = bam.Simulation.init(n_firms=100, seed=42)
>>> results = sim.run(n_periods=100, collect=True)
>>> # Get all data as DataFrame
>>> df = results.to_dataframe()
>>> # Get specific role data
>>> prod_df = results.get_role_data("Producer")
>>> # Access economy metrics directly
>>> unemployment = results.economy_data["unemployment_rate"]
>>> # Access relationship data (when collected)
>>> results = sim.run(n_periods=100, collect={"LoanBook": True, "aggregate": "sum"})
>>> total_principal = results.relationship_data["LoanBook"]["principal"]
"""
role_data: dict[str, dict[str, NDArray[Any]]] = field(default_factory=dict)
"""Per-role time series data keyed by role name then variable name."""
economy_data: dict[str, NDArray[Any]] = field(default_factory=dict)
"""Economy-wide metric time series keyed by metric name."""
relationship_data: dict[str, dict[str, NDArray[Any] | list[NDArray[Any]]]] = field(
default_factory=dict
)
"""Per-relationship time series data keyed by relationship name then field name."""
config: dict[str, Any] = field(default_factory=dict)
"""Configuration parameters used for this simulation run."""
metadata: dict[str, Any] = field(default_factory=dict)
"""Run metadata including seed, runtime, and period count."""
[docs]
def to_dataframe(
self,
roles: list[str] | None = None,
variables: list[str] | None = None,
include_economy: bool = True,
aggregate: str | None = None,
relationships: list[str] | None = None,
) -> DataFrame:
"""
Export results to a pandas DataFrame.
Parameters
----------
roles : list of str, optional
Specific roles to include. If None, includes all roles.
variables : list of str, optional
Specific variables to include. If None, includes all variables.
include_economy : bool, default=True
Whether to include economy-wide metrics.
aggregate : {'mean', 'median', 'sum', 'std'}, optional
How to aggregate agent-level data. If None, returns all agents.
relationships : list of str, optional
Specific relationships to include. If None, includes all relationships
with aggregated data. Relationships with non-aggregated data (list of
arrays) are skipped with a warning.
Returns
-------
pd.DataFrame
DataFrame with simulation results. Index is period number.
Columns depend on parameters and aggregation method.
Raises
------
ImportError
If pandas is not installed.
Examples
--------
# Get everything
>>> df = results.to_dataframe()
# Get only Producer price and inventory, averaged
>>> df = results.to_dataframe(
... roles=["Producer"], variables=["price", "inventory"], aggregate="mean"
... )
# Get only economy metrics
>>> df = results.to_dataframe(include_economy=True, roles=[])
# Include relationship data
>>> df = results.to_dataframe(relationships=["LoanBook"])
"""
pd = _import_pandas()
import warnings
dfs = []
# Add role data
if roles is None:
roles = list(self.role_data.keys())
for role_name in roles:
if role_name not in self.role_data:
continue
role_dict = self.role_data[role_name]
for var_name, data in role_dict.items():
if variables and var_name not in variables:
continue
# Handle both 1D (already aggregated) and 2D (per-agent) data
if data.ndim == 1:
# Data is already 1D (aggregated during collection)
df = pd.DataFrame({f"{role_name}.{var_name}": data})
dfs.append(df)
elif aggregate:
if data.ndim > 2:
# Skip 3D+ data — can't aggregate to 1D column
continue
# 2D data, aggregate across agents (axis=1)
if aggregate == "mean":
agg_data = np.mean(data, axis=1)
elif aggregate == "median":
agg_data = np.median(data, axis=1)
elif aggregate == "sum":
agg_data = np.sum(data, axis=1)
elif aggregate == "std":
agg_data = np.std(data, axis=1)
else:
raise ValueError(f"Unknown aggregation method: {aggregate}")
df = pd.DataFrame({f"{role_name}.{var_name}.{aggregate}": agg_data})
dfs.append(df)
else:
if data.ndim == 2:
# 2D data, return all agents
_n_periods, n_agents = data.shape
columns = {
f"{role_name}.{var_name}.{i}": data[:, i]
for i in range(n_agents)
}
df = pd.DataFrame(columns)
dfs.append(df)
# Skip 3D+ data (e.g. job_apps_targets) — not DataFrame-friendly
# Add relationship data
if relationships is None:
relationships = list(self.relationship_data.keys())
for rel_name in relationships:
if rel_name not in self.relationship_data:
continue
rel_dict = self.relationship_data[rel_name]
for var_name, rel_data in rel_dict.items():
if variables and var_name not in variables:
continue
# Check if data is a list (non-aggregated variable-length arrays)
if isinstance(rel_data, list):
warnings.warn(
f"Relationship '{rel_name}.{var_name}' has non-aggregated "
f"variable-length data and cannot be included in DataFrame. "
f"Access it directly via results.relationship_data['{rel_name}']"
f"['{var_name}'] or use aggregation during collection.",
stacklevel=2,
)
continue
# Data is already 1D (aggregated during collection)
df = pd.DataFrame({f"{rel_name}.{var_name}": rel_data})
dfs.append(df)
# Add economy data
if include_economy and self.economy_data:
econ_df = pd.DataFrame(self.economy_data)
dfs.append(econ_df)
# Combine all DataFrames
if not dfs:
return cast("DataFrame", pd.DataFrame())
result = pd.concat(dfs, axis=1)
result.index.name = "period"
return cast("DataFrame", result)
[docs]
def get_role_data(self, role_name: str, aggregate: str | None = None) -> DataFrame:
"""
Get data for a specific role as a DataFrame.
Parameters
----------
role_name : str
Name of the role (e.g., 'Producer', 'Worker').
aggregate : {'mean', 'median', 'sum', 'std'}, optional
How to aggregate across agents.
Returns
-------
pd.DataFrame
DataFrame with the role's time series data.
Raises
------
ImportError
If pandas is not installed.
Examples
--------
>>> prod_df = results.get_role_data("Producer")
>>> prod_mean = results.get_role_data("Producer", aggregate="mean")
"""
return self.to_dataframe(
roles=[role_name], include_economy=False, aggregate=aggregate
)
[docs]
def get_relationship_data(
self, rel_name: str, aggregate: str | None = None
) -> DataFrame:
"""
Get data for a specific relationship as a DataFrame.
Parameters
----------
rel_name : str
Name of the relationship (e.g., 'LoanBook').
aggregate : {'mean', 'median', 'sum', 'std'}, optional
How to aggregate (only used if data needs re-aggregation).
Returns
-------
pd.DataFrame
DataFrame with the relationship's time series data.
Raises
------
ImportError
If pandas is not installed.
Notes
-----
If the relationship data was collected without aggregation
(variable-length arrays per period), this method will issue a
warning and return an empty DataFrame. Use
``results.relationship_data[rel_name]`` directly for such data.
Examples
--------
>>> loans_df = results.get_relationship_data("LoanBook")
"""
return self.to_dataframe(
roles=[],
relationships=[rel_name],
include_economy=False,
aggregate=aggregate,
)
@property
def economy_metrics(self) -> DataFrame:
"""
Get economy-wide metrics as a DataFrame.
Returns
-------
pd.DataFrame
DataFrame with economy time series (unemployment rate, GDP, etc.).
Raises
------
ImportError
If pandas is not installed.
Examples
--------
>>> econ_df = results.economy_metrics
>>> econ_df[["unemployment_rate", "avg_price"]].plot()
"""
pd = _import_pandas()
if not self.economy_data:
return cast("DataFrame", pd.DataFrame())
df = pd.DataFrame(self.economy_data)
df.index.name = "period"
return cast("DataFrame", df)
@property
def data(self) -> dict[str, dict[str, NDArray[Any] | list[NDArray[Any]]]]:
"""
Unified access to all data (roles + economy + relationships).
Economy data is accessible under the "Economy" key.
Relationship data is merged with role data (relationships have
unique names so no conflicts).
Returns
-------
dict
Combined role, economy, and relationship data. Keys are role names,
relationship names, and "Economy" for economy metrics.
Examples
--------
>>> results.data["Producer"]["price"]
>>> results.data["Economy"]["unemployment_rate"]
>>> results.data["LoanBook"]["principal"] # if collected
"""
combined: dict[str, dict[str, NDArray[Any] | list[NDArray[Any]]]] = {}
# Add role data (NDArray values are compatible with the union type)
for role_name, role_dict in self.role_data.items():
combined[role_name] = cast(
dict[str, NDArray[Any] | list[NDArray[Any]]], role_dict
)
if self.economy_data:
combined["Economy"] = cast(
dict[str, NDArray[Any] | list[NDArray[Any]]], self.economy_data
)
# Add relationship data (already has the right type)
for rel_name, rel_dict in self.relationship_data.items():
combined[rel_name] = rel_dict
return combined
[docs]
def get(
self,
name: str,
variable_name: str,
aggregate: str | None = None,
) -> NDArray[Any] | list[NDArray[Any]]:
"""
Get a variable as a numpy array.
This provides a convenient way to access simulation data without
needing to navigate nested dictionaries.
Parameters
----------
name : str
Role, relationship, or "Economy" name (e.g., "Producer",
"LoanBook", "Economy").
variable_name : str
Variable name ("price", "principal", "unemployment_rate", etc.)
aggregate : {'mean', 'sum', 'std', 'median'}, optional
Aggregation method for 2D data. If provided, reduces
(n_periods, n_agents) to (n_periods,).
Returns
-------
NDArray or list[NDArray]
1D array (n_periods,), 2D array (n_periods, n_agents), or
list of arrays for non-aggregated relationship data.
Raises
------
KeyError
If role/relationship or variable not found.
Examples
--------
>>> productivity = results.get("Producer", "labor_productivity")
>>> avg_prod = results.get("Producer", "labor_productivity", aggregate="mean")
>>> unemployment = results.get("Economy", "unemployment_rate")
>>> total_principal = results.get("LoanBook", "principal")
"""
# Handle Economy data specially
if name == "Economy":
if variable_name not in self.economy_data:
available = list(self.economy_data.keys())
raise KeyError(
f"'{variable_name}' not found in Economy. Available: {available}"
)
return self.economy_data[variable_name]
# Check role data first
if name in self.role_data:
role_dict = self.role_data[name]
if variable_name not in role_dict:
available = list(role_dict.keys())
raise KeyError(
f"'{variable_name}' not found in {name}. Available: {available}"
)
data = role_dict[variable_name]
# Apply aggregation if requested and data is 2D
if aggregate and data.ndim == 2:
AggFunc = Callable[[NDArray[Any], int], NDArray[Any]]
agg_funcs: dict[str, AggFunc] = {
"mean": np.mean,
"sum": np.sum,
"std": np.std,
"median": np.median,
}
if aggregate not in agg_funcs:
raise ValueError(
f"Unknown aggregation '{aggregate}'. "
f"Use one of: {list(agg_funcs.keys())}"
)
return agg_funcs[aggregate](data, 1)
return data
# Check relationship data
if name in self.relationship_data:
rel_dict = self.relationship_data[name]
if variable_name not in rel_dict:
available = list(rel_dict.keys())
raise KeyError(
f"'{variable_name}' not found in {name}. Available: {available}"
)
rel_data = rel_dict[variable_name]
# Relationship data is either 1D (aggregated) or list of arrays
# No additional aggregation is applied here (already done during collection)
return rel_data
# Not found in role_data or relationship_data
available_roles = list(self.role_data.keys())
available_rels = list(self.relationship_data.keys())
# For backwards compatibility, use "Role" in error message
raise KeyError(
f"Role '{name}' not found. Available: {available_roles + available_rels}"
)
[docs]
def get_array(
self,
name: str,
variable_name: str,
aggregate: str | None = None,
) -> NDArray[Any] | list[NDArray[Any]]:
"""Deprecated: use ``get()`` instead."""
import warnings
warnings.warn(
"get_array() is deprecated, use get() instead",
DeprecationWarning,
stacklevel=2,
)
return self.get(name, variable_name, aggregate=aggregate)
@property
def summary(self) -> DataFrame:
"""
Get summary statistics for key metrics.
Returns
-------
pd.DataFrame
Summary statistics (mean, std, min, max) for key variables.
Raises
------
ImportError
If pandas is not installed.
Examples
--------
>>> print(results.summary)
"""
# Get aggregated data (this will call _import_pandas via to_dataframe)
df = self.to_dataframe(aggregate="mean")
# Compute summary statistics
summary = df.describe().T
# Add additional statistics if useful
summary["cv"] = summary["std"] / summary["mean"] # Coefficient of variation
return summary
[docs]
def save(self, filepath: str) -> None:
"""
Save results to disk (HDF5 or pickle format).
Parameters
----------
filepath : str
Path to save file. Use .h5 for HDF5, .pkl for pickle.
Examples
--------
>>> results.save("results.h5")
>>> results.save("results.pkl")
"""
# Implementation would use pandas HDFStore or pickle
# This is a placeholder for the interface
raise NotImplementedError("Save functionality not yet implemented")
[docs]
@classmethod
def load(cls, filepath: str) -> SimulationResults:
"""
Load results from disk.
Parameters
----------
filepath : str
Path to saved results file.
Returns
-------
SimulationResults
Loaded results object.
Examples
--------
>>> results = SimulationResults.load("results.h5")
"""
# Implementation would use pandas HDFStore or pickle
# This is a placeholder for the interface
raise NotImplementedError("Load functionality not yet implemented")
[docs]
def __repr__(self) -> str:
"""String representation showing summary information."""
n_periods = self.metadata.get("n_periods", 0)
n_firms = self.metadata.get("n_firms", 0)
n_households = self.metadata.get("n_households", 0)
roles_str = ", ".join(self.role_data.keys()) if self.role_data else "None"
rels_str = (
", ".join(self.relationship_data.keys())
if self.relationship_data
else "None"
)
return (
f"SimulationResults("
f"periods={n_periods}, "
f"firms={n_firms}, "
f"households={n_households}, "
f"roles=[{roles_str}], "
f"relationships=[{rels_str}])"
)
def _all_names(self) -> list[str]:
"""Return sorted list of all available names (roles, Economy, relationships)."""
names = sorted(self.role_data.keys())
if self.economy_data:
names.append("Economy")
names.extend(sorted(self.relationship_data.keys()))
return names
[docs]
def __getitem__(self, key: str) -> Any:
"""Access data via flat 'Name.variable' key."""
if "." not in key:
available_fn = getattr(self.__class__, "available", None)
available = available_fn(self) if available_fn is not None else []
matching = [k for k in available if k.startswith(f"{key}.")]
if matching:
raise KeyError(
f"Use '{key}.variable_name' format. "
f"Available: {', '.join(matching)}"
)
raise KeyError(
f"'{key}' not found. Use 'Name.variable' format. "
f"Available names: {', '.join(self._all_names())}"
)
name, var_name = key.split(".", 1)
if name in self.role_data:
if var_name in self.role_data[name]:
return self.role_data[name][var_name]
available = sorted(self.role_data[name].keys())
raise KeyError(
f"'{var_name}' not found in {name}. Available: {', '.join(available)}"
)
if name == "Economy":
if var_name in self.economy_data:
return self.economy_data[var_name]
available = sorted(self.economy_data.keys())
raise KeyError(
f"'{var_name}' not found in Economy. Available: {', '.join(available)}"
)
if name in self.relationship_data:
if var_name in self.relationship_data[name]:
return self.relationship_data[name][var_name]
available = sorted(self.relationship_data[name].keys())
raise KeyError(
f"'{var_name}' not found in {name}. Available: {', '.join(available)}"
)
raise KeyError(f"'{name}' not found. Available: {', '.join(self._all_names())}")
[docs]
def __getattr__(self, name: str) -> _Namespace:
"""Attribute-style access to collected data namespaces."""
# Guard: don't intercept private attrs or dataclass fields during init
if name.startswith("_") or name in self.__dataclass_fields__:
raise AttributeError(name)
if name in self.role_data:
return _Namespace(self.role_data[name], name)
if name == "Economy" and self.economy_data:
return _Namespace(self.economy_data, "Economy")
if name in self.relationship_data:
return _Namespace(self.relationship_data[name], name)
available = self._all_names()
raise AttributeError(
f"'{name}' was not collected. Available: {', '.join(available)}"
)
[docs]
def __dir__(self) -> list[str]:
"""List available attributes including collected data names."""
attrs = list(super().__dir__())
attrs.extend(self._all_names())
return sorted(set(attrs))
[docs]
def available(self) -> list[str]:
"""List all collected data as 'Name.variable' strings.
Returns
-------
list[str]
Sorted list of available data keys.
Examples
--------
>>> results.available()
['Economy.avg_price', 'Economy.inflation', 'Producer.production', ...]
"""
keys: list[str] = []
for role_name, role_vars in self.role_data.items():
for var_name in role_vars:
keys.append(f"{role_name}.{var_name}")
for metric_name in self.economy_data:
keys.append(f"Economy.{metric_name}")
for rel_name, rel_vars in self.relationship_data.items():
for var_name in rel_vars:
keys.append(f"{rel_name}.{var_name}")
return sorted(keys)