"""State storage abstraction for simulation state.

Provides a unified interface for storing and retrieving simulation state,
with implementations for:

- In-memory storage (default, fast but not persistent)
- Redis storage (optional, enables decoupled UI polling and persistence)

This allows the simulation loop to snapshot state without blocking, and
enables external systems (like a web UI) to poll state independently.
"""

import json
import logging
import re
import threading
import time
from abc import ABC, abstractmethod
from collections import OrderedDict
from dataclasses import dataclass
from typing import Optional, Dict

logger = logging.getLogger(__name__)


@dataclass
class StateSnapshot:
    """A point-in-time snapshot of simulation state."""

    turn: int
    timestamp: float
    data: dict

    def to_json(self) -> str:
        """Convert to JSON string."""
        return json.dumps({
            "turn": self.turn,
            "timestamp": self.timestamp,
            "data": self.data,
        })

    @classmethod
    def from_json(cls, json_str: str) -> "StateSnapshot":
        """Create from JSON string produced by :meth:`to_json`."""
        obj = json.loads(json_str)
        return cls(
            turn=obj["turn"],
            timestamp=obj["timestamp"],
            data=obj["data"],
        )


class StateStore(ABC):
    """Abstract interface for state storage.

    Implementations should be thread-safe for concurrent read/write.
    """

    @abstractmethod
    def save_state(self, key: str, snapshot: StateSnapshot) -> bool:
        """Save a state snapshot.

        Args:
            key: Unique key for this state (e.g., "world", "market", "agent_123")
            snapshot: The state snapshot to save

        Returns:
            True if saved successfully
        """

    @abstractmethod
    def get_state(self, key: str) -> Optional[StateSnapshot]:
        """Get the latest state snapshot for a key.

        Args:
            key: The state key to retrieve

        Returns:
            The snapshot if found, None otherwise
        """

    @abstractmethod
    def get_all_states(self, prefix: str = "") -> Dict[str, StateSnapshot]:
        """Get all states matching a prefix.

        Args:
            prefix: Key prefix to filter by (empty = all)

        Returns:
            Dict mapping keys to snapshots
        """

    @abstractmethod
    def delete_state(self, key: str) -> bool:
        """Delete a state snapshot.

        Args:
            key: The state key to delete

        Returns:
            True if deleted (or didn't exist)
        """

    @abstractmethod
    def clear_all(self) -> None:
        """Clear all stored states."""

    @abstractmethod
    def is_healthy(self) -> bool:
        """Check if the store is healthy and accessible."""


class MemoryStateStore(StateStore):
    """In-memory state storage (default implementation).

    Fast but not persistent across restarts. Every operation is guarded by a
    single ``threading.Lock``. Entries are kept in least-recently-used order:
    when the store is full, saving a *new* key evicts the entry that was least
    recently saved or read.
    """

    def __init__(self, max_entries: int = 1000):
        """Initialize memory store.

        Args:
            max_entries: Maximum number of entries to keep (LRU eviction)
        """
        # The OrderedDict doubles as the LRU list: its first item is the
        # least recently used. move_to_end/popitem are O(1), unlike the
        # list.remove/list.pop(0) bookkeeping a separate list would need.
        self._data: OrderedDict[str, StateSnapshot] = OrderedDict()
        self._lock = threading.Lock()
        self._max_entries = max_entries

    def save_state(self, key: str, snapshot: StateSnapshot) -> bool:
        with self._lock:
            if key in self._data:
                # Overwriting refreshes recency; size is unchanged, no eviction.
                self._data.move_to_end(key)
            elif self._data and len(self._data) >= self._max_entries:
                # Evict the least recently used entry. This is keyed on
                # position, not key truthiness, so an empty-string key is
                # evicted like any other.
                self._data.popitem(last=False)
            self._data[key] = snapshot
        return True

    def get_state(self, key: str) -> Optional[StateSnapshot]:
        with self._lock:
            snapshot = self._data.get(key)
            if snapshot is not None:
                # A read counts as a use for LRU purposes.
                self._data.move_to_end(key)
            return snapshot

    def get_all_states(self, prefix: str = "") -> Dict[str, StateSnapshot]:
        with self._lock:
            if not prefix:
                return dict(self._data)
            return {k: v for k, v in self._data.items() if k.startswith(prefix)}

    def delete_state(self, key: str) -> bool:
        with self._lock:
            self._data.pop(key, None)
        return True

    def clear_all(self) -> None:
        with self._lock:
            self._data.clear()

    def is_healthy(self) -> bool:
        return True


