451 lines
13 KiB
Python
451 lines
13 KiB
Python
"""State storage abstraction for simulation state.

Provides a unified interface for storing and retrieving simulation state,
with implementations for:

- In-memory storage (default, fast but not persistent)
- Redis storage (optional, enables decoupled UI polling and persistence)
- No-op stub storage (used when state storage is disabled in config)

This allows the simulation loop to snapshot state without blocking,
and enables external systems (like a web UI) to poll state independently.
"""
|
|
|
|
import json
|
|
import time
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass
|
|
from typing import Optional, Dict
|
|
|
|
|
|
@dataclass
class StateSnapshot:
    """A point-in-time snapshot of simulation state."""

    turn: int  # simulation turn the snapshot was taken at
    timestamp: float  # wall-clock time of the snapshot (as supplied by the caller)
    data: dict  # arbitrary JSON-serializable state payload

    def to_json(self) -> str:
        """Serialize this snapshot to a JSON string.

        The object is emitted with keys in the order ``turn``,
        ``timestamp``, ``data``.
        """
        payload = {"turn": self.turn, "timestamp": self.timestamp, "data": self.data}
        return json.dumps(payload)

    @classmethod
    def from_json(cls, json_str: str) -> "StateSnapshot":
        """Rebuild a snapshot from a string produced by :meth:`to_json`.

        Raises:
            json.JSONDecodeError: If ``json_str`` is not valid JSON.
            KeyError: If ``turn``, ``timestamp`` or ``data`` is missing.
        """
        parsed = json.loads(json_str)
        turn, timestamp, data = parsed["turn"], parsed["timestamp"], parsed["data"]
        return cls(turn=turn, timestamp=timestamp, data=data)
|
|
|
|
|
|
class StateStore(ABC):
    """Abstract interface every state backend implements.

    Concrete stores are expected to tolerate concurrent readers and
    writers (thread-safe read/write).
    """

    @abstractmethod
    def save_state(self, key: str, snapshot: StateSnapshot) -> bool:
        """Persist ``snapshot`` under ``key``.

        Args:
            key: Unique identifier for the state, e.g. "world", "market"
                or "agent_123".
            snapshot: Snapshot to store.

        Returns:
            True when the snapshot was stored.
        """
        ...

    @abstractmethod
    def get_state(self, key: str) -> Optional[StateSnapshot]:
        """Fetch the most recent snapshot stored under ``key``.

        Args:
            key: Identifier to look up.

        Returns:
            The stored snapshot, or None when nothing is stored for ``key``.
        """
        ...

    @abstractmethod
    def get_all_states(self, prefix: str = "") -> Dict[str, StateSnapshot]:
        """Fetch every snapshot whose key starts with ``prefix``.

        Args:
            prefix: Key prefix to filter on; the empty string selects all.

        Returns:
            Mapping of key to snapshot.
        """
        ...

    @abstractmethod
    def delete_state(self, key: str) -> bool:
        """Remove the snapshot stored under ``key``.

        Args:
            key: Identifier to remove.

        Returns:
            True when the key was removed or was already absent.
        """
        ...

    @abstractmethod
    def clear_all(self) -> None:
        """Remove every stored snapshot."""
        ...

    @abstractmethod
    def is_healthy(self) -> bool:
        """Report whether the backend is reachable and usable."""
        ...
