villsim/backend/core/storage.py

451 lines
13 KiB
Python

"""State storage abstraction for simulation state.
Provides a unified interface for storing and retrieving simulation state,
with implementations for:
- In-memory storage (default, fast but not persistent)
- Redis storage (optional, enables decoupled UI polling and persistence)
This allows the simulation loop to snapshot state without blocking,
and enables external systems (like a web UI) to poll state independently.
"""
import json
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Dict
@dataclass
class StateSnapshot:
    """Snapshot of simulation state captured at one moment.

    Round-trips through JSON as an object with exactly three members,
    serialized in this order: ``turn``, ``timestamp`` and ``data``.
    """

    turn: int  # simulation turn the snapshot belongs to
    timestamp: float  # capture time; save_simulation_state() fills it from time.time()
    data: dict  # state payload; must be JSON-serializable for to_json() to succeed

    def to_json(self) -> str:
        """Serialize the snapshot into a JSON string."""
        field_names = ("turn", "timestamp", "data")
        return json.dumps({name: getattr(self, name) for name in field_names})

    @classmethod
    def from_json(cls, json_str: str) -> "StateSnapshot":
        """Build a snapshot from a JSON string such as one produced by ``to_json``.

        Members other than ``turn``, ``timestamp`` and ``data`` are ignored.

        Raises:
            KeyError: If ``turn``, ``timestamp`` or ``data`` is missing
                (checked in that order).
        """
        decoded = json.loads(json_str)
        values = [decoded[name] for name in ("turn", "timestamp", "data")]
        return cls(*values)
class StateStore(ABC):
    """Common contract for every simulation-state backend.

    Snapshots are addressed by string keys. Implementations should be
    thread-safe for concurrent read/write.
    """

    @abstractmethod
    def save_state(self, key: str, snapshot: StateSnapshot) -> bool:
        """Store *snapshot* under *key*.

        Args:
            key: Unique key for this state (e.g., "world", "market", "agent_123").
            snapshot: The state snapshot to save.

        Returns:
            True if the snapshot was saved successfully.
        """

    @abstractmethod
    def get_state(self, key: str) -> Optional[StateSnapshot]:
        """Look up the most recent snapshot stored under *key*.

        Args:
            key: The state key to retrieve.

        Returns:
            The snapshot when one exists, otherwise None.
        """

    @abstractmethod
    def get_all_states(self, prefix: str = "") -> Dict[str, StateSnapshot]:
        """Collect every stored snapshot whose key begins with *prefix*.

        Args:
            prefix: Key prefix to filter by; the empty string selects all keys.

        Returns:
            A dict mapping each matching key to its snapshot.
        """

    @abstractmethod
    def delete_state(self, key: str) -> bool:
        """Remove the snapshot stored under *key*.

        Args:
            key: The state key to delete.

        Returns:
            True if the key was deleted or was not present to begin with.
        """

    @abstractmethod
    def clear_all(self) -> None:
        """Remove every snapshot held by this store."""

    @abstractmethod
    def is_healthy(self) -> bool:
        """Report whether the backing storage is reachable and usable."""