class RedisStateStore(StateStore):
    """Redis-backed state storage.

    Enables:
    - Persistent state across restarts
    - Decoupled UI polling (web clients can read state independently)
    - Distributed access (multiple simulation instances)

    Operations are best-effort: failures are logged and reported through the
    return value (False / None / {}) rather than raised.

    Requires redis-py: pip install redis
    """

    # Batch size used for SCAN hints and for bulk MGET / DELETE calls.
    _SCAN_BATCH = 500

    def __init__(
        self,
        host: str = "localhost",
        port: int = 6379,
        db: int = 0,
        password: Optional[str] = None,
        prefix: str = "villsim:",
        ttl_seconds: int = 3600,  # 1 hour default TTL
    ):
        """Initialize Redis store.

        Args:
            host: Redis server host
            port: Redis server port
            db: Redis database number
            password: Redis password (if required)
            prefix: Key prefix for all keys (for namespacing)
            ttl_seconds: Time-to-live for entries (0 = no expiry)

        Raises:
            ImportError: If the ``redis`` package is not installed.
            ConnectionError: If the server cannot be reached.
        """
        self._prefix = prefix
        self._ttl = ttl_seconds
        self._client = None
        self._connection_params = {
            "host": host,
            "port": port,
            "db": db,
            "password": password,
            "decode_responses": True,
        }
        self._connect()

    def _connect(self) -> None:
        """Establish connection to Redis and verify it with PING."""
        try:
            import redis
        except ImportError as err:
            raise ImportError(
                "Redis support requires the 'redis' package. "
                "Install with: pip install redis"
            ) from err
        try:
            client = redis.Redis(**self._connection_params)
            client.ping()  # Fail fast if the server is unreachable.
        except Exception as e:
            self._client = None
            raise ConnectionError(f"Failed to connect to Redis: {e}") from e
        self._client = client

    def _make_key(self, key: str) -> str:
        """Create full Redis key with prefix."""
        return f"{self._prefix}{key}"

    @staticmethod
    def _glob_escape(text: str) -> str:
        """Escape Redis glob metacharacters (* ? [ ] \\) so *text* matches literally."""
        return re.sub(r"([*?\[\]\\])", r"\\\1", text)

    def _scan_keys(self, prefix: str) -> list[str]:
        """Return full Redis keys in our namespace that start with *prefix*.

        Uses SCAN rather than KEYS so a large keyspace does not block the
        server. SCAN may report a key more than once, so duplicates are
        removed (first-seen order kept).
        """
        pattern = self._glob_escape(self._make_key(prefix)) + "*"
        keys = self._client.scan_iter(match=pattern, count=self._SCAN_BATCH)
        return list(dict.fromkeys(keys))

    def save_state(self, key: str, snapshot: StateSnapshot) -> bool:
        if not self._client:
            return False
        try:
            full_key = self._make_key(key)
            data = snapshot.to_json()
            if self._ttl > 0:
                self._client.setex(full_key, self._ttl, data)
            else:
                self._client.set(full_key, data)
            return True
        except Exception as exc:
            logger.warning("Redis save_state failed for key %r: %s", key, exc)
            return False

    def get_state(self, key: str) -> Optional[StateSnapshot]:
        if not self._client:
            return None
        try:
            data = self._client.get(self._make_key(key))
            if data:
                return StateSnapshot.from_json(data)
            return None
        except Exception as exc:
            logger.warning("Redis get_state failed for key %r: %s", key, exc)
            return None

    def get_all_states(self, prefix: str = "") -> Dict[str, StateSnapshot]:
        if not self._client:
            return {}
        try:
            keys = self._scan_keys(prefix)
            result: Dict[str, StateSnapshot] = {}
            for start in range(0, len(keys), self._SCAN_BATCH):
                batch = keys[start:start + self._SCAN_BATCH]
                # One MGET round-trip per batch instead of one GET per key.
                for full_key, data in zip(batch, self._client.mget(batch)):
                    # A key may have expired between SCAN and MGET.
                    if data:
                        # Strip our namespace prefix to recover the caller's key.
                        result[full_key[len(self._prefix):]] = StateSnapshot.from_json(data)
            return result
        except Exception as exc:
            logger.warning("Redis get_all_states failed for prefix %r: %s", prefix, exc)
            return {}

    def delete_state(self, key: str) -> bool:
        if not self._client:
            return False
        try:
            self._client.delete(self._make_key(key))
            return True
        except Exception as exc:
            logger.warning("Redis delete_state failed for key %r: %s", key, exc)
            return False

    def clear_all(self) -> None:
        """Delete every key under this store's namespace prefix."""
        if not self._client:
            return
        try:
            keys = self._scan_keys("")
            for start in range(0, len(keys), self._SCAN_BATCH):
                self._client.delete(*keys[start:start + self._SCAN_BATCH])
        except Exception as exc:
            logger.warning("Redis clear_all failed: %s", exc)

    def is_healthy(self) -> bool:
        if not self._client:
            return False
        try:
            self._client.ping()
            return True
        except Exception:
            return False

    def publish_state_update(self, channel: str, key: str) -> None:
        """Publish a state update notification (for real-time subscribers).

        Publishes ``{"key": key, "timestamp": <now>}`` as JSON on the channel
        ``<prefix>updates:<channel>``, for WebSocket-style updates where
        clients subscribe to state changes. Best-effort: failures are logged.
        """
        if not self._client:
            return
        try:
            self._client.publish(
                f"{self._prefix}updates:{channel}",
                json.dumps({"key": key, "timestamp": time.time()}),
            )
        except Exception as exc:
            logger.warning("Redis publish_state_update failed on %r: %s", channel, exc)


