341 lines
14 KiB
Python
341 lines
14 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from dataclasses import asdict, dataclass
|
|
from pathlib import Path
|
|
|
|
from app.core.agent.processes.v2 import V2IntentRouter, V2Process
|
|
from app.core.agent.utils.llm import AgentLlmService, PromptLoader
|
|
from app.core.rag.persistence import RagRepository
|
|
from app.core.rag.retrieval.session_retriever import RagSessionRetriever
|
|
from app.core.shared.gigachat.client import GigaChatClient
|
|
from app.core.shared.gigachat.settings import GigaChatSettings
|
|
from app.core.shared.gigachat.token_provider import GigaChatTokenProvider
|
|
from app.infra.observability.module_trace import RequestTraceContext
|
|
from app.core.agent.utils.process_v2.anchor_signals import route_anchor_summary
|
|
from app.core.agent.utils.process_v2.evidence.assembler import DocsEvidenceAssembler
|
|
from app.core.agent.utils.process_v2.evidence.gate import DocsEvidenceGate
|
|
from app.core.agent.utils.process_v2.models import V2Intent
|
|
from app.core.agent.utils.process_v2.plan_resolver import V2RetrievalPolicyResolver
|
|
from app.core.agent.utils.process_v2.rag_retrieval import DocsMetadataLookupIndex, V2RagRetrievalAdapter
|
|
from tests.pipeline_setup_v3.core.models import ExecutionPayload, V3Case
|
|
from tests.pipeline_setup_v3.shared.rag_indexer import DeterministicEmbedder
|
|
from tests.pipeline_setup_v4.executors.process_v2_router_executor import _KeywordLlm
|
|
|
|
|
|
class V2ProcessAdapter:
    """Run pipeline-setup cases through the process_v2 components.

    The constructor wires together an intent router, a retrieval policy
    resolver, a RAG retrieval adapter and a ``V2Process``. ``execute`` is the
    synchronous entry point; what it does depends on ``case.mode``:

    * ``router_only`` - route the query and return the routing result.
    * ``router_rag``  - route, resolve a retrieval plan and fetch candidate rows.
    * ``full_chain``  - run the complete ``V2Process`` and read results back
      from the events it logged.
    """

    # Modes handled by ``_execute_async``. Validated before any routing or
    # retrieval happens so that an unknown mode fails immediately instead of
    # after (or hidden behind) a retrieval error.
    _SUPPORTED_MODES = ("router_only", "router_rag", "full_chain")

    def __init__(self, *, workflow_llm_enabled: bool = True) -> None:
        """Build the router, policy resolver, retrieval adapter and process.

        Args:
            workflow_llm_enabled: Forwarded unchanged to ``V2Process``.
        """
        self._workflow_llm_enabled = workflow_llm_enabled
        # The process receives the service built by ``_build_v2_llm``; the
        # router is given a separate ``_KeywordLlm`` instance.
        self._llm = _build_v2_llm()
        self._router = V2IntentRouter(llm=_KeywordLlm(), enable_llm_disambiguation=True)
        self._policy = V2RetrievalPolicyResolver()
        retriever = RagSessionRetriever(repository=RagRepository(), embedder=DeterministicEmbedder())
        self._retrieval = V2RagRetrievalAdapter(retriever)
        self._process = V2Process(
            llm=self._llm,
            policy_resolver=self._policy,
            rag_adapter=self._retrieval,
            evidence_assembler=DocsEvidenceAssembler(),
            evidence_gate=DocsEvidenceGate(),
            router=self._router,
            workflow_llm_enabled=workflow_llm_enabled,
        )

    def execute(self, case: V3Case, rag_session_id: str | None) -> ExecutionPayload:
        """Run ``case`` synchronously and return its execution payload.

        Drives ``_execute_async`` with ``asyncio.run``, so it must not be
        called from a thread that already has a running event loop.
        """
        return asyncio.run(self._execute_async(case, rag_session_id))

    async def _execute_async(self, case: V3Case, rag_session_id: str | None) -> ExecutionPayload:
        """Route the query and execute the branch selected by ``case.mode``.

        Args:
            case: Test case providing ``query`` and ``mode``.
            rag_session_id: RAG session to retrieve from, or ``None``.

        Returns:
            The payload for the selected mode.

        Raises:
            ValueError: If ``case.mode`` is not a supported mode (raised
                before any work is done), or if retrieval is required but
                ``rag_session_id`` is missing (see ``_retrieve_rows``).
        """
        if case.mode not in self._SUPPORTED_MODES:
            raise ValueError(f"Unsupported process_v2 adapter mode: {case.mode}")

        runtime = _RuntimeStub(query=case.query)
        route = self._router.route(case.query)
        anchors = route.anchors

        _log_pipeline_step(
            runtime,
            "router_resolved",
            {
                "domain": route.routing_domain,
                "intent": route.intent,
                "subintent": route.subintent,
                "confidence": route.confidence,
            },
        )
        _log_pipeline_step(
            runtime,
            "anchors_extracted",
            {
                "signal_types": route_anchor_summary(route)["signal_types"],
                "endpoint_paths": anchors.endpoint_paths,
                "target_doc_hints": anchors.target_doc_hints,
                "matched_aliases": anchors.matched_aliases,
            },
        )
        _log_pipeline_step(
            runtime,
            "alias_resolution",
            {
                "resolved_aliases": anchors.matched_aliases,
                "target_doc_hints": anchors.target_doc_hints,
            },
        )

        if case.mode == "router_only":
            return ExecutionPayload(
                actual=_actual_from_v2(route),
                details=_details(case.query, route=route, pipeline_steps=_build_pipeline_steps(runtime.logs)),
            )

        if case.mode == "full_chain":
            # The full chain runs on its own fresh runtime, so the steps
            # logged above are not part of the payload it returns.
            return await self._execute_full_chain(case, rag_session_id, route)

        # Only "router_rag" can reach this point thanks to the up-front check.
        plan = self._policy.resolve(route)
        _log_pipeline_step(
            runtime,
            "retrieval_profile_selected",
            {"profile": plan.profile, "layers": plan.layers, "filters": plan.filters},
        )

        semantic_rows = await self._retrieve_rows(route, rag_session_id, plan)
        seeded_rows = await self._seed_candidates_from_target_hints(route, rag_session_id, plan)
        metadata_rows = self._metadata_lookup_candidates([*seeded_rows, *semantic_rows], route)
        # Merge order decides which duplicate survives: seeded rows first,
        # then metadata-lookup hits, then plain semantic results.
        rows = self._merge_candidate_rows(seeded_rows, metadata_rows, semantic_rows)

        _log_pipeline_step(
            runtime,
            "candidate_generation",
            {
                "resolved_aliases": anchors.matched_aliases,
                "target_doc_hints": anchors.target_doc_hints,
                "candidate_docs_before_ranking": [self._trace_row(row) for row in rows[:8]],
                "sources": {
                    "seeded": [self._trace_row(row) for row in seeded_rows[:5]],
                    "metadata_lookup": [self._trace_row(row) for row in metadata_rows[:5]],
                    "semantic": [self._trace_row(row) for row in semantic_rows[:5]],
                },
            },
        )
        _log_pipeline_step(
            runtime,
            "retrieval_executed",
            {
                "query": case.query,
                "profile": plan.profile,
                "row_count": len(rows),
                "target_doc_hints": anchors.target_doc_hints,
                "top_results": [self._trace_row(row) for row in rows[:5]],
            },
        )

        return ExecutionPayload(
            actual=_actual_from_v2(route, rows=rows, plan=plan, answer_mode="partial"),
            details=_details(
                case.query,
                route=route,
                plan=plan,
                rows=rows,
                pipeline_steps=_build_pipeline_steps(runtime.logs),
            ),
        )

    async def _retrieve_rows(self, route, rag_session_id: str | None, plan) -> list[dict]:
        """Fetch rows for the normalized query under ``plan``.

        Without a session id this returns ``[]`` for ``GENERAL_QA`` routes and
        raises ``ValueError`` for every other intent.
        """
        if not rag_session_id:
            if route.intent == V2Intent.GENERAL_QA:
                return []
            raise ValueError("process_v2 cases with DOCS intent require rag_session_id")
        return await self._retrieval.fetch_rows(rag_session_id, route.normalized_query, plan)

    async def _seed_candidates_from_target_hints(self, route, rag_session_id: str | None, plan) -> list[dict]:
        """Fetch rows by exact path for the route's ``target_doc_hints``.

        Returns ``[]`` when there is no session id or no hints.
        """
        if not rag_session_id or not route.anchors.target_doc_hints:
            return []
        return await self._retrieval.fetch_exact_paths(
            rag_session_id,
            paths=route.anchors.target_doc_hints,
            layers=plan.layers,
        )

    def _metadata_lookup_candidates(self, rows: list[dict], route) -> list[dict]:
        """Return the result of a ``DocsMetadataLookupIndex`` lookup over ``rows``."""
        return DocsMetadataLookupIndex(rows).lookup(route)

    def _merge_candidate_rows(self, *groups: list[dict]) -> list[dict]:
        """Concatenate row groups, dropping duplicates while keeping order.

        Two rows are duplicates when their ``path``, ``layer`` and
        ``metadata.section_path`` (each stringified, missing values as ``""``)
        are equal. The first occurrence wins.
        """
        merged: list[dict] = []
        seen: set[tuple[str, str, str]] = set()
        for rows in groups:
            for row in rows:
                metadata = dict(row.get("metadata") or {})
                key = (
                    str(row.get("path") or ""),
                    str(row.get("layer") or ""),
                    str(metadata.get("section_path") or ""),
                )
                if key in seen:
                    continue
                seen.add(key)
                merged.append(row)
        return merged

    async def _execute_full_chain(self, case: V3Case, rag_session_id: str | None, route) -> ExecutionPayload:
        """Run ``V2Process`` end to end and build the payload from its logged events.

        A fresh runtime is created for the run. The retrieval plan, fetched
        rows and answer mode are read from the first matching events in that
        runtime's logs; routing fields come from the ``route`` passed in.
        """
        runtime = _RuntimeStub(query=case.query, rag_session_id=rag_session_id)
        result = await self._process.run(runtime)

        retrieval_plan = _event_payload(runtime.logs, "process.v2.retrieval_policy", "retrieval_plan_resolved")
        fetched = _event_payload(runtime.logs, "process.v2.rag_retrieval", "rag_rows_fetched")
        rows = list(fetched.get("rows") or [])
        answer_generated = _event_payload(runtime.logs, "process.v2.pipeline", "answer_generated")

        return ExecutionPayload(
            actual={
                "domain": route.routing_domain,
                "intent": route.intent,
                "sub_intent": route.subintent,
                "rag_count": len(rows),
                "llm_answer": result.answer,
                "answer_mode": str(answer_generated.get("answer_mode") or ""),
                "path_scope": tuple(),
                "symbol_candidates": tuple(),
                "entity_candidates": tuple(_entity_candidates(rows)),
                "doc_scope": tuple(_doc_scope(rows)),
                "layers": tuple(retrieval_plan.get("layers") or []),
                "filters": dict(retrieval_plan.get("filters") or {}),
            },
            details={
                "query": case.query,
                "router_result": asdict(route),
                "retrieval_plan": retrieval_plan,
                "rows": rows,
                "answer": result.answer,
                "logs": runtime.logs,
                "pipeline_steps": _build_pipeline_steps(runtime.logs),
            },
        )

    def _trace_row(self, row: dict) -> dict[str, object]:
        """Return a compact, string-only summary of ``row`` for step logs.

        Missing values become ``""``; ``summary_text`` and the stripped
        ``content`` are truncated to 400 characters.
        """
        metadata = row.get("metadata") or {}
        content = str(row.get("content") or "").strip()
        # Document id may live under either metadata key or on the row itself.
        document_id = metadata.get("document_id") or metadata.get("doc_id") or row.get("document_id") or ""
        return {
            "layer": str(row.get("layer") or ""),
            "path": str(row.get("path") or ""),
            "title": str(row.get("title") or ""),
            "document_id": str(document_id),
            "entity_name": str(metadata.get("entity_name") or ""),
            "summary_text": str(metadata.get("summary_text") or "")[:400],
            "section_path": str(metadata.get("section_path") or ""),
            "metadata_domain": str(metadata.get("domain") or ""),
            "metadata_subdomain": str(metadata.get("subdomain") or ""),
            "content_preview": content[:400],
        }
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class _RequestStub:
|
|
request_id: str
|
|
message: str
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class _SessionStub:
|
|
active_rag_session_id: str | None = None
|
|
|
|
|
|
class _PublisherStub:
|
|
async def publish_status(self, request_id: str, source: str, message: str, payload: dict | None = None) -> None:
|
|
return None
|
|
|
|
|
|
class _TraceLoggerStub:
|
|
def __init__(self, store: list[dict]) -> None:
|
|
self._store = store
|
|
|
|
def log_module(self, request_id: str, module: str, title: str, payload: dict | None = None) -> None:
|
|
self._store.append(
|
|
{"request_id": request_id, "module": module, "event": title, "payload": dict(payload or {})}
|
|
)
|
|
|
|
|
|
class _RuntimeStub:
    """Stand-in runtime carrying the request, session, publisher and trace.

    Everything logged through ``trace`` ends up in ``logs``, because the
    trace context is built on a ``_TraceLoggerStub`` that shares that list.
    """

    def __init__(self, *, query: str, rag_session_id: str | None = None) -> None:
        """Create the runtime for one query.

        Args:
            query: Stored as the request message.
            rag_session_id: Stored as the session's active RAG session id.
        """
        # Shared by reference with the trace logger below.
        self.logs: list[dict] = []
        self.request = _RequestStub(request_id="pipeline_setup_v3", message=query)
        self.session = _SessionStub(active_rag_session_id=rag_session_id)
        self.publisher = _PublisherStub()
        self.trace = RequestTraceContext(
            request_id=self.request.request_id,
            logger=_TraceLoggerStub(self.logs),
        )