class MemoryStateStore(StateStore):
    """In-memory state storage (default implementation).

    Fast but not persistent across restarts. Every operation is guarded by a
    single ``threading.Lock``. Once ``max_entries`` keys are stored, saving a
    new key evicts the least-recently-used one. Both ``save_state`` and
    ``get_state`` count as a use.
    """

    def __init__(self, max_entries: int = 1000):
        """Initialize memory store.

        Args:
            max_entries: Maximum number of entries to keep (LRU eviction).
                With a value below 1, at most one entry is kept: each new key
                evicts the previous one.
        """
        import threading
        from collections import OrderedDict

        # The dict's order doubles as recency order: the first item is the
        # least recently used and the last item the most recently used. This
        # gives O(1) touch/evict, and unlike a separate list it cannot drift
        # out of sync with the stored data.
        self._data: "OrderedDict[str, StateSnapshot]" = OrderedDict()
        self._lock = threading.Lock()
        self._max_entries = max_entries

    def save_state(self, key: str, snapshot: StateSnapshot) -> bool:
        """Store *snapshot* under *key* and mark the key most recently used.

        Returns:
            Always True.
        """
        with self._lock:
            if key in self._data:
                # Overwriting an existing key never needs eviction; just refresh it.
                self._data.move_to_end(key)
            elif self._data and len(self._data) >= self._max_entries:
                # At capacity: evict the least recently used entry. This also
                # works for the empty-string key, which the old truthiness
                # check silently skipped.
                self._data.popitem(last=False)
            self._data[key] = snapshot
            return True

    def get_state(self, key: str) -> Optional[StateSnapshot]:
        """Return the snapshot for *key* (None if absent), marking it most recently used."""
        with self._lock:
            snapshot = self._data.get(key)
            if snapshot is not None:
                self._data.move_to_end(key)
            return snapshot

    def get_all_states(self, prefix: str = "") -> Dict[str, StateSnapshot]:
        """Return a plain-dict copy of all entries whose key starts with *prefix*.

        Does not affect LRU order.
        """
        with self._lock:
            if not prefix:
                return dict(self._data)
            return {k: v for k, v in self._data.items() if k.startswith(prefix)}

    def delete_state(self, key: str) -> bool:
        """Remove *key* if present. Returns True whether or not it existed."""
        with self._lock:
            self._data.pop(key, None)
            return True

    def clear_all(self) -> None:
        """Drop every stored entry."""
        with self._lock:
            self._data.clear()

    def is_healthy(self) -> bool:
        """An in-process dict is always available."""
        return True
