import asyncio import json import time from collections import defaultdict from dataclasses import dataclass @dataclass class Event: name: str payload: dict class EventBus: def __init__(self) -> None: self._channels: dict[str, list[asyncio.Queue[Event]]] = defaultdict(list) self._history: dict[str, list[Event]] = defaultdict(list) self._lock = asyncio.Lock() self._history_limit = 5000 async def subscribe(self, channel_id: str, replay: bool = True) -> asyncio.Queue[Event]: queue: asyncio.Queue[Event] = asyncio.Queue() snapshot: list[Event] = [] async with self._lock: self._channels[channel_id].append(queue) if replay: snapshot = list(self._history.get(channel_id, [])) for event in snapshot: await queue.put(event) return queue async def unsubscribe(self, channel_id: str, queue: asyncio.Queue[Event]) -> None: async with self._lock: if channel_id not in self._channels: return items = self._channels[channel_id] if queue in items: items.remove(queue) if not items: del self._channels[channel_id] async def publish(self, channel_id: str, name: str, payload: dict) -> None: event_payload = dict(payload) event_payload.setdefault("published_at_ms", int(time.time() * 1000)) event = Event(name=name, payload=event_payload) async with self._lock: queues = list(self._channels.get(channel_id, [])) history = self._history[channel_id] history.append(event) if len(history) > self._history_limit: del history[: len(history) - self._history_limit] for queue in queues: await queue.put(event) @staticmethod def as_sse(event: Event) -> str: return f"event: {event.name}\ndata: {json.dumps(event.payload, ensure_ascii=False)}\n\n"