Files
agent/tests/pipeline_setup_v3/runtime/agent_runtime_adapter.py
2026-03-12 23:33:51 +03:00

188 lines
8.2 KiB
Python

from __future__ import annotations
import math
from app.modules.agent.runtime import (
AgentRuntimeExecutor,
RuntimeRepoContextFactory,
RuntimeRetrievalAdapter,
)
from app.modules.agent.llm import AgentLlmService
from app.modules.agent.llm.prompt_loader import PromptLoader
from app.modules.agent.runtime.steps.context import build_retrieval_request, build_retrieval_result
from app.modules.agent.intent_router_v2 import ConversationState, IntentRouterV2
from app.modules.shared.gigachat.client import GigaChatClient
from app.modules.shared.gigachat.settings import GigaChatSettings
from app.modules.shared.gigachat.token_provider import GigaChatTokenProvider
from tests.pipeline_setup_v3.core.models import ExecutionPayload, V3Case
class AgentRuntimeAdapter:
    """Drives a v3 pipeline test case against the real agent runtime.

    Supports three execution modes per case:
      - "router_only":  intent routing only, no retrieval;
      - "router_rag":   routing plus RAG retrieval and symbol resolution;
      - "full_chain":   the complete executor path, including the LLM answer.
    Each mode produces an ExecutionPayload with an `actual` summary dict and a
    `details` dict of raw intermediate artifacts for debugging/reporting.
    """

    def __init__(self) -> None:
        self._router = IntentRouterV2()
        self._repo_context_factory = RuntimeRepoContextFactory()
        self._retrieval = RuntimeRetrievalAdapter()
        # Built lazily on first full_chain run — see _executor_instance().
        self._executor: AgentRuntimeExecutor | None = None

    def execute(self, case: V3Case, rag_session_id: str | None) -> ExecutionPayload:
        """Dispatch *case* to the handler for its mode.

        Raises:
            ValueError: if a RAG-dependent mode has no session id, or the
                mode is not one of the three supported values.
        """
        if case.mode == "router_only":
            return self._run_router_only(case)
        # Every remaining mode retrieves from a RAG session.
        if not rag_session_id:
            raise ValueError(f"Case '{case.case_id}' requires rag_session_id or repo_path")
        if case.mode == "router_rag":
            return self._run_router_rag(case, rag_session_id)
        if case.mode == "full_chain":
            return self._run_full_chain(case, rag_session_id)
        raise ValueError(f"Unsupported mode: {case.mode}")

    def _run_router_only(self, case: V3Case) -> ExecutionPayload:
        """Route the query only; no retrieval, so rag_rows stays empty."""
        route = self._route(case.query)
        actual = self._actual_from_route(route)
        details = {
            "query": case.query,
            "router_result": route.model_dump(mode="json"),
            "rag_rows": [],
        }
        return ExecutionPayload(actual=actual, details=details)

    def _run_router_rag(self, case: V3Case, rag_session_id: str) -> ExecutionPayload:
        """Route the query, run retrieval, and resolve any pending symbol.

        Does not invoke the LLM: answer_mode is hard-coded to "partial".
        """
        route = self._route(case.query)
        request = build_retrieval_request(route, rag_session_id)
        raw_rows = self._retrieve_rows(rag_session_id, request)
        # Symbol resolution starts from the router's own attempt and is
        # refined against the retrieved C1 symbol-catalog rows.
        symbol_resolution = self._resolve_symbol(route.symbol_resolution.model_dump(), raw_rows)
        retrieval_result = build_retrieval_result(raw_rows, self._retrieval.consume_retrieval_report(), symbol_resolution)
        actual = self._actual_from_route(
            route,
            rag_rows=raw_rows,
            answer_mode="partial",
            path_scope=list(request.path_scope or []),
            layers=list(request.requested_layers or []),
            symbol_candidates=list(request.symbol_candidates or []),
        )
        details = {
            "query": case.query,
            "router_result": route.model_dump(mode="json"),
            "retrieval_request": request.model_dump(mode="json"),
            "retrieval_result": retrieval_result.model_dump(mode="json"),
            "rag_rows": raw_rows,
            "symbol_resolution": symbol_resolution,
        }
        return ExecutionPayload(actual=actual, details=details)

    def _run_full_chain(self, case: V3Case, rag_session_id: str) -> ExecutionPayload:
        """Run the complete executor pipeline (router → RAG → LLM answer)."""
        result = self._executor_instance().execute(user_query=case.query, rag_session_id=rag_session_id)
        route = result.router_result
        request = result.retrieval_request
        # Any stage of the chain may have been skipped, so every field from
        # `result` is guarded before being dumped or listed.
        actual = self._actual_from_route(
            route,
            rag_rows=list(result.retrieval_result.raw_rows) if result.retrieval_result else [],
            llm_answer=result.final_answer,
            answer_mode=_answer_status(result.answer_mode, result.llm_used),
            path_scope=list(request.path_scope or []) if request else [],
            layers=list(request.requested_layers or []) if request else [],
            symbol_candidates=list(route.query_plan.symbol_candidates or []) if route else [],
        )
        details = {
            "query": case.query,
            "router_result": route.model_dump(mode="json") if route else {},
            "retrieval_request": request.model_dump(mode="json") if request else {},
            "retrieval_result": result.retrieval_result.model_dump(mode="json") if result.retrieval_result else {},
            "diagnostics": result.diagnostics.model_dump(mode="json"),
            "rag_rows": list(result.retrieval_result.raw_rows) if result.retrieval_result else [],
            "validation": result.validation.model_dump(mode="json"),
            "token_usage": _token_usage(result),
            "steps": list(result.runtime_trace),
        }
        return ExecutionPayload(actual=actual, details=details)

    def _route(self, query: str):
        """Run the v2 intent router with a fresh conversation state."""
        return self._router.route(query, ConversationState(), self._repo_context_factory.build())

    def _retrieve_rows(self, rag_session_id: str, request) -> list[dict]:
        """Fetch RAG rows for *request*.

        OPEN_FILE with a path scope bypasses the ranked plan and pulls the
        exact source chunks for those paths; everything else goes through
        the plan-driven retrieval.
        """
        if request.sub_intent == "OPEN_FILE" and request.path_scope:
            return self._retrieval.retrieve_exact_files(
                rag_session_id,
                paths=list(request.path_scope),
                layers=["C0_SOURCE_CHUNKS"],
                limit=200,
                query=request.query,
                ranking_profile=str(getattr(request.retrieval_spec, "rerank_profile", "") or ""),
            )
        return self._retrieval.retrieve_with_plan(
            rag_session_id,
            request.query,
            request.retrieval_spec,
            request.retrieval_constraints,
            query_plan=request.query_plan,
        )

    def _resolve_symbol(self, initial: dict, rag_rows: list[dict]) -> dict:
        """Refine a "pending" symbol resolution using retrieved catalog rows.

        Non-pending resolutions pass through unchanged. Otherwise the titles
        of C1_SYMBOL_CATALOG rows are matched against the router's candidate
        alternatives: an exact hit resolves, any hits make it ambiguous, and
        no hits mean not_found. Confidence values are fixed per outcome.
        """
        if str(initial.get("status") or "") != "pending":
            return initial
        candidates = [str(item).strip() for item in initial.get("alternatives", []) if str(item).strip()]
        found = [
            str(row.get("title") or "").strip()
            for row in rag_rows
            if str(row.get("layer") or "") == "C1_SYMBOL_CATALOG" and str(row.get("title") or "").strip()
        ]
        exact = next((item for item in found if item in candidates), None)
        if exact:
            return {"status": "resolved", "resolved_symbol": exact, "alternatives": found[:5], "confidence": 0.99}
        if found:
            return {"status": "ambiguous", "resolved_symbol": None, "alternatives": found[:5], "confidence": 0.55}
        return {"status": "not_found", "resolved_symbol": None, "alternatives": [], "confidence": 0.0}

    def _actual_from_route(
        self,
        route,
        *,
        rag_rows: list[dict] | None = None,
        llm_answer: str = "",
        answer_mode: str = "partial",
        path_scope: list[str] | None = None,
        layers: list[str] | None = None,
        symbol_candidates: list[str] | None = None,
    ) -> dict:
        """Build the flat `actual` summary dict compared against expectations.

        *route* may be None (e.g. a full_chain run that failed before
        routing), in which case the route-derived fields are None.
        List-like fields are snapshotted as tuples.
        """
        route_dump = route.model_dump(mode="json") if route else {}
        return {
            "intent": route_dump.get("intent"),
            "sub_intent": dict(route_dump.get("query_plan") or {}).get("sub_intent"),
            "graph_id": route_dump.get("graph_id"),
            "conversation_mode": route_dump.get("conversation_mode"),
            "rag_count": len(rag_rows or []),
            "llm_answer": llm_answer,
            "answer_mode": answer_mode,
            "path_scope": tuple(path_scope or []),
            "symbol_candidates": tuple(symbol_candidates or []),
            "layers": tuple(layers or []),
        }

    def _executor_instance(self) -> AgentRuntimeExecutor:
        """Lazily construct the executor (needs GigaChat env configuration)."""
        if self._executor is None:
            self._executor = AgentRuntimeExecutor(_build_llm())
        return self._executor
