# File: agent/app/modules/agent/service.py (484 lines, 19 KiB, Python)
from __future__ import annotations
from dataclasses import dataclass, field
from collections.abc import Awaitable, Callable
import inspect
import logging
import re
from typing import TYPE_CHECKING
from app.modules.agent.engine.orchestrator import OrchestratorService, TaskSpecBuilder
from app.modules.agent.engine.orchestrator.metrics_persister import MetricsPersister
from app.modules.agent.engine.orchestrator.models import RoutingMeta
from app.modules.agent.engine.orchestrator.step_registry import StepRegistry
from app.modules.agent.engine.router import build_router_service
from app.modules.agent.llm import AgentLlmService
from app.modules.agent.story_session_recorder import StorySessionRecorder
from app.modules.agent.changeset_validator import ChangeSetValidator
from app.modules.agent.confluence_service import ConfluenceService
from app.modules.agent.repository import AgentRepository
from app.modules.contracts import RagRetriever
from app.modules.shared.checkpointer import get_checkpointer
from app.schemas.changeset import ChangeItem
from app.schemas.chat import TaskResultType
from app.core.exceptions import AppError
from app.schemas.common import ModuleName
LOGGER = logging.getLogger(__name__)
if TYPE_CHECKING:
from app.modules.rag.explain.retriever_v2 import CodeExplainRetrieverV2
def _truncate_for_log(text: str | None, max_chars: int = 1500) -> str:
value = (text or "").replace("\n", "\\n").strip()
if len(value) <= max_chars:
return value
return value[:max_chars].rstrip() + "...[truncated]"
@dataclass
class AgentResult:
    """Final outcome of a single agent run (plain answer and/or a changeset)."""

    # Discriminator: ANSWER for text-only results, CHANGESET when edits are returned.
    result_type: TaskResultType
    # Free-text assistant reply; may be None when only a changeset is produced.
    answer: str | None = None
    # Validated file-modification items proposed by the agent.
    changeset: list[ChangeItem] = field(default_factory=list)
    # Diagnostic metadata (routing decision, orchestrator meta/steps, usage flags).
    meta: dict = field(default_factory=dict)
