# Source code for ambr.sequences

"""AgentList and view-based subset types.

The full population view lives at ``model.agents``. Filtered and scatter
views are produced by ``agents.where(...)`` / ``agents[mask]`` /
``agents.at[ids]``. All views route column reads/writes through
``model.agents_df`` so the DataFrame is the single source of truth.
"""

from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import numpy as np
import polars as pl

from .agent import Agent
from .model import Model


# Names that must live on the Python instance, not as DataFrame columns.
_INTERNAL_ATTRS = frozenset({
    "model",
    "agent_type",
    "_agent_objects",
    "_agents_by_id",
    "_ids",
    "_parent",
})


def _normalize_delta(name: str, value: Any, n: int) -> np.ndarray:
    """Coerce a scatter_add / _write_column value into a length-``n`` ndarray."""
    if isinstance(value, pl.Series):
        if value.len() != n:
            raise ValueError(
                f"length {value.len()} for {name!r} does not match view length {n}"
            )
        return value.to_numpy()
    if isinstance(value, np.ndarray):
        if len(value) != n:
            raise ValueError(
                f"length {len(value)} for {name!r} does not match view length {n}"
            )
        return value
    if isinstance(value, list):
        if len(value) != n:
            raise ValueError(
                f"length {len(value)} for {name!r} does not match view length {n}"
            )
        return np.asarray(value)
    return np.full(n, value)


