# Source-viewer header: 129 lines, 6.4 KiB, Python.
from __future__ import annotations


class InMemoryDocsRetrievalAdapter:
    """Docs retrieval adapter that serves every call from rows held in memory.

    Rows are plain dicts. The keys this class reads are ``layer``, ``path``,
    ``title``, ``span_start``, ``span_end`` and ``metadata`` (itself a dict that
    may carry ``document_id``, ``doc_id``, ``subject_id``, ``fact_id`` or
    ``relation_id``). Session ids, query strings and constraints passed to the
    retrieval methods are accepted for interface compatibility but not used.
    """

    # Layers whose non-empty ``path`` values seed the targeted D0 chunk fallback.
    _SUPPORT_LAYERS = frozenset(
        {
            "D1_DOCUMENT_CATALOG",
            "D2_FACT_INDEX",
            "D3_ENTITY_CATALOG",
            "D4_WORKFLOW_INDEX",
            "D5_RELATION_GRAPH",
        }
    )
    # Layers skipped for RELATED_DOCS_EXPLAIN plans when relation rows suffice.
    _RELATION_SKIPPABLE_LAYERS = frozenset({"D1_DOCUMENT_CATALOG", "D0_DOC_CHUNKS"})
    # Minimum number of D5_RELATION_GRAPH rows that enables the skip above.
    _MIN_RELATION_HITS = 2

    def __init__(self, rows: list[dict], materialized_rows: list[dict] | None = None) -> None:
        """Store copies of the row lists.

        Args:
            rows: Rows served by ``retrieve_with_plan`` and ``materialize_document_ids``.
            materialized_rows: Extra rows visible only to ``materialize_document_ids``.
        """
        # Copy so later mutation of the caller's lists does not leak in.
        self._rows = list(rows)
        self._materialized_rows = list(materialized_rows or [])
        # Report of the most recent ``retrieve_with_plan`` call ({} before any call).
        self._report: dict = {}

    def retrieve_with_plan(self, rag_session_id: str, query: str, retrieval_spec, retrieval_constraints=None, *, query_plan=None, trace=None) -> list[dict]:
        """Return the stored rows of every layer listed in ``retrieval_spec.layer_queries``.

        Only ``retrieval_spec`` (its ``layer_queries`` items' ``layer_id``) and
        ``query_plan`` (its optional ``sub_intent``) influence the result.

        * When ``query_plan.sub_intent == "RELATED_DOCS_EXPLAIN"`` and at least two
          D5_RELATION_GRAPH rows are stored, D1_DOCUMENT_CATALOG / D0_DOC_CHUNKS
          layers are not executed and are recorded as skipped.
        * When D0_DOC_CHUNKS is planned and either D2_FACT_INDEX (if planned) or
          D0_DOC_CHUNKS ended with zero hits, D0 rows whose stripped ``path``
          matches a path from the D1-D5 rows are merged in. The report marks the
          fallback as used only if that added rows not already present.

        The retrieval report (see ``consume_retrieval_report``) is rebuilt on
        every call.

        Returns:
            Rows in layer-execution order, deduplicated by ``_row_key``.
        """
        planned_layers = [str(item.layer_id) for item in retrieval_spec.layer_queries]
        query_sub_intent = str(getattr(query_plan, "sub_intent", "") or "")
        relation_hits = sum(1 for row in self._rows if self._layer_of(row) == "D5_RELATION_GRAPH")
        result: list[dict] = []
        executed_layers: list[str] = []
        layer_diagnostics: dict[str, dict] = {}

        # A set: it is only tested for emptiness and membership, and a list made
        # the membership test below O(len(support_paths)) per row.
        support_paths = {
            str(row.get("path") or "").strip()
            for row in self._rows
            if self._layer_of(row) in self._SUPPORT_LAYERS
        }
        support_paths.discard("")

        # The report holds references to the containers above, so it reflects
        # the updates made by the loop and the fallback.
        self._report = {
            "planned_layers": planned_layers,
            "executed_layers": executed_layers,
            "retrieval_mode_by_layer": {},
            "fallback": {"used": False},
            "layer_diagnostics": layer_diagnostics,
        }

        for item in retrieval_spec.layer_queries:
            layer_id = str(item.layer_id)
            if (
                query_sub_intent == "RELATED_DOCS_EXPLAIN"
                and layer_id in self._RELATION_SKIPPABLE_LAYERS
                and relation_hits >= self._MIN_RELATION_HITS
            ):
                layer_diagnostics[layer_id] = {"hits": 0, "top_ids": [], "skipped": True, "reason": "relation_primary_sufficient"}
                continue
            executed_layers.append(layer_id)
            layer_rows = [row for row in self._rows if self._layer_of(row) == layer_id]
            result.extend(layer_rows)
            layer_diagnostics[layer_id] = self._layer_diagnostics(layer_rows)

        fallback_used = False
        fallback_reason = None
        d2_empty = "D2_FACT_INDEX" in planned_layers and self._hits(layer_diagnostics, "D2_FACT_INDEX") == 0
        d0_empty = "D0_DOC_CHUNKS" in planned_layers and self._hits(layer_diagnostics, "D0_DOC_CHUNKS") == 0
        if support_paths and "D0_DOC_CHUNKS" in planned_layers and (d2_empty or d0_empty):
            targeted = [
                row
                for row in self._rows
                if self._layer_of(row) == "D0_DOC_CHUNKS"
                and str(row.get("path") or "").strip() in support_paths
            ]
            merged = self._dedupe([*result, *targeted])
            # Rows the fallback contributed beyond what the layer pass returned.
            new_targeted = self._subtract_rows(merged, result)
            result = merged
            layer_diagnostics["D0_DOC_CHUNKS"] = self._layer_diagnostics(
                [row for row in result if self._layer_of(row) == "D0_DOC_CHUNKS"]
            )
            if new_targeted:
                fallback_used = True
                fallback_reason = "targeted_chunk_retrieval"

        self._report["fallback"] = {"used": fallback_used, "reason": fallback_reason}
        return self._dedupe(result)

    def retrieve_exact_files(self, rag_session_id: str, *, repo_id=None, paths: list[str], layers=None, limit: int = 200, query: str = "", ranking_profile: str = "") -> list[dict]:
        """Return no rows; exact-file retrieval is not supported by this adapter."""
        return []

    def materialize_document_ids(self, rag_session_id: str, *, document_ids: list[str], layers=None, limit: int = 200) -> list[dict]:
        """Return stored rows whose metadata references any of ``document_ids``.

        A row matches when the stripped ``document_id``, ``doc_id`` or
        ``subject_id`` in its ``metadata`` equals one of the requested ids. Both
        the retrieval rows and the materialized-only rows are searched, in that
        order, optionally restricted to ``layers``.

        Returns:
            Matching rows, deduplicated, truncated to ``max(1, int(limit))``
            entries. Returns ``[]`` when no non-blank id is requested.
        """
        allowed_ids = {str(item).strip() for item in document_ids if str(item).strip()}
        if not allowed_ids:
            return []
        allowed_layers = {str(item).strip() for item in list(layers or []) if str(item).strip()}
        result = []
        for row in [*self._rows, *self._materialized_rows]:
            if allowed_layers and self._layer_of(row) not in allowed_layers:
                continue
            metadata = dict(row.get("metadata") or {})
            # Missing keys yield "", which can never be in ``allowed_ids``.
            candidates = {str(metadata.get(key) or "").strip() for key in ("document_id", "doc_id", "subject_id")}
            if candidates & allowed_ids:
                result.append(row)
        return self._dedupe(result)[: max(1, int(limit))]

    def hydrate_resolved_symbol_sources(self, rag_session_id: str, base_query: str, rag_rows: list[dict], symbol_resolution: dict, retrieval_spec, retrieval_constraints=None) -> list[dict]:
        """Return a shallow copy of ``rag_rows``; no hydration is performed."""
        return list(rag_rows)

    def force_symbol_context_c0(self, rag_session_id: str, *, rag_rows: list[dict], symbol_resolution: dict, limit: int = 20) -> list[dict]:
        """Return a shallow copy of ``rag_rows``; no symbol context is added."""
        return list(rag_rows)

    def consume_retrieval_report(self) -> dict:
        """Return a shallow copy of the last ``retrieve_with_plan`` report.

        The stored report is not cleared; ``{}`` is returned before any call.
        """
        return dict(self._report)

    @staticmethod
    def _layer_of(row: dict) -> str:
        """Return the row's ``layer`` value as a string ("" when missing or falsy)."""
        return str(row.get("layer") or "")

    @staticmethod
    def _hits(layer_diagnostics: dict[str, dict], layer_id: str) -> int:
        """Return the recorded ``hits`` count for ``layer_id`` (0 when absent)."""
        return int(dict(layer_diagnostics.get(layer_id) or {}).get("hits") or 0)

    def _layer_diagnostics(self, rows: list[dict]) -> dict:
        """Summarize ``rows``: total hit count plus unique ids/titles of the first five rows."""
        top_ids: list[str] = []
        top_sections: list[str] = []
        for row in rows[:5]:
            metadata = dict(row.get("metadata") or {})
            # First present identifier wins; fall back to the row path.
            candidate = metadata.get("document_id") or metadata.get("doc_id") or metadata.get("fact_id") or metadata.get("relation_id") or row.get("path")
            value = str(candidate or "").strip()
            if value and value not in top_ids:
                top_ids.append(value)
            title = str(row.get("title") or "").strip()
            if title and title not in top_sections:
                top_sections.append(title)
        return {"hits": len(rows), "top_ids": top_ids, "top_sections": top_sections}

    def _dedupe(self, rows: list[dict]) -> list[dict]:
        """Drop rows whose ``_row_key`` was already seen, keeping first occurrences in order."""
        result: list[dict] = []
        seen: set[tuple[str, str, str, int | None, int | None]] = set()
        for row in rows:
            key = self._row_key(row)
            if key in seen:
                continue
            seen.add(key)
            result.append(row)
        return result

    def _subtract_rows(self, rows: list[dict], baseline: list[dict]) -> list[dict]:
        """Return the rows of ``rows`` whose ``_row_key`` does not occur in ``baseline``."""
        baseline_keys = {self._row_key(row) for row in baseline}
        return [row for row in rows if self._row_key(row) not in baseline_keys]

    def _row_key(self, row: dict) -> tuple[str, str, str, int | None, int | None]:
        """Identity key used for deduplication: (layer, path, title, span_start, span_end)."""
        return (
            str(row.get("layer") or ""),
            str(row.get("path") or ""),
            str(row.get("title") or ""),
            row.get("span_start"),
            row.get("span_end"),
        )