class RedisStateStore(StateStore):
    """Redis-backed state storage.

    Enables:
    - Persistent state across restarts
    - Decoupled UI polling (web clients can read state independently)
    - Distributed access (multiple simulation instances)

    Each snapshot is stored as a JSON string under ``<prefix><key>``. The
    constructor raises if Redis is unreachable. Every other operation is
    best-effort: Redis errors are swallowed and reported as False, None or
    an empty dict.

    Requires redis-py: pip install redis
    """

    # Keys requested per SCAN step, and keys sent per MGET / DEL call.
    _BATCH_SIZE = 500

    def __init__(
        self,
        host: str = "localhost",
        port: int = 6379,
        db: int = 0,
        password: Optional[str] = None,
        prefix: str = "villsim:",
        ttl_seconds: int = 3600,  # 1 hour default TTL
    ):
        """Initialize Redis store and connect immediately.

        Args:
            host: Redis server host.
            port: Redis server port.
            db: Redis database number.
            password: Redis password (if required).
            prefix: Key prefix for all keys (for namespacing).
            ttl_seconds: Time-to-live for entries; 0 (or any value <= 0)
                disables expiry.

        Raises:
            ImportError: If the ``redis`` package is not installed.
            ConnectionError: If the server cannot be reached.
        """
        self._prefix = prefix
        self._ttl = ttl_seconds
        self._client = None
        self._connection_params = {
            "host": host,
            "port": port,
            "db": db,
            "password": password,
            "decode_responses": True,  # keys/values come back as str, not bytes
        }
        self._connect()

    def _connect(self) -> None:
        """Establish connection to Redis and verify it with PING.

        Raises:
            ImportError: If the ``redis`` package is not installed.
            ConnectionError: If creating the client or pinging the server fails.
        """
        try:
            import redis
        except ImportError as exc:
            raise ImportError(
                "Redis support requires the 'redis' package. "
                "Install with: pip install redis"
            ) from exc
        try:
            self._client = redis.Redis(**self._connection_params)
            self._client.ping()
        except Exception as exc:
            self._client = None
            raise ConnectionError(f"Failed to connect to Redis: {exc}") from exc

    def _make_key(self, key: str) -> str:
        """Create full Redis key with prefix."""
        return f"{self._prefix}{key}"

    @staticmethod
    def _escape_glob(text: str) -> str:
        """Backslash-escape Redis glob metacharacters so *text* matches literally.

        Without this, a store or key prefix containing ``*``, ``?`` or ``[``
        would also match unrelated keys.
        """
        return "".join("\\" + ch if ch in "\\*?[]" else ch for ch in text)

    def _scan_keys(self, prefix: str = "") -> list:
        """Return the distinct full Redis keys starting with ``<store prefix><prefix>``.

        Uses incremental SCAN instead of KEYS. KEYS walks the whole keyspace
        in one blocking call and stalls every other Redis client. SCAN may
        report the same key more than once, hence the de-duplication, which
        preserves first-seen order.
        """
        pattern = self._escape_glob(self._make_key(prefix)) + "*"
        found = self._client.scan_iter(match=pattern, count=self._BATCH_SIZE)
        return list(dict.fromkeys(found))

    def _batches(self, items: list):
        """Yield consecutive slices of *items*, each at most ``_BATCH_SIZE`` long."""
        for start in range(0, len(items), self._BATCH_SIZE):
            yield items[start:start + self._BATCH_SIZE]

    def save_state(self, key: str, snapshot: StateSnapshot) -> bool:
        """Write *snapshot* as JSON, with expiry when a positive TTL is configured.

        Returns:
            True on success, False if disconnected or Redis raised.
        """
        if not self._client:
            return False
        try:
            full_key = self._make_key(key)
            data = snapshot.to_json()
            if self._ttl > 0:
                self._client.setex(full_key, self._ttl, data)
            else:
                self._client.set(full_key, data)
            return True
        except Exception:
            return False

    def get_state(self, key: str) -> Optional[StateSnapshot]:
        """Return the snapshot for *key*, or None if missing, disconnected or on error."""
        if not self._client:
            return None
        try:
            data = self._client.get(self._make_key(key))
            if data:
                return StateSnapshot.from_json(data)
            return None
        except Exception:
            return None

    def get_all_states(self, prefix: str = "") -> Dict[str, StateSnapshot]:
        """Return every snapshot whose key (without the store prefix) starts with *prefix*.

        *prefix* is matched literally, not as a glob pattern. On any Redis or
        decoding error the whole call returns an empty dict.
        """
        if not self._client:
            return {}
        try:
            result = {}
            strip = len(self._prefix)
            for batch in self._batches(self._scan_keys(prefix)):
                # One MGET round trip per batch instead of one GET per key.
                for full_key, data in zip(batch, self._client.mget(batch)):
                    # A key can expire between SCAN and MGET; MGET yields None for it.
                    if data:
                        result[full_key[strip:]] = StateSnapshot.from_json(data)
            return result
        except Exception:
            return {}

    def delete_state(self, key: str) -> bool:
        """Delete *key*. Returns True even if it did not exist; False on error."""
        if not self._client:
            return False
        try:
            self._client.delete(self._make_key(key))
            return True
        except Exception:
            return False

    def clear_all(self) -> None:
        """Delete every key under this store's prefix (best-effort; errors are ignored)."""
        if not self._client:
            return
        try:
            # The full key list is collected before any DEL is issued.
            for batch in self._batches(self._scan_keys()):
                self._client.delete(*batch)
        except Exception:
            pass

    def is_healthy(self) -> bool:
        """Return True if the server answers PING."""
        if not self._client:
            return False
        try:
            self._client.ping()
            return True
        except Exception:
            return False

    def publish_state_update(self, channel: str, key: str) -> None:
        """Publish a state update notification (for real-time subscribers).

        The message is sent on ``<prefix>updates:<channel>`` as JSON containing
        the updated key and the current ``time.time()``. This can be used for
        WebSocket-style updates where clients subscribe to state changes.
        Errors are ignored.
        """
        if not self._client:
            return
        try:
            self._client.publish(
                f"{self._prefix}updates:{channel}",
                json.dumps({"key": key, "timestamp": time.time()})
            )
        except Exception:
            pass
class StubStateStore(StateStore):
    """Null-object store used when state storage is switched off.

    Writes and deletes report success, reads find nothing, and nothing is
    ever retained.
    """

    def is_healthy(self) -> bool:
        """Nothing can fail, so the stub is always healthy."""
        return True

    def save_state(self, key: str, snapshot: StateSnapshot) -> bool:
        """Discard *snapshot* while reporting success."""
        return True

    def get_state(self, key: str) -> Optional[StateSnapshot]:
        """Nothing is ever stored, so every lookup misses."""
        return None

    def get_all_states(self, prefix: str = "") -> Dict[str, StateSnapshot]:
        """Return a fresh empty mapping regardless of *prefix*."""
        return dict()

    def delete_state(self, key: str) -> bool:
        """Pretend the deletion succeeded."""
        return True

    def clear_all(self) -> None:
        """No-op: there is nothing to clear."""
        return None
