ййй
This commit is contained in:
@@ -0,0 +1,108 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class InMemoryDocsRetrievalAdapter:
|
||||
def __init__(self, rows: list[dict]) -> None:
|
||||
self._rows = list(rows)
|
||||
self._report: dict = {}
|
||||
|
||||
def retrieve_with_plan(self, rag_session_id: str, query: str, retrieval_spec, retrieval_constraints=None, *, query_plan=None) -> list[dict]:
|
||||
planned_layers = [str(item.layer_id) for item in retrieval_spec.layer_queries]
|
||||
query_sub_intent = str(getattr(query_plan, "sub_intent", "") or "")
|
||||
relation_rows = [row for row in self._rows if str(row.get("layer") or "") == "D5_RELATION_GRAPH"]
|
||||
relation_hits = len(relation_rows)
|
||||
result: list[dict] = []
|
||||
executed_layers: list[str] = []
|
||||
layer_diagnostics: dict[str, dict] = {}
|
||||
fallback_used = False
|
||||
fallback_reason = None
|
||||
support_paths = [
|
||||
str(row.get("path") or "").strip()
|
||||
for row in self._rows
|
||||
if str(row.get("layer") or "") in {"D1_DOCUMENT_CATALOG", "D2_FACT_INDEX", "D3_ENTITY_CATALOG", "D4_WORKFLOW_INDEX", "D5_RELATION_GRAPH"}
|
||||
and str(row.get("path") or "").strip()
|
||||
]
|
||||
self._report = {
|
||||
"planned_layers": planned_layers,
|
||||
"executed_layers": executed_layers,
|
||||
"retrieval_mode_by_layer": {},
|
||||
"fallback": {"used": False},
|
||||
"layer_diagnostics": layer_diagnostics,
|
||||
}
|
||||
for item in retrieval_spec.layer_queries:
|
||||
layer_id = str(item.layer_id)
|
||||
if query_sub_intent == "RELATED_DOCS_EXPLAIN" and layer_id in {"D1_DOCUMENT_CATALOG", "D0_DOC_CHUNKS"} and relation_hits >= 2:
|
||||
layer_diagnostics[layer_id] = {"hits": 0, "top_ids": [], "skipped": True, "reason": "relation_primary_sufficient"}
|
||||
continue
|
||||
executed_layers.append(layer_id)
|
||||
layer_rows = [row for row in self._rows if str(row.get("layer") or "") == layer_id]
|
||||
result.extend(layer_rows)
|
||||
layer_diagnostics[layer_id] = self._layer_diagnostics(layer_rows)
|
||||
d2_empty = "D2_FACT_INDEX" in planned_layers and int(dict(layer_diagnostics.get("D2_FACT_INDEX") or {}).get("hits") or 0) == 0
|
||||
d0_empty = "D0_DOC_CHUNKS" in planned_layers and int(dict(layer_diagnostics.get("D0_DOC_CHUNKS") or {}).get("hits") or 0) == 0
|
||||
if support_paths and "D0_DOC_CHUNKS" in planned_layers and (d2_empty or d0_empty):
|
||||
targeted = [
|
||||
row for row in self._rows
|
||||
if str(row.get("layer") or "") == "D0_DOC_CHUNKS" and str(row.get("path") or "").strip() in support_paths
|
||||
]
|
||||
merged = self._dedupe([*result, *targeted])
|
||||
new_targeted = self._subtract_rows(merged, result)
|
||||
result = merged
|
||||
layer_diagnostics["D0_DOC_CHUNKS"] = self._layer_diagnostics(
|
||||
[row for row in result if str(row.get("layer") or "") == "D0_DOC_CHUNKS"]
|
||||
)
|
||||
if new_targeted:
|
||||
fallback_used = True
|
||||
fallback_reason = "targeted_chunk_retrieval"
|
||||
self._report["fallback"] = {"used": fallback_used, "reason": fallback_reason}
|
||||
return self._dedupe(result)
|
||||
|
||||
def retrieve_exact_files(self, rag_session_id: str, *, repo_id=None, paths: list[str], layers=None, limit: int = 200, query: str = "", ranking_profile: str = "") -> list[dict]:
|
||||
return []
|
||||
|
||||
def hydrate_resolved_symbol_sources(self, rag_session_id: str, base_query: str, rag_rows: list[dict], symbol_resolution: dict, retrieval_spec, retrieval_constraints=None) -> list[dict]:
|
||||
return list(rag_rows)
|
||||
|
||||
def force_symbol_context_c0(self, rag_session_id: str, *, rag_rows: list[dict], symbol_resolution: dict, limit: int = 20) -> list[dict]:
|
||||
return list(rag_rows)
|
||||
|
||||
def consume_retrieval_report(self) -> dict:
|
||||
return dict(self._report)
|
||||
|
||||
def _layer_diagnostics(self, rows: list[dict]) -> dict:
|
||||
top_ids: list[str] = []
|
||||
top_sections: list[str] = []
|
||||
for row in rows[:5]:
|
||||
metadata = dict(row.get("metadata") or {})
|
||||
candidate = metadata.get("document_id") or metadata.get("doc_id") or metadata.get("fact_id") or metadata.get("relation_id") or row.get("path")
|
||||
value = str(candidate or "").strip()
|
||||
if value and value not in top_ids:
|
||||
top_ids.append(value)
|
||||
title = str(row.get("title") or "").strip()
|
||||
if title and title not in top_sections:
|
||||
top_sections.append(title)
|
||||
return {"hits": len(rows), "top_ids": top_ids, "top_sections": top_sections}
|
||||
|
||||
def _dedupe(self, rows: list[dict]) -> list[dict]:
|
||||
result: list[dict] = []
|
||||
seen: set[tuple[str, str, str, int | None, int | None]] = set()
|
||||
for row in rows:
|
||||
key = self._row_key(row)
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
result.append(row)
|
||||
return result
|
||||
|
||||
def _subtract_rows(self, rows: list[dict], baseline: list[dict]) -> list[dict]:
|
||||
baseline_keys = {self._row_key(row) for row in baseline}
|
||||
return [row for row in rows if self._row_key(row) not in baseline_keys]
|
||||
|
||||
def _row_key(self, row: dict) -> tuple[str, str, str, int | None, int | None]:
|
||||
return (
|
||||
str(row.get("layer") or ""),
|
||||
str(row.get("path") or ""),
|
||||
str(row.get("title") or ""),
|
||||
row.get("span_start"),
|
||||
row.get("span_end"),
|
||||
)
|
||||
Reference in New Issue
Block a user