|
|
|
|
|
|
class MemoryStateStore(StateStore):
    """In-memory state storage (default implementation).

    Fast but not persistent across restarts. Thread-safe: every operation
    runs while holding a single ``threading.Lock``.

    Entries are kept in least-recently-used order (both ``save_state`` and
    a successful ``get_state`` count as a use). When ``max_entries``
    distinct keys are already stored, saving a *new* key first evicts the
    least recently used one.
    """

    def __init__(self, max_entries: int = 1000):
        """Initialize memory store.

        Args:
            max_entries: Maximum number of entries to keep (LRU eviction)
        """
        import threading
        from collections import OrderedDict

        # The OrderedDict doubles as the LRU index: its first item is the
        # least recently used key, and move_to_end()/popitem(last=False) are
        # O(1). A separate list needed O(n) remove()/pop(0) on every access.
        self._data: "OrderedDict[str, StateSnapshot]" = OrderedDict()
        self._lock = threading.Lock()
        self._max_entries = max_entries

    def save_state(self, key: str, snapshot: StateSnapshot) -> bool:
        """Store ``snapshot`` under ``key``, evicting the LRU entry if full.

        Returns:
            Always True.
        """
        with self._lock:
            # Evict only when inserting a new key into a full store; updating
            # an existing key never changes the entry count. Checking the dict
            # itself (not the truthiness of the evicted key) keeps an
            # empty-string key evictable.
            if key not in self._data and len(self._data) >= self._max_entries and self._data:
                self._data.popitem(last=False)

            self._data[key] = snapshot
            # Re-assigning an existing key keeps its old position; mark it as
            # most recently used explicitly.
            self._data.move_to_end(key)
            return True

    def get_state(self, key: str) -> Optional[StateSnapshot]:
        """Return the snapshot for ``key`` (marking it recently used), or None."""
        with self._lock:
            snapshot = self._data.get(key)
            if snapshot is not None:
                self._data.move_to_end(key)
            return snapshot

    def get_all_states(self, prefix: str = "") -> Dict[str, StateSnapshot]:
        """Return a copy of all entries whose key starts with ``prefix``.

        Reading through this method does not affect LRU order.
        """
        with self._lock:
            if not prefix:
                return dict(self._data)
            return {k: v for k, v in self._data.items() if k.startswith(prefix)}

    def delete_state(self, key: str) -> bool:
        """Remove ``key`` if present. Always returns True."""
        with self._lock:
            self._data.pop(key, None)
            return True

    def clear_all(self) -> None:
        """Remove every entry."""
        with self._lock:
            self._data.clear()

    def is_healthy(self) -> bool:
        """An in-process store is always available."""
        return True
