297 lines
12 KiB
Python
297 lines
12 KiB
Python
from dataclasses import dataclass, field
|
|
from collections.abc import Awaitable, Callable
|
|
import inspect
|
|
import asyncio
|
|
import logging
|
|
import re
|
|
|
|
from app.modules.agent.engine.router import build_router_service
|
|
from app.modules.agent.engine.graphs.progress_registry import progress_registry
|
|
from app.modules.agent.llm import AgentLlmService
|
|
from app.modules.agent.changeset_validator import ChangeSetValidator
|
|
from app.modules.agent.confluence_service import ConfluenceService
|
|
from app.modules.agent.repository import AgentRepository
|
|
from app.modules.contracts import RagRetriever
|
|
from app.modules.shared.checkpointer import get_checkpointer
|
|
from app.schemas.changeset import ChangeItem
|
|
from app.schemas.chat import TaskResultType
|
|
from app.core.exceptions import AppError
|
|
from app.schemas.common import ModuleName
|
|
|
|
# Module-level logger for this runtime module.
LOGGER = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class AgentResult:
    """Outcome of one agent run, as handed back to the caller.

    Only ``result_type`` is required; every other field has a default, and
    the mutable defaults are built per instance so results never share them.
    """

    # Which kind of result this is (e.g. a plain answer or a changeset).
    result_type: TaskResultType
    # Free-text reply for the user; may be absent.
    answer: str | None = None
    # Proposed file changes; a fresh empty list unless provided.
    changeset: list[ChangeItem] = field(default_factory=list)
    # Auxiliary run metadata; a fresh empty dict unless provided.
    meta: dict = field(default_factory=dict)
