Fix state
tests/pipeline_setup_v3/runtime/agent_runtime_adapter.py (normal file, 171 lines added)
@@ -0,0 +1,171 @@
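"""Adapter that drives pipeline_setup_v3 test cases through the Code QA runtime.

Bridges V3Case inputs to the intent router, the retrieval adapter, and the full
CodeQaRuntimeExecutor chain, returning a normalized ExecutionPayload per case.
"""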
from __future__ import annotations

from app.modules.agent.code_qa_runtime import CodeQaRuntimeExecutor
from app.modules.agent.code_qa_runtime.repo_context import CodeQaRepoContextFactory
from app.modules.agent.code_qa_runtime.retrieval_adapter import CodeQaRetrievalAdapter
from app.modules.agent.llm import AgentLlmService
from app.modules.agent.prompt_loader import PromptLoader
from app.modules.rag.code_qa_pipeline.retrieval_request_builder import build_retrieval_request
from app.modules.rag.code_qa_pipeline.retrieval_result_builder import build_retrieval_result
from app.modules.rag.intent_router_v2 import ConversationState, IntentRouterV2
from app.modules.shared.gigachat.client import GigaChatClient
from app.modules.shared.gigachat.settings import GigaChatSettings
from app.modules.shared.gigachat.token_provider import GigaChatTokenProvider
from tests.pipeline_setup_v3.core.models import ExecutionPayload, V3Case


class AgentRuntimeAdapter:
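    """Runs a single V3Case against the Code QA stack and packages the result as an ExecutionPayload."""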

    def __init__(self) -> None:
        self._router = IntentRouterV2()
        self._repo_context_factory = CodeQaRepoContextFactory()
        self._retrieval = CodeQaRetrievalAdapter()
        self._executor: CodeQaRuntimeExecutor | None = None

    def execute(self, case: V3Case, rag_session_id: str | None) -> ExecutionPayload:
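        """Dispatch the case by mode: router_only, router_rag, or full_chain.

        The RAG-backed modes require a rag_session_id; router_only does not.
        """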
        if case.mode == "router_only":
            return self._run_router_only(case)
        if not rag_session_id:
            raise ValueError(f"Case '{case.case_id}' requires rag_session_id or repo_path")
        if case.mode == "router_rag":
            return self._run_router_rag(case, rag_session_id)
        if case.mode == "full_chain":
            return self._run_full_chain(case, rag_session_id)
        raise ValueError(f"Unsupported mode: {case.mode}")

    def _run_router_only(self, case: V3Case) -> ExecutionPayload:
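        """Run only the intent router; no retrieval rows are collected."""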
        route = self._route(case.query)
        actual = self._actual_from_route(route)
        details = {
            "query": case.query,
            "router_result": route.model_dump(mode="json"),
            "rag_rows": [],
        }
        return ExecutionPayload(actual=actual, details=details)

    def _run_router_rag(self, case: V3Case, rag_session_id: str) -> ExecutionPayload:
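        """Route the query, build a retrieval request, and fetch RAG rows without calling the LLM."""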
        route = self._route(case.query)
        request = build_retrieval_request(route, rag_session_id)
        raw_rows = self._retrieve_rows(rag_session_id, request)
        symbol_resolution = self._resolve_symbol(route.symbol_resolution.model_dump(), raw_rows)
        retrieval_result = build_retrieval_result(raw_rows, self._retrieval.consume_retrieval_report(), symbol_resolution)
        actual = self._actual_from_route(
            route,
            rag_rows=raw_rows,
            answer_mode="partial",
            path_scope=list(request.path_scope or []),
            layers=list(request.requested_layers or []),
            symbol_candidates=list(request.symbol_candidates or []),
        )
        details = {
            "query": case.query,
            "router_result": route.model_dump(mode="json"),
            "retrieval_request": request.model_dump(mode="json"),
            "retrieval_result": retrieval_result.model_dump(mode="json"),
            "rag_rows": raw_rows,
            "symbol_resolution": symbol_resolution,
        }
        return ExecutionPayload(actual=actual, details=details)

    def _run_full_chain(self, case: V3Case, rag_session_id: str) -> ExecutionPayload:
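        """Delegate to CodeQaRuntimeExecutor and unpack its result into the payload shape."""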
        result = self._executor_instance().execute(user_query=case.query, rag_session_id=rag_session_id)
        route = result.router_result
        request = result.retrieval_request
        actual = self._actual_from_route(
            route,
            rag_rows=list(result.retrieval_result.raw_rows) if result.retrieval_result else [],
            llm_answer=result.final_answer,
            answer_mode=_answer_status(result.answer_mode, result.llm_used),
            path_scope=list(request.path_scope or []) if request else [],
            layers=list(request.requested_layers or []) if request else [],
            symbol_candidates=list(route.query_plan.symbol_candidates or []) if route else [],
        )
        details = {
            "query": case.query,
            "router_result": route.model_dump(mode="json") if route else {},
            "retrieval_request": request.model_dump(mode="json") if request else {},
            "retrieval_result": result.retrieval_result.model_dump(mode="json") if result.retrieval_result else {},
            "diagnostics": result.diagnostics.model_dump(mode="json"),
            "rag_rows": list(result.retrieval_result.raw_rows) if result.retrieval_result else [],
            "validation": result.validation.model_dump(mode="json"),
            "steps": list(result.runtime_trace),
        }
        return ExecutionPayload(actual=actual, details=details)

    def _route(self, query: str):
        return self._router.route(query, ConversationState(), self._repo_context_factory.build())

    def _retrieve_rows(self, rag_session_id: str, request) -> list[dict]:
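        """Fetch RAG rows: exact-file reads for OPEN_FILE requests with a path scope, plan-based retrieval otherwise."""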
        if request.sub_intent == "OPEN_FILE" and request.path_scope:
            return self._retrieval.retrieve_exact_files(
                rag_session_id,
                paths=list(request.path_scope),
                layers=["C0_SOURCE_CHUNKS"],
                limit=200,
                query=request.query,
                ranking_profile=str(getattr(request.retrieval_spec, "rerank_profile", "") or ""),
            )
        return self._retrieval.retrieve_with_plan(
            rag_session_id,
            request.query,
            request.retrieval_spec,
            request.retrieval_constraints,
            query_plan=request.query_plan,
        )

    def _resolve_symbol(self, initial: dict, rag_rows: list[dict]) -> dict:
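        """Upgrade a pending symbol resolution using titles found in the C1_SYMBOL_CATALOG layer."""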
        if str(initial.get("status") or "") != "pending":
            return initial
        candidates = [str(item).strip() for item in initial.get("alternatives", []) if str(item).strip()]
        found = [
            str(row.get("title") or "").strip()
            for row in rag_rows
            if str(row.get("layer") or "") == "C1_SYMBOL_CATALOG" and str(row.get("title") or "").strip()
        ]
        exact = next((item for item in found if item in candidates), None)
        if exact:
            return {"status": "resolved", "resolved_symbol": exact, "alternatives": found[:5], "confidence": 0.99}
        if found:
            return {"status": "ambiguous", "resolved_symbol": None, "alternatives": found[:5], "confidence": 0.55}
        return {"status": "not_found", "resolved_symbol": None, "alternatives": [], "confidence": 0.0}

    def _actual_from_route(
        self,
        route,
        *,
        rag_rows: list[dict] | None = None,
        llm_answer: str = "",
        answer_mode: str = "partial",
        path_scope: list[str] | None = None,
        layers: list[str] | None = None,
        symbol_candidates: list[str] | None = None,
    ) -> dict:
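        """Flatten the route and run artifacts into a single "actual" dict (intent, counts, scopes, answer)."""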
        route_dump = route.model_dump(mode="json") if route else {}
        return {
            "intent": route_dump.get("intent"),
            "sub_intent": dict(route_dump.get("query_plan") or {}).get("sub_intent"),
            "graph_id": route_dump.get("graph_id"),
            "conversation_mode": route_dump.get("conversation_mode"),
            "rag_count": len(rag_rows or []),
            "llm_answer": llm_answer,
            "answer_mode": answer_mode,
            "path_scope": tuple(path_scope or []),
            "symbol_candidates": tuple(symbol_candidates or []),
            "layers": tuple(layers or []),
        }

    def _executor_instance(self) -> CodeQaRuntimeExecutor:
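        """Create the full-chain executor lazily, so LLM credentials are only needed for full_chain cases."""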
        if self._executor is None:
            self._executor = CodeQaRuntimeExecutor(_build_llm())
        return self._executor


def _build_llm() -> AgentLlmService:
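    """Build an AgentLlmService backed by a GigaChat client configured from environment settings."""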
    settings = GigaChatSettings.from_env()
    client = GigaChatClient(settings, GigaChatTokenProvider(settings))
    return AgentLlmService(client=client, prompts=PromptLoader())


def _answer_status(answer_mode: str, llm_used: bool) -> str:
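    """Report "answered" only when the executor ran in normal mode and actually used the LLM."""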
    if answer_mode == "normal" and llm_used:
        return "answered"
    return answer_mode
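A minimal usage sketch from a test, assuming V3Case can be constructed with just case_id, mode, and query and that ExecutionPayload exposes actual as an attribute (the real models may require more fields):

    from tests.pipeline_setup_v3.core.models import V3Case
    from tests.pipeline_setup_v3.runtime.agent_runtime_adapter import AgentRuntimeAdapter

    adapter = AgentRuntimeAdapter()
    # router_only cases skip retrieval, so no rag_session_id is needed.
    case = V3Case(case_id="demo-1", mode="router_only", query="Where is the GigaChat token provider configured?")
    payload = adapter.execute(case, rag_session_id=None)
    assert payload.actual["rag_count"] == 0  # no RAG rows are collected in router_only mode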