|
|
|
|
|
|
class RedisStateStore(StateStore):
    """Redis-backed state storage.

    Enables:
    - Persistent state across restarts
    - Decoupled UI polling (web clients can read state independently)
    - Distributed access (multiple simulation instances)

    Every Redis key is namespaced with ``prefix``. Apart from construction
    (which raises if Redis is unreachable), operations are best-effort: on
    any Redis error they return False / None / {} or do nothing instead of
    raising.

    Requires redis-py: pip install redis
    """

    # Keys sent per SCAN hint / MGET / DEL call when walking the namespace,
    # so a large namespace never produces one huge blocking command.
    _BATCH_SIZE = 500

    def __init__(
        self,
        host: str = "localhost",
        port: int = 6379,
        db: int = 0,
        password: Optional[str] = None,
        prefix: str = "villsim:",
        ttl_seconds: int = 3600,  # 1 hour default TTL
    ):
        """Initialize Redis store.

        Args:
            host: Redis server host
            port: Redis server port
            db: Redis database number
            password: Redis password (if required)
            prefix: Key prefix for all keys (for namespacing)
            ttl_seconds: Time-to-live for entries (0 = no expiry)

        Raises:
            ImportError: If the ``redis`` package is not installed.
            ConnectionError: If Redis cannot be reached.
        """
        self._prefix = prefix
        self._ttl = ttl_seconds
        self._client = None
        self._connection_params = {
            "host": host,
            "port": port,
            "db": db,
            "password": password,
            "decode_responses": True,
        }
        self._connect()

    def _connect(self) -> None:
        """Establish connection to Redis and verify it with PING.

        Raises:
            ImportError: If the ``redis`` package is not installed.
            ConnectionError: If the client cannot be created or does not
                answer PING. ``self._client`` is reset to None in that case.
        """
        try:
            import redis
        except ImportError as e:
            raise ImportError(
                "Redis support requires the 'redis' package. "
                "Install with: pip install redis"
            ) from e

        try:
            self._client = redis.Redis(**self._connection_params)
            # Test connection
            self._client.ping()
        except Exception as e:
            self._client = None
            raise ConnectionError(f"Failed to connect to Redis: {e}") from e

    def _make_key(self, key: str) -> str:
        """Create full Redis key with prefix."""
        return f"{self._prefix}{key}"

    @staticmethod
    def _escape_glob(text: str) -> str:
        """Escape Redis glob metacharacters so ``text`` matches literally."""
        return "".join("\\" + ch if ch in "*?[]\\" else ch for ch in text)

    def _scan_keys(self, pattern: str) -> list:
        """Return the distinct keys matching the glob ``pattern``.

        Uses incremental SCAN rather than KEYS, which blocks the Redis server
        for the whole keyspace walk. SCAN may report a key more than once, so
        duplicates are dropped (order preserved).
        """
        found = self._client.scan_iter(match=pattern, count=self._BATCH_SIZE)
        return list(dict.fromkeys(found))

    def save_state(self, key: str, snapshot: StateSnapshot) -> bool:
        """Store ``snapshot`` as JSON, with a TTL when ``ttl_seconds > 0``."""
        if not self._client:
            return False

        try:
            full_key = self._make_key(key)
            data = snapshot.to_json()

            if self._ttl > 0:
                self._client.setex(full_key, self._ttl, data)
            else:
                self._client.set(full_key, data)

            return True
        except Exception:
            return False

    def get_state(self, key: str) -> Optional[StateSnapshot]:
        """Return the snapshot for ``key``; None if absent, invalid, or on error."""
        if not self._client:
            return None

        try:
            data = self._client.get(self._make_key(key))
            if data:
                return StateSnapshot.from_json(data)
            return None
        except Exception:
            return None

    def get_all_states(self, prefix: str = "") -> Dict[str, StateSnapshot]:
        """Return all snapshots whose (un-namespaced) key starts with ``prefix``.

        ``prefix`` is matched literally; glob characters in it are escaped.
        Entries that fail to parse are skipped. Returns {} on Redis errors.
        """
        if not self._client:
            return {}

        try:
            pattern = self._escape_glob(self._make_key(prefix)) + "*"
            full_keys = self._scan_keys(pattern)
            strip = len(self._prefix)

            result = {}
            for start in range(0, len(full_keys), self._BATCH_SIZE):
                batch = full_keys[start:start + self._BATCH_SIZE]
                for full_key, data in zip(batch, self._client.mget(batch)):
                    # A key can expire between SCAN and MGET; MGET yields None.
                    if not data:
                        continue
                    try:
                        # Remove our namespace prefix to get the original key.
                        result[full_key[strip:]] = StateSnapshot.from_json(data)
                    except (ValueError, KeyError, TypeError):
                        # One corrupt entry should not hide all the others.
                        continue

            return result
        except Exception:
            return {}

    def delete_state(self, key: str) -> bool:
        """Delete ``key``; True unless Redis is unavailable or errors."""
        if not self._client:
            return False

        try:
            self._client.delete(self._make_key(key))
            return True
        except Exception:
            return False

    def clear_all(self) -> None:
        """Delete every key in this store's namespace (best-effort)."""
        if not self._client:
            return

        try:
            full_keys = self._scan_keys(self._escape_glob(self._prefix) + "*")
            for start in range(0, len(full_keys), self._BATCH_SIZE):
                self._client.delete(*full_keys[start:start + self._BATCH_SIZE])
        except Exception:
            # Best-effort: a failed clear should not crash the caller.
            pass

    def is_healthy(self) -> bool:
        """Return True if a client exists and answers PING."""
        if not self._client:
            return False

        try:
            self._client.ping()
            return True
        except Exception:
            return False

    def publish_state_update(self, channel: str, key: str) -> None:
        """Publish a state update notification (for real-time subscribers).

        This can be used for WebSocket-style updates where clients
        subscribe to state changes. The message is a JSON object with
        ``key`` and ``timestamp`` fields, published on
        ``<prefix>updates:<channel>``. Errors are ignored.
        """
        if not self._client:
            return

        try:
            self._client.publish(
                f"{self._prefix}updates:{channel}",
                json.dumps({"key": key, "timestamp": time.time()})
            )
        except Exception:
            pass