class _BaseView:
    """Attribute/assignment protocol shared by every view type."""

    # Subclasses override.
    def _ids_series(self) -> pl.Series:
        raise NotImplementedError

    # --- attribute protocol -------------------------------------------------

    def __getattr__(self, name: str):
        if name.startswith("_") or name in _INTERNAL_ATTRS:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        model = self.__dict__.get("model")
        if model is None:
            raise AttributeError(name)
        df = model.agents_df
        if name not in df.columns:
            # Backward-compat fallback: if the name is a callable method
            # on the underlying agents, return a wrapper that dispatches
            # to each agent and collects results into a numpy array.
            root = self._root()
            agents = getattr(root, '_agent_objects', None)
            if agents:
                first = agents[0] if agents else None
                method = getattr(first, name, None) if first is not None else None
                if callable(method):
                    import numpy as np
                    def _dispatch(*args, **kwargs):
                        results = [getattr(a, name)(*args, **kwargs) for a in self]
                        try:
                            return np.array(results)
                        except (ValueError, TypeError):
                            return results
                    return _dispatch
            raise AttributeError(
                f"{type(self).__name__!r} has no column {name!r}; "
                f"available columns: {df.columns}"
            )
        ids = self._ids_series()
        if ids.len() == df.height and ids.equals(df["id"]):
            return df[name]
        # Align the returned Series with the view's id order so that
        # duplicate-id scatter views still return a length-matched column.
        # When name is "id", skip the join to avoid duplicate column error.
        if name == "id":
            return ids
        ids_df = pl.DataFrame([ids.rename("id")])
        return ids_df.join(df.select("id", name), on="id", how="left")[name]

    def __setattr__(self, name: str, value: Any) -> None:
        if name.startswith("_") or name in _INTERNAL_ATTRS:
            object.__setattr__(self, name, value)
            return
        self._write_column(name, value)

    # --- columnar writes ----------------------------------------------------

    def _write_column(self, name: str, value: Any) -> None:
        """Assign ``value`` to column ``name`` over this view's agents.

        Accepts scalars, ``pl.Series`` / ``np.ndarray`` / list matching
        ``len(view)``, and ``pl.Expr`` evaluated over the view's rows.
        """
        model = self.__dict__["model"]
        model._flush_pending_writes()

        ids = self._ids_series()
        n = ids.len()
        df = model.agents_df

        if isinstance(value, pl.Expr):
            sub = df.filter(pl.col("id").is_in(ids.to_list())).select(
                pl.col("id"), value.alias("__new__")
            )
            if name not in df.columns:
                df = df.with_columns(pl.Series(name, [None] * df.height, strict=False))
            joined = df.join(sub, on="id", how="left").with_columns(
                pl.when(pl.col("__new__").is_not_null())
                .then(pl.col("__new__"))
                .otherwise(pl.col(name))
                .alias(name)
            ).drop("__new__")
            model.population.data = joined
            return

        values = value if isinstance(value, pl.Series) else pl.Series(
            name,
            [value] * n if not isinstance(value, (list, np.ndarray)) else value,
            strict=False,
        )
        if values.len() != n:
            raise ValueError(
                f"Cannot assign Series of length {values.len()} to view of length {n}"
            )

        # Whole-population fast path: one with_columns, no join.
        if n == df.height and ids.equals(df["id"]):
            model.population.data = df.with_columns(values.alias(name))
            return

        if name not in df.columns:
            df = df.with_columns(pl.Series(name, [None] * df.height, strict=False))
        update_df = pl.DataFrame([ids.rename("id"), values.rename(name)])
        model.population.data = df.update(update_df, on="id", how="left")

    @property
    def ids(self) -> pl.Series:
        return self._ids_series()

    # --- filtering ----------------------------------------------------------

    def where(self, predicate) -> "FilteredAgentList":
        """Return a view of agents matching a boolean Series or Polars expression."""
        model = self.__dict__["model"]
        base_ids = self._ids_series()
        if isinstance(predicate, pl.Series):
            if predicate.dtype != pl.Boolean:
                raise TypeError("Series predicate must be boolean")
            if predicate.len() != base_ids.len():
                raise ValueError(
                    f"Boolean mask length {predicate.len()} does not match "
                    f"view length {base_ids.len()}"
                )
            new_ids = base_ids.filter(predicate)
        elif isinstance(predicate, pl.Expr):
            df = model.agents_df
            sub = df.filter(pl.col("id").is_in(base_ids.to_list())).filter(predicate)
            new_ids = sub["id"]
        else:
            raise TypeError("predicate must be a polars Series (boolean) or Expr")
        return FilteredAgentList(model, new_ids, parent=self._root())

    def select(self, selection) -> "_BaseView":
        """AgentPy-compatible filter.

        Accepts bool masks (list/ndarray/Series), Polars expressions,
        and id lists. Returns a FilteredAgentList or ScatterAgentList.
        """
        model = self.__dict__["model"]
        root = self._root()
        if isinstance(selection, pl.Expr):
            return self.where(selection)
        if isinstance(selection, pl.Series):
            if selection.dtype == pl.Boolean:
                return self.where(selection)
            return FilteredAgentList(model, selection.rename("id"), parent=root)
        if isinstance(selection, (list, np.ndarray)):
            arr = np.asarray(selection)
            if arr.dtype == bool:
                ids = self._ids_series()
                if len(arr) != ids.len():
                    raise ValueError(
                        f"Boolean mask length ({len(arr)}) does not match "
                        f"view length ({ids.len()})"
                    )
                picked = ids.filter(pl.Series("mask", arr))
                return FilteredAgentList(model, picked, parent=root)
            # List of ids
            return FilteredAgentList(model, pl.Series("id", arr.tolist()), parent=root)
        raise TypeError(f"select() unsupported type: {type(selection)}")

    def _root(self) -> "AgentList":
        return self.__dict__.get("_parent") or self  # type: ignore[return-value]

    # --- length / iteration -------------------------------------------------

    def __len__(self) -> int:
        return self._ids_series().len()

    def __iter__(self):
        lookup = getattr(self._root(), "_agents_by_id", None) or {}
        for aid in self._ids_series().to_list():
            agent = lookup.get(aid)
            if agent is not None:
                yield agent

    @property
    def at(self) -> "_AtIndexer":
        return _AtIndexer(self)

    # --- method dispatch ----------------------------------------------------

    def call(self, method_name: str, *args, **kwargs):
        """Invoke ``method_name`` on each Python Agent in this view."""
        results = []
        for agent in self:
            method = getattr(agent, method_name, None)
            if callable(method):
                results.append(method(*args, **kwargs))
        try:
            return np.array(results)
        except ValueError:
            return results

    def apply(self, func: Callable[[Agent], Any]) -> pl.Series:
        return pl.Series([func(a) for a in self])

    # --- legacy aliases -----------------------------------------------------

    def record(self, name: str, value: Any) -> None:
        """Legacy alias for ``view.<name> = value``."""
        self._write_column(name, value)

    def update_data(self, data: Dict[str, Any]) -> None:
        """Legacy alias for multi-column write over this view's agents."""
        for k, v in data.items():
            self._write_column(k, v)

    # --- scatter-add --------------------------------------------------------

    def scatter_add(self, **increments: Any) -> None:
        """Accumulate per-id deltas into columns, summing across duplicate ids.

        ``view.at[[1, 1, 3]].scatter_add(wealth=1)`` gives agent ``1`` a +2
        and agent ``3`` a +1. Accepts the same value shapes as column
        assignment.
        """
        if not increments:
            return
        model = self.__dict__["model"]
        model._flush_pending_writes()

        ids = self._ids_series()
        n = ids.len()
        df = model.agents_df
        if n == 0:
            missing = [c for c in increments if c not in df.columns]
            if missing:
                model.population.data = df.with_columns(
                    [pl.Series(c, [None] * df.height, strict=False) for c in missing]
                )
            return

        delta_np: Dict[str, np.ndarray] = {
            col: _normalize_delta(col, val, n) for col, val in increments.items()
        }

        positions = _resolve_positions(model, df, ids.to_numpy())

        new_columns: List[pl.Series] = []
        for col_name, delta in delta_np.items():
            if col_name in df.columns:
                base = df[col_name].to_numpy()
                # np.add.at won't upcast the output dtype, so expand first
                # when we're adding e.g. a float delta to an int column.
                result_dtype = np.result_type(base.dtype, delta.dtype)
                base = base.astype(result_dtype, copy=True) if base.dtype != result_dtype else base.copy()
            else:
                base = np.zeros(df.height, dtype=delta.dtype)
            np.add.at(base, positions, delta)
            new_columns.append(pl.Series(col_name, base, strict=False))

        model.population.data = df.with_columns(new_columns)


