From 85fcaae31b197f377320b4af4575c8eaa43d7bfa Mon Sep 17 00:00:00 2001 From: zosimovaa Date: Tue, 28 Apr 2026 14:57:09 +0300 Subject: [PATCH] =?UTF-8?q?API=20=D0=B4=D0=BB=D1=8F=20=D0=BF=D1=80=D0=BE?= =?UTF-8?q?=D1=81=D0=BC=D0=BE=D1=82=D1=80=D0=B0=20=D0=BB=D0=BE=D0=B3=D0=BE?= =?UTF-8?q?=D0=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 2 +- src/app_runtime/contracts/trace.py | 37 ++++ src/app_runtime/control/base.py | 14 ++ src/app_runtime/control/http_app.py | 64 ++++++- src/app_runtime/control/http_channel.py | 13 +- src/app_runtime/control/service.py | 1 + src/app_runtime/core/runtime.py | 12 ++ src/app_runtime/tracing/reader.py | 87 +++++++++ src/app_runtime/tracing/transport.py | 51 ++++-- tests/test_trace_endpoint.py | 226 ++++++++++++++++++++++++ 10 files changed, 489 insertions(+), 18 deletions(-) create mode 100644 src/app_runtime/tracing/reader.py create mode 100644 tests/test_trace_endpoint.py diff --git a/pyproject.toml b/pyproject.toml index 3b9525f..5579238 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "plba" -version = "0.2.9" +version = "0.3.0" description = "Platform runtime for business applications" readme = "README.md" requires-python = ">=3.11" diff --git a/src/app_runtime/contracts/trace.py b/src/app_runtime/contracts/trace.py index 38ff248..71964ba 100644 --- a/src/app_runtime/contracts/trace.py +++ b/src/app_runtime/contracts/trace.py @@ -58,3 +58,40 @@ class TraceTransport(Protocol): def write_message(self, record: TraceLogMessage) -> None: """Persist trace log message.""" + + +@dataclass(frozen=True) +class TraceLogRecord: + id: int + trace_id: str + event_time: datetime + step: str + status: str + level: TraceLevel + message: str + attrs_json: Any + + def as_dict(self, *, include_attrs_json: bool) -> dict[str, Any]: + payload: dict[str, Any] = { + "id": self.id, + "trace_id": self.trace_id, + "event_time": self.event_time.isoformat(), + "step": self.step, + "status": self.status, + "level": self.level, + "message": self.message, + } + if include_attrs_json: + payload["attrs_json"] = self.attrs_json + return payload + + +@dataclass(frozen=True) +class TraceLogView: + parent_id: str | None + records: tuple[TraceLogRecord, ...] = () + + +class TraceLogReader(Protocol): + def read_trace(self, trace_id: str, levels: tuple[TraceLevel, ...]) -> TraceLogView | None: + """Load trace context and filtered log records.""" diff --git a/src/app_runtime/control/base.py b/src/app_runtime/control/base.py index 9c65fbf..90dadf0 100644 --- a/src/app_runtime/control/base.py +++ b/src/app_runtime/control/base.py @@ -3,7 +3,9 @@ from __future__ import annotations from abc import ABC, abstractmethod from collections.abc import Awaitable, Callable from dataclasses import dataclass +from typing import Literal +from app_runtime.contracts.trace import TraceLevel, TraceLogView from app_runtime.core.types import HealthPayload @@ -17,6 +19,17 @@ class ControlActionRequest: ActionResult = str | dict[str, object] ActionHandler = Callable[[ControlActionRequest], Awaitable[ActionResult]] HealthHandler = Callable[[], Awaitable[HealthPayload]] +TraceResponseFormat = Literal["json", "text"] + + +@dataclass(slots=True) +class TraceQueryRequest: + levels: tuple[TraceLevel, ...] = ("ERROR", "WARNING") + include_attrs_json: bool = False + response_format: TraceResponseFormat = "text" + + +TraceLookupHandler = Callable[[str, TraceQueryRequest], Awaitable[TraceLogView]] @dataclass(slots=True) @@ -25,6 +38,7 @@ class ControlActionSet: start: ActionHandler stop: ActionHandler status: ActionHandler + trace_lookup: TraceLookupHandler | None = None class ControlChannel(ABC): diff --git a/src/app_runtime/control/http_app.py b/src/app_runtime/control/http_app.py index 5d42c6f..e7c4601 100644 --- a/src/app_runtime/control/http_app.py +++ b/src/app_runtime/control/http_app.py @@ -1,13 +1,16 @@ from __future__ import annotations +import json import logging import time from collections.abc import Awaitable, Callable +from typing import cast from fastapi import FastAPI, Request -from fastapi.responses import JSONResponse +from fastapi.responses import JSONResponse, PlainTextResponse -from app_runtime.control.base import ControlActionRequest +from app_runtime.control.base import ControlActionRequest, TraceQueryRequest +from app_runtime.contracts.trace import TraceLevel, TraceLogView from app_runtime.core.types import HealthPayload LOGGER = logging.getLogger(__name__) @@ -18,6 +21,7 @@ class HttpControlAppFactory: self, health_provider: Callable[[], Awaitable[HealthPayload]], action_provider: Callable[[str, str, ControlActionRequest], Awaitable[JSONResponse]], + trace_provider: Callable[[str, TraceQueryRequest], Awaitable[TraceLogView]] | None = None, ) -> FastAPI: app = FastAPI(title="PLBA Control API") @@ -46,6 +50,22 @@ class HttpControlAppFactory: return JSONResponse(content={"status": "error", "detail": str(exc)}, status_code=400) return await action_provider(action, client_source, action_request) + @app.get("/traces/{traceid}") + async def trace(traceid: str, request: Request): + if trace_provider is None: + return JSONResponse(content={"status": "error", "detail": "trace lookup is not configured"}, status_code=503) + try: + trace_request = self._trace_request(request) + except ValueError as exc: + return JSONResponse(content={"status": "error", "detail": str(exc)}, status_code=400) + try: + payload = await trace_provider(traceid, trace_request) + except KeyError: + return JSONResponse(content={"status": "error", "detail": f"trace not found: {traceid}"}, status_code=404) + except RuntimeError as exc: + return JSONResponse(content={"status": "error", "detail": str(exc)}, status_code=503) + return self._trace_response(payload, trace_request) + return app def _action_request(self, request: Request) -> ControlActionRequest: @@ -86,3 +106,43 @@ class HttpControlAppFactory: if value < 0: raise ValueError(f"query parameter must be >= 0: {name}={raw_value}") return value + + def _trace_request(self, request: Request) -> TraceQueryRequest: + raw_levels = request.query_params.get("levels") + raw_format = request.query_params.get("format", "text") + response_format = raw_format.strip().lower() + if response_format not in {"json", "text"}: + raise ValueError(f"unsupported trace format: {raw_format}") + return TraceQueryRequest( + levels=self._trace_levels(raw_levels), + include_attrs_json=self._bool_param(request, "attrs_json") or False, + response_format=response_format, + ) + + def _trace_levels(self, raw_levels: str | None) -> tuple[TraceLevel, ...]: + if raw_levels is None: + return ("ERROR", "WARNING") + parts = [item.strip().upper() for item in raw_levels.split(",")] + levels = tuple(item for item in parts if item) + if not levels: + raise ValueError("trace levels must not be empty") + unsupported = [level for level in levels if level not in {"DEBUG", "INFO", "WARNING", "ERROR"}] + if unsupported: + raise ValueError(f"unsupported trace levels: {', '.join(unsupported)}") + return cast(tuple[TraceLevel, ...], levels) + + def _trace_response(self, trace_view: TraceLogView, request: TraceQueryRequest) -> JSONResponse | PlainTextResponse: + if request.response_format == "json": + return JSONResponse( + content={ + "parent_id": trace_view.parent_id or "", + "messages": [record.as_dict(include_attrs_json=request.include_attrs_json) for record in trace_view.records], + } + ) + lines = [trace_view.parent_id or ""] + for record in trace_view.records: + line = record.message + if request.include_attrs_json: + line = f"{line}, {json.dumps(record.attrs_json, ensure_ascii=False, separators=(',', ':'))}" + lines.append(line) + return PlainTextResponse(content="\n".join(lines)) diff --git a/src/app_runtime/control/http_channel.py b/src/app_runtime/control/http_channel.py index d6b4808..b1b1166 100644 --- a/src/app_runtime/control/http_channel.py +++ b/src/app_runtime/control/http_channel.py @@ -4,7 +4,8 @@ import asyncio from fastapi.responses import JSONResponse -from app_runtime.control.base import ControlActionRequest, ControlActionSet, ControlChannel +from app_runtime.control.base import ControlActionRequest, ControlActionSet, ControlChannel, TraceQueryRequest +from app_runtime.contracts.trace import TraceLogView from app_runtime.control.http_app import HttpControlAppFactory from app_runtime.control.http_runner import UvicornThreadRunner @@ -18,7 +19,7 @@ class HttpControlChannel(ControlChannel): async def start(self, actions: ControlActionSet) -> None: self._actions = actions - app = self._factory.create(self._health_response, self._action_response) + app = self._factory.create(self._health_response, self._action_response, self._trace_response) await self._runner.start(app) async def stop(self) -> None: @@ -67,3 +68,11 @@ class HttpControlChannel(ControlChannel): if action != "stop" or request.wait is False or request.timeout is None: return base_timeout return max(base_timeout, float(request.timeout) + 1.0) + + async def _trace_response(self, trace_id: str, request: TraceQueryRequest) -> TraceLogView: + if self._actions is None or self._actions.trace_lookup is None: + raise RuntimeError("trace lookup is not configured") + return await asyncio.wait_for( + self._actions.trace_lookup(trace_id, request), + timeout=float(self._timeout), + ) diff --git a/src/app_runtime/control/service.py b/src/app_runtime/control/service.py index 821e6c9..26c44a0 100644 --- a/src/app_runtime/control/service.py +++ b/src/app_runtime/control/service.py @@ -43,6 +43,7 @@ class ControlPlaneService: start=runtime.start_runtime, stop=runtime.stop_runtime, status=runtime.runtime_status, + trace_lookup=runtime.trace_logs, ) for channel in self._channels: await channel.start(actions) diff --git a/src/app_runtime/core/runtime.py b/src/app_runtime/core/runtime.py index d2ef66f..28e3c4b 100644 --- a/src/app_runtime/core/runtime.py +++ b/src/app_runtime/core/runtime.py @@ -5,6 +5,8 @@ from time import monotonic, sleep from app_runtime.config.providers import FileConfigProvider from app_runtime.contracts.application import ApplicationModule from app_runtime.control.base import ControlActionRequest +from app_runtime.control.base import TraceQueryRequest +from app_runtime.contracts.trace import TraceLogView from app_runtime.control.service import ControlPlaneService from app_runtime.core.configuration import ConfigurationManager from app_runtime.core.registration import ModuleRegistry @@ -12,6 +14,7 @@ from app_runtime.core.service_container import ServiceContainer from app_runtime.core.types import HealthPayload, LifecycleState from app_runtime.health.registry import HealthRegistry from app_runtime.logging.manager import LogManager +from app_runtime.tracing.reader import build_trace_log_reader from app_runtime.tracing.service import TraceService from app_runtime.workers.supervisor import WorkerSupervisor @@ -127,6 +130,15 @@ class RuntimeManager: self._refresh_state() return self._state.value + async def trace_logs(self, trace_id: str, request: TraceQueryRequest) -> TraceLogView: + reader = build_trace_log_reader(self.traces.transport) + if reader is None: + raise RuntimeError("trace log reader is not configured") + trace_view = reader.read_trace(trace_id, request.levels) + if trace_view is None: + raise KeyError(trace_id) + return trace_view + def _register_core_services(self) -> None: if self._core_registered: return diff --git a/src/app_runtime/tracing/reader.py b/src/app_runtime/tracing/reader.py new file mode 100644 index 0000000..b6f7e77 --- /dev/null +++ b/src/app_runtime/tracing/reader.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +import json +from typing import Any + +from app_runtime.contracts.trace import TraceLevel, TraceLogReader, TraceLogRecord, TraceLogView, TraceTransport +from app_runtime.tracing.transport import MySqlTraceConnectionFactory, MySqlTraceTransport + + +class MySqlTraceLogReader(TraceLogReader): + def __init__(self, connection_factory: MySqlTraceConnectionFactory) -> None: + self._connection_factory = connection_factory + + def read_trace(self, trace_id: str, levels: tuple[TraceLevel, ...]) -> TraceLogView | None: + parent_id = self._read_parent_id(trace_id) + if parent_id is None and not self._trace_exists(trace_id): + return None + records = self._read_records(trace_id, levels) + return TraceLogView(parent_id=parent_id, records=tuple(records)) + + def _trace_exists(self, trace_id: str) -> bool: + query = "SELECT 1 FROM trace_contexts WHERE trace_id = %s" + with self._connection_factory.connect() as connection: + with connection.cursor() as cursor: + cursor.execute(query, (trace_id,)) + return cursor.fetchone() is not None + + def _read_parent_id(self, trace_id: str) -> str | None: + query = "SELECT parent_id FROM trace_contexts WHERE trace_id = %s" + with self._connection_factory.connect() as connection: + with connection.cursor() as cursor: + cursor.execute(query, (trace_id,)) + row = cursor.fetchone() + if row is None: + return None + return self._string_or_none(row.get("parent_id")) + + def _read_records(self, trace_id: str, levels: tuple[TraceLevel, ...]) -> list[TraceLogRecord]: + placeholders = ", ".join(["%s"] * len(levels)) + query = f""" + SELECT id, trace_id, event_time, step, status, level, message, attrs_json + FROM trace_messages + WHERE trace_id = %s AND level IN ({placeholders}) + ORDER BY event_time ASC, id ASC + """ + params: tuple[object, ...] = (trace_id, *levels) + with self._connection_factory.connect() as connection: + with connection.cursor() as cursor: + cursor.execute(query, params) + rows = cursor.fetchall() + return [self._build_record(row) for row in rows] + + def _build_record(self, row: dict[str, Any]) -> TraceLogRecord: + return TraceLogRecord( + id=int(row["id"]), + trace_id=str(row["trace_id"]), + event_time=row["event_time"], + step=str(row["step"] or ""), + status=str(row["status"] or ""), + level=str(row["level"]), + message=str(row["message"] or ""), + attrs_json=self._load_json(row.get("attrs_json")), + ) + + def _load_json(self, raw_value: Any) -> Any: + if raw_value is None or isinstance(raw_value, (dict, list, int, float, bool)): + return raw_value + if isinstance(raw_value, (bytes, bytearray)): + raw_value = raw_value.decode("utf-8") + if isinstance(raw_value, str): + try: + return json.loads(raw_value) + except json.JSONDecodeError: + return raw_value + return raw_value + + def _string_or_none(self, value: Any) -> str | None: + if value is None: + return None + text = str(value) + return text or None + + +def build_trace_log_reader(transport: TraceTransport) -> TraceLogReader | None: + if isinstance(transport, MySqlTraceTransport): + return MySqlTraceLogReader(transport.create_connection_factory()) + return None diff --git a/src/app_runtime/tracing/transport.py b/src/app_runtime/tracing/transport.py index c3b6db7..0bde4fc 100644 --- a/src/app_runtime/tracing/transport.py +++ b/src/app_runtime/tracing/transport.py @@ -15,7 +15,7 @@ class NoOpTraceTransport(TraceTransport): del record -class MySqlTraceTransport(TraceTransport): +class MySqlTraceConnectionFactory: def __init__( self, *, @@ -31,6 +31,39 @@ class MySqlTraceTransport(TraceTransport): self._user = user self._password = password + def connect(self): # type: ignore[no-untyped-def] + import pymysql + + return pymysql.connect( + host=self._host, + port=self._port, + user=self._user, + password=self._password, + database=self._database, + charset="utf8mb4", + autocommit=True, + cursorclass=pymysql.cursors.DictCursor, + ) + + +class MySqlTraceTransport(TraceTransport): + def __init__( + self, + *, + host: str, + port: int, + database: str, + user: str, + password: str, + ) -> None: + self._connections = MySqlTraceConnectionFactory( + host=host, + port=port, + database=database, + user=user, + password=password, + ) + def write_context(self, record: TraceContextRecord) -> None: query = """ INSERT INTO trace_contexts (trace_id, parent_id, alias, type, event_time, attrs_json) @@ -69,21 +102,13 @@ class MySqlTraceTransport(TraceTransport): self._execute(query, params) def _execute(self, query: str, params: tuple[object, ...]) -> None: - import pymysql - - with pymysql.connect( - host=self._host, - port=self._port, - user=self._user, - password=self._password, - database=self._database, - charset="utf8mb4", - autocommit=True, - cursorclass=pymysql.cursors.DictCursor, - ) as connection: + with self._connections.connect() as connection: with connection.cursor() as cursor: cursor.execute(query, params) + def create_connection_factory(self) -> MySqlTraceConnectionFactory: + return self._connections + def _dumps(self, payload: dict[str, object]) -> str: return json.dumps(payload, ensure_ascii=False, default=self._json_default) diff --git a/tests/test_trace_endpoint.py b/tests/test_trace_endpoint.py new file mode 100644 index 0000000..470f1a4 --- /dev/null +++ b/tests/test_trace_endpoint.py @@ -0,0 +1,226 @@ +from __future__ import annotations + +import asyncio +from datetime import datetime, timezone + +from fastapi.responses import JSONResponse +from fastapi.testclient import TestClient + +import app_runtime.core.runtime as runtime_module +from app_runtime.control.base import ControlActionRequest, TraceQueryRequest +from app_runtime.control.http_app import HttpControlAppFactory +from app_runtime.contracts.trace import TraceLogRecord, TraceLogView +from app_runtime.core.runtime import RuntimeManager +from app_runtime.tracing.reader import MySqlTraceLogReader + + +def _trace_record( + *, + row_id: int, + level: str, + message: str, + attrs_json: object | None = None, +) -> TraceLogRecord: + return TraceLogRecord( + id=row_id, + trace_id="trace-1", + event_time=datetime(2026, 4, 28, 10, 11, 12, tzinfo=timezone.utc), + step="process", + status="failed", + level=level, # type: ignore[arg-type] + message=message, + attrs_json=attrs_json if attrs_json is not None else {}, + ) + + +def _build_client(trace_provider=None) -> TestClient: + async def health_provider(): + return {"status": "ok"} + + async def action_provider(_action: str, _client_source: str, _request: ControlActionRequest) -> JSONResponse: + return JSONResponse(content={"status": "ok"}) + + app = HttpControlAppFactory().create(health_provider, action_provider, trace_provider) + return TestClient(app) + + +def test_trace_endpoint_returns_text_with_default_levels() -> None: + captured: list[tuple[str, TraceQueryRequest]] = [] + + async def trace_provider(trace_id: str, request: TraceQueryRequest) -> TraceLogView: + captured.append((trace_id, request)) + return TraceLogView( + parent_id="root-trace", + records=( + _trace_record(row_id=1, level="ERROR", message="first error"), + _trace_record(row_id=2, level="WARNING", message="second warning"), + ), + ) + + client = _build_client(trace_provider) + try: + response = client.get("/traces/trace-1") + finally: + client.close() + + assert response.status_code == 200 + assert response.text == "root-trace\nfirst error\nsecond warning" + assert captured == [("trace-1", TraceQueryRequest(levels=("ERROR", "WARNING"), include_attrs_json=False, response_format="text"))] + + +def test_trace_endpoint_appends_attrs_json_in_text_mode() -> None: + async def trace_provider(_trace_id: str, _request: TraceQueryRequest) -> TraceLogView: + return TraceLogView( + parent_id=None, + records=( + _trace_record(row_id=1, level="ERROR", message="failure", attrs_json={"attempt": 2, "source": "crm"}), + ), + ) + + client = _build_client(trace_provider) + try: + response = client.get("/traces/trace-1?attrs_json=true") + finally: + client.close() + + assert response.status_code == 200 + assert response.text == '\nfailure, {"attempt":2,"source":"crm"}' + + +def test_trace_endpoint_returns_json_payload() -> None: + async def trace_provider(_trace_id: str, _request: TraceQueryRequest) -> TraceLogView: + return TraceLogView( + parent_id="parent-1", + records=( + _trace_record(row_id=3, level="INFO", message="done", attrs_json={"batch": 7}), + ), + ) + + client = _build_client(trace_provider) + try: + response = client.get("/traces/trace-1?format=json&attrs_json=true&levels=info") + finally: + client.close() + + assert response.status_code == 200 + assert response.json() == { + "parent_id": "parent-1", + "messages": [ + { + "id": 3, + "trace_id": "trace-1", + "event_time": "2026-04-28T10:11:12+00:00", + "step": "process", + "status": "failed", + "level": "INFO", + "message": "done", + "attrs_json": {"batch": 7}, + } + ], + } + + +def test_trace_endpoint_validates_query_params() -> None: + client = _build_client(lambda _trace_id, _request: None) + try: + invalid_level = client.get("/traces/trace-1?levels=error,fatal") + invalid_format = client.get("/traces/trace-1?format=xml") + finally: + client.close() + + assert invalid_level.status_code == 400 + assert invalid_level.json() == {"status": "error", "detail": "unsupported trace levels: FATAL"} + assert invalid_format.status_code == 400 + assert invalid_format.json() == {"status": "error", "detail": "unsupported trace format: xml"} + + +def test_runtime_trace_logs_uses_configured_reader(monkeypatch) -> None: + expected = TraceLogView(parent_id="root", records=(_trace_record(row_id=1, level="ERROR", message="boom"),)) + + class StubReader: + def read_trace(self, trace_id: str, levels: tuple[str, ...]) -> TraceLogView | None: + assert trace_id == "trace-1" + assert levels == ("ERROR",) + return expected + + monkeypatch.setattr(runtime_module, "build_trace_log_reader", lambda _transport: StubReader()) + runtime = RuntimeManager() + + result = asyncio.run(runtime.trace_logs("trace-1", TraceQueryRequest(levels=("ERROR",)))) + + assert result == expected + + +def test_mysql_trace_log_reader_maps_db_rows() -> None: + class FakeCursor: + def __init__(self) -> None: + self.executed: list[tuple[str, tuple[object, ...]]] = [] + + def execute(self, query: str, params: tuple[object, ...]) -> None: + self.executed.append((query, params)) + + def fetchone(self) -> dict[str, object] | None: + return {"parent_id": "root-77"} + + def fetchall(self) -> list[dict[str, object]]: + return [ + { + "id": 8, + "trace_id": "trace-1", + "event_time": datetime(2026, 4, 28, 10, 11, 12, tzinfo=timezone.utc), + "step": "parse", + "status": "failed", + "level": "ERROR", + "message": "broken", + "attrs_json": '{"attempt":1}', + } + ] + + def __enter__(self) -> FakeCursor: + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + class FakeConnection: + def __init__(self, cursor: FakeCursor) -> None: + self._cursor = cursor + + def cursor(self) -> FakeCursor: + return self._cursor + + def __enter__(self) -> FakeConnection: + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + class FakeConnectionFactory: + def __init__(self) -> None: + self.cursor = FakeCursor() + + def connect(self) -> FakeConnection: + return FakeConnection(self.cursor) + + factory = FakeConnectionFactory() + reader = MySqlTraceLogReader(factory) # type: ignore[arg-type] + + view = reader.read_trace("trace-1", ("ERROR", "WARNING")) + + assert view == TraceLogView( + parent_id="root-77", + records=( + TraceLogRecord( + id=8, + trace_id="trace-1", + event_time=datetime(2026, 4, 28, 10, 11, 12, tzinfo=timezone.utc), + step="parse", + status="failed", + level="ERROR", + message="broken", + attrs_json={"attempt": 1}, + ), + ), + ) + assert len(factory.cursor.executed) == 2 + assert factory.cursor.executed[1][1] == ("trace-1", "ERROR", "WARNING")