|
|
|
|
|
|
class StubStateStore(StateStore):
    """State store that discards everything (used when storage is disabled).

    Writes and deletes report success, reads find nothing, and the store
    always reports itself healthy.
    """

    def save_state(self, key: str, snapshot: StateSnapshot) -> bool:
        """Pretend to store the snapshot; always succeeds."""
        return True

    def get_state(self, key: str) -> Optional[StateSnapshot]:
        """Nothing is ever stored, so nothing is ever found."""
        return None

    def get_all_states(self, prefix: str = "") -> Dict[str, StateSnapshot]:
        """Always an empty mapping."""
        return {}

    def delete_state(self, key: str) -> bool:
        """Pretend to delete; always succeeds."""
        return True

    def clear_all(self) -> None:
        """Nothing to clear."""
        return None

    def is_healthy(self) -> bool:
        """A no-op store cannot fail."""
        return True
|
|
|
|
|
|
# Process-wide state store; created lazily by get_state_store() and
# discarded by reset_state_store().
_state_store: Optional[StateStore] = None


def get_state_store() -> StateStore:
    """Return the process-wide state store, creating it on first use.

    The concrete store is chosen by ``_create_state_store()`` from the
    loaded config. The same instance is returned on every subsequent call
    until ``reset_state_store()`` discards it.
    """
    global _state_store

    if _state_store is not None:
        return _state_store

    _state_store = _create_state_store()
    return _state_store
|
|
|
|
|
|
def _create_state_store() -> StateStore:
    """Create the appropriate state store based on config.

    Resolution order:

    1. If ``config.performance.state_storage_enabled`` is false, return a
       ``StubStateStore``. Disabling storage takes precedence, so no Redis
       connection is attempted.
    2. If ``config.redis.enabled`` is true and the Redis store can be
       constructed and passes its health check, return it.
    3. Otherwise return a ``MemoryStateStore``.

    Missing config sections/attributes fall back to the defaults passed to
    the ``getattr`` calls below. A Redis failure never raises here; it
    emits a ``RuntimeWarning`` and falls back to in-memory storage.
    """
    import warnings

    from backend.config import get_config

    config = get_config()

    # Storage explicitly disabled: honour it before touching Redis.
    perf_config = getattr(config, 'performance', None)
    if perf_config and not getattr(perf_config, 'state_storage_enabled', True):
        return StubStateStore()

    # Check for Redis config
    redis_config = getattr(config, 'redis', None)
    if redis_config and getattr(redis_config, 'enabled', False):
        try:
            store = RedisStateStore(
                host=getattr(redis_config, 'host', 'localhost'),
                port=getattr(redis_config, 'port', 6379),
                db=getattr(redis_config, 'db', 0),
                password=getattr(redis_config, 'password', None),
                prefix=getattr(redis_config, 'prefix', 'villsim:'),
                ttl_seconds=getattr(redis_config, 'ttl_seconds', 3600),
            )
        except Exception as e:
            # Deliberate fallback, but make it visible instead of silent.
            warnings.warn(
                f"Redis state store unavailable ({e}); "
                "falling back to in-memory storage.",
                RuntimeWarning,
                stacklevel=2,
            )
        else:
            if store.is_healthy():
                return store
            warnings.warn(
                "Redis state store failed its health check; "
                "falling back to in-memory storage.",
                RuntimeWarning,
                stacklevel=2,
            )

    # Default to memory store
    return MemoryStateStore()
|
|
|
|
|
|
def reset_state_store() -> None:
    """Clear and discard the global state store.

    Any existing store has ``clear_all()`` called on it before the global
    reference is dropped; the next ``get_state_store()`` builds a new one.
    """
    global _state_store

    existing = _state_store
    if existing:
        existing.clear_all()
    _state_store = None
|
|
|
|
|
|
def save_simulation_state(turn: int, state_data: dict) -> bool:
    """Snapshot the current simulation state into the global store.

    The snapshot is stamped with the current ``time.time()`` and saved
    under the key ``"simulation:current"``, replacing any previous one.

    Args:
        turn: Current simulation turn
        state_data: Full state dict (world, market, agents, etc.)

    Returns:
        Whatever the store's ``save_state`` reports (True on success).
    """
    store = get_state_store()
    now = time.time()
    current = StateSnapshot(turn=turn, timestamp=now, data=state_data)
    return store.save_state("simulation:current", current)
|
|
|
|
|
|
def get_simulation_state() -> Optional[StateSnapshot]:
    """Return the snapshot saved under "simulation:current", or None."""
    return get_state_store().get_state("simulation:current")
|