Files
plba/tests/test_application_http.py
T

210 lines
7.2 KiB
Python

from __future__ import annotations
import http.client
from dataclasses import dataclass
from pathlib import Path
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import FileResponse
from fastapi.testclient import TestClient
import pytest
from app_runtime.contracts.application import ApplicationModule
from app_runtime.control.http_channel import HttpControlChannel
from app_runtime.core.registration import ModuleRegistry
from app_runtime.core.runtime import RuntimeManager
from app_runtime.http.base import ApplicationHttpChannel
from app_runtime.http.http_channel import HttpApplicationChannel
# Multipart uploads need the python-multipart distribution. Releases >= 0.0.13
# import as ``python_multipart``; older releases only provide the legacy
# ``multipart`` module name, which Starlette/FastAPI still accept. Importing
# ``multipart.multipart.parse_options_header`` tells the legacy python-multipart
# layout apart from the unrelated single-module ``multipart`` package.
try:
    import python_multipart  # noqa: F401
except ImportError:
    try:
        from multipart.multipart import parse_options_header  # noqa: F401
    except ImportError:
        HAS_MULTIPART = False
    else:
        HAS_MULTIPART = True
else:
    HAS_MULTIPART = True
class RecordingChannel(ApplicationHttpChannel):
    """Application HTTP channel test double that never opens a socket.

    ``start`` appends the FastAPI app it receives to ``apps``; ``stop`` only
    increments ``stop_calls``. Tests inspect both attributes directly.
    """

    def __init__(self) -> None:
        self.apps: list[FastAPI] = []  # every app passed to start(), in order
        self.stop_calls = 0  # number of times stop() was awaited

    async def start(self, app: FastAPI) -> None:
        """Record *app* instead of serving it."""
        self.apps += [app]

    async def stop(self) -> None:
        """Count the shutdown request; there is nothing to tear down."""
        self.stop_calls = self.stop_calls + 1
class PingRoutes:
    """Route registrar adding ``GET /estimate/ping`` that answers ``{"status": "ok"}``.

    ``services`` is accepted to match the registrar interface but is not used.
    """

    def register(self, app: FastAPI, services) -> None:  # type: ignore[no-untyped-def]
        async def ping() -> dict[str, str]:
            return {"status": "ok"}

        # Equivalent to decorating ``ping`` with ``@app.get(...)``.
        app.get("/estimate/ping")(ping)
@dataclass
class ServiceBackedRoutes:
    """Route registrar covering service lookup, multipart upload and file download.

    Attributes:
        download_path: File returned by ``GET /estimate/api/tasks/result``.
    """

    download_path: Path

    def register(self, app: FastAPI, services) -> None:  # type: ignore[no-untyped-def]
        # Looked up once at registration time. ``services`` is presumably the
        # runtime service container; the test registers a plain dict under this key.
        query_service = services.get("task_query_service")

        @app.get("/estimate/api/tasks")
        async def list_tasks() -> dict[str, object]:
            marker_value = query_service["marker"]
            return {"marker": marker_value}

        # The parameter name ``file`` is also the expected multipart form field name.
        @app.post("/estimate/api/tasks")
        async def create_task(file: UploadFile = File(...)) -> dict[str, object]:
            content = await file.read()
            size = len(content)
            return {"filename": file.filename, "size": size}

        @app.get("/estimate/api/tasks/result")
        async def download_result() -> FileResponse:
            # ``self.download_path`` is read per request, not captured at registration.
            target = self.download_path
            return FileResponse(target, filename=target.name)
class MetricsRoutes:
    """Route registrar adding ``GET /estimate/api/metrics`` with a fixed count of 1.

    ``services`` is accepted to match the registrar interface but is not used.
    """

    def register(self, app: FastAPI, services) -> None:  # type: ignore[no-untyped-def]
        async def metrics() -> dict[str, int]:
            return {"count": 1}

        # Equivalent to decorating ``metrics`` with ``@app.get(...)``.
        app.get("/estimate/api/metrics")(metrics)
class HttpModule(ApplicationModule):
    """Application module whose only contribution is HTTP route registrars.

    Every positional argument is passed, in order, to
    ``ModuleRegistry.add_http_routes`` when the module is registered.
    """

    def __init__(self, *registrars: object) -> None:
        self._registrars: tuple[object, ...] = registrars

    @property
    def name(self) -> str:
        """Fixed module name used by the runtime."""
        return "http-module"

    def register(self, registry: ModuleRegistry) -> None:
        """Hand each stored registrar to *registry*."""
        for routes in self._registrars:
            registry.add_http_routes(routes)
def _application_client(channel: RecordingChannel) -> TestClient:
    """Build a TestClient for the first app the recording channel was started with.

    Fails with a bare ``AssertionError`` if the channel never received an app.
    """
    assert channel.apps
    first_app, *_later_apps = channel.apps
    return TestClient(first_app)