|
|
|
|
|
|
class GraphAgentRuntime:
    """Route a user message to a graph, run it, and shape its output.

    The runtime does the following:

    1. Asks the router service which (domain, process) graph should handle
       the message.
    2. Assembles the graph input state from RAG context, Confluence pages and
       the request's files.
    3. Runs the graph in a worker thread.
    4. Post-processes any proposed changeset before wrapping everything in an
       :class:`AgentResult`.
    """

    # Runs of whitespace, collapsed to one space when deciding whether an
    # update only changes formatting (see _collapse_whitespace).
    _WHITESPACE_RE = re.compile(r"\s+")

    def __init__(
        self,
        rag: RagRetriever,
        confluence: ConfluenceService,
        changeset_validator: ChangeSetValidator,
        llm: AgentLlmService,
        agent_repository: AgentRepository,
    ) -> None:
        """Store collaborators and build the router service.

        Args:
            rag: Retriever queried once per run for context snippets.
            confluence: Service used to fetch pages for ``confluence_url``
                attachments.
            changeset_validator: Validates the final changeset before it is
                returned.
            llm: LLM service handed to the router factory.
            agent_repository: Repository handed to the router factory.
        """
        self._rag = rag
        self._confluence = confluence
        self._changeset_validator = changeset_validator
        self._router = build_router_service(llm, agent_repository)
        # Obtained lazily via get_checkpointer() on the first _resolve_graph call.
        self._checkpointer = None

    async def run(
        self,
        *,
        task_id: str,
        dialog_session_id: str,
        rag_session_id: str,
        mode: str,
        message: str,
        attachments: list[dict],
        files: list[dict],
        progress_cb: Callable[[str, str, str, dict | None], Awaitable[None] | None] | None = None,
    ) -> AgentResult:
        """Execute one agent task end to end.

        Args:
            task_id: Task identifier. It is also used as the graph state's
                ``progress_key`` and as the progress-registry key.
            dialog_session_id: Dialog identifier, used for routing, as the
                graph ``thread_id``, and for persisting the exchange.
            rag_session_id: Session queried in RAG; also stored in the graph
                state as ``project_id``.
            mode: Routing mode passed to the router.
            message: The user's message.
            attachments: Attachment dicts. Only ``type == "confluence_url"``
                entries are used.
            files: Request file dicts with ``path`` / ``content`` /
                ``content_hash`` keys.
            progress_cb: Optional callback ``(stage, message, kind, meta)``.
                It may be sync or async.

        Returns:
            An ``AgentResult``, which is one of:

            - ``CHANGESET`` when validated changes remain.
            - ``ANSWER`` when the graph proposed no changes.
            - ``ANSWER`` flagged ``changeset_filtered_out`` when every
              proposed change was dropped as a no-op.

        Raises:
            AppError: ``missing_base_hash`` when an ``update`` change targets
                a file without a content hash in ``files``.
            RuntimeError: when no graph factory (not even the default) exists.
        """
        LOGGER.warning(
            "GraphAgentRuntime.run started: task_id=%s dialog_session_id=%s mode=%s",
            task_id,
            dialog_session_id,
            mode,
        )
        await self._emit_progress(progress_cb, "agent.route", "Определяю тип запроса и подбираю граф.", meta={"mode": mode})
        route = self._router.resolve(message, dialog_session_id, mode=mode)
        await self._emit_progress(
            progress_cb,
            "agent.route.resolved",
            "Маршрут выбран, готовлю контекст для выполнения.",
            meta={"domain_id": route.domain_id, "process_id": route.process_id},
        )

        graph = self._resolve_graph(route.domain_id, route.process_id)
        files_map = self._build_files_map(files)

        await self._emit_progress(progress_cb, "agent.rag", "Собираю релевантный контекст из RAG.")
        rag_ctx = await self._rag.retrieve(rag_session_id, message)
        await self._emit_progress(progress_cb, "agent.attachments", "Обрабатываю дополнительные вложения.")
        conf_pages = await self._fetch_confluence_pages(attachments)
        state = {
            "task_id": task_id,
            "project_id": rag_session_id,
            "message": message,
            "progress_key": task_id,
            "rag_context": self._format_rag(rag_ctx),
            "confluence_context": self._format_confluence(conf_pages),
            "files_map": files_map,
        }

        await self._emit_progress(progress_cb, "agent.graph", "Запускаю выполнение графа.")
        # The callback is registered under task_id (== state["progress_key"])
        # only while the graph runs. Presumably graph nodes look it up there;
        # verify against progress_registry users.
        if progress_cb is not None:
            progress_registry.register(task_id, progress_cb)
        try:
            # graph.invoke returns its result directly (not awaited), so it
            # runs in a worker thread to keep the event loop free.
            result = await asyncio.to_thread(
                self._invoke_graph,
                graph,
                state,
                dialog_session_id,
            )
        finally:
            if progress_cb is not None:
                progress_registry.unregister(task_id)
        await self._emit_progress(progress_cb, "agent.graph.done", "Граф завершил обработку результата.")

        answer = result.get("answer")
        changeset = result.get("changeset") or []

        if changeset:
            await self._emit_progress(progress_cb, "agent.changeset", "Проверяю и валидирую предложенные изменения.")
            changeset = self._enrich_changeset_hashes(changeset, files_map)
            changeset = self._sanitize_changeset(changeset, files_map)

            if not changeset:
                # Every proposed change was a no-op: fall back to a plain answer.
                final_answer = (answer or "").strip() or "Предложенные правки были отброшены как нерелевантные или косметические."
                await self._emit_progress(progress_cb, "agent.answer", "После фильтрации правок формирую ответ без changeset.")
                self._persist_turn(dialog_session_id, route, message, final_answer)
                final = AgentResult(
                    result_type=TaskResultType.ANSWER,
                    answer=final_answer,
                    meta={**self._result_meta(route, conf_pages), "changeset_filtered_out": True},
                )
                self._log_answer_completed(task_id, route, final)
                return final

            validated = self._changeset_validator.validate(task_id, changeset)
            final_answer = (answer or "").strip() or None
            # The dialog history needs some assistant text; summarize the
            # changeset size when the graph gave no answer.
            self._persist_turn(dialog_session_id, route, message, final_answer or f"changeset:{len(validated)}")
            final = AgentResult(
                result_type=TaskResultType.CHANGESET,
                answer=final_answer,
                changeset=validated,
                meta=self._result_meta(route, conf_pages),
            )
            LOGGER.warning(
                "GraphAgentRuntime.run completed: task_id=%s route=%s/%s result_type=%s changeset_items=%s",
                task_id,
                route.domain_id,
                route.process_id,
                final.result_type.value,
                len(final.changeset),
            )
            return final

        final_answer = answer or ""
        await self._emit_progress(progress_cb, "agent.answer", "Формирую финальный ответ.")
        self._persist_turn(dialog_session_id, route, message, final_answer)
        final = AgentResult(
            result_type=TaskResultType.ANSWER,
            answer=final_answer,
            meta=self._result_meta(route, conf_pages),
        )
        self._log_answer_completed(task_id, route, final)
        return final

    def _persist_turn(self, dialog_session_id: str, route, user_message: str, assistant_message: str) -> None:
        """Record one user/assistant exchange through the router for the resolved route."""
        self._router.persist_context(
            dialog_session_id,
            domain_id=route.domain_id,
            process_id=route.process_id,
            user_message=user_message,
            assistant_message=assistant_message,
        )

    @staticmethod
    def _result_meta(route, conf_pages: list[dict]) -> dict:
        """Build the metadata dict attached to every result of ``run``.

        ``used_rag`` is reported unconditionally because ``run`` always calls
        the retriever, even if it returns nothing.
        """
        return {"route": route.model_dump(), "used_rag": True, "used_confluence": bool(conf_pages)}

    @staticmethod
    def _log_answer_completed(task_id: str, route, final: AgentResult) -> None:
        """Log completion of a run that ends with an ``ANSWER`` result."""
        LOGGER.warning(
            "GraphAgentRuntime.run completed: task_id=%s route=%s/%s result_type=%s answer_len=%s",
            task_id,
            route.domain_id,
            route.process_id,
            final.result_type.value,
            len(final.answer or ""),
        )

    async def _emit_progress(
        self,
        progress_cb: Callable[[str, str, str, dict | None], Awaitable[None] | None] | None,
        stage: str,
        message: str,
        *,
        kind: str = "task_progress",
        meta: dict | None = None,
    ) -> None:
        """Invoke the progress callback, awaiting it if it returns an awaitable.

        This is a no-op when ``progress_cb`` is None. ``meta`` defaults to an
        empty dict so the callback always receives a dict.
        """
        if progress_cb is None:
            return
        result = progress_cb(stage, message, kind, meta or {})
        if inspect.isawaitable(result):
            await result

    def _resolve_graph(self, domain_id: str, process_id: str):
        """Build the graph for a route, falling back to ``("default", "general")``.

        The factory receives the shared checkpointer, which is created on the
        first call.

        Raises:
            RuntimeError: if neither the route nor the default has a factory.
        """
        if self._checkpointer is None:
            self._checkpointer = get_checkpointer()
        factory = self._router.graph_factory(domain_id, process_id)
        if factory is None:
            factory = self._router.graph_factory("default", "general")
        if factory is None:
            raise RuntimeError("No graph factory configured")
        LOGGER.warning("_resolve_graph resolved: domain_id=%s process_id=%s", domain_id, process_id)
        return factory(self._checkpointer)

    def _invoke_graph(self, graph, state: dict, dialog_session_id: str):
        """Run ``graph.invoke`` with the dialog id as the ``thread_id`` config value."""
        return graph.invoke(
            state,
            config={"configurable": {"thread_id": dialog_session_id}},
        )

    async def _fetch_confluence_pages(self, attachments: list[dict]) -> list[dict]:
        """Fetch pages for ``confluence_url`` attachments, one by one, in order.

        Attachments of any other type are ignored.
        """
        pages: list[dict] = []
        for item in attachments:
            if item.get("type") == "confluence_url":
                pages.append(await self._confluence.fetch_page(item["url"]))
        LOGGER.warning("_fetch_confluence_pages completed: pages=%s", len(pages))
        return pages

    @staticmethod
    def _as_text(value: object) -> str:
        """Stringify ``value``, mapping None to "" instead of the literal "None"."""
        return "" if value is None else str(value)

    def _format_rag(self, items: list[dict]) -> str:
        """Join the ``content`` of retrieved items, one per line.

        Missing or None content becomes an empty line.
        """
        return "\n".join(self._as_text(x.get("content")) for x in items)

    def _format_confluence(self, pages: list[dict]) -> str:
        """Join the ``content_markdown`` of fetched pages, one per line.

        Missing or None content becomes an empty line.
        """
        return "\n".join(self._as_text(x.get("content_markdown")) for x in pages)

    def _build_files_map(self, files: list[dict]) -> dict[str, dict]:
        """Index request files by normalized path.

        Paths use forward slashes and are stripped of surrounding whitespace.
        Entries with an empty or None path are skipped. Later duplicates
        overwrite earlier ones. None values become "" rather than "None", so a
        null ``content_hash`` is treated as missing downstream.
        """
        output: dict[str, dict] = {}
        for item in files:
            path = self._as_text(item.get("path")).replace("\\", "/").strip()
            if not path:
                continue
            output[path] = {
                "path": path,
                "content": self._as_text(item.get("content")),
                "content_hash": self._as_text(item.get("content_hash")),
            }
        LOGGER.warning("_build_files_map completed: files=%s", len(output))
        return output

    def _lookup_file(self, files_map: dict[str, dict], path: str) -> dict | None:
        """Find a file entry for ``path``.

        The exact normalized match is tried first, then a case-insensitive
        match. Normalization mirrors ``_build_files_map``: forward slashes and
        stripped whitespace. Returns None when nothing matches.
        """
        normalized = (path or "").replace("\\", "/").strip()
        if normalized in files_map:
            return files_map[normalized]
        low = normalized.lower()
        for key, value in files_map.items():
            if key.lower() == low:
                return value
        return None

    def _enrich_changeset_hashes(self, items: list[ChangeItem], files_map: dict[str, dict]) -> list[ChangeItem]:
        """Set ``base_hash`` on every ``update`` item from its source file.

        Items are mutated in place and returned in their original order.

        Raises:
            AppError: ``missing_base_hash`` if an update's file is unknown or
                has no content hash.
        """
        enriched: list[ChangeItem] = []
        for item in items:
            if item.op.value == "update":
                source = self._lookup_file(files_map, item.path)
                if not source or not source.get("content_hash"):
                    raise AppError(
                        "missing_base_hash",
                        f"Cannot build update for {item.path}: no file hash in request context",
                        ModuleName.AGENT,
                    )
                item.base_hash = str(source["content_hash"])
            enriched.append(item)
        LOGGER.warning("_enrich_changeset_hashes completed: items=%s", len(enriched))
        return enriched

    def _sanitize_changeset(self, items: list[ChangeItem], files_map: dict[str, dict]) -> list[ChangeItem]:
        """Drop ``update`` items that do not really change their file.

        An update is dropped when its proposed content equals the original
        exactly, or after collapsing whitespace. Non-update items and updates
        to unknown files are kept unchanged.

        NOTE(review): whitespace-only edits are treated as cosmetic. That also
        drops indentation-only fixes, which matter in indentation-sensitive
        languages.
        """
        sanitized: list[ChangeItem] = []
        dropped_noop = 0
        dropped_ws = 0
        for item in items:
            if item.op.value != "update":
                sanitized.append(item)
                continue
            source = self._lookup_file(files_map, item.path)
            if not source:
                sanitized.append(item)
                continue
            original = str(source.get("content", ""))
            proposed = item.proposed_content or ""
            if proposed == original:
                dropped_noop += 1
                continue
            if self._collapse_whitespace(proposed) == self._collapse_whitespace(original):
                dropped_ws += 1
                continue
            sanitized.append(item)
        if dropped_noop or dropped_ws:
            LOGGER.warning(
                "_sanitize_changeset dropped items: noop=%s whitespace_only=%s kept=%s",
                dropped_noop,
                dropped_ws,
                len(sanitized),
            )
        return sanitized

    def _collapse_whitespace(self, text: str) -> str:
        """Strip ``text`` (None → "") and collapse inner whitespace runs to one space."""
        return self._WHITESPACE_RE.sub(" ", (text or "").strip())
|