def _resolve_positions(model: Model, df: pl.DataFrame, ids_np: np.ndarray) -> np.ndarray:
    """Map view ids to row positions in ``df``, caching the lookup by id version."""
    df_ids_np = df["id"].to_numpy()
    # Contiguous [0, N) fast path (the common case after add_agents).
    if (
        df_ids_np.dtype.kind in ("i", "u")
        and df_ids_np.size == df.height
        and df_ids_np.size > 0
        and df_ids_np[0] == 0
        and df_ids_np[-1] == df_ids_np.size - 1
        and np.array_equal(df_ids_np, np.arange(df_ids_np.size))
    ):
        return np.asarray(ids_np, dtype=np.int64)

    # General case: reuse the id→position hash table across scatter_add calls
    # within the same step, invalidating whenever the id column changes.
    cached: Optional[Tuple[int, Dict[int, int]]] = getattr(model, "_id_pos_cache", None)
    version = getattr(model, "_id_version", 0)
    if cached is None or cached[0] != version:
        id_to_pos = {int(v): i for i, v in enumerate(df_ids_np)}
        model._id_pos_cache = (version, id_to_pos)
    else:
        id_to_pos = cached[1]
    return np.fromiter(
        (id_to_pos[int(v)] for v in ids_np),
        dtype=np.int64,
        count=int(ids_np.size),
    )


class _AtIndexer:
    """Indexer behind ``view.at``: ``view.at[ids]`` builds a ScatterAgentList.

    Accepted keys are a single int (numpy integers included), a ``pl.Series``
    (renamed to ``"id"``), or a list/ndarray of ids. Ids may repeat.
    """

    def __init__(self, view: "_BaseView"):
        object.__setattr__(self, "_view", view)

    def __getitem__(self, key) -> "ScatterAgentList":
        """Return a ScatterAgentList over ``key`` rooted at the view's AgentList.

        Raises:
            TypeError: for any key type other than int, list, ndarray or Series.
        """
        source: _BaseView = self._view
        owner = source.__dict__["model"]

        if isinstance(key, pl.Series):
            id_values = key.rename("id")
        elif isinstance(key, (list, np.ndarray)):
            id_values = pl.Series("id", list(key))
        elif isinstance(key, (int, np.integer)):
            id_values = pl.Series("id", [int(key)])
        else:
            raise TypeError(
                f"at[...] accepts int, list, ndarray, or Series (got {type(key).__name__})"
            )

        return ScatterAgentList(owner, id_values, parent=source._root())


