155 lines
6.5 KiB
Python
155 lines
6.5 KiB
Python
from __future__ import annotations
|
|
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
from app.modules.agent.code_qa_runtime import CodeQaRuntimeExecutor
|
|
from app.modules.agent.llm import AgentLlmService
|
|
from app.modules.agent.prompt_loader import PromptLoader
|
|
from app.modules.shared.gigachat.client import GigaChatClient
|
|
from app.modules.shared.gigachat.settings import GigaChatSettings
|
|
from app.modules.shared.gigachat.token_provider import GigaChatTokenProvider
|
|
from tests.pipeline_setup.suite_02_pipeline.pipeline_intent_rag.helpers.env_bootstrap import PipelineEnvLoader
|
|
from tests.pipeline_setup.suite_02_pipeline.pipeline_intent_rag.helpers.models import PhraseCase, PipelineResult
|
|
from tests.pipeline_setup.suite_02_pipeline.pipeline_intent_rag.helpers.pipeline_runner import (
|
|
IntentRouterRagPipelineRunner,
|
|
RouterOnlyRunner,
|
|
)
|
|
from tests.pipeline_setup_v2.core.models import ExecutionPayload, V2Case
|
|
|
|
|
|
class RuntimeAdapter:
    """Adapts v2 pipeline cases onto the legacy suite-02 intent/RAG runners.

    Converts a ``V2Case`` into the older ``PhraseCase`` shape, dispatches it to
    the runner matching ``case.mode``, and repackages the runner's record as an
    ``ExecutionPayload`` for the v2 comparison layer.
    """

    def __init__(self) -> None:
        # Root of the legacy suite's directory: climb two levels up from this
        # file, then descend into the old pipeline_intent_rag test tree.
        self._test_root = (
            Path(__file__).resolve().parents[2]
            / "pipeline_setup"
            / "suite_02_pipeline"
            / "pipeline_intent_rag"
        ).resolve()
        # Side effect: loads the env vars the legacy pipeline expects.
        PipelineEnvLoader(self._test_root).load()
        # One shared run timestamp for every case executed by this adapter.
        # NOTE(review): naive local time — confirm consumers don't expect UTC.
        self._started_at = datetime.now()
        # Both built lazily on first use; see _rag_components / _executor_instance.
        self._rag_adapter = None
        self._executor: CodeQaRuntimeExecutor | None = None

    def execute(self, case: V2Case, rag_session_id: str | None) -> ExecutionPayload:
        """Run one case through the runner selected by ``case.mode``.

        Modes ``"router_only"`` and ``"router_rag"`` use the legacy runners;
        any other mode falls through to the full executor chain.
        """
        phrase = PhraseCase(
            case_id=case.case_id,
            text=case.query,
            rag_session_id=rag_session_id,
            expected_intent=case.expectations.router.intent,
            expect_non_empty_rag=bool(case.expectations.retrieval.non_empty),
        )
        if case.mode == "router_only":
            result = RouterOnlyRunner(started_at=self._started_at).run_case(phrase)
        elif case.mode == "router_rag":
            result = IntentRouterRagPipelineRunner(
                started_at=self._started_at,
                rag_adapter=self._rag_components(),
                session_resolver=_StaticSessionResolver(),
            ).run_case(phrase)
        else:
            result = self._run_full_chain(phrase)
        record = result.to_record()
        return ExecutionPayload(actual=_actual_from_runtime_record(record), details=record)

    def _run_full_chain(self, phrase: PhraseCase) -> PipelineResult:
        """Execute the full code-QA chain and map its result to a PipelineResult."""
        result = self._executor_instance().execute(user_query=phrase.text, rag_session_id=phrase.rag_session_id or "")
        # Copy so the .get() calls below operate on a plain dict snapshot.
        route = dict(result.diagnostics.router_result)
        retrieval = result.diagnostics.retrieval_request
        diagnostics = {
            "router_plan": {
                "sub_intent": route.get("sub_intent"),
                "graph_id": route.get("graph_id"),
                "path_scope": list(retrieval.get("path_scope") or []),
                "layers": list(retrieval.get("requested_layers") or []),
                # router_result may be absent (e.g. short-circuit paths) — default to [].
                "symbol_candidates": list(result.router_result.query_plan.symbol_candidates or []) if result.router_result else [],
            },
            "timings_ms": dict(result.diagnostics.timings_ms or {}),
            "answer_policy": {
                # llm_used=False means the chain produced an answer without calling the LLM.
                "short_circuit": not result.llm_used,
                "answer_mode": _answer_status(result.answer_mode, result.llm_used),
                "failure_reason": ",".join(result.validation.reasons),
            },
        }
        return PipelineResult(
            case=phrase,
            mode="full_chain",
            run_started_at=self._started_at,
            rag_session_id=phrase.rag_session_id,
            intent=str(route.get("intent") or ""),
            graph_id=str(route.get("graph_id") or ""),
            conversation_mode=str(route.get("conversation_mode") or ""),
            # Fall back to the raw phrase when the router produced no rewritten query.
            query=str(retrieval.get("query") or phrase.text),
            rag_rows=list(result.retrieval_result.raw_rows) if result.retrieval_result else [],
            symbol_resolution=result.router_result.symbol_resolution.model_dump() if result.router_result else {},
            llm_answer=result.final_answer,
            diagnostics=diagnostics,
        )

    def _rag_components(self):
        """Lazily build (and cache) the RAG DB adapter used by router_rag mode."""
        if self._rag_adapter is None:
            # Imported here so modes that never touch RAG skip the DB-layer import cost.
            from app.modules.rag.persistence.repository import RagRepository
            from tests.pipeline_setup.suite_02_pipeline.pipeline_intent_rag.helpers.rag_db_adapter import (
                RagDbAdapter,
                SessionEmbeddingDimensions,
            )

            self._rag_adapter = RagDbAdapter(
                repository=RagRepository(),
                dim_resolver=SessionEmbeddingDimensions(),
            )
        return self._rag_adapter

    def _executor_instance(self) -> CodeQaRuntimeExecutor:
        """Lazily build (and cache) the full-chain executor with its LLM service."""
        if self._executor is None:
            self._executor = CodeQaRuntimeExecutor(_build_llm())
        return self._executor
