70 lines
2.2 KiB
Python
70 lines
2.2 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from types import SimpleNamespace
|
|
|
|
from app.core.api.application.request_service import RequestService
|
|
from app.core.api.application.request_start_gate import RequestStartGate
|
|
from app.core.api.application.stream_service import StreamService
|
|
from app.core.api.infrastructure.ids.request_id_factory import RequestIdFactory
|
|
from app.core.api.infrastructure.stores.in_memory_request_store import InMemoryRequestStore
|
|
from app.core.api.infrastructure.streaming.sse_event_channel import SseEventChannel
|
|
|
|
|
|
class FakeRuntime:
    """Stand-in runtime that records each ``run`` invocation and signals it."""

    def __init__(self) -> None:
        # Every (request, session) pair handed to ``run``, in call order.
        self.calls: list[tuple[object, object]] = []
        # Set by ``run``; tests await it to learn that the runtime was invoked.
        self.started = asyncio.Event()

    async def run(self, request, session) -> None:
        """Record the ``(request, session)`` pair and set ``started``."""
        invocation = (request, session)
        self.calls.append(invocation)
        self.started.set()
|
|
|
|
|
|
class FakeSessions:
    """Session lookup stub that resolves any id to one fixed session."""

    def get(self, _session_id: str):
        """Return a new object whose ``session_id`` is always ``"sess-1"``.

        The requested id is ignored.
        """
        session = SimpleNamespace(session_id="sess-1")
        return session
|
|
|
|
|
|
async def _wait_briefly() -> None:
    """Suspend the current task for a short fixed pause of 50 ms.

    Tests use this to give background tasks a chance to run before they
    assert on state.
    """
    pause_seconds = 0.05
    await asyncio.sleep(pause_seconds)
|
|
|
|
|
|
def test_request_service_waits_for_stream_subscriber_before_runtime_start() -> None:
    """The runtime must not start until the request's start gate is marked ready.

    Steps:
    1. Create a request.
    2. Give any background work a moment to run.
    3. Confirm the runtime has not been invoked.
    4. Mark the request ready and expect exactly one runtime invocation.

    The gate timeout is deliberately far longer than the ``wait_for`` window
    used after ``mark_ready``. If the two were equal (both 1 s), the gate's
    timeout would elapse inside that window. Whatever the gate does at that
    point could then satisfy the final assertion even if ``mark_ready`` had no
    effect.
    """
    gate_timeout_seconds = 10.0
    start_window_seconds = 1.0

    gate = RequestStartGate(timeout_seconds=gate_timeout_seconds)
    runtime = FakeRuntime()
    service = RequestService(
        request_store=InMemoryRequestStore(),
        request_ids=RequestIdFactory(),
        sessions=FakeSessions(),
        runtime=runtime,
        start_gate=gate,
    )

    async def scenario() -> None:
        request = await service.create("sess-1", "hello", "v2")

        # Give any task spawned by ``create`` a chance to run; the runtime must
        # still be untouched because nothing has marked the request ready.
        await _wait_briefly()
        assert runtime.calls == []

        gate.mark_ready(request.request_id)
        await asyncio.wait_for(runtime.started.wait(), timeout=start_window_seconds)
        assert len(runtime.calls) == 1

    asyncio.run(scenario())
|
|
|
|
|
|
def test_stream_service_subscribe_marks_request_ready() -> None:
    """Subscribing to a request's stream must release waiters on its start gate.

    The waiter is confirmed to be blocked before ``subscribe`` is called. This
    shows that ``subscribe`` is what releases it, rather than the request
    already being ready.

    The gate timeout is kept well above the ``wait_for`` window. If
    ``wait_until_ready`` happens to return when its timeout elapses, it still
    cannot finish inside that window on its own.
    """
    gate_timeout_seconds = 10.0
    release_window_seconds = 1.0

    gate = RequestStartGate(timeout_seconds=gate_timeout_seconds)
    gate.register("req-1")
    service = StreamService(
        channel=SseEventChannel(),
        request_exists=lambda request_id: request_id == "req-1",
        start_gate=gate,
    )

    async def scenario() -> None:
        waiter = asyncio.create_task(gate.wait_until_ready("req-1"))

        # A freshly created task does not run until the current coroutine
        # yields. Yield once so the waiter actually parks on the gate; then
        # confirm it is still pending before subscribing.
        await asyncio.sleep(0)
        assert not waiter.done()

        await service.subscribe("req-1")
        await asyncio.wait_for(waiter, timeout=release_window_seconds)

    asyncio.run(scenario())
|