class StubStateStore(StateStore):
    """No-op state store for when storage is disabled.

    All operations succeed but don't actually store anything.
    """

    def save_state(self, key: str, snapshot: StateSnapshot) -> bool:
        return True

    def get_state(self, key: str) -> Optional[StateSnapshot]:
        return None

    def get_all_states(self, prefix: str = "") -> Dict[str, StateSnapshot]:
        return {}

    def delete_state(self, key: str) -> bool:
        return True

    def clear_all(self) -> None:
        pass

    def is_healthy(self) -> bool:
        return True


# Global state store instance (created lazily by get_state_store()).
_state_store: Optional[StateStore] = None

# Key under which the convenience helpers store the whole simulation state.
_SIMULATION_KEY = "simulation:current"


def get_state_store() -> StateStore:
    """Get the global state store instance.

    Creates a store based on config:
    - If Redis is configured and available, uses Redis
    - Otherwise falls back to in-memory storage (or a no-op store when
      ``performance.state_storage_enabled`` is False)
    """
    global _state_store
    if _state_store is None:
        _state_store = _create_state_store()
    return _state_store


def _create_state_store() -> StateStore:
    """Create the appropriate state store based on config."""
    from backend.config import get_config

    config = get_config()

    # Check for Redis config
    redis_config = getattr(config, 'redis', None)
    if redis_config and getattr(redis_config, 'enabled', False):
        try:
            store = RedisStateStore(
                host=getattr(redis_config, 'host', 'localhost'),
                port=getattr(redis_config, 'port', 6379),
                db=getattr(redis_config, 'db', 0),
                password=getattr(redis_config, 'password', None),
                prefix=getattr(redis_config, 'prefix', 'villsim:'),
                ttl_seconds=getattr(redis_config, 'ttl_seconds', 3600),
            )
        except Exception as exc:
            # Redis is optional: fall through to local storage, but say why.
            logger.warning("Redis state store unavailable, falling back: %s", exc)
        else:
            if store.is_healthy():
                return store
            logger.warning("Redis state store unhealthy, falling back")

    # Check if storage is disabled
    perf_config = getattr(config, 'performance', None)
    if perf_config and not getattr(perf_config, 'state_storage_enabled', True):
        return StubStateStore()

    # Default to memory store
    return MemoryStateStore()


def reset_state_store() -> None:
    """Clear and discard the global state store; the next access recreates it."""
    global _state_store
    if _state_store is not None:
        _state_store.clear_all()
    _state_store = None


def save_simulation_state(turn: int, state_data: dict) -> bool:
    """Convenience function to save simulation state.

    Args:
        turn: Current simulation turn
        state_data: Full state dict (world, market, agents, etc.)

    Returns:
        True if saved successfully
    """
    store = get_state_store()
    snapshot = StateSnapshot(
        turn=turn,
        timestamp=time.time(),
        data=state_data,
    )
    return store.save_state(_SIMULATION_KEY, snapshot)


def get_simulation_state() -> Optional[StateSnapshot]:
    """Convenience function to get current simulation state."""
    store = get_state_store()
    return store.get_state(_SIMULATION_KEY)