Files
agent/app/modules/chat/service.py
2026-02-25 14:47:19 +03:00

277 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
import logging
from app.core.exceptions import AppError
from app.modules.contracts import AgentRunner
from app.schemas.chat import ChatMessageRequest, TaskResultType, TaskStatus
from app.schemas.common import ErrorPayload, ModuleName
from app.modules.chat.dialog_store import DialogSessionStore
from app.modules.chat.task_store import TaskState, TaskStore
from app.modules.shared.event_bus import EventBus
from app.modules.shared.idempotency_store import IdempotencyStore
from app.modules.shared.retry_executor import RetryExecutor
# Module-level logger; its name is this module's import path.
LOGGER = logging.getLogger(name=__name__)
class ChatOrchestrator:
    """Turn chat messages into background tasks and drive them through the agent runtime.

    Each enqueued message becomes a ``TaskState`` kept in the ``TaskStore``; status,
    progress, heartbeat, result and error events for that task are published on the
    ``EventBus`` under the task id.
    """

    def __init__(
        self,
        task_store: TaskStore,
        dialogs: DialogSessionStore,
        idempotency: IdempotencyStore,
        runtime: AgentRunner,
        events: EventBus,
        retry: RetryExecutor,
        rag_session_exists,
        message_sink,
    ) -> None:
        """Store collaborators.

        Args:
            task_store: Creates, loads and saves ``TaskState`` objects.
            dialogs: Looks up dialog sessions by id.
            idempotency: Maps idempotency keys to task ids.
            runtime: Agent runner invoked once per task (through ``retry``).
            events: Publishes per-task events.
            retry: Executor that runs the agent call; presumably retries transient
                failures (TODO confirm against RetryExecutor).
            rag_session_exists: Callable taking a RAG session id; a falsy result is
                treated as "session does not exist".
            message_sink: Callable that records dialog messages, called as
                ``(dialog_session_id, role, text, task_id=..., payload=...)``.
        """
        self._task_store = task_store
        self._dialogs = dialogs
        self._idempotency = idempotency
        self._runtime = runtime
        self._events = events
        self._retry = retry
        self._rag_session_exists = rag_session_exists
        self._message_sink = message_sink
        # Strong references to in-flight background tasks. The event loop keeps only
        # weak references to tasks, so an unreferenced task can be garbage-collected
        # before it finishes.
        self._background_tasks: set[asyncio.Task] = set()

    async def enqueue_message(
        self,
        request: ChatMessageRequest,
        idempotency_key: str | None,
    ) -> TaskState:
        """Create a task for ``request`` and start processing it in the background.

        If ``idempotency_key`` already maps to a task that still exists in the store,
        that task is returned and nothing new is scheduled.

        Args:
            request: The incoming chat message.
            idempotency_key: Optional client-supplied key used to deduplicate requests.

        Returns:
            The new (or reused) task state; processing continues asynchronously.
        """
        if idempotency_key:
            existing = self._idempotency.get_task_id(idempotency_key)
            if existing:
                task = self._task_store.get(existing)
                if task:
                    LOGGER.warning(
                        "enqueue_message reused task by idempotency key: task_id=%s mode=%s",
                        task.task_id,
                        request.mode.value,
                    )
                    return task
        task = self._task_store.create()
        if idempotency_key:
            self._idempotency.put(idempotency_key, task.task_id)
        background = asyncio.create_task(self._process_task(task.task_id, request))
        # Hold a reference until the task completes; drop it afterwards.
        self._background_tasks.add(background)
        background.add_done_callback(self._background_tasks.discard)
        LOGGER.warning(
            "enqueue_message created task: task_id=%s mode=%s",
            task.task_id,
            request.mode.value,
        )
        return task

    async def _process_task(self, task_id: str, request: ChatMessageRequest) -> None:
        """Run the agent for one task and publish its lifecycle events.

        Marks the task RUNNING, resolves the dialog/RAG sessions, runs the agent via the
        retry executor and then stores and publishes either the result (status DONE) or
        an ``ErrorPayload`` (status ERROR). A heartbeat coroutine emits "thinking"
        events while the task is being processed and is stopped in ``finally``.
        Returns silently if the task id is unknown.
        """
        task = self._task_store.get(task_id)
        if not task:
            return
        task.status = TaskStatus.RUNNING
        self._task_store.save(task)
        await self._events.publish(task_id, "task_status", {"task_id": task_id, "status": task.status.value})
        await self._publish_progress(task_id, "task.start", "Запрос принят, начинаю обработку.", progress=5)
        heartbeat_stop = asyncio.Event()
        heartbeat_task = asyncio.create_task(self._run_heartbeat(task_id, heartbeat_stop))
        try:
            await self._publish_progress(task_id, "task.sessions", "Проверяю сессии диалога и проекта.", progress=10)
            dialog_session_id, rag_session_id = self._resolve_sessions(request)
            await self._publish_progress(task_id, "task.sessions.done", "Сессии проверены, запускаю агента.", progress=15)
            loop = asyncio.get_running_loop()

            def progress_cb(stage: str, message: str, kind: str = "task_progress", meta: dict | None = None):
                # Synchronous callback handed to the runtime. run_coroutine_threadsafe
                # schedules the publish on this task's loop, so the callback may be
                # invoked from another thread; the returned future is not awaited.
                asyncio.run_coroutine_threadsafe(
                    self._events.publish(
                        task_id,
                        kind,
                        {
                            "task_id": task_id,
                            "stage": stage,
                            "message": message,
                            "meta": meta or {},
                        },
                    ),
                    loop,
                )

            # ``op`` may be invoked several times by the retry executor; the user
            # message must be recorded in the dialog only once.
            user_message_recorded = False

            async def op():
                nonlocal user_message_recorded
                if not user_message_recorded:
                    self._message_sink(dialog_session_id, "user", request.message, task_id=task_id)
                    user_message_recorded = True
                await self._publish_progress(task_id, "task.agent.run", "Агент анализирует запрос и готовит ответ.", progress=20)
                return await self._runtime.run(
                    task_id=task_id,
                    dialog_session_id=dialog_session_id,
                    rag_session_id=rag_session_id,
                    mode=request.mode.value,
                    message=request.message,
                    attachments=[a.model_dump(mode="json") for a in request.attachments],
                    files=[f.model_dump(mode="json") for f in request.files],
                    progress_cb=progress_cb,
                )

            result = await self._retry.run(op)
            await self._publish_progress(task_id, "task.finalize", "Сохраняю финальный результат.", progress=95)
            task.status = TaskStatus.DONE
            task.result_type = TaskResultType(result.result_type)
            task.answer = result.answer
            # Guard against a missing changeset so the len()/iteration below cannot
            # turn an already-successful task into an error.
            task.changeset = result.changeset or []
            if task.result_type == TaskResultType.ANSWER and task.answer:
                self._message_sink(dialog_session_id, "assistant", task.answer, task_id=task_id)
            elif task.result_type == TaskResultType.CHANGESET:
                self._message_sink(
                    dialog_session_id,
                    "assistant",
                    f"changeset:{len(task.changeset)}",
                    task_id=task_id,
                    payload={
                        "result_type": TaskResultType.CHANGESET.value,
                        "changeset": [item.model_dump(mode="json") for item in task.changeset],
                    },
                )
            self._task_store.save(task)
            await self._events.publish(
                task_id,
                "task_result",
                {
                    "task_id": task_id,
                    "status": task.status.value,
                    "result_type": task.result_type.value,
                    "answer": task.answer,
                    "changeset": [item.model_dump(mode="json") for item in task.changeset],
                    "meta": getattr(result, "meta", {}) or {},
                },
            )
            await self._publish_progress(task_id, "task.done", "Обработка завершена.", progress=100)
            LOGGER.warning(
                "_process_task completed: task_id=%s status=%s result_type=%s changeset_items=%s",
                task_id,
                task.status.value,
                task.result_type.value if task.result_type else "",
                len(task.changeset),
            )
        # asyncio.TimeoutError is a separate class before Python 3.11 (an alias of the
        # builtin TimeoutError from 3.11 on), so it is listed explicitly.
        except (AppError, TimeoutError, asyncio.TimeoutError, ConnectionError, OSError) as exc:
            task.status = TaskStatus.ERROR
            if isinstance(exc, AppError):
                payload = ErrorPayload(code=exc.code, desc=exc.desc, module=exc.module)
            else:
                payload = ErrorPayload(
                    code="retry_exhausted",
                    desc="Temporary failure after retries. Please retry request.",
                    module=ModuleName.BACKEND,
                )
            task.error = payload
            self._task_store.save(task)
            await self._publish_progress(task_id, "task.error", "Не удалось завершить обработку запроса.", kind="task_thinking")
            await self._events.publish(task_id, "task_error", payload.model_dump(mode="json"))
            LOGGER.warning(
                "_process_task handled error: task_id=%s code=%s module=%s desc=%s",
                task_id,
                payload.code,
                payload.module.value,
                payload.desc,
            )
        except Exception:
            task.status = TaskStatus.ERROR
            payload = ErrorPayload(
                code="agent_runtime_error",
                desc="Agent execution failed unexpectedly. Please retry request.",
                module=ModuleName.AGENT,
            )
            task.error = payload
            self._task_store.save(task)
            await self._publish_progress(
                task_id,
                "task.error",
                "Во время выполнения возникла внутренняя ошибка.",
                kind="task_thinking",
            )
            await self._events.publish(task_id, "task_error", payload.model_dump(mode="json"))
            LOGGER.exception(
                "_process_task unexpected error: task_id=%s code=%s",
                task_id,
                payload.code,
            )
        finally:
            heartbeat_stop.set()
            try:
                await heartbeat_task
            except Exception:
                # Heartbeat events are best-effort: a failed publish there must not
                # escape after the task outcome has been recorded, nor mask an
                # exception already propagating from the try body.
                LOGGER.exception("_process_task heartbeat failed: task_id=%s", task_id)

    async def _publish_progress(
        self,
        task_id: str,
        stage: str,
        message: str,
        *,
        progress: int | None = None,
        kind: str = "task_progress",
        meta: dict | None = None,
    ) -> None:
        """Publish one progress-style event for ``task_id``.

        Args:
            task_id: Task the event belongs to.
            stage: Machine-readable stage name (e.g. ``"task.start"``).
            message: Human-readable status text.
            progress: Optional percentage; coerced to int and clamped to 0..100.
            kind: Event kind to publish under.
            meta: Extra metadata; defaults to an empty dict.
        """
        payload = {
            "task_id": task_id,
            "stage": stage,
            "message": message,
            "meta": meta or {},
        }
        if progress is not None:
            payload["progress"] = max(0, min(100, int(progress)))
        await self._events.publish(task_id, kind, payload)
        LOGGER.warning(
            "_publish_progress emitted: task_id=%s kind=%s stage=%s progress=%s",
            task_id,
            kind,
            stage,
            payload.get("progress"),
        )

    async def _run_heartbeat(self, task_id: str, stop_event: asyncio.Event) -> None:
        """Emit a rotating "thinking" message every 5 seconds until ``stop_event`` is set."""
        messages = (
            "Собираю данные по проекту.",
            "Анализирую контекст и формирую структуру ответа.",
            "Проверяю согласованность промежуточного результата.",
        )
        index = 0
        while not stop_event.is_set():
            try:
                # Wakes immediately when the event is set; otherwise times out after 5 s.
                await asyncio.wait_for(stop_event.wait(), timeout=5.0)
            except asyncio.TimeoutError:
                await self._publish_progress(
                    task_id,
                    "task.heartbeat",
                    messages[index % len(messages)],
                    kind="task_thinking",
                    meta={"heartbeat": True},
                )
                index += 1
        LOGGER.warning("_run_heartbeat stopped: task_id=%s ticks=%s", task_id, index)

    def _resolve_sessions(self, request: ChatMessageRequest) -> tuple[str, str]:
        """Return ``(dialog_session_id, rag_session_id)`` for ``request``.

        Explicit dialog/RAG session ids take precedence: the dialog must exist and be
        bound to the given RAG session. Otherwise the legacy ``session_id`` /
        ``project_id`` pair is accepted, with ``project_id`` used as the RAG session id.

        Raises:
            AppError: ``dialog_not_found``, ``dialog_rag_mismatch``,
                ``rag_session_not_found`` or ``missing_sessions``.
        """
        if request.dialog_session_id and request.rag_session_id:
            dialog = self._dialogs.get(request.dialog_session_id)
            if not dialog:
                raise AppError("dialog_not_found", "Dialog session not found", ModuleName.BACKEND)
            if dialog.rag_session_id != request.rag_session_id:
                raise AppError("dialog_rag_mismatch", "Dialog session does not belong to rag session", ModuleName.BACKEND)
            LOGGER.warning(
                "_resolve_sessions resolved by dialog_session_id: dialog_session_id=%s rag_session_id=%s",
                request.dialog_session_id,
                request.rag_session_id,
            )
            return request.dialog_session_id, request.rag_session_id
        # Legacy compatibility: old session_id/project_id flow.
        if request.session_id and request.project_id:
            if not self._rag_session_exists(request.project_id):
                raise AppError("rag_session_not_found", "RAG session not found", ModuleName.RAG)
            LOGGER.warning(
                "_resolve_sessions resolved by legacy session/project: session_id=%s project_id=%s",
                request.session_id,
                request.project_id,
            )
            return request.session_id, request.project_id
        raise AppError(
            "missing_sessions",
            "dialog_session_id and rag_session_id are required",
            ModuleName.BACKEND,
        )