294 lines
12 KiB
Python
294 lines
12 KiB
Python
import asyncio
|
||
import logging
|
||
|
||
from app.core.exceptions import AppError
|
||
from app.modules.contracts import AgentRunner
|
||
from app.schemas.chat import ChatMessageRequest, TaskResultType, TaskStatus
|
||
from app.schemas.common import ErrorPayload, ModuleName
|
||
from app.modules.chat.dialog_store import DialogSessionStore
|
||
from app.modules.chat.task_store import TaskState, TaskStore
|
||
from app.modules.shared.event_bus import EventBus
|
||
from app.modules.shared.idempotency_store import IdempotencyStore
|
||
from app.modules.shared.retry_executor import RetryExecutor
|
||
|
||
LOGGER = logging.getLogger(__name__)
|
||
|
||
|
||
def _truncate_for_log(text: str, max_chars: int = 1200) -> str:
|
||
value = (text or "").replace("\n", "\\n").strip()
|
||
if len(value) <= max_chars:
|
||
return value
|
||
return value[:max_chars].rstrip() + "...[truncated]"
|
||
|
||
|
||
class ChatOrchestrator:
    """Accept chat messages and process them as background tasks.

    For every accepted message a task is created in the task store and a
    background coroutine runs the agent runtime through the retry executor,
    publishing status, progress, result and error events on the event bus.
    """

    # Seconds between two heartbeat events while a task is being processed.
    _HEARTBEAT_INTERVAL_SECONDS = 5.0

    def __init__(
        self,
        task_store: TaskStore,
        dialogs: DialogSessionStore,
        idempotency: IdempotencyStore,
        runtime: AgentRunner,
        events: EventBus,
        retry: RetryExecutor,
        rag_session_exists,
        message_sink,
    ) -> None:
        """Store the collaborators used while processing chat tasks.

        Args:
            task_store: Creates, loads and saves task state.
            dialogs: Looks up dialog sessions by id.
            idempotency: Maps idempotency keys to task ids.
            runtime: Agent runner awaited to produce the task result.
            events: Event bus on which task events are published.
            retry: Executor that runs the agent operation.
            rag_session_exists: Callable taking a RAG session id and
                returning a truthy value when the session exists.
            message_sink: Callable invoked as
                ``message_sink(dialog_session_id, role, text, task_id=...,
                payload=...)`` to record dialog messages.
        """
        self._task_store = task_store
        self._dialogs = dialogs
        self._idempotency = idempotency
        self._runtime = runtime
        self._events = events
        self._retry = retry
        self._rag_session_exists = rag_session_exists
        self._message_sink = message_sink
        # Strong references to in-flight background tasks: the event loop
        # itself only keeps weak references, so an otherwise unreferenced
        # task may be garbage-collected before it finishes.
        self._background_tasks: set[asyncio.Task] = set()

    async def enqueue_message(
        self,
        request: ChatMessageRequest,
        idempotency_key: str | None,
    ) -> TaskState:
        """Create a task for *request* and start processing it in background.

        Args:
            request: Incoming chat message request.
            idempotency_key: Optional key; when it already maps to a stored
                task, that task is returned and nothing new is started.

        Returns:
            The reused or newly created task state.
        """
        if idempotency_key:
            existing = self._idempotency.get_task_id(idempotency_key)
            if existing:
                task = self._task_store.get(existing)
                if task:
                    LOGGER.warning(
                        "enqueue_message reused task by idempotency key: task_id=%s mode=%s",
                        task.task_id,
                        request.mode.value,
                    )
                    return task

        task = self._task_store.create()
        if idempotency_key:
            self._idempotency.put(idempotency_key, task.task_id)

        worker = asyncio.create_task(self._process_task(task.task_id, request))
        # Keep the task alive until it completes, then drop the reference.
        self._background_tasks.add(worker)
        worker.add_done_callback(self._background_tasks.discard)

        LOGGER.warning(
            "enqueue_message created task: task_id=%s mode=%s",
            task.task_id,
            request.mode.value,
        )
        return task

    async def _process_task(self, task_id: str, request: ChatMessageRequest) -> None:
        """Run the agent for one task and publish its lifecycle events.

        Returns silently when the task is not found in the task store.
        Failures are recorded on the task and published as ``task_error``
        events instead of being propagated.
        """
        task = self._task_store.get(task_id)
        if not task:
            return
        task.status = TaskStatus.RUNNING
        self._task_store.save(task)
        await self._events.publish(task_id, "task_status", {"task_id": task_id, "status": task.status.value})
        await self._publish_progress(task_id, "task.start", "Запрос принят, начинаю обработку.", progress=5)

        heartbeat_stop = asyncio.Event()
        heartbeat_task = asyncio.create_task(self._run_heartbeat(task_id, heartbeat_stop))

        try:
            await self._publish_progress(task_id, "task.sessions", "Проверяю сессии диалога и проекта.", progress=10)
            dialog_session_id, rag_session_id = self._resolve_sessions(request)
            LOGGER.warning(
                "incoming chat request: task_id=%s dialog_session_id=%s rag_session_id=%s mode=%s attachments=%s files=%s message=%s",
                task_id,
                dialog_session_id,
                rag_session_id,
                request.mode.value,
                len(request.attachments),
                len(request.files),
                _truncate_for_log(request.message),
            )
            await self._publish_progress(task_id, "task.sessions.done", "Сессии проверены, запускаю агента.", progress=15)
            loop = asyncio.get_running_loop()

            def progress_cb(stage: str, message: str, kind: str = "task_progress", meta: dict | None = None):
                # Schedules the publish on the orchestrator's loop, so the
                # callback may be invoked from another thread.
                future = asyncio.run_coroutine_threadsafe(
                    self._events.publish(
                        task_id,
                        kind,
                        {
                            "task_id": task_id,
                            "stage": stage,
                            "message": message,
                            "meta": meta or {},
                        },
                    ),
                    loop,
                )

                def report_failure(done) -> None:
                    # Nobody awaits this future, so surface failures in the
                    # log instead of losing them silently.
                    if done.cancelled():
                        return
                    error = done.exception()
                    if error is not None:
                        LOGGER.warning(
                            "progress_cb publish failed: task_id=%s stage=%s",
                            task_id,
                            stage,
                            exc_info=error,
                        )

                future.add_done_callback(report_failure)

            user_message_recorded = False

            async def op():
                nonlocal user_message_recorded
                # NOTE(review): the retry executor presumably re-invokes this
                # coroutine on transient failures; record the user message
                # only once so a retry does not duplicate it in the dialog.
                if not user_message_recorded:
                    self._message_sink(dialog_session_id, "user", request.message, task_id=task_id)
                    user_message_recorded = True
                await self._publish_progress(task_id, "task.agent.run", "Агент анализирует запрос и готовит ответ.", progress=20)
                return await self._runtime.run(
                    task_id=task_id,
                    dialog_session_id=dialog_session_id,
                    rag_session_id=rag_session_id,
                    mode=request.mode.value,
                    message=request.message,
                    attachments=[a.model_dump(mode="json") for a in request.attachments],
                    files=[f.model_dump(mode="json") for f in request.files],
                    progress_cb=progress_cb,
                )

            result = await self._retry.run(op)
            await self._publish_progress(task_id, "task.finalize", "Сохраняю финальный результат.", progress=95)
            task.status = TaskStatus.DONE
            task.result_type = TaskResultType(result.result_type)
            task.answer = result.answer
            task.changeset = result.changeset
            if task.result_type == TaskResultType.ANSWER and task.answer:
                self._message_sink(dialog_session_id, "assistant", task.answer, task_id=task_id)
            elif task.result_type == TaskResultType.CHANGESET:
                self._message_sink(
                    dialog_session_id,
                    "assistant",
                    f"changeset:{len(task.changeset)}",
                    task_id=task_id,
                    payload={
                        "result_type": TaskResultType.CHANGESET.value,
                        "changeset": [item.model_dump(mode="json") for item in task.changeset],
                    },
                )
            self._task_store.save(task)
            await self._events.publish(
                task_id,
                "task_result",
                {
                    "task_id": task_id,
                    "status": task.status.value,
                    "result_type": task.result_type.value,
                    "answer": task.answer,
                    "changeset": [item.model_dump(mode="json") for item in task.changeset],
                    "meta": getattr(result, "meta", {}) or {},
                },
            )
            await self._publish_progress(task_id, "task.done", "Обработка завершена.", progress=100)
            LOGGER.warning(
                "_process_task completed: task_id=%s status=%s result_type=%s changeset_items=%s",
                task_id,
                task.status.value,
                task.result_type.value if task.result_type else "",
                len(task.changeset),
            )
        except (AppError, TimeoutError, ConnectionError, OSError) as exc:
            if isinstance(exc, AppError):
                payload = ErrorPayload(code=exc.code, desc=exc.desc, module=exc.module)
            else:
                payload = ErrorPayload(
                    code="retry_exhausted",
                    desc="Temporary failure after retries. Please retry request.",
                    module=ModuleName.BACKEND,
                )
            await self._fail_task(task, task_id, payload, "Не удалось завершить обработку запроса.")
            LOGGER.warning(
                "_process_task handled error: task_id=%s code=%s module=%s desc=%s",
                task_id,
                payload.code,
                payload.module.value,
                payload.desc,
            )
        except Exception:
            payload = ErrorPayload(
                code="agent_runtime_error",
                desc="Agent execution failed unexpectedly. Please retry request.",
                module=ModuleName.AGENT,
            )
            await self._fail_task(task, task_id, payload, "Во время выполнения возникла внутренняя ошибка.")
            LOGGER.exception(
                "_process_task unexpected error: task_id=%s code=%s",
                task_id,
                payload.code,
            )
        finally:
            heartbeat_stop.set()
            try:
                await heartbeat_task
            except Exception:
                # A heartbeat publish failure must not replace the outcome of
                # the task itself; log it and finish normally.
                LOGGER.exception("_run_heartbeat failed: task_id=%s", task_id)

    async def _fail_task(self, task: TaskState, task_id: str, payload: ErrorPayload, message: str) -> None:
        """Mark *task* as failed, persist it and publish the error events."""
        task.status = TaskStatus.ERROR
        task.error = payload
        self._task_store.save(task)
        await self._publish_progress(task_id, "task.error", message, kind="task_thinking")
        await self._events.publish(task_id, "task_error", payload.model_dump(mode="json"))

    async def _publish_progress(
        self,
        task_id: str,
        stage: str,
        message: str,
        *,
        progress: int | None = None,
        kind: str = "task_progress",
        meta: dict | None = None,
    ) -> None:
        """Publish one progress-style event for a task and log it.

        Args:
            task_id: Task the event belongs to.
            stage: Stage identifier placed in the payload.
            message: Human-readable status text placed in the payload.
            progress: Optional percentage; clamped to ``0..100`` and included
                in the payload only when given.
            kind: Event name passed to the event bus.
            meta: Optional extra data; an empty dict is sent when omitted.
        """
        payload = {
            "task_id": task_id,
            "stage": stage,
            "message": message,
            "meta": meta or {},
        }
        if progress is not None:
            payload["progress"] = max(0, min(100, int(progress)))
        await self._events.publish(task_id, kind, payload)
        LOGGER.warning(
            "_publish_progress emitted: task_id=%s kind=%s stage=%s progress=%s",
            task_id,
            kind,
            stage,
            payload.get("progress"),
        )

    async def _run_heartbeat(self, task_id: str, stop_event: asyncio.Event) -> None:
        """Publish a rotating ``task_thinking`` message until *stop_event* is set."""
        messages = (
            "Собираю данные по проекту.",
            "Анализирую контекст и формирую структуру ответа.",
            "Проверяю согласованность промежуточного результата.",
        )
        index = 0
        while not stop_event.is_set():
            try:
                # Doubles as an interruptible sleep: returns as soon as the
                # stop event is set, times out otherwise.
                await asyncio.wait_for(stop_event.wait(), timeout=self._HEARTBEAT_INTERVAL_SECONDS)
            except asyncio.TimeoutError:
                await self._publish_progress(
                    task_id,
                    "task.heartbeat",
                    messages[index % len(messages)],
                    kind="task_thinking",
                    meta={"heartbeat": True},
                )
                index += 1
        LOGGER.warning("_run_heartbeat stopped: task_id=%s ticks=%s", task_id, index)

    def _resolve_sessions(self, request: ChatMessageRequest) -> tuple[str, str]:
        """Return ``(dialog_session_id, rag_session_id)`` for *request*.

        Raises:
            AppError: When the dialog session is unknown, when it belongs to
                a different RAG session, when the legacy project id does not
                refer to an existing RAG session, or when neither pair of
                identifiers is supplied.
        """
        # Current flow: explicit dialog and RAG session ids.
        if request.dialog_session_id and request.rag_session_id:
            dialog = self._dialogs.get(request.dialog_session_id)
            if not dialog:
                raise AppError("dialog_not_found", "Dialog session not found", ModuleName.BACKEND)
            if dialog.rag_session_id != request.rag_session_id:
                raise AppError("dialog_rag_mismatch", "Dialog session does not belong to rag session", ModuleName.BACKEND)
            LOGGER.warning(
                "_resolve_sessions resolved by dialog_session_id: dialog_session_id=%s rag_session_id=%s",
                request.dialog_session_id,
                request.rag_session_id,
            )
            return request.dialog_session_id, request.rag_session_id

        # Legacy compatibility: old session_id/project_id flow, where the
        # project id is used as the RAG session id.
        if request.session_id and request.project_id:
            if not self._rag_session_exists(request.project_id):
                raise AppError("rag_session_not_found", "RAG session not found", ModuleName.RAG)
            LOGGER.warning(
                "_resolve_sessions resolved by legacy session/project: session_id=%s project_id=%s",
                request.session_id,
                request.project_id,
            )
            return request.session_id, request.project_id

        raise AppError(
            "missing_sessions",
            "dialog_session_id and rag_session_id are required",
            ModuleName.BACKEND,
        )
|