from __future__ import annotations


class InMemoryDocsRetrievalAdapter:
    """Docs retrieval adapter that serves rows from in-memory lists.

    Rows are plain dicts. This class reads the keys ``layer``, ``path``,
    ``title``, ``span_start``, ``span_end`` and ``metadata`` from them and
    returns the row objects themselves (not copies).
    """

    # Layers whose non-empty row paths seed the targeted chunk fallback.
    _SUPPORT_LAYERS = frozenset(
        {
            "D1_DOCUMENT_CATALOG",
            "D2_FACT_INDEX",
            "D3_ENTITY_CATALOG",
            "D4_WORKFLOW_INDEX",
            "D5_RELATION_GRAPH",
        }
    )
    # Layers skipped for RELATED_DOCS_EXPLAIN when enough relation rows exist.
    _RELATION_SKIPPABLE_LAYERS = frozenset({"D1_DOCUMENT_CATALOG", "D0_DOC_CHUNKS"})

    def __init__(self, rows: list[dict], materialized_rows: list[dict] | None = None) -> None:
        """Store shallow copies of the row lists and start with an empty report.

        Args:
            rows: Rows searched by ``retrieve_with_plan`` and
                ``materialize_document_ids``.
            materialized_rows: Extra rows searched only by
                ``materialize_document_ids``; ``None`` means none.
        """
        self._rows = list(rows)
        self._materialized_rows = list(materialized_rows or [])
        self._report: dict = {}

    @staticmethod
    def _layer_of(row: dict) -> str:
        """Return the row's ``layer`` as a string ('' when missing or falsy)."""
        return str(row.get("layer") or "")

    @staticmethod
    def _path_of(row: dict) -> str:
        """Return the row's ``path`` stripped of whitespace ('' when missing)."""
        return str(row.get("path") or "").strip()

    def retrieve_with_plan(
        self,
        rag_session_id: str,
        query: str,
        retrieval_spec,
        retrieval_constraints=None,
        *,
        query_plan=None,
        trace=None,
    ) -> list[dict]:
        """Return the stored rows for every planned layer, in plan order.

        Behaviour:
          * When ``query_plan.sub_intent`` is ``"RELATED_DOCS_EXPLAIN"`` and
            at least two ``D5_RELATION_GRAPH`` rows are stored, the
            ``D1_DOCUMENT_CATALOG`` and ``D0_DOC_CHUNKS`` layers are skipped
            and recorded as such in the diagnostics.
          * When ``D0_DOC_CHUNKS`` is planned and either it or a planned
            ``D2_FACT_INDEX`` produced no hits, chunk rows whose path matches
            a path of any stored support-layer row are merged in. The
            fallback is reported as used only if that added new rows.

        A report is stored on the instance and can be read with
        ``consume_retrieval_report``. ``rag_session_id``, ``query``,
        ``retrieval_constraints`` and ``trace`` are accepted but not used.

        Args:
            retrieval_spec: Object exposing ``layer_queries``, an iterable of
                items with a ``layer_id`` attribute.
            query_plan: Optional object; only its ``sub_intent`` attribute is
                read.

        Returns:
            The de-duplicated list of matching rows.
        """
        # Read the layer ids exactly once: a one-shot iterable of layer
        # queries would otherwise be exhausted before the execution loop.
        planned_layers = [str(item.layer_id) for item in retrieval_spec.layer_queries]
        query_sub_intent = str(getattr(query_plan, "sub_intent", "") or "")
        relation_hits = sum(
            1 for row in self._rows if self._layer_of(row) == "D5_RELATION_GRAPH"
        )

        result: list[dict] = []
        executed_layers: list[str] = []
        layer_diagnostics: dict[str, dict] = {}
        fallback_used = False
        fallback_reason = None

        # A set, because it is only used for membership tests below.
        support_paths = {
            self._path_of(row)
            for row in self._rows
            if self._layer_of(row) in self._SUPPORT_LAYERS
        }
        support_paths.discard("")

        # The report shares ``executed_layers`` and ``layer_diagnostics`` with
        # the locals, so it is filled in as the loop below progresses.
        self._report = {
            "planned_layers": planned_layers,
            "executed_layers": executed_layers,
            "retrieval_mode_by_layer": {},
            "fallback": {"used": False},
            "layer_diagnostics": layer_diagnostics,
        }

        relation_primary = query_sub_intent == "RELATED_DOCS_EXPLAIN" and relation_hits >= 2
        for layer_id in planned_layers:
            if relation_primary and layer_id in self._RELATION_SKIPPABLE_LAYERS:
                layer_diagnostics[layer_id] = {
                    "hits": 0,
                    "top_ids": [],
                    "skipped": True,
                    "reason": "relation_primary_sufficient",
                }
                continue
            executed_layers.append(layer_id)
            layer_rows = [row for row in self._rows if self._layer_of(row) == layer_id]
            result.extend(layer_rows)
            layer_diagnostics[layer_id] = self._layer_diagnostics(layer_rows)

        def planned_without_hits(layer: str) -> bool:
            # A skipped layer has hits == 0, so it counts as empty here too.
            if layer not in planned_layers:
                return False
            return int(dict(layer_diagnostics.get(layer) or {}).get("hits") or 0) == 0

        d2_empty = planned_without_hits("D2_FACT_INDEX")
        d0_empty = planned_without_hits("D0_DOC_CHUNKS")

        if support_paths and "D0_DOC_CHUNKS" in planned_layers and (d2_empty or d0_empty):
            targeted = [
                row
                for row in self._rows
                if self._layer_of(row) == "D0_DOC_CHUNKS"
                and self._path_of(row) in support_paths
            ]
            merged = self._dedupe([*result, *targeted])
            new_targeted = self._subtract_rows(merged, result)
            result = merged
            layer_diagnostics["D0_DOC_CHUNKS"] = self._layer_diagnostics(
                [row for row in result if self._layer_of(row) == "D0_DOC_CHUNKS"]
            )
            if new_targeted:
                fallback_used = True
                fallback_reason = "targeted_chunk_retrieval"

        self._report["fallback"] = {"used": fallback_used, "reason": fallback_reason}
        return self._dedupe(result)

    def retrieve_exact_files(
        self,
        rag_session_id: str,
        *,
        repo_id=None,
        paths: list[str],
        layers=None,
        limit: int = 200,
        query: str = "",
        ranking_profile: str = "",
    ) -> list[dict]:
        """Return an empty list; this adapter does no exact-file retrieval."""
        return []

    def materialize_document_ids(
        self,
        rag_session_id: str,
        *,
        document_ids: list[str],
        layers=None,
        limit: int = 200,
    ) -> list[dict]:
        """Return rows whose metadata references one of ``document_ids``.

        Both the regular and the materialized rows are searched. A row matches
        when its metadata ``document_id``, ``doc_id`` or ``subject_id`` equals
        one of the requested ids (compared as stripped strings).

        Args:
            document_ids: Ids to look for; blank entries are ignored.
            layers: Optional iterable of layer names; when it contains any
                non-blank name, only rows from those layers are considered.
            limit: Maximum number of rows returned; values below 1 are
                treated as 1.

        Returns:
            De-duplicated matching rows, or ``[]`` when no usable id is given.
        """
        allowed_ids = {str(item).strip() for item in document_ids if str(item).strip()}
        if not allowed_ids:
            return []
        allowed_layers = {str(item).strip() for item in list(layers or []) if str(item).strip()}

        matches: list[dict] = []
        for row in [*self._rows, *self._materialized_rows]:
            if allowed_layers and self._layer_of(row) not in allowed_layers:
                continue
            metadata = dict(row.get("metadata") or {})
            candidates = {
                str(metadata.get(key) or "").strip()
                for key in ("document_id", "doc_id", "subject_id")
            }
            # ``allowed_ids`` never contains "", so absent ids cannot match.
            if candidates & allowed_ids:
                matches.append(row)
        return self._dedupe(matches)[: max(1, int(limit))]

    def hydrate_resolved_symbol_sources(
        self,
        rag_session_id: str,
        base_query: str,
        rag_rows: list[dict],
        symbol_resolution: dict,
        retrieval_spec,
        retrieval_constraints=None,
    ) -> list[dict]:
        """Return a shallow copy of ``rag_rows``; nothing is hydrated."""
        return list(rag_rows)

    def force_symbol_context_c0(
        self,
        rag_session_id: str,
        *,
        rag_rows: list[dict],
        symbol_resolution: dict,
        limit: int = 20,
    ) -> list[dict]:
        """Return a shallow copy of ``rag_rows``; no context is added."""
        return list(rag_rows)

    def consume_retrieval_report(self) -> dict:
        """Return a shallow copy of the report of the last ``retrieve_with_plan``.

        The stored report is not cleared; nested values are shared with it.
        """
        return dict(self._report)

    def _layer_diagnostics(self, rows: list[dict]) -> dict:
        """Summarise ``rows``: total hit count plus ids/titles of the first five.

        The id of a row is the first truthy value among metadata
        ``document_id``, ``doc_id``, ``fact_id``, ``relation_id`` and the
        row's ``path``. Ids and titles are de-duplicated, keeping first-seen
        order.
        """
        top_ids: list[str] = []
        top_sections: list[str] = []
        for row in rows[:5]:
            metadata = dict(row.get("metadata") or {})
            candidate = (
                metadata.get("document_id")
                or metadata.get("doc_id")
                or metadata.get("fact_id")
                or metadata.get("relation_id")
                or row.get("path")
            )
            value = str(candidate or "").strip()
            if value and value not in top_ids:
                top_ids.append(value)
            title = str(row.get("title") or "").strip()
            if title and title not in top_sections:
                top_sections.append(title)
        return {"hits": len(rows), "top_ids": top_ids, "top_sections": top_sections}

    def _dedupe(self, rows: list[dict]) -> list[dict]:
        """Return ``rows`` without repeats of the same row key, keeping order."""
        result: list[dict] = []
        seen: set[tuple[str, str, str, int | None, int | None]] = set()
        for row in rows:
            key = self._row_key(row)
            if key in seen:
                continue
            seen.add(key)
            result.append(row)
        return result

    def _subtract_rows(self, rows: list[dict], baseline: list[dict]) -> list[dict]:
        """Return the rows of ``rows`` whose key does not occur in ``baseline``."""
        baseline_keys = {self._row_key(row) for row in baseline}
        return [row for row in rows if self._row_key(row) not in baseline_keys]

    def _row_key(self, row: dict) -> tuple[str, str, str, int | None, int | None]:
        """Return the identity key of a row: layer, path, title and span bounds."""
        return (
            str(row.get("layer") or ""),
            str(row.get("path") or ""),
            str(row.get("title") or ""),
            row.get("span_start"),
            row.get("span_end"),
        )