|
|
|
|
|
|
class _StaticSessionResolver:
|
|
def resolve(self, case: PhraseCase) -> str:
|
|
if not case.rag_session_id:
|
|
raise ValueError(f"Case '{case.case_id}' requires rag_session_id or repo_path")
|
|
return case.rag_session_id
|
|
|
|
|
|
def _actual_from_runtime_record(record: dict) -> dict:
|
|
diagnostics = dict(record.get("diagnostics") or {})
|
|
router_plan = dict(diagnostics.get("router_plan") or {})
|
|
summary = dict(record.get("summary") or {})
|
|
router_summary = dict(summary.get("router") or {})
|
|
return {
|
|
"intent": record.get("actual_intent"),
|
|
"sub_intent": router_summary.get("sub_intent") or router_plan.get("sub_intent"),
|
|
"graph_id": record.get("graph_id"),
|
|
"conversation_mode": record.get("conversation_mode"),
|
|
"rag_count": int(record.get("rag_count") or 0),
|
|
"llm_answer": str(record.get("llm_answer") or "").strip(),
|
|
"answer_mode": dict(summary.get("llm") or {}).get("answer_status"),
|
|
"path_scope": tuple(router_plan.get("path_scope") or []),
|
|
"symbol_candidates": tuple(router_plan.get("symbol_candidates") or []),
|
|
"layers": tuple(router_plan.get("layers") or []),
|
|
}
|
|
|
|
|
|
def _ms(started: float) -> int:
|
|
return 0
|
|
|
|
|
|
def _sub_intent_from_result(result: PipelineResult) -> str | None:
|
|
router_plan = dict(result.diagnostics.get("router_plan") or {})
|
|
value = str(router_plan.get("sub_intent") or "").strip()
|
|
return value or None
|
|
|
|
|
|
def _answer_status(answer_mode: str, llm_used: bool) -> str:
|
|
if answer_mode == "normal" and llm_used:
|
|
return "answered"
|
|
return answer_mode
|
|
|
|
|
|
def _build_llm() -> "AgentLlmService":
    """Build an AgentLlmService backed by a GigaChat client configured from env vars."""
    cfg = GigaChatSettings.from_env()
    token_provider = GigaChatTokenProvider(cfg)
    gigachat = GigaChatClient(cfg, token_provider)
    return AgentLlmService(client=gigachat, prompts=PromptLoader())
|