|
|
|
|
|
|
def _build_client() -> GigaChatClient:
    """Create a ``GigaChatClient`` from settings loaded via ``GigaChatSettings.from_env``.

    The same settings object is used for the client and its token provider.
    """
    settings = GigaChatSettings.from_env()
    token_provider = GigaChatTokenProvider(settings)
    return GigaChatClient(settings, token_provider)
|
|
|
|
|
|
def _build_v2_llm() -> AgentLlmService:
    """Create the ``AgentLlmService`` used by the process, with its prompt files.

    Prompt paths are resolved relative to the directory three levels above
    the directory containing this file.
    """
    # Resolve the base directory once instead of once per prompt file.
    base_dir = Path(__file__).resolve().parents[3]
    prompt_paths = [
        base_dir / "src/app/core/agent/processes/v2/workflows/doc_explain_summary/steps/prompts/prompts.yml",
        base_dir / "src/app/core/agent/processes/v2/workflows/general_qa_summary/steps/prompts/prompts.yml",
        base_dir / "src/app/core/agent/processes/v2/intent_router/routers/prompts.yml",
    ]
    return AgentLlmService(client=_build_client(), prompts=PromptLoader(prompt_paths))
|
|
|
|
|
|
def _actual_from_v2(route, *, rows: list[dict] | None = None, plan=None, answer: str = "", answer_mode: str = "partial") -> dict:
    """Build the ``actual`` result dict for a routed (and optionally retrieved) case.

    Args:
        route: Router result; supplies ``routing_domain``, ``intent`` and ``subintent``.
        rows: Retrieved rows; ``None`` is treated as no rows.
        plan: Retrieval plan; its ``layers`` / ``filters`` attributes are read
            when present, otherwise empty values are used.
        answer: Stored as ``llm_answer``.
        answer_mode: Stored as ``answer_mode``.

    Returns:
        A dict with routing fields, row-derived candidates and plan fields.
        ``path_scope`` and ``symbol_candidates`` are always empty tuples.
    """
    # Normalise once so every row-derived field sees the same list.
    rows = rows or []
    return {
        "domain": route.routing_domain,
        "intent": route.intent,
        "sub_intent": route.subintent,
        "rag_count": len(rows),
        "llm_answer": answer,
        "answer_mode": answer_mode,
        "path_scope": tuple(),
        "symbol_candidates": tuple(),
        "entity_candidates": tuple(_entity_candidates(rows)),
        "doc_scope": tuple(_doc_scope(rows)),
        "layers": tuple(getattr(plan, "layers", []) or []),
        "filters": dict(getattr(plan, "filters", {}) or {}),
    }