class AgentList(_BaseView):
    """Full view over a model's population. Lives at ``model.agents``.

    Holds the Python ``Agent`` objects in order (``_agent_objects``) plus an
    ``id -> Agent`` mapping (``_agents_by_id``) used by derived views. When
    no Agent objects are tracked, ids and length fall back to
    ``model.agents_df``.
    """

    def __init__(
        self,
        model: Model,
        agents_or_n: Optional[Union[List[Agent], int]] = None,
        agent_type: Optional[Type[Agent]] = None,
    ):
        """Track existing agents, or create ``agents_or_n`` new ones.

        Args:
            model: Owning model.
            agents_or_n: A list of agents to track, or a count of new agents
                to create as ``agent_type(model, i)`` followed by ``setup()``.
                ``None`` is treated as an empty list.
            agent_type: Agent class. Required when ``agents_or_n`` is a count;
                otherwise defaults to the type of the first agent in the list.

        Raises:
            ValueError: if ``agents_or_n`` is a count and ``agent_type`` is None.
        """
        object.__setattr__(self, "model", model)
        object.__setattr__(self, "_agent_objects", [])
        object.__setattr__(self, "_agents_by_id", {})
        if agents_or_n is None:
            agents_or_n = []
        resolved_type: Optional[Type[Agent]] = agent_type
        if isinstance(agents_or_n, list):
            objs = list(agents_or_n)
            if resolved_type is None and objs:
                resolved_type = type(objs[0])
            for a in objs:
                self._track_agent(a)
        else:
            if agent_type is None:
                raise ValueError("agent_type is required when creating new agents")
            resolved_type = agent_type
            for i in range(agents_or_n):
                a = agent_type(model, i)
                a.setup()
                self._track_agent(a)
        object.__setattr__(self, "agent_type", resolved_type)

    # --- internal tracking ---------------------------------------------------

    def _register_id(self, agent: Agent) -> None:
        """Index ``agent`` by its ``id`` attribute (agents without one are skipped)."""
        aid = getattr(agent, "id", None)
        if aid is not None:
            self._agents_by_id[aid] = agent

    def _track_agent(self, agent: Agent) -> None:
        """Append ``agent`` to the ordered list and index it by id."""
        self._agent_objects.append(agent)
        self._register_id(agent)

    def _untrack_agent(self, agent: Agent) -> None:
        """Drop ``agent``'s id from the id index (the ordered list is untouched)."""
        aid = getattr(agent, "id", None)
        if aid is not None:
            self._agents_by_id.pop(aid, None)

    # --- view hooks ----------------------------------------------------------

    def _ids_series(self) -> pl.Series:
        # When Agent objects are tracked, return only their IDs — not
        # the entire population. This keeps each AgentList isolated.
        if self._agent_objects:
            ids = [a.id for a in self._agent_objects]
            return pl.Series("id", ids, dtype=pl.Int64)
        # Fall back to the full DataFrame for view-only or vectorized usage.
        df = self.model.agents_df
        return df["id"] if "id" in df.columns else pl.Series("id", [], dtype=pl.Int64)

    @property
    def agents(self) -> List[Agent]:
        """Legacy accessor: list of underlying ``Agent`` objects."""
        return self._agent_objects

    def __iter__(self):
        return iter(self._agent_objects)

    def __len__(self) -> int:
        # Prefer the Python-side tracking list (OOP-style models) and fall
        # back to agents_df.height for fully vectorized models that never
        # materialise Agent instances.
        if self._agent_objects:
            return len(self._agent_objects)
        model = self.__dict__.get("model")
        if model is None:
            return 0
        try:
            df = model.agents_df
        except Exception:
            # Deliberate best-effort: len() must not raise while the model's
            # population is not yet available.
            return 0
        if not isinstance(df, pl.DataFrame) or "id" not in df.columns:
            return 0
        return df.height

    def __contains__(self, agent) -> bool:
        return agent in self._agent_objects

    def __repr__(self) -> str:
        return f"AgentList({len(self)} agents)"

    def __getitem__(self, idx):
        """Index by position (int/slice/position list), boolean mask, or ``pl.Expr``.

        - int / slice: the tracked Agent(s) at those positions;
        - ``pl.Expr`` or boolean ``pl.Series``: ``self.where(idx)``;
        - other ``pl.Series``: treated as ids;
        - boolean list/ndarray: mask over the tracked agents;
        - other list/ndarray: positions into the tracked agents.

        Raises:
            ValueError: if a boolean list/ndarray mask has the wrong length.
            TypeError: for unsupported index types.
        """
        if isinstance(idx, (int, np.integer)):
            return self._agent_objects[int(idx)]
        if isinstance(idx, slice):
            return self._agent_objects[idx]
        if isinstance(idx, pl.Expr):
            return self.where(idx)
        if isinstance(idx, pl.Series):
            if idx.dtype == pl.Boolean:
                return self.where(idx)
            return FilteredAgentList(self.model, idx.rename("id"), parent=self)
        if isinstance(idx, (list, np.ndarray)):
            arr = np.asarray(idx)
            if arr.dtype == bool:
                if len(arr) != len(self._agent_objects):
                    raise ValueError(
                        f"Boolean mask length ({len(arr)}) does not match "
                        f"AgentList length ({len(self._agent_objects)})"
                    )
                picked = [
                    getattr(a, "id", None)
                    for a, keep in zip(self._agent_objects, arr)
                    if keep
                ]
                return FilteredAgentList(
                    self.model, pl.Series("id", picked), parent=self
                )
            # list of positions → pick those agents by index
            picked_ids = [
                getattr(self._agent_objects[int(i)], "id", None) for i in arr
            ]
            return FilteredAgentList(
                self.model, pl.Series("id", picked_ids), parent=self
            )
        raise TypeError(f"Invalid index type: {type(idx)}")

    def __setitem__(self, idx, agent) -> None:
        """Replace the agent(s) at ``idx`` and keep the id index in sync.

        ``idx`` may be an int, or a slice with an iterable of agents.
        """
        if isinstance(idx, slice):
            old_agents = self._agent_objects[idx]
            new_agents = list(agent)
            # Assign first: an extended slice of the wrong length raises
            # here, before the id index has been touched.
            self._agent_objects[idx] = new_agents
            for old in old_agents:
                self._untrack_agent(old)
            for new in new_agents:
                self._register_id(new)
            return
        old = self._agent_objects[idx]
        self._untrack_agent(old)
        self._agent_objects[idx] = agent
        self._register_id(agent)

    def __add__(self, other):
        """Return a new AgentList with this list's agents followed by ``other``'s."""
        if isinstance(other, AgentList):
            combined = self._agent_objects + other._agent_objects
        elif isinstance(other, list):
            combined = self._agent_objects + other
        else:
            raise TypeError(f"Cannot add {type(other)} to AgentList")
        return AgentList(self.model, combined, agent_type=self.agent_type)

    # --- list-like mutation --------------------------------------------------

    def append(self, agent: Agent) -> None:
        self._track_agent(agent)

    def extend(self, agents: List[Agent]) -> None:
        for a in agents:
            self._track_agent(a)

    def remove(self, agent: Agent) -> None:
        self._agent_objects.remove(agent)
        self._untrack_agent(agent)

    def clear(self) -> None:
        self._agent_objects.clear()
        self._agents_by_id.clear()

    def copy(self) -> "AgentList":
        """Shallow copy: a new AgentList over the same Agent objects."""
        new_list = AgentList(self.model, list(self._agent_objects))
        new_list.agent_type = self.agent_type
        return new_list

    def index(self, agent: Agent) -> int:
        return self._agent_objects.index(agent)

    def count(self, agent: Agent) -> int:
        return self._agent_objects.count(agent)

    def pop(self, idx: int = -1) -> Agent:
        a = self._agent_objects.pop(idx)
        self._untrack_agent(a)
        return a

    def insert(self, idx: int, agent: Agent) -> None:
        self._agent_objects.insert(idx, agent)
        self._register_id(agent)

    def reverse(self) -> None:
        self._agent_objects.reverse()

    def sort(self, key=None, reverse: bool = False) -> None:
        self._agent_objects.sort(key=key, reverse=reverse)

    # --- legacy property ----------------------------------------------------

    @property
    def agent_ids(self):
        """Each tracked agent's ``id``, falling back to its position if it has none."""
        return [getattr(agent, "id", i) for i, agent in enumerate(self._agent_objects)]

    # --- legacy APIs (kept as thin wrappers around the column protocol) ----

    def get_data(self) -> pl.DataFrame:
        """Return ``model.agents_df``, or an empty DataFrame if the model has none."""
        if hasattr(self.model, "agents_df"):
            return self.model.agents_df
        return pl.DataFrame()

    def group_by(self, by: str) -> Dict[Any, "FilteredAgentList"]:
        """Map each distinct value of column ``by`` to a view of its agents.

        Returns an empty dict if the model has no ``agents_df`` or the
        column does not exist.
        """
        groups: Dict[Any, FilteredAgentList] = {}
        if not hasattr(self.model, "agents_df"):
            return groups
        df = self.model.agents_df
        if by not in df.columns:
            return groups
        for group_value, group_df in df.group_by(by):
            # Newer polars yields group keys as tuples even for one column.
            groups[group_value[0] if isinstance(group_value, tuple) else group_value] = (
                FilteredAgentList(self.model, group_df["id"], parent=self)
            )
        return groups
