"""Tests for the request start-gate handshake.

``RequestService`` must not start the runtime for a newly created request
until its ``RequestStartGate`` entry is marked ready, and
``StreamService.subscribe`` is expected to be what marks a registered request
ready.
"""

from __future__ import annotations

import asyncio
from types import SimpleNamespace

from app.core.api.application.request_service import RequestService
from app.core.api.application.request_start_gate import RequestStartGate
from app.core.api.application.stream_service import StreamService
from app.core.api.infrastructure.ids.request_id_factory import RequestIdFactory
from app.core.api.infrastructure.stores.in_memory_request_store import (
    InMemoryRequestStore,
)
from app.core.api.infrastructure.streaming.sse_event_channel import (
    SseEventChannel,
)


class FakeRuntime:
    """Runtime stub that records every ``run`` call and signals when one happens."""

    def __init__(self) -> None:
        # Set on each ``run`` call so a test can await "the runtime started".
        # Constructed inside the running loop by the tests below so that, on
        # Python < 3.10, the Event binds to the loop ``asyncio.run`` owns.
        self.started = asyncio.Event()
        # Every (request, session) pair passed to ``run``, in call order.
        self.calls: list[tuple[object, object]] = []

    async def run(self, request: object, session: object) -> None:
        """Record the call and set ``started``; does no real work."""
        self.calls.append((request, session))
        self.started.set()


class FakeSessions:
    """Session lookup stub that always resolves to session ``"sess-1"``."""

    def get(self, _session_id: str):
        # The requested id is ignored; the tests here only ask for "sess-1".
        return SimpleNamespace(session_id="sess-1")


async def _wait_briefly() -> None:
    """Yield to the event loop for 50 ms so pending background tasks can run."""
    await asyncio.sleep(0.05)


def test_request_service_waits_for_stream_subscriber_before_runtime_start() -> None:
    """The runtime must not run for a new request until the gate is marked ready."""

    async def scenario() -> None:
        # Build every fixture inside the running loop so any asyncio
        # primitives they create (FakeRuntime.started, and presumably the
        # gate's per-request state) are bound to this loop.
        # NOTE(review): timeout_seconds presumably bounds how long the service
        # waits for a subscriber; it must stay well above the 0.05 s probe.
        gate = RequestStartGate(timeout_seconds=1.0)
        runtime = FakeRuntime()
        service = RequestService(
            request_store=InMemoryRequestStore(),
            request_ids=RequestIdFactory(),
            sessions=FakeSessions(),
            runtime=runtime,
            start_gate=gate,
        )

        request = await service.create("sess-1", "hello", "v2")

        # Give any background start a chance to run; it must still be blocked.
        await _wait_briefly()
        assert runtime.calls == []

        # Releasing the gate lets the runtime start exactly once.
        gate.mark_ready(request.request_id)
        await asyncio.wait_for(runtime.started.wait(), timeout=1.0)
        assert len(runtime.calls) == 1

    asyncio.run(scenario())


def test_stream_service_subscribe_marks_request_ready() -> None:
    """Subscribing to a registered request releases its start-gate waiters."""

    async def scenario() -> None:
        # Created inside the running loop for the same loop-binding reason as
        # the test above (``register`` may allocate asyncio primitives).
        gate = RequestStartGate(timeout_seconds=1.0)
        gate.register("req-1")
        service = StreamService(
            channel=SseEventChannel(),
            request_exists=lambda request_id: request_id == "req-1",
            start_gate=gate,
        )

        waiter = asyncio.create_task(gate.wait_until_ready("req-1"))

        # Let the waiter start and block first; without this check the test
        # would also pass for a gate that never blocks at all.
        await asyncio.sleep(0)
        assert not waiter.done()

        await service.subscribe("req-1")
        await asyncio.wait_for(waiter, timeout=1.0)

    asyncio.run(scenario())