Files
agent/app/modules/agent/service.py
T
2026-02-25 14:47:19 +03:00

297 lines
12 KiB
Python

from dataclasses import dataclass, field
from collections.abc import Awaitable, Callable
import inspect
import asyncio
import logging
import re
from app.modules.agent.engine.router import build_router_service
from app.modules.agent.engine.graphs.progress_registry import progress_registry
from app.modules.agent.llm import AgentLlmService
from app.modules.agent.changeset_validator import ChangeSetValidator
from app.modules.agent.confluence_service import ConfluenceService
from app.modules.agent.repository import AgentRepository
from app.modules.contracts import RagRetriever
from app.modules.shared.checkpointer import get_checkpointer
from app.schemas.changeset import ChangeItem
from app.schemas.chat import TaskResultType
from app.core.exceptions import AppError
from app.schemas.common import ModuleName
# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)
@dataclass
class AgentResult:
    """Outcome of one ``GraphAgentRuntime.run`` call.

    ``result_type`` tells the caller which payload is meaningful: ``answer``
    for plain replies, ``changeset`` for proposed file edits. ``meta`` holds
    auxiliary details about the run (``run`` stores the resolved route and
    usage flags there).
    """

    # Which kind of result this is (answer vs. changeset).
    result_type: TaskResultType
    # Assistant text; may be None when only a changeset is returned.
    answer: str | None = field(default=None)
    # Proposed file changes; each instance gets its own fresh list.
    changeset: list[ChangeItem] = field(default_factory=lambda: [])
    # Extra run diagnostics; each instance gets its own fresh dict.
    meta: dict = field(default_factory=lambda: {})