class _SubView(_BaseView):
    """Common base for views whose membership is an explicit id Series.

    The ids are kept exactly as given (renamed to ``"id"`` if needed), so they
    may repeat. Agent objects are looked up through the parent AgentList's
    ``_agents_by_id`` mapping.
    """

    def __init__(self, model: Model, ids: pl.Series, parent: AgentList):
        object.__setattr__(self, "model", model)
        named_ids = ids if ids.name == "id" else ids.rename("id")
        object.__setattr__(self, "_ids", named_ids)
        object.__setattr__(self, "_parent", parent)

    def _ids_series(self) -> pl.Series:
        """Return the stored id Series."""
        return self._ids

    def _root(self) -> AgentList:
        """Return the AgentList this view was derived from."""
        return self._parent

    def __getitem__(self, idx):
        """Index by position (int/slice/list/ndarray) within this view.

        - int: the tracked Agent at that position (IndexError if its id has
          no tracked Agent);
        - slice: a list of Agents, with ``None`` for ids lacking an Agent;
        - boolean list/ndarray: ``self.select(mask)``;
        - other list/ndarray: a FilteredAgentList of those positions;
        - ``pl.Expr`` or boolean ``pl.Series``: ``self.where(idx)``;
        - other ``pl.Series``: a FilteredAgentList over those ids.
        """
        if isinstance(idx, (int, np.integer)):
            wanted = self._ids.to_list()[int(idx)]
            by_id = getattr(self._root(), "_agents_by_id", None) or {}
            found = by_id.get(wanted)
            if found is None:
                raise IndexError(f"Agent id={wanted} not found in AgentList")
            return found

        if isinstance(idx, slice):
            window = self._ids.to_list()[idx]
            by_id = getattr(self._root(), "_agents_by_id", None) or {}
            return [by_id.get(wanted) for wanted in window]

        if isinstance(idx, (list, np.ndarray)):
            positions = np.asarray(idx)
            if positions.dtype == bool:
                return self.select(positions)
            all_ids = self._ids.to_list()
            # List of positions → return a FilteredAgentList view
            chosen = pl.Series("id", [all_ids[int(p)] for p in positions])
            return FilteredAgentList(
                self.__dict__["model"], chosen, parent=self._root()
            )

        if isinstance(idx, pl.Expr):
            return self.where(idx)

        if isinstance(idx, pl.Series):
            if idx.dtype == pl.Boolean:
                return self.where(idx)
            return FilteredAgentList(
                self.__dict__["model"], idx.rename("id"), parent=self._root()
            )

        raise TypeError(f"{type(self).__name__} indices must be int, slice, list, or ndarray")

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self._ids.len()} agents)"
class FilteredAgentList(_SubView):
    """Subset view of agents.

    Returned by ``agents.where(...)``, ``agents[mask]``, ``select(...)`` and
    ``group_by(...)``. All behaviour is inherited from ``_SubView``.
    """
class ScatterAgentList(_SubView):
    """Id-keyed view returned by ``agents.at[ids]``.

    The ids may repeat, which is what lets ``scatter_add`` accumulate several
    increments into the same agent. All behaviour is inherited from
    ``_SubView``.
    """