|
|
|
|
|
|
def _details(query: str, **payload) -> dict:
|
|
details = {"query": query}
|
|
for key, value in payload.items():
|
|
if key == "route":
|
|
details["router_result"] = asdict(value)
|
|
elif key == "plan":
|
|
details["retrieval_plan"] = asdict(value)
|
|
else:
|
|
details[key] = value
|
|
return details
|
|
|
|
|
|
def _doc_scope(rows: list[dict]) -> list[str]:
|
|
values: list[str] = []
|
|
for row in rows:
|
|
metadata = dict(row.get("metadata") or {})
|
|
for candidate in (
|
|
row.get("document_id"),
|
|
metadata.get("document_id"),
|
|
metadata.get("doc_id"),
|
|
row.get("path"),
|
|
):
|
|
value = str(candidate or "").strip()
|
|
if value and value not in values:
|
|
values.append(value)
|
|
return values
|
|
|
|
|
|
def _entity_candidates(rows: list[dict]) -> list[str]:
|
|
values: list[str] = []
|
|
for row in rows:
|
|
metadata = dict(row.get("metadata") or {})
|
|
value = str(row.get("entity_name") or metadata.get("entity_name") or row.get("title") or "").strip()
|
|
if value and value not in values and str(row.get("layer") or "") == "D3_ENTITY_CATALOG":
|
|
values.append(value)
|
|
return values
|
|
|
|
|
|
def _build_pipeline_steps(logs: list[dict]) -> list[dict]:
|
|
steps: list[dict] = []
|
|
for item in logs:
|
|
if item.get("module") != "process.v2.pipeline":
|
|
continue
|
|
steps.append({"step": item.get("event"), "output": item.get("payload") or {}})
|
|
return steps
|
|
|
|
|
|
def _event_payload(logs: list[dict], module: str, event: str) -> dict[str, object]:
|
|
for item in logs:
|
|
if item.get("module") == module and item.get("event") == event:
|
|
payload = item.get("payload") or {}
|
|
if isinstance(payload, dict):
|
|
return dict(payload)
|
|
return {}
|
|
return {}
|
|
|
|
|
|
def _log_pipeline_step(runtime: _RuntimeStub, step: str, payload: dict[str, object]) -> None:
|
|
runtime.logs.append(
|
|
{
|
|
"request_id": runtime.request.request_id,
|
|
"module": "process.v2.pipeline",
|
|
"event": step,
|
|
"payload": payload,
|
|
}
|
|
)
|