from __future__ import annotations import json from sqlalchemy import text from app.modules.rag.retrieval.query_terms import extract_query_terms from app.modules.shared.db import get_engine class RagQueryRepository: def retrieve( self, rag_session_id: str, query_embedding: list[float], *, query_text: str = "", limit: int = 5, layers: list[str] | None = None, path_prefixes: list[str] | None = None, prefer_non_tests: bool = False, ) -> list[dict]: emb = "[" + ",".join(str(x) for x in query_embedding) + "]" filters = ["rag_session_id = :sid"] params: dict = {"sid": rag_session_id, "emb": emb, "lim": limit} if layers: filters.append("layer = ANY(:layers)") params["layers"] = layers if path_prefixes: or_filters = [] for idx, prefix in enumerate(path_prefixes): key = f"path_{idx}" params[key] = f"{prefix}%" or_filters.append(f"path LIKE :{key}") filters.append("(" + " OR ".join(or_filters) + ")") term_filters = [] terms = extract_query_terms(query_text) for idx, term in enumerate(terms): exact_key = f"term_exact_{idx}" prefix_key = f"term_prefix_{idx}" contains_key = f"term_contains_{idx}" params[exact_key] = term params[prefix_key] = f"{term}%" params[contains_key] = f"%{term}%" term_filters.append( "CASE " f"WHEN lower(COALESCE(qname, '')) = :{exact_key} THEN 0 " f"WHEN lower(COALESCE(symbol_id, '')) = :{exact_key} THEN 1 " f"WHEN lower(COALESCE(title, '')) = :{exact_key} THEN 2 " f"WHEN lower(COALESCE(qname, '')) LIKE :{prefix_key} THEN 3 " f"WHEN lower(COALESCE(title, '')) LIKE :{prefix_key} THEN 4 " f"WHEN lower(COALESCE(path, '')) LIKE :{contains_key} THEN 5 " f"WHEN lower(COALESCE(content, '')) LIKE :{contains_key} THEN 6 " "ELSE 100 END" ) lexical_sql = "LEAST(" + ", ".join(term_filters) + ")" if term_filters else "100" test_penalty_sql = ( "CASE " "WHEN lower(path) LIKE 'tests/%' OR lower(path) LIKE '%/tests/%' OR lower(path) LIKE 'test_%' OR lower(path) LIKE '%/test_%' " "THEN 1 ELSE 0 END" if prefer_non_tests else "0" ) layer_rank_sql = ( "CASE " "WHEN layer = 'C3_ENTRYPOINTS' THEN 0 " "WHEN layer = 'C1_SYMBOL_CATALOG' THEN 1 " "WHEN layer = 'C2_DEPENDENCY_GRAPH' THEN 2 " "WHEN layer = 'C0_SOURCE_CHUNKS' THEN 3 " "WHEN layer = 'D1_MODULE_CATALOG' THEN 0 " "WHEN layer = 'D2_FACT_INDEX' THEN 1 " "WHEN layer = 'D3_SECTION_INDEX' THEN 2 " "WHEN layer = 'D4_POLICY_INDEX' THEN 3 " "ELSE 10 END" ) sql = f""" SELECT path, content, layer, title, metadata_json, span_start, span_end, {lexical_sql} AS lexical_rank, {test_penalty_sql} AS test_penalty, {layer_rank_sql} AS layer_rank, (embedding <=> CAST(:emb AS vector)) AS distance FROM rag_chunks WHERE {' AND '.join(filters)} ORDER BY lexical_rank ASC, test_penalty ASC, layer_rank ASC, embedding <=> CAST(:emb AS vector) LIMIT :lim """ with get_engine().connect() as conn: rows = conn.execute(text(sql), params).mappings().fetchall() return [self._row_to_dict(row) for row in rows] def fallback_chunks(self, rag_session_id: str, *, limit: int = 5, layers: list[str] | None = None) -> list[dict]: filters = ["rag_session_id = :sid"] params: dict = {"sid": rag_session_id, "lim": limit} if layers: filters.append("layer = ANY(:layers)") params["layers"] = layers sql = f""" SELECT path, content, layer, title, metadata_json, span_start, span_end FROM rag_chunks WHERE {' AND '.join(filters)} ORDER BY id DESC LIMIT :lim """ with get_engine().connect() as conn: rows = conn.execute(text(sql), params).mappings().fetchall() return [self._row_to_dict(row) for row in rows] def _row_to_dict(self, row) -> dict: data = dict(row) data["metadata"] = json.loads(str(data.pop("metadata_json") or "{}")) return data