def _http_request(port: int, path: str) -> tuple[int, bytes]:
connection = http.client.HTTPConnection("127.0.0.1", port, timeout=2)
try:
connection.request("GET", path)
response = connection.getresponse()
payload = response.read()
return response.status, payload
finally:
connection.close()
def test_runtime_starts_application_http_and_registers_routes() -> None:
    """Starting the runtime hands exactly one app to the channel; stopping stops it once.

    The app must serve the registered ping route and set a numeric
    ``x-response-time-ms`` header.
    """
    runtime = RuntimeManager()
    recorder = RecordingChannel()
    runtime.application_http.register_channel(recorder)
    runtime.register_module(HttpModule(PingRoutes()))
    runtime.start(start_control_plane=False)
    try:
        assert len(recorder.apps) == 1
        # TestClient.__enter__ returns the client itself.
        with _application_client(recorder) as client:
            ping = client.get("/estimate/ping")
            assert ping.status_code == 200
            assert ping.json() == {"status": "ok"}
            elapsed_ms = ping.headers["x-response-time-ms"]
            assert elapsed_ms.isdigit()
    finally:
        runtime.stop(stop_control_plane=False)
    assert recorder.stop_calls == 1
def test_application_routes_see_runtime_services_and_support_upload_download(tmp_path: Path) -> None:
    """Routes resolve runtime services, accept multipart uploads and serve files.

    Everything after ``runtime.start`` — including building the test client —
    runs inside ``try``, so the runtime is stopped even if no app was handed to
    the channel.
    """
    if not HAS_MULTIPART:
        pytest.skip("python-multipart is not installed in the local environment")
    runtime = RuntimeManager()
    runtime.services.register("task_query_service", {"marker": "from-container"})
    result_path = tmp_path / "result.txt"
    result_path.write_text("ready", encoding="utf-8")
    channel = RecordingChannel()
    runtime.application_http.register_channel(channel)
    runtime.register_module(HttpModule(ServiceBackedRoutes(result_path), MetricsRoutes()))
    runtime.start(start_control_plane=False)
    try:
        # Built inside ``try``: _application_client asserts the channel got an
        # app, and a failure there must not leave the runtime running.
        with _application_client(channel) as client:
            list_response = client.get("/estimate/api/tasks")
            assert list_response.status_code == 200
            assert list_response.json() == {"marker": "from-container"}

            upload_response = client.post(
                "/estimate/api/tasks",
                files={"file": ("input.txt", b"payload", "text/plain")},
            )
            assert upload_response.status_code == 200
            assert upload_response.json() == {"filename": "input.txt", "size": 7}

            metrics_response = client.get("/estimate/api/metrics")
            assert metrics_response.status_code == 200
            assert metrics_response.json() == {"count": 1}

            download_response = client.get("/estimate/api/tasks/result")
            assert download_response.status_code == 200
            assert download_response.content == b"ready"
    finally:
        runtime.stop(stop_control_plane=False)
    # Same shutdown contract as the single-route test: the channel is stopped once.
    assert channel.stop_calls == 1
def test_application_http_stop_shuts_down_real_server() -> None:
    """Stopping the runtime must close the real application HTTP listener.

    The channel binds port 0 (an OS-assigned port). The actual port is read
    while the server is running, and the post-stop probe targets that port. It
    does not depend on whatever ``channel.port`` reports after the channel is
    stopped.
    """
    runtime = RuntimeManager()
    channel = HttpApplicationChannel(host="127.0.0.1", port=0, timeout=2)
    runtime.application_http.register_channel(channel)
    runtime.register_module(HttpModule(PingRoutes()))
    runtime.start(start_control_plane=False)
    try:
        # NOTE(review): whether the channel keeps ``port`` after stop is not
        # visible here. Capturing it now keeps the check below from passing
        # vacuously against a reset (e.g. 0/None) port.
        port = channel.port
        status, _ = _http_request(port, "/estimate/ping")
        assert status == 200
    finally:
        runtime.stop(stop_control_plane=False)
    try:
        _http_request(port, "/estimate/ping")
    except OSError:
        # Refused, reset or timed out: the server is no longer serving.
        pass
    else:
        raise AssertionError("application HTTP server is still reachable after stop")
def test_control_plane_and_application_http_work_independently() -> None:
    """Control-plane and application HTTP servers expose disjoint routes.

    Stopping the application HTTP side alone must leave the control plane
    serving.
    """
    runtime = RuntimeManager()
    control_channel = HttpControlChannel(host="127.0.0.1", port=0, timeout=2)
    app_channel = HttpApplicationChannel(host="127.0.0.1", port=0, timeout=2)
    runtime.control_plane.register_channel(control_channel)
    runtime.application_http.register_channel(app_channel)
    runtime.register_module(HttpModule(PingRoutes()))
    runtime.start()
    try:
        # Each server answers its own route...
        control_status, _ = _http_request(control_channel.port, "/health")
        app_status, _ = _http_request(app_channel.port, "/estimate/ping")
        assert control_status == 200
        assert app_status == 200

        # ...and does not expose the other server's route.
        control_missing_status, _ = _http_request(control_channel.port, "/estimate/ping")
        app_missing_status, _ = _http_request(app_channel.port, "/health")
        assert control_missing_status == 404
        assert app_missing_status == 404

        # Shutting down only the application HTTP side keeps the control plane up.
        runtime.application_http.stop()
        control_status, _ = _http_request(control_channel.port, "/health")
        assert control_status == 200
    finally:
        # Nested so the runtime is still stopped if the control-plane shutdown raises.
        try:
            runtime.control_plane.stop()
        finally:
            runtime.stop(stop_control_plane=False)