# Listing metadata: 58 lines, 2.0 KiB, Python.
import asyncio
import json
import time
from collections import defaultdict, deque
from dataclasses import dataclass


@dataclass()
class Event:
    """A single message carried on an EventBus channel."""

    # Event type; EventBus.as_sse writes it as the SSE ``event:`` field.
    name: str
    # Event body; EventBus.as_sse serializes it with json.dumps.
    payload: dict


class EventBus:
    """In-process publish/subscribe hub keyed by channel id.

    Each subscriber gets its own unbounded ``asyncio.Queue``. Every published
    event is also kept in a bounded per-channel history so that a new
    subscriber can optionally replay what was published before it joined.
    """

    def __init__(self, *, history_limit: int = 5000) -> None:
        """Create an empty bus.

        Args:
            history_limit: Maximum number of events retained per channel for
                replay; older events are discarded first. ``0`` keeps no
                history. Defaults to 5000.

        Raises:
            ValueError: If ``history_limit`` is negative.
        """
        if history_limit < 0:
            raise ValueError(f"history_limit must be >= 0, got {history_limit}")
        self._history_limit = history_limit
        self._channels: dict[str, list[asyncio.Queue[Event]]] = defaultdict(list)
        # A bounded deque evicts the oldest event in O(1) once full, instead of
        # shifting an entire list on every publish at the limit.
        # NOTE(review): channel entries are never removed from the history, so
        # the number of retained channels grows with every distinct channel_id
        # ever published to; consider an eviction policy if ids are unbounded.
        self._history: dict[str, deque[Event]] = defaultdict(
            lambda: deque(maxlen=self._history_limit)
        )
        self._lock = asyncio.Lock()

    async def subscribe(self, channel_id: str, replay: bool = True) -> asyncio.Queue[Event]:
        """Register a new subscriber queue on ``channel_id``.

        Args:
            channel_id: Channel to listen on.
            replay: When true, the queue is pre-filled with the channel's
                retained history (oldest first) ahead of any live event.

        Returns:
            An unbounded queue that receives every event published to the
            channel from now on. Pass it to ``unsubscribe`` when done.
        """
        queue: asyncio.Queue[Event] = asyncio.Queue()
        async with self._lock:
            # Replay and registration form one critical section with
            # non-suspending puts, so replayed history always precedes live
            # events and no event is delivered twice or skipped.
            if replay:
                for event in self._history.get(channel_id, ()):
                    queue.put_nowait(event)
            self._channels[channel_id].append(queue)
        return queue

    async def unsubscribe(self, channel_id: str, queue: asyncio.Queue[Event]) -> None:
        """Detach ``queue`` from ``channel_id``; a no-op if it is not attached.

        The channel's subscriber entry is dropped once its last queue leaves.
        Retained history for the channel is left untouched.
        """
        async with self._lock:
            queues = self._channels.get(channel_id)
            if queues is None:
                return
            try:
                queues.remove(queue)
            except ValueError:
                pass  # Not (or no longer) subscribed: nothing to remove.
            if not queues:
                del self._channels[channel_id]

    async def publish(self, channel_id: str, name: str, payload: dict) -> None:
        """Publish an event to every current subscriber of ``channel_id``.

        ``payload`` is shallow-copied, and a ``published_at_ms`` key (wall-clock
        milliseconds from ``time.time()``) is added unless the caller already
        supplied one. The event is appended to the channel's history even when
        nobody is subscribed.
        """
        event_payload = dict(payload)
        event_payload.setdefault("published_at_ms", int(time.time() * 1000))
        event = Event(name=name, payload=event_payload)
        async with self._lock:
            # Record and fan out in the same critical section, so delivery
            # order matches history order. A concurrent subscribe() then sees
            # this event either in its replay or as a live delivery, never both.
            self._history[channel_id].append(event)
            # Subscriber queues are created unbounded in subscribe(), so
            # put_nowait can never raise QueueFull here.
            for queue in self._channels.get(channel_id, ()):
                queue.put_nowait(event)

    @staticmethod
    def as_sse(event: Event) -> str:
        """Serialize ``event`` as a single Server-Sent Events frame.

        Raises:
            ValueError: If ``event.name`` contains CR or LF, which would end
                the ``event:`` field early and let the remainder be parsed
                as additional SSE fields.
            TypeError: If the payload is not JSON-serializable (from json.dumps).
        """
        if "\n" in event.name or "\r" in event.name:
            raise ValueError(f"SSE event name must not contain line breaks: {event.name!r}")
        # json.dumps without ``indent`` emits one line (control characters in
        # strings are always escaped), so a single data: field suffices.
        data = json.dumps(event.payload, ensure_ascii=False)
        return f"event: {event.name}\ndata: {data}\n\n"