class GraphAgentRuntime:
    """Coordinates one agent task: routing, orchestration, and changeset post-processing.

    The runtime resolves a route for the incoming message, builds a task spec,
    delegates execution to the orchestrator, then validates/sanitizes any
    proposed changeset before returning an :class:`AgentResult`.
    """

    def __init__(
        self,
        rag: RagRetriever,
        confluence: ConfluenceService,
        changeset_validator: ChangeSetValidator,
        llm: AgentLlmService,
        agent_repository: AgentRepository,
        story_recorder: StorySessionRecorder | None = None,
        code_explain_retriever: CodeExplainRetrieverV2 | None = None,
    ) -> None:
        self._rag = rag
        self._confluence = confluence
        self._changeset_validator = changeset_validator
        self._router = build_router_service(llm, agent_repository, rag)
        self._task_spec_builder = TaskSpecBuilder()
        self._orchestrator = OrchestratorService(step_registry=StepRegistry(code_explain_retriever))
        self._metrics_persister = MetricsPersister(agent_repository)
        self._story_recorder = story_recorder
        # Created lazily in _resolve_graph on first graph resolution.
        self._checkpointer = None

    async def run(
        self,
        *,
        task_id: str,
        dialog_session_id: str,
        rag_session_id: str,
        mode: str,
        message: str,
        attachments: list[dict],
        files: list[dict],
        progress_cb: Callable[[str, str, str, dict | None], Awaitable[None] | None] | None = None,
    ) -> AgentResult:
        """Execute one agent task end to end.

        Resolves the route, assembles context (attachments, Confluence pages,
        client files), runs the orchestrator, and returns either a CHANGESET
        result (validated edits) or an ANSWER result (plain text). Progress is
        reported through *progress_cb* at each stage; router context and quality
        metrics are persisted on every exit path.
        """
        LOGGER.info(
            "GraphAgentRuntime.run started: task_id=%s dialog_session_id=%s mode=%s",
            task_id,
            dialog_session_id,
            mode,
        )
        await self._emit_progress(progress_cb, "agent.route", "Определяю тип запроса и подбираю граф.", meta={"mode": mode})
        route = self._router.resolve(message, dialog_session_id, mode=mode)
        # Logged at WARNING so routing decisions surface under default log levels.
        LOGGER.warning(
            "router decision: task_id=%s dialog_session_id=%s mode=%s route=%s/%s reason=%s confidence=%s fallback_used=%s",
            task_id,
            dialog_session_id,
            mode,
            route.domain_id,
            route.process_id,
            route.reason,
            route.confidence,
            route.fallback_used,
        )
        await self._emit_progress(
            progress_cb,
            "agent.route.resolved",
            "Маршрут выбран, готовлю контекст для выполнения.",
            meta={"domain_id": route.domain_id, "process_id": route.process_id},
        )
        files_map = self._build_files_map(files)
        # RAG retrieval is not performed here; an empty context is passed through
        # (note `used_rag: False` in every result meta below).
        rag_ctx: list[dict] = []
        await self._emit_progress(progress_cb, "agent.attachments", "Обрабатываю дополнительные вложения.")
        conf_pages = await self._fetch_confluence_pages(attachments)
        route_meta = RoutingMeta(
            domain_id=route.domain_id,
            process_id=route.process_id,
            confidence=route.confidence,
            reason=route.reason,
            fallback_used=route.fallback_used,
        )
        task_spec = self._task_spec_builder.build(
            task_id=task_id,
            dialog_session_id=dialog_session_id,
            rag_session_id=rag_session_id,
            mode=mode,
            message=message,
            route=route_meta,
            attachments=attachments,
            files=files,
            rag_items=rag_ctx,
            rag_context=self._format_rag(rag_ctx),
            confluence_context=self._format_confluence(conf_pages),
            files_map=files_map,
        )
        await self._emit_progress(progress_cb, "agent.orchestrator", "Строю и выполняю план оркестрации.")
        orchestrator_result = await self._orchestrator.run(
            task=task_spec,
            graph_resolver=self._resolve_graph,
            graph_invoker=self._invoke_graph,
            progress_cb=progress_cb,
        )
        await self._emit_progress(progress_cb, "agent.orchestrator.done", "Оркестратор завершил выполнение плана.")
        answer = orchestrator_result.answer
        changeset = orchestrator_result.changeset or []
        orchestrator_meta = orchestrator_result.meta or {}
        quality_meta = self._extract_quality_meta(orchestrator_meta)
        orchestrator_steps = [item.model_dump(mode="json") for item in orchestrator_result.steps]
        # Best-effort story recording happens before changeset filtering, so the
        # recorded changeset is the raw orchestrator output.
        self._record_session_story_artifacts(
            dialog_session_id=dialog_session_id,
            rag_session_id=rag_session_id,
            scenario=str(orchestrator_meta.get("scenario", task_spec.scenario.value)),
            attachments=[a.model_dump(mode="json") for a in task_spec.attachments],
            answer=answer,
            changeset=changeset,
        )
        if changeset:
            await self._emit_progress(progress_cb, "agent.changeset", "Проверяю и валидирую предложенные изменения.")
            changeset = self._enrich_changeset_hashes(changeset, files_map)
            changeset = self._sanitize_changeset(changeset, files_map)
            if not changeset:
                # Every proposed edit was a no-op or whitespace-only: degrade to a
                # plain ANSWER result and flag it via `changeset_filtered_out`.
                final_answer = (answer or "").strip() or "Предложенные правки были отброшены как нерелевантные или косметические."
                await self._emit_progress(progress_cb, "agent.answer", "После фильтрации правок формирую ответ без changeset.")
                self._router.persist_context(
                    dialog_session_id,
                    domain_id=route.domain_id,
                    process_id=route.process_id,
                    user_message=message,
                    assistant_message=final_answer,
                    decision_type=route.decision_type,
                )
                LOGGER.info(
                    "final agent answer: task_id=%s route=%s/%s answer=%s",
                    task_id,
                    route.domain_id,
                    route.process_id,
                    _truncate_for_log(final_answer),
                )
                self._persist_quality_metrics(
                    task_id=task_id,
                    dialog_session_id=dialog_session_id,
                    rag_session_id=rag_session_id,
                    route=route,
                    scenario=str(orchestrator_meta.get("scenario", task_spec.scenario.value)),
                    quality=quality_meta,
                )
                return AgentResult(
                    result_type=TaskResultType.ANSWER,
                    answer=final_answer,
                    meta={
                        "route": route.model_dump(),
                        "used_rag": False,
                        "used_confluence": bool(conf_pages),
                        "changeset_filtered_out": True,
                        "orchestrator": orchestrator_meta,
                        "orchestrator_steps": orchestrator_steps,
                    },
                )
            # Changeset survived filtering: validate and return it.
            validated = self._changeset_validator.validate(task_id, changeset)
            final_answer = (answer or "").strip() or None
            self._router.persist_context(
                dialog_session_id,
                domain_id=route.domain_id,
                process_id=route.process_id,
                user_message=message,
                # Dialog history still needs an assistant turn even without text.
                assistant_message=final_answer or f"changeset:{len(validated)}",
                decision_type=route.decision_type,
            )
            final = AgentResult(
                result_type=TaskResultType.CHANGESET,
                answer=final_answer,
                changeset=validated,
                meta={
                    "route": route.model_dump(),
                    "used_rag": False,
                    "used_confluence": bool(conf_pages),
                    "orchestrator": orchestrator_meta,
                    "orchestrator_steps": orchestrator_steps,
                },
            )
            self._persist_quality_metrics(
                task_id=task_id,
                dialog_session_id=dialog_session_id,
                rag_session_id=rag_session_id,
                route=route,
                scenario=str(orchestrator_meta.get("scenario", task_spec.scenario.value)),
                quality=quality_meta,
            )
            LOGGER.info(
                "GraphAgentRuntime.run completed: task_id=%s route=%s/%s result_type=%s changeset_items=%s",
                task_id,
                route.domain_id,
                route.process_id,
                final.result_type.value,
                len(final.changeset),
            )
            LOGGER.info(
                "final agent answer: task_id=%s route=%s/%s answer=%s",
                task_id,
                route.domain_id,
                route.process_id,
                _truncate_for_log(final.answer),
            )
            return final
        # No changeset at all: plain textual answer path.
        final_answer = answer or ""
        await self._emit_progress(progress_cb, "agent.answer", "Формирую финальный ответ.")
        self._router.persist_context(
            dialog_session_id,
            domain_id=route.domain_id,
            process_id=route.process_id,
            user_message=message,
            assistant_message=final_answer,
            decision_type=route.decision_type,
        )
        final = AgentResult(
            result_type=TaskResultType.ANSWER,
            answer=final_answer,
            meta={
                "route": route.model_dump(),
                "used_rag": False,
                "used_confluence": bool(conf_pages),
                "orchestrator": orchestrator_meta,
                "orchestrator_steps": orchestrator_steps,
            },
        )
        self._persist_quality_metrics(
            task_id=task_id,
            dialog_session_id=dialog_session_id,
            rag_session_id=rag_session_id,
            route=route,
            scenario=str(orchestrator_meta.get("scenario", task_spec.scenario.value)),
            quality=quality_meta,
        )
        LOGGER.info(
            "GraphAgentRuntime.run completed: task_id=%s route=%s/%s result_type=%s answer_len=%s",
            task_id,
            route.domain_id,
            route.process_id,
            final.result_type.value,
            len(final.answer or ""),
        )
        LOGGER.info(
            "final agent answer: task_id=%s route=%s/%s answer=%s",
            task_id,
            route.domain_id,
            route.process_id,
            _truncate_for_log(final.answer),
        )
        return final

    def _extract_quality_meta(self, orchestrator_meta: dict) -> dict:
        """Return the nested ``quality`` dict from orchestrator meta, or ``{}``."""
        if not isinstance(orchestrator_meta, dict):
            return {}
        quality = orchestrator_meta.get("quality")
        return quality if isinstance(quality, dict) else {}

    def _persist_quality_metrics(
        self,
        *,
        task_id: str,
        dialog_session_id: str,
        rag_session_id: str,
        route,
        scenario: str,
        quality: dict,
    ) -> None:
        """Persist run quality metrics; a missing/empty *quality* dict is a no-op."""
        if not quality:
            return
        self._metrics_persister.save(
            task_id=task_id,
            dialog_session_id=dialog_session_id,
            rag_session_id=rag_session_id,
            scenario=scenario,
            domain_id=str(route.domain_id),
            process_id=str(route.process_id),
            quality=quality,
        )

    def _record_session_story_artifacts(
        self,
        *,
        dialog_session_id: str,
        rag_session_id: str,
        scenario: str,
        attachments: list[dict],
        answer: str | None,
        changeset: list[ChangeItem],
    ) -> None:
        """Record run artifacts via the optional story recorder (best-effort).

        Recording failures are logged and swallowed so they never break the run.
        """
        if self._story_recorder is None:
            return
        try:
            self._story_recorder.record_run(
                dialog_session_id=dialog_session_id,
                rag_session_id=rag_session_id,
                scenario=scenario,
                attachments=attachments,
                answer=answer,
                changeset=changeset,
            )
        except Exception:  # noqa: BLE001 — deliberate best-effort boundary
            LOGGER.exception("story session artifact recording failed")

    async def _emit_progress(
        self,
        progress_cb: Callable[[str, str, str, dict | None], Awaitable[None] | None] | None,
        stage: str,
        message: str,
        *,
        kind: str = "task_progress",
        meta: dict | None = None,
    ) -> None:
        """Invoke the progress callback, awaiting it when it is a coroutine."""
        if progress_cb is None:
            return
        result = progress_cb(stage, message, kind, meta or {})
        # Callback may be sync or async; only await actual awaitables.
        if inspect.isawaitable(result):
            await result

    def _resolve_graph(self, domain_id: str, process_id: str):
        """Build the graph for a route, falling back to default/general.

        Raises RuntimeError when no graph factory is configured at all.
        """
        if self._checkpointer is None:
            self._checkpointer = get_checkpointer()
        factory = self._router.graph_factory(domain_id, process_id)
        if factory is None:
            factory = self._router.graph_factory("default", "general")
        if factory is None:
            raise RuntimeError("No graph factory configured")
        LOGGER.debug("_resolve_graph resolved: domain_id=%s process_id=%s", domain_id, process_id)
        return factory(self._checkpointer)

    def _invoke_graph(self, graph, state: dict, dialog_session_id: str):
        """Invoke *graph* with the dialog session id as the checkpoint thread id."""
        return graph.invoke(
            state,
            config={"configurable": {"thread_id": dialog_session_id}},
        )

    async def _fetch_confluence_pages(self, attachments: list[dict]) -> list[dict]:
        """Fetch every ``confluence_url`` attachment sequentially; other types are ignored."""
        pages: list[dict] = []
        for item in attachments:
            if item.get("type") == "confluence_url":
                pages.append(await self._confluence.fetch_page(item["url"]))
        LOGGER.info("_fetch_confluence_pages completed: pages=%s", len(pages))
        return pages

    def _format_rag(self, items: list[dict]) -> str:
        """Render RAG items into a prompt-friendly text: one block per item, blank-line separated."""
        blocks: list[str] = []
        for item in items:
            source = str(item.get("source", "") or item.get("path", "") or "")
            layer = str(item.get("layer", "") or "").strip()
            title = str(item.get("title", "") or "").strip()
            metadata = item.get("metadata", {}) or {}
            lines = []
            if source:
                lines.append(f"Source: {source}")
            if layer:
                lines.append(f"Layer: {layer}")
            if title:
                lines.append(f"Title: {title}")
            if metadata:
                # Only a whitelisted subset of metadata keys is surfaced as hints.
                hints = []
                for key in ("module_id", "qname", "predicate", "entry_type", "framework", "section_path"):
                    value = metadata.get(key)
                    if value:
                        hints.append(f"{key}={value}")
                if hints:
                    lines.append("Meta: " + ", ".join(hints))
            content = str(item.get("content", "")).strip()
            if content:
                lines.append(content)
            if lines:
                blocks.append("\n".join(lines))
        return "\n\n".join(blocks)

    def _format_confluence(self, pages: list[dict]) -> str:
        """Concatenate the markdown content of fetched Confluence pages."""
        return "\n".join(str(x.get("content_markdown", "")) for x in pages)

    def _build_files_map(self, files: list[dict]) -> dict[str, dict]:
        """Index request files by forward-slash-normalized path; empty paths are skipped."""
        output: dict[str, dict] = {}
        for item in files:
            path = str(item.get("path", "")).replace("\\", "/").strip()
            if not path:
                continue
            output[path] = {
                "path": path,
                "content": str(item.get("content", "")),
                "content_hash": str(item.get("content_hash", "")),
            }
        LOGGER.debug("_build_files_map completed: files=%s", len(output))
        return output

    def _lookup_file(self, files_map: dict[str, dict], path: str) -> dict | None:
        """Find a file entry by path: exact match first, then case-insensitive."""
        normalized = (path or "").replace("\\", "/")
        if normalized in files_map:
            return files_map[normalized]
        low = normalized.lower()
        for key, value in files_map.items():
            if key.lower() == low:
                return value
        return None

    def _enrich_changeset_hashes(self, items: list[ChangeItem], files_map: dict[str, dict]) -> list[ChangeItem]:
        """Attach the request-supplied content hash to every ``update`` item.

        Raises AppError("missing_base_hash") when an update targets a file that
        is absent from the request context or carries no hash. Non-update items
        pass through unchanged.
        """
        enriched: list[ChangeItem] = []
        for item in items:
            if item.op.value == "update":
                source = self._lookup_file(files_map, item.path)
                if not source or not source.get("content_hash"):
                    raise AppError(
                        "missing_base_hash",
                        f"Cannot build update for {item.path}: no file hash in request context",
                        ModuleName.AGENT,
                    )
                item.base_hash = str(source["content_hash"])
            enriched.append(item)
        LOGGER.debug("_enrich_changeset_hashes completed: items=%s", len(enriched))
        return enriched

    def _sanitize_changeset(self, items: list[ChangeItem], files_map: dict[str, dict]) -> list[ChangeItem]:
        """Drop ``update`` items whose proposed content is identical to the original
        (exactly, or after collapsing all whitespace). Other ops — and updates whose
        source file is unknown — are kept as-is.
        """
        sanitized: list[ChangeItem] = []
        dropped_noop = 0
        dropped_ws = 0
        for item in items:
            if item.op.value != "update":
                sanitized.append(item)
                continue
            source = self._lookup_file(files_map, item.path)
            if not source:
                sanitized.append(item)
                continue
            original = str(source.get("content", ""))
            proposed = item.proposed_content or ""
            if proposed == original:
                dropped_noop += 1
                continue
            if self._collapse_whitespace(proposed) == self._collapse_whitespace(original):
                dropped_ws += 1
                continue
            sanitized.append(item)
        if dropped_noop or dropped_ws:
            LOGGER.info(
                "_sanitize_changeset dropped items: noop=%s whitespace_only=%s kept=%s",
                dropped_noop,
                dropped_ws,
                len(sanitized),
            )
        return sanitized

    def _collapse_whitespace(self, text: str) -> str:
        """Normalize all whitespace runs to single spaces for cosmetic-diff comparison."""
        return re.sub(r"\s+", " ", (text or "").strip())