from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING, Type
import warnings
import polars as pl
import numpy as np
if TYPE_CHECKING:
from .agent import Agent
[docs]
class Population:
"""
Manages the columnar state of all agents using Polars DataFrames.
Acts as the single point of truth for agent data.
"""
def __init__(self, schema: Dict[str, Type] = None):
if schema is None:
schema = {}
# Core columns that always exist
self.schema = {
'id': pl.Int64,
'step': pl.Int64,
**schema
}
# Initialize empty DataFrame
self.data = pl.DataFrame(schema=self.schema)
# Buffer for batch operations
self._pending_updates: Dict[int, Dict[str, Any]] = {}
@property
def size(self) -> int:
return len(self.data)
def _align_and_concat(self, new_df: pl.DataFrame) -> pl.DataFrame:
"""
Robustly concatenate new_df to self.data, handling type mismatches.
This is the core fix for Polars Null vs Int64 conflicts.
"""
if self.data.is_empty():
return new_df
# Get union of all columns
all_cols = set(self.data.columns) | set(new_df.columns)
# Align self.data: add missing columns from new_df
for col in all_cols:
if col not in self.data.columns:
# Column exists in new_df but not self.data
dtype = new_df.schema[col]
self.data = self.data.with_columns(pl.lit(None).cast(dtype).alias(col))
if col not in new_df.columns:
# Column exists in self.data but not new_df
dtype = self.data.schema[col]
new_df = new_df.with_columns(pl.lit(None).cast(dtype).alias(col))
# Handle type mismatches between matching columns
for col in all_cols:
left_type = self.data.schema[col]
right_type = new_df.schema[col]
if left_type != right_type:
# Promote Null to concrete type
if left_type == pl.Null and right_type != pl.Null:
self.data = self.data.with_columns(pl.col(col).cast(right_type))
elif right_type == pl.Null and left_type != pl.Null:
new_df = new_df.with_columns(pl.col(col).cast(left_type))
# For other mismatches, try to find supertype or cast to Object
else:
try:
# Try casting new_df to self.data's type
new_df = new_df.with_columns(pl.col(col).cast(left_type))
except (pl.exceptions.ComputeError, pl.exceptions.InvalidOperationError, TypeError, ValueError):
# Last resort: cast both to String. This is
# destructive — log a warning so the user can
# pre-cast their input columns explicitly.
warnings.warn(
f"Column {col!r} has mismatched types "
f"({left_type} vs {right_type}); falling back to "
f"String. Pre-cast your columns to avoid this.",
UserWarning,
stacklevel=2,
)
self.data = self.data.with_columns(pl.col(col).cast(pl.Utf8))
new_df = new_df.with_columns(pl.col(col).cast(pl.Utf8))
# Ensure column order matches
new_df = new_df.select(self.data.columns)
return pl.concat([self.data, new_df], how="vertical")
[docs]
def add_agent(self, agent_id: int, step: int = 0, **attributes):
"""Adds a single agent to the population."""
row = {'id': agent_id, 'step': step, **attributes}
# Ensure all schema columns are present
for col, dtype in self.schema.items():
if col not in row:
row[col] = None
new_row = pl.DataFrame([row])
self.data = self._align_and_concat(new_row)
[docs]
def batch_add_agents(self, count: int, step: int = 0, **attributes):
"""Adds multiple agents efficiently."""
start_id = self.data['id'].max() + 1 if not self.data.is_empty() else 0
ids = range(start_id, start_id + count)
new_data = {
'id': list(ids),
'step': [step] * count
}
for k, v in attributes.items():
if isinstance(v, pl.Series):
if v.len() != count:
raise ValueError(f"Attribute {k} length mismatch")
new_data[k] = v.to_list()
elif isinstance(v, (list, np.ndarray)):
if len(v) != count:
raise ValueError(f"Attribute {k} length mismatch")
new_data[k] = list(v) if isinstance(v, np.ndarray) else v
else:
new_data[k] = [v] * count
# Fill missing schema columns
for col in self.schema:
if col not in new_data:
new_data[col] = [None] * count
new_df = pl.DataFrame(new_data)
self.data = self._align_and_concat(new_df)
[docs]
def get_agent_value(self, agent_id: int, column: str) -> Any:
res = self.data.filter(pl.col("id") == agent_id).select(column)
if res.is_empty():
raise KeyError(f"Agent {agent_id} not found")
return res.item(0, 0)
[docs]
def set_agent_value(self, agent_id: int, column: str, value: Any):
"""Sets a value for a single agent. Very slow if used in loops."""
# Determine Polars type from value
if hasattr(value, 'dtype'): # Handle numpy scalars
if np.issubdtype(value.dtype, np.integer):
pl_type = pl.Int64
elif np.issubdtype(value.dtype, np.floating):
pl_type = pl.Float64
else:
pl_type = pl.Object
else:
pl_type = pl.Int64 if isinstance(value, int) else pl.Float64 if isinstance(value, float) else pl.Utf8 if isinstance(value, str) else pl.Object
# Check if column exists, if not create it with correct type
if column not in self.data.columns:
self.data = self.data.with_columns(pl.lit(None).cast(pl_type).alias(column))
# Polars explicit update
self.data = self.data.with_columns(
pl.when(pl.col("id") == agent_id)
.then(pl.lit(value))
.otherwise(pl.col(column))
.alias(column)
)
[docs]
def batch_update(self, updates: Dict[str, Union[np.ndarray, list]], selector: Optional[pl.Expr] = None):
"""Updates columns for all agents (or a filtered subset)."""
if selector is None:
self.data = self.data.with_columns([
pl.Series(k, v) for k, v in updates.items()
])
else:
cols = []
for col, val in updates.items():
cols.append(
pl.when(selector)
.then(val)
.otherwise(pl.col(col))
.alias(col)
)
self.data = self.data.with_columns(cols)
[docs]
def batch_update_by_ids(self, ids: Union[list, np.ndarray], data: Dict[str, Union[list, np.ndarray, Any]]):
"""Updates specific agents identified by IDs."""
id_series = pl.Series("id", ids)
count = len(ids)
update_data = {"id": id_series}
for col, val in data.items():
if isinstance(val, (list, np.ndarray)):
if len(val) != count:
raise ValueError(f"Value length mismatch for {col}")
update_data[f"{col}_new"] = val
else:
update_data[f"{col}_new"] = [val] * count
update_df = pl.DataFrame(update_data)
self.data = self.data.join(update_df, on="id", how="left")
cols = []
for col in data.keys():
new_col = f"{col}_new"
cols.append(
pl.when(pl.col(new_col).is_not_null())
.then(pl.col(new_col))
.otherwise(pl.col(col))
.alias(col)
)
self.data = self.data.with_columns(cols).drop([f"{col}_new" for col in data.keys()])
[docs]
def create_batch_context(self):
"""Legacy batched-update context manager.
.. deprecated::
Prefer the vectorized view API:
``model.agents.at[ids].col = values`` or
``model.agents.at[ids].scatter_add(col=delta)``. The view API
flushes through the same hash-join path but is discoverable
via attribute access rather than a context manager.
"""
import warnings
warnings.warn(
"Population.create_batch_context() is deprecated; use "
"model.agents.at[ids].col = values (or scatter_add) instead. "
"See the AMBER quickstart for examples.",
DeprecationWarning,
stacklevel=2,
)
return BatchUpdateContext(self)
[docs]
class BatchUpdateContext:
"""Context manager for buffering updates to minimize DataFrame copies."""
def __init__(self, population: Population):
self.population = population
self.updates = {}
def __enter__(self):
return self
[docs]
def add_update(self, agent_id: int, col: str, val: Any):
if agent_id not in self.updates:
self.updates[agent_id] = {}
self.updates[agent_id][col] = val
def __exit__(self, exc_type, exc_val, exc_tb):
if not self.updates:
return
ids = list(self.updates.keys())
cols = set()
for u in self.updates.values():
cols.update(u.keys())
final_data = {}
for col in cols:
vals = []
for aid in ids:
vals.append(self.updates[aid].get(col, None))
final_data[col] = vals
self.population.batch_update_by_ids(ids, final_data)