Files
agent/tests/unit_tests/api/test_request_start_gate.py

70 lines
2.2 KiB
Python

from __future__ import annotations
import asyncio
from types import SimpleNamespace
from app.core.api.application.request_service import RequestService
from app.core.api.application.request_start_gate import RequestStartGate
from app.core.api.application.stream_service import StreamService
from app.core.api.infrastructure.ids.request_id_factory import RequestIdFactory
from app.core.api.infrastructure.stores.in_memory_request_store import InMemoryRequestStore
from app.core.api.infrastructure.streaming.sse_event_channel import SseEventChannel
class FakeRuntime:
def __init__(self) -> None:
self.started = asyncio.Event()
self.calls: list[tuple[object, object]] = []
async def run(self, request, session) -> None:
self.calls.append((request, session))
self.started.set()
class FakeSessions:
def get(self, _session_id: str):
return SimpleNamespace(session_id="sess-1")
async def _wait_briefly() -> None:
await asyncio.sleep(0.05)
def test_request_service_waits_for_stream_subscriber_before_runtime_start() -> None:
gate = RequestStartGate(timeout_seconds=1.0)
runtime = FakeRuntime()
service = RequestService(
request_store=InMemoryRequestStore(),
request_ids=RequestIdFactory(),
sessions=FakeSessions(),
runtime=runtime,
start_gate=gate,
)
async def scenario() -> None:
request = await service.create("sess-1", "hello", "v2")
await _wait_briefly()
assert runtime.calls == []
gate.mark_ready(request.request_id)
await asyncio.wait_for(runtime.started.wait(), timeout=1.0)
assert len(runtime.calls) == 1
asyncio.run(scenario())
def test_stream_service_subscribe_marks_request_ready() -> None:
gate = RequestStartGate(timeout_seconds=1.0)
gate.register("req-1")
service = StreamService(
channel=SseEventChannel(),
request_exists=lambda request_id: request_id == "req-1",
start_gate=gate,
)
async def scenario() -> None:
waiter = asyncio.create_task(gate.wait_until_ready("req-1"))
await service.subscribe("req-1")
await asyncio.wait_for(waiter, timeout=1.0)
asyncio.run(scenario())