Files
agent/tests/docs_qa_eval/fixture_adapter.py
T
2026-03-27 15:51:10 +03:00

109 lines
5.4 KiB
Python

from __future__ import annotations
class InMemoryDocsRetrievalAdapter:
    """Fixture retrieval adapter that serves pre-baked rows from memory.

    Rows are plain dicts; the keys read here are ``layer``, ``path``,
    ``title``, ``metadata``, ``span_start`` and ``span_end``. Each call to
    :meth:`retrieve_with_plan` replaces the stored retrieval report, which
    :meth:`consume_retrieval_report` returns.
    """

    _RELATION_LAYER = "D5_RELATION_GRAPH"
    _CHUNK_LAYER = "D0_DOC_CHUNKS"
    _FACT_LAYER = "D2_FACT_INDEX"
    # Layers skipped for RELATED_DOCS_EXPLAIN when enough relation rows exist.
    _RELATION_SKIPPABLE_LAYERS = frozenset({"D1_DOCUMENT_CATALOG", "D0_DOC_CHUNKS"})
    # Minimum number of D5 fixture rows required before those layers are skipped.
    _RELATION_SKIP_MIN_HITS = 2
    # Layers whose row paths select which D0 chunks the targeted fallback may add.
    _SUPPORT_LAYERS = frozenset(
        {
            "D1_DOCUMENT_CATALOG",
            "D2_FACT_INDEX",
            "D3_ENTITY_CATALOG",
            "D4_WORKFLOW_INDEX",
            "D5_RELATION_GRAPH",
        }
    )
    # Number of leading rows inspected when collecting diagnostic ids/titles.
    _DIAGNOSTIC_TOP_N = 5

    def __init__(self, rows: list[dict]) -> None:
        # Copy so later mutation of the caller's list does not alter the fixture.
        self._rows = list(rows)
        # Report of the most recent retrieve_with_plan call ({} before any call).
        self._report: dict = {}

    def retrieve_with_plan(
        self,
        rag_session_id: str,
        query: str,
        retrieval_spec,
        retrieval_constraints=None,
        *,
        query_plan=None,
    ) -> list[dict]:
        """Return the fixture rows for every layer planned in ``retrieval_spec``.

        Each executed layer contributes every fixture row whose ``layer``
        equals its id. When ``query_plan.sub_intent`` is
        ``"RELATED_DOCS_EXPLAIN"`` and the fixture holds at least two
        ``D5_RELATION_GRAPH`` rows, ``D1_DOCUMENT_CATALOG`` and
        ``D0_DOC_CHUNKS`` are skipped. If D0 is planned and either D2 or D0
        ended with zero hits, D0 chunks whose path also appears on a D1–D5
        fixture row are merged in. The report marks the fallback as used only
        when that merge added rows that were not already returned.

        Args:
            rag_session_id: Ignored; accepted for interface compatibility.
            query: Ignored; accepted for interface compatibility.
            retrieval_spec: Object whose ``layer_queries`` items expose ``layer_id``.
            retrieval_constraints: Ignored; accepted for interface compatibility.
            query_plan: Optional object; only its ``sub_intent`` attribute is read.

        Returns:
            Rows deduplicated on (layer, path, title, span_start, span_end),
            in first-seen order.
        """
        # Materialize once: the loop below reuses these ids, so a one-shot
        # ``layer_queries`` iterable is not consumed twice.
        planned_layers = [str(item.layer_id) for item in retrieval_spec.layer_queries]
        query_sub_intent = str(getattr(query_plan, "sub_intent", "") or "")
        # Counted over the whole fixture, whether or not D5 is itself planned.
        relation_hits = sum(
            1 for row in self._rows if self._layer_of(row) == self._RELATION_LAYER
        )
        skip_broad_layers = (
            query_sub_intent == "RELATED_DOCS_EXPLAIN"
            and relation_hits >= self._RELATION_SKIP_MIN_HITS
        )
        # A set gives O(1) membership in the fallback filter below.
        support_paths = {
            self._path_of(row)
            for row in self._rows
            if self._layer_of(row) in self._SUPPORT_LAYERS
        }
        support_paths.discard("")

        result: list[dict] = []
        executed_layers: list[str] = []
        layer_diagnostics: dict[str, dict] = {}
        # The report holds references to the containers above, so the per-layer
        # updates made below are visible through it.
        self._report = {
            "planned_layers": planned_layers,
            "executed_layers": executed_layers,
            "retrieval_mode_by_layer": {},
            "fallback": {"used": False},
            "layer_diagnostics": layer_diagnostics,
        }

        for layer_id in planned_layers:
            if skip_broad_layers and layer_id in self._RELATION_SKIPPABLE_LAYERS:
                layer_diagnostics[layer_id] = {
                    "hits": 0,
                    "top_ids": [],
                    "skipped": True,
                    "reason": "relation_primary_sufficient",
                }
                continue
            executed_layers.append(layer_id)
            layer_rows = self._rows_in_layer(self._rows, layer_id)
            result.extend(layer_rows)
            layer_diagnostics[layer_id] = self._layer_diagnostics(layer_rows)

        fallback_used = False
        fallback_reason = None
        d0_planned = self._CHUNK_LAYER in planned_layers
        d2_empty = (
            self._FACT_LAYER in planned_layers
            and self._layer_hits(layer_diagnostics, self._FACT_LAYER) == 0
        )
        d0_empty = d0_planned and self._layer_hits(layer_diagnostics, self._CHUNK_LAYER) == 0
        if support_paths and d0_planned and (d2_empty or d0_empty):
            # Only chunks from documents that some structured-layer row references.
            targeted = [
                row
                for row in self._rows_in_layer(self._rows, self._CHUNK_LAYER)
                if self._path_of(row) in support_paths
            ]
            merged = self._dedupe([*result, *targeted])
            new_targeted = self._subtract_rows(merged, result)
            result = merged
            # Recomputed from the merged result; this replaces any "skipped" entry.
            layer_diagnostics[self._CHUNK_LAYER] = self._layer_diagnostics(
                self._rows_in_layer(result, self._CHUNK_LAYER)
            )
            if new_targeted:
                fallback_used = True
                fallback_reason = "targeted_chunk_retrieval"
        self._report["fallback"] = {"used": fallback_used, "reason": fallback_reason}
        return self._dedupe(result)

    def retrieve_exact_files(
        self,
        rag_session_id: str,
        *,
        repo_id=None,
        paths: list[str],
        layers=None,
        limit: int = 200,
        query: str = "",
        ranking_profile: str = "",
    ) -> list[dict]:
        """Fixture stub: always return an empty list; all arguments are ignored."""
        return []

    def hydrate_resolved_symbol_sources(
        self,
        rag_session_id: str,
        base_query: str,
        rag_rows: list[dict],
        symbol_resolution: dict,
        retrieval_spec,
        retrieval_constraints=None,
    ) -> list[dict]:
        """Fixture stub: return a shallow copy of ``rag_rows`` unchanged."""
        return list(rag_rows)

    def force_symbol_context_c0(
        self,
        rag_session_id: str,
        *,
        rag_rows: list[dict],
        symbol_resolution: dict,
        limit: int = 20,
    ) -> list[dict]:
        """Fixture stub: return a shallow copy of ``rag_rows`` unchanged."""
        return list(rag_rows)

    def consume_retrieval_report(self) -> dict:
        """Return a shallow copy of the latest report.

        Despite the name, the stored report is not cleared; calling this twice
        returns equal dicts.
        """
        return dict(self._report)

    @staticmethod
    def _layer_of(row: dict) -> str:
        """Return the row's layer id as a string ("" when missing or falsy)."""
        return str(row.get("layer") or "")

    @staticmethod
    def _path_of(row: dict) -> str:
        """Return the row's path with surrounding whitespace stripped ("" when missing)."""
        return str(row.get("path") or "").strip()

    @staticmethod
    def _layer_hits(diagnostics: dict[str, dict], layer_id: str) -> int:
        """Return the recorded hit count for ``layer_id`` (0 when absent)."""
        return int((diagnostics.get(layer_id) or {}).get("hits") or 0)

    def _rows_in_layer(self, rows: list[dict], layer_id: str) -> list[dict]:
        """Return the rows of ``rows`` that belong to ``layer_id``, in order."""
        return [row for row in rows if self._layer_of(row) == layer_id]

    def _layer_diagnostics(self, rows: list[dict]) -> dict:
        """Summarize ``rows`` as a hit count plus distinct ids and titles.

        ``hits`` counts all rows. ``top_ids`` and ``top_sections`` come from
        the first ``_DIAGNOSTIC_TOP_N`` rows only. The id is the first truthy
        value among metadata ``document_id``, ``doc_id``, ``fact_id`` and
        ``relation_id``, falling back to the row's ``path``.
        """
        top_ids: list[str] = []
        top_sections: list[str] = []
        for row in rows[: self._DIAGNOSTIC_TOP_N]:
            metadata = dict(row.get("metadata") or {})
            candidate = (
                metadata.get("document_id")
                or metadata.get("doc_id")
                or metadata.get("fact_id")
                or metadata.get("relation_id")
                or row.get("path")
            )
            value = str(candidate or "").strip()
            if value and value not in top_ids:
                top_ids.append(value)
            title = str(row.get("title") or "").strip()
            if title and title not in top_sections:
                top_sections.append(title)
        return {"hits": len(rows), "top_ids": top_ids, "top_sections": top_sections}

    def _dedupe(self, rows: list[dict]) -> list[dict]:
        """Drop rows whose key was already seen, keeping first occurrences in order."""
        result: list[dict] = []
        seen: set[tuple[str, str, str, int | None, int | None]] = set()
        for row in rows:
            key = self._row_key(row)
            if key in seen:
                continue
            seen.add(key)
            result.append(row)
        return result

    def _subtract_rows(self, rows: list[dict], baseline: list[dict]) -> list[dict]:
        """Return the rows of ``rows`` whose key does not occur in ``baseline``."""
        baseline_keys = {self._row_key(row) for row in baseline}
        return [row for row in rows if self._row_key(row) not in baseline_keys]

    def _row_key(self, row: dict) -> tuple[str, str, str, int | None, int | None]:
        """Identity key used for deduplication: layer, raw path, title and span bounds."""
        return (
            str(row.get("layer") or ""),
            str(row.get("path") or ""),
            str(row.get("title") or ""),
            row.get("span_start"),
            row.get("span_end"),
        )