class GraphAgentRuntime:
    """Run a user request through a routed agent graph and package the result.

    Each call to :meth:`run` resolves a route (domain/process) for the message,
    gathers context (RAG retrieval plus Confluence pages referenced by
    attachments), invokes the selected graph in a worker thread, and turns the
    graph output into an :class:`AgentResult` — either a plain answer or a
    validated changeset.
    """

    def __init__(
        self,
        rag: RagRetriever,
        confluence: ConfluenceService,
        changeset_validator: ChangeSetValidator,
        llm: AgentLlmService,
        agent_repository: AgentRepository,
    ) -> None:
        """Store collaborators and build the router.

        Args:
            rag: Retriever queried for context relevant to the message.
            confluence: Service used to fetch pages referenced by attachments.
            changeset_validator: Validates changesets produced by the graph.
            llm: LLM service handed to ``build_router_service``.
            agent_repository: Repository handed to ``build_router_service``.
        """
        self._rag = rag
        self._confluence = confluence
        self._changeset_validator = changeset_validator
        self._router = build_router_service(llm, agent_repository)
        # Created lazily on the first graph resolution (see _resolve_graph).
        self._checkpointer = None

    async def run(
        self,
        *,
        task_id: str,
        dialog_session_id: str,
        rag_session_id: str,
        mode: str,
        message: str,
        attachments: list[dict],
        files: list[dict],
        progress_cb: Callable[[str, str, str, dict | None], Awaitable[None] | None] | None = None,
    ) -> AgentResult:
        """Execute one agent task end to end.

        Args:
            task_id: Task identifier; also used as the progress-registry key.
            dialog_session_id: Dialog session used for routing, for persisted
                context, and as the graph ``thread_id``.
            rag_session_id: Session passed to the RAG retriever; also stored as
                ``project_id`` in the graph state.
            mode: Routing mode forwarded to the router.
            message: The user's message.
            attachments: Extra inputs; items whose ``type`` is
                ``"confluence_url"`` are fetched from Confluence.
            files: Request files as dicts with ``path``/``content``/
                ``content_hash`` keys.
            progress_cb: Optional ``(stage, message, kind, meta)`` callback;
                may be a plain function or return an awaitable.

        Returns:
            An ``AgentResult`` of type ``CHANGESET`` when the graph produced
            changes that survive sanitization, otherwise of type ``ANSWER``.

        Raises:
            AppError: If an ``update`` change targets a file with no content
                hash in ``files``.
        """
        LOGGER.warning(
            "GraphAgentRuntime.run started: task_id=%s dialog_session_id=%s mode=%s",
            task_id,
            dialog_session_id,
            mode,
        )
        await self._emit_progress(progress_cb, "agent.route", "Определяю тип запроса и подбираю граф.", meta={"mode": mode})
        route = self._router.resolve(message, dialog_session_id, mode=mode)
        await self._emit_progress(
            progress_cb,
            "agent.route.resolved",
            "Маршрут выбран, готовлю контекст для выполнения.",
            meta={"domain_id": route.domain_id, "process_id": route.process_id},
        )
        graph = self._resolve_graph(route.domain_id, route.process_id)
        files_map = self._build_files_map(files)

        await self._emit_progress(progress_cb, "agent.rag", "Собираю релевантный контекст из RAG.")
        rag_ctx = await self._rag.retrieve(rag_session_id, message)
        await self._emit_progress(progress_cb, "agent.attachments", "Обрабатываю дополнительные вложения.")
        conf_pages = await self._fetch_confluence_pages(attachments)

        state = {
            "task_id": task_id,
            "project_id": rag_session_id,
            "message": message,
            "progress_key": task_id,
            "rag_context": self._format_rag(rag_ctx),
            "confluence_context": self._format_confluence(conf_pages),
            "files_map": files_map,
        }

        await self._emit_progress(progress_cb, "agent.graph", "Запускаю выполнение графа.")
        # The callback is registered under task_id (the same value as
        # state["progress_key"]); presumably graph nodes look it up there while
        # running in the worker thread. Always unregister, even on failure.
        if progress_cb is not None:
            progress_registry.register(task_id, progress_cb)
        try:
            # _invoke_graph is a blocking call; run it off the event loop.
            result = await asyncio.to_thread(
                self._invoke_graph,
                graph,
                state,
                dialog_session_id,
            )
        finally:
            if progress_cb is not None:
                progress_registry.unregister(task_id)
        await self._emit_progress(progress_cb, "agent.graph.done", "Граф завершил обработку результата.")

        answer = result.get("answer")
        changeset = result.get("changeset") or []

        if changeset:
            await self._emit_progress(progress_cb, "agent.changeset", "Проверяю и валидирую предложенные изменения.")
            changeset = self._enrich_changeset_hashes(changeset, files_map)
            changeset = self._sanitize_changeset(changeset, files_map)

            if not changeset:
                # Every proposed change was a no-op or whitespace-only; fall
                # back to a plain answer so the user still gets a reply.
                final_answer = (answer or "").strip() or "Предложенные правки были отброшены как нерелевантные или косметические."
                await self._emit_progress(progress_cb, "agent.answer", "После фильтрации правок формирую ответ без changeset.")
                self._persist_turn(dialog_session_id, route, message, final_answer)
                final = AgentResult(
                    result_type=TaskResultType.ANSWER,
                    answer=final_answer,
                    meta=self._result_meta(route, conf_pages, changeset_filtered_out=True),
                )
                LOGGER.warning(
                    "GraphAgentRuntime.run completed: task_id=%s route=%s/%s result_type=%s changeset_filtered_out=True answer_len=%s",
                    task_id,
                    route.domain_id,
                    route.process_id,
                    final.result_type.value,
                    len(final.answer or ""),
                )
                return final

            validated = self._changeset_validator.validate(task_id, changeset)
            final_answer = (answer or "").strip() or None
            # Persisted history needs some assistant text; summarise the
            # changeset size when the graph gave no textual answer.
            self._persist_turn(
                dialog_session_id,
                route,
                message,
                final_answer or f"changeset:{len(validated)}",
            )
            final = AgentResult(
                result_type=TaskResultType.CHANGESET,
                answer=final_answer,
                changeset=validated,
                meta=self._result_meta(route, conf_pages),
            )
            LOGGER.warning(
                "GraphAgentRuntime.run completed: task_id=%s route=%s/%s result_type=%s changeset_items=%s",
                task_id,
                route.domain_id,
                route.process_id,
                final.result_type.value,
                len(final.changeset),
            )
            return final

        final_answer = answer or ""
        await self._emit_progress(progress_cb, "agent.answer", "Формирую финальный ответ.")
        self._persist_turn(dialog_session_id, route, message, final_answer)
        final = AgentResult(
            result_type=TaskResultType.ANSWER,
            answer=final_answer,
            meta=self._result_meta(route, conf_pages),
        )
        LOGGER.warning(
            "GraphAgentRuntime.run completed: task_id=%s route=%s/%s result_type=%s answer_len=%s",
            task_id,
            route.domain_id,
            route.process_id,
            final.result_type.value,
            len(final.answer or ""),
        )
        return final

    def _persist_turn(self, dialog_session_id: str, route, user_message: str, assistant_message: str) -> None:
        """Record one user/assistant exchange for ``route`` via the router.

        ``route`` is the object returned by ``self._router.resolve`` (it must
        expose ``domain_id`` and ``process_id``).
        """
        self._router.persist_context(
            dialog_session_id,
            domain_id=route.domain_id,
            process_id=route.process_id,
            user_message=user_message,
            assistant_message=assistant_message,
        )

    def _result_meta(self, route, conf_pages: list[dict], **extra) -> dict:
        """Build the ``meta`` dict attached to every result returned by ``run``.

        Keys are ``route`` (``route.model_dump()``), ``used_rag`` (always True,
        since retrieval is always performed), ``used_confluence``, followed by
        any ``extra`` entries.
        """
        meta = {
            "route": route.model_dump(),
            "used_rag": True,
            "used_confluence": bool(conf_pages),
        }
        meta.update(extra)
        return meta

    async def _emit_progress(
        self,
        progress_cb: Callable[[str, str, str, dict | None], Awaitable[None] | None] | None,
        stage: str,
        message: str,
        *,
        kind: str = "task_progress",
        meta: dict | None = None,
    ) -> None:
        """Report a progress event through ``progress_cb`` if one was given.

        The callback may be synchronous or return an awaitable; in the latter
        case it is awaited. ``meta`` defaults to an empty dict.
        """
        if progress_cb is None:
            return
        result = progress_cb(stage, message, kind, meta or {})
        if inspect.isawaitable(result):
            await result

    def _resolve_graph(self, domain_id: str, process_id: str):
        """Build the graph for ``domain_id``/``process_id``.

        Falls back to the ``("default", "general")`` factory when the route has
        none. The checkpointer is obtained once and shared by every graph
        built here.

        Raises:
            RuntimeError: If neither the route nor the fallback has a factory.
        """
        if self._checkpointer is None:
            self._checkpointer = get_checkpointer()
        factory = self._router.graph_factory(domain_id, process_id)
        if factory is None:
            factory = self._router.graph_factory("default", "general")
        if factory is None:
            raise RuntimeError("No graph factory configured")
        LOGGER.warning("_resolve_graph resolved: domain_id=%s process_id=%s", domain_id, process_id)
        return factory(self._checkpointer)

    def _invoke_graph(self, graph, state: dict, dialog_session_id: str):
        """Invoke ``graph`` on ``state``, passing the dialog session as ``thread_id``."""
        return graph.invoke(
            state,
            config={"configurable": {"thread_id": dialog_session_id}},
        )

    async def _fetch_confluence_pages(self, attachments: list[dict]) -> list[dict]:
        """Fetch, one after another, every attachment of type ``confluence_url``.

        Other attachment types are ignored. A ``confluence_url`` item without a
        ``url`` key raises ``KeyError``.
        """
        pages: list[dict] = []
        for item in attachments:
            if item.get("type") == "confluence_url":
                pages.append(await self._confluence.fetch_page(item["url"]))
        LOGGER.warning("_fetch_confluence_pages completed: pages=%s", len(pages))
        return pages

    def _as_text(self, value) -> str:
        """Convert ``value`` to ``str``, mapping ``None`` to ``""`` instead of ``"None"``."""
        return "" if value is None else str(value)

    def _normalize_path(self, path) -> str:
        """Normalize a file path: ``None`` -> ``""``, backslashes -> ``/``, trimmed."""
        return self._as_text(path).replace("\\", "/").strip()

    def _format_rag(self, items: list[dict]) -> str:
        """Join the ``content`` of retrieved RAG items, one per line."""
        return "\n".join(self._as_text(x.get("content")) for x in items)

    def _format_confluence(self, pages: list[dict]) -> str:
        """Join the ``content_markdown`` of fetched Confluence pages, one per line."""
        return "\n".join(self._as_text(x.get("content_markdown")) for x in pages)

    def _build_files_map(self, files: list[dict]) -> dict[str, dict]:
        """Index request files by normalized path.

        Entries whose path is missing, ``None`` or blank are skipped. Missing
        or ``None`` content/hash values become ``""``, so a ``None`` hash is
        correctly treated as absent by ``_enrich_changeset_hashes``. Later
        entries with the same path overwrite earlier ones.
        """
        output: dict[str, dict] = {}
        for item in files:
            path = self._normalize_path(item.get("path"))
            if not path:
                continue
            output[path] = {
                "path": path,
                "content": self._as_text(item.get("content")),
                "content_hash": self._as_text(item.get("content_hash")),
            }
        LOGGER.warning("_build_files_map completed: files=%s", len(output))
        return output

    def _lookup_file(self, files_map: dict[str, dict], path: str) -> dict | None:
        """Find ``path`` in ``files_map``: exact match first, then case-insensitive.

        ``path`` is normalized the same way as the keys built by
        ``_build_files_map``. Returns None when nothing matches.
        """
        normalized = self._normalize_path(path)
        if normalized in files_map:
            return files_map[normalized]
        low = normalized.lower()
        for key, value in files_map.items():
            if key.lower() == low:
                return value
        return None

    def _enrich_changeset_hashes(self, items: list[ChangeItem], files_map: dict[str, dict]) -> list[ChangeItem]:
        """Set ``base_hash`` on every ``update`` item from the request file hashes.

        Items are mutated in place and returned in the same order. Non-update
        items pass through untouched.

        Raises:
            AppError: ``missing_base_hash`` when an update targets a file that
                is not in ``files_map`` or has an empty content hash.
        """
        enriched: list[ChangeItem] = []
        for item in items:
            if item.op.value == "update":
                source = self._lookup_file(files_map, item.path)
                if not source or not source.get("content_hash"):
                    raise AppError(
                        "missing_base_hash",
                        f"Cannot build update for {item.path}: no file hash in request context",
                        ModuleName.AGENT,
                    )
                item.base_hash = str(source["content_hash"])
            enriched.append(item)
        LOGGER.warning("_enrich_changeset_hashes completed: items=%s", len(enriched))
        return enriched

    def _sanitize_changeset(self, items: list[ChangeItem], files_map: dict[str, dict]) -> list[ChangeItem]:
        """Drop ``update`` items that do not meaningfully change their file.

        An update is dropped when its proposed content equals the original, or
        when the two are equal after collapsing whitespace. Non-update items
        and updates to files not found in ``files_map`` are kept.

        NOTE(review): collapsing whitespace also treats indentation-only
        changes as cosmetic, which can drop semantic edits to Python/YAML
        files — confirm this is intended.
        """
        sanitized: list[ChangeItem] = []
        dropped_noop = 0
        dropped_ws = 0
        for item in items:
            if item.op.value != "update":
                sanitized.append(item)
                continue
            source = self._lookup_file(files_map, item.path)
            if not source:
                sanitized.append(item)
                continue
            original = str(source.get("content", ""))
            proposed = item.proposed_content or ""
            if proposed == original:
                dropped_noop += 1
                continue
            if self._collapse_whitespace(proposed) == self._collapse_whitespace(original):
                dropped_ws += 1
                continue
            sanitized.append(item)
        if dropped_noop or dropped_ws:
            LOGGER.warning(
                "_sanitize_changeset dropped items: noop=%s whitespace_only=%s kept=%s",
                dropped_noop,
                dropped_ws,
                len(sanitized),
            )
        return sanitized

    def _collapse_whitespace(self, text: str) -> str:
        """Trim ``text`` and replace every run of whitespace with a single space."""
        return re.sub(r"\s+", " ", (text or "").strip())