def _build_llm() -> AgentLlmService:
    """Wire up an AgentLlmService backed by a GigaChat client configured from the environment."""
    env_settings = GigaChatSettings.from_env()
    token_provider = GigaChatTokenProvider(env_settings)
    gigachat_client = GigaChatClient(env_settings, token_provider)
    return AgentLlmService(client=gigachat_client, prompts=PromptLoader())
def _answer_status(answer_mode: str, llm_used: bool) -> str:
if answer_mode == "normal" and llm_used:
return "answered"
return answer_mode
def _token_usage(result) -> dict:
    """Estimate prompt token usage for the run's final LLM draft.

    Args:
        result: executor result; only `draft_answer` (with `prompt_name`
            and `prompt_payload`) is read.

    Returns:
        An empty dict when no draft was produced (e.g. the LLM stage was
        skipped); otherwise the prompt name plus a chars/4 token estimate,
        floored at 1.
    """
    draft = result.draft_answer
    if draft is None:
        return {}
    system_prompt = PromptLoader().load(draft.prompt_name)
    total_chars = len(system_prompt or "") + len(draft.prompt_payload or "")
    # Exact integer ceiling division: math.ceil(total / 4) routes through
    # float division, which is inexact for very large counts and needs no
    # import here. (total + 3) // 4 is equivalent for non-negative ints.
    tokens_in_estimate = max(1, (total_chars + 3) // 4)
    return {
        "prompt_name": draft.prompt_name,
        "tokens_in_estimate": tokens_in_estimate,
    }