# Process-wide state store, created lazily on first access by get_state_store().
_state_store: Optional[StateStore] = None


def get_state_store() -> StateStore:
    """Return the shared state store, building it on first use.

    The concrete store is chosen by ``_create_state_store()`` from config:
    - If Redis is configured and available, uses Redis
    - Otherwise falls back to in-memory storage
    """
    global _state_store
    if _state_store is not None:
        return _state_store
    _state_store = _create_state_store()
    return _state_store
def _create_state_store() -> StateStore:
    """Create the appropriate state store based on config.

    Selection order:
    1. ``RedisStateStore`` when ``config.redis.enabled`` is truthy and the
       new store passes ``is_healthy()``.
    2. ``StubStateStore`` when ``config.performance.state_storage_enabled``
       is set to a falsy value.
    3. ``MemoryStateStore`` otherwise.

    Failing to use Redis is not fatal, and selection continues with step 2.
    The failure is logged as a warning rather than silently swallowed, so a
    misconfigured Redis does not quietly degrade to non-persistent storage.
    """
    import logging

    from backend.config import get_config

    logger = logging.getLogger(__name__)
    config = get_config()

    # Redis is opt-in: a missing `redis` section or a falsy `enabled` skips it.
    redis_config = getattr(config, 'redis', None)
    if redis_config and getattr(redis_config, 'enabled', False):
        try:
            store = RedisStateStore(
                host=getattr(redis_config, 'host', 'localhost'),
                port=getattr(redis_config, 'port', 6379),
                db=getattr(redis_config, 'db', 0),
                password=getattr(redis_config, 'password', None),
                prefix=getattr(redis_config, 'prefix', 'villsim:'),
                ttl_seconds=getattr(redis_config, 'ttl_seconds', 3600),
            )
            if store.is_healthy():
                return store
            logger.warning(
                "Redis state store failed its health check; falling back to non-Redis storage"
            )
        except Exception as exc:
            # RedisStateStore raises ImportError (package missing) or
            # ConnectionError (server unreachable) from its constructor.
            logger.warning(
                "Redis state store unavailable (%s); falling back to non-Redis storage",
                exc,
            )

    # Storage can be switched off entirely; the stub accepts and drops everything.
    perf_config = getattr(config, 'performance', None)
    if perf_config and not getattr(perf_config, 'state_storage_enabled', True):
        return StubStateStore()

    # Default to memory store.
    return MemoryStateStore()
def reset_state_store() -> None:
    """Clear and discard the global state store.

    All data held by the current store is erased via its ``clear_all()``.
    For a Redis-backed store, that deletes the prefixed keys on the server.
    The next ``get_state_store()`` call then builds a fresh store.

    The global is detached *before* clearing, so it is reset even if
    ``clear_all()`` raises (the exception still propagates to the caller).
    Calling this when no store exists is a no-op.
    """
    global _state_store
    store, _state_store = _state_store, None
    if store is not None:
        store.clear_all()
def save_simulation_state(turn: int, state_data: dict) -> bool:
    """Snapshot the full simulation state under the ``simulation:current`` key.

    The store is resolved before the timestamp is taken, so the recorded
    time is after any lazy store construction.

    Args:
        turn: Current simulation turn.
        state_data: Full state dict (world, market, agents, etc.).

    Returns:
        The result of the active store's ``save_state`` (True if saved successfully).
    """
    active_store = get_state_store()
    return active_store.save_state(
        "simulation:current",
        StateSnapshot(turn, time.time(), state_data),
    )
def get_simulation_state() -> Optional[StateSnapshot]:
    """Return the snapshot stored under ``simulation:current``, or None if there is none."""
    return get_state_store().get_state("simulation:current")