from __future__ import annotations

from app.modules.rag.retrieval.query_terms import extract_query_terms

# Suffix appended to LIKE comparisons so backslash acts as the escape
# character.  The Python literal reaches SQL as " ESCAPE E'\\'", which the
# PostgreSQL E-string parser reduces to a single backslash.  It pairs with
# the backslash-escaping performed by _escape_like_value below.
_LIKE_ESCAPE_SQL = " ESCAPE E'\\\\'"


class RetrievalStatementBuilder:
    """Builds parameterized SQL statements (query text + named bind params)
    for retrieving rows from the ``rag_chunks`` table.

    Two retrieval strategies are produced:

    * :meth:`build_retrieve` — hybrid search: pgvector cosine-distance
      (``<=>``) ordering combined with lexical-match, test-path-penalty, and
      layer-priority tie-breakers.
    * :meth:`build_lexical_code` — purely lexical lookup restricted to
      source-code chunks (layer ``C0_SOURCE_CHUNKS``), for exact/prefix/
      substring term matching without an embedding.

    All statements use ``:name``-style bind parameters; callers execute the
    returned SQL with the returned params dict.  The builder itself holds no
    state — every method only assembles strings and fills the params dict.
    """

    def build_retrieve(
        self,
        rag_session_id: str,
        query_embedding: list[float],
        *,
        query_text: str = "",
        limit: int = 5,
        layers: list[str] | None = None,
        path_prefixes: list[str] | None = None,
        exclude_path_prefixes: list[str] | None = None,
        exclude_like_patterns: list[str] | None = None,
        prefer_non_tests: bool = False,
    ) -> tuple[str, dict]:
        """Build the hybrid (lexical + vector) retrieval statement.

        Args:
            rag_session_id: Session whose chunks are searched.
            query_embedding: Query vector; serialized to pgvector's
                ``[x1,x2,...]`` text form and CAST to ``vector`` server-side.
            query_text: Optional free text; its extracted terms drive the
                lexical rank tie-breaker.
            limit: Maximum number of rows returned.
            layers: Optional whitelist of layer names (``layer = ANY(...)``).
            path_prefixes: Only include paths starting with one of these.
            exclude_path_prefixes: Drop paths starting with one of these.
            exclude_like_patterns: Drop paths matching these raw LIKE
                patterns (compared against ``lower(path)``).
            prefer_non_tests: When True, rows matching the exclusion
                patterns are additionally down-ranked via ``test_penalty``.

        Returns:
            ``(sql, params)`` ready for execution with named bind parameters.
        """
        # pgvector accepts a bracketed text literal; it is CAST inside the
        # statement rather than bound as a native vector type.
        emb = "[" + ",".join(str(x) for x in query_embedding) + "]"
        filters = ["rag_session_id = :sid"]
        params: dict = {"sid": rag_session_id, "emb": emb, "lim": limit}
        self._append_prefix_group(filters, params, "path", path_prefixes)
        self._append_prefix_group(filters, params, "exclude_prefix", exclude_path_prefixes, negate=True)
        self._append_like_group(filters, params, "exclude_like", exclude_like_patterns, negate=True)
        if layers:
            filters.append("layer = ANY(:layers)")
            params["layers"] = layers
        lexical_sql = self._lexical_rank_sql(query_text, params)
        # NOTE(review): the penalty reuses the same patterns applied above as
        # NOT-filters, so for most surviving rows it is 0.  It can still fire
        # for prefix patterns because the filter compares ``path`` case-
        # sensitively while the penalty compares ``lower(path)`` — confirm
        # this asymmetry is intended.
        test_penalty_sql = self._test_penalty_sql(
            prefer_non_tests,
            params,
            base_key="penalty",
            path_prefixes=exclude_path_prefixes,
            like_patterns=exclude_like_patterns,
        )
        # Layer priority: code layers (C*) and doc layers (D*) each map onto
        # ranks 0-3 (entrypoints/module catalogs first); unknown layers sink
        # to 10 so they sort after everything recognized.
        layer_rank_sql = (
            "CASE "
            "WHEN layer = 'C3_ENTRYPOINTS' THEN 0 "
            "WHEN layer = 'C1_SYMBOL_CATALOG' THEN 1 "
            "WHEN layer = 'C2_DEPENDENCY_GRAPH' THEN 2 "
            "WHEN layer = 'C0_SOURCE_CHUNKS' THEN 3 "
            "WHEN layer = 'D1_MODULE_CATALOG' THEN 0 "
            "WHEN layer = 'D2_FACT_INDEX' THEN 1 "
            "WHEN layer = 'D3_SECTION_INDEX' THEN 2 "
            "WHEN layer = 'D4_POLICY_INDEX' THEN 3 "
            "ELSE 10 END"
        )
        # Order of tie-breakers: lexical match quality, then test penalty,
        # then layer priority, with raw vector distance deciding last.
        sql = f""" SELECT path, content, layer, title, metadata_json, span_start, span_end, {lexical_sql} AS lexical_rank, {test_penalty_sql} AS test_penalty, {layer_rank_sql} AS layer_rank, (embedding <=> CAST(:emb AS vector)) AS distance FROM rag_chunks WHERE {' AND '.join(filters)} ORDER BY lexical_rank ASC, test_penalty ASC, layer_rank ASC, embedding <=> CAST(:emb AS vector) LIMIT :lim """
        return sql, params

    def build_lexical_code(
        self,
        rag_session_id: str,
        *,
        query_text: str,
        limit: int = 5,
        path_prefixes: list[str] | None = None,
        exclude_path_prefixes: list[str] | None = None,
        exclude_like_patterns: list[str] | None = None,
        prefer_non_tests: bool = False,
    ) -> tuple[str | None, dict]:
        """Build a purely lexical retrieval statement over code chunks.

        Only rows in the ``C0_SOURCE_CHUNKS`` layer are considered.  Returns
        ``(None, {})`` when no usable terms can be extracted from
        ``query_text``; otherwise ``(sql, params)``.  Filtering parameters
        have the same meaning as in :meth:`build_retrieve`.
        """
        terms = extract_query_terms(query_text)
        if not terms:
            # No terms -> no meaningful lexical query; caller must skip.
            return None, {}
        filters = ["rag_session_id = :sid", "layer = 'C0_SOURCE_CHUNKS'"]
        params: dict = {"sid": rag_session_id, "lim": limit}
        self._append_prefix_group(filters, params, "path", path_prefixes)
        self._append_prefix_group(filters, params, "exclude_prefix", exclude_path_prefixes, negate=True)
        self._append_like_group(filters, params, "exclude_like", exclude_like_patterns, negate=True)
        lexical_filters: list[str] = []
        lexical_ranks: list[str] = []
        for idx, term in enumerate(terms):
            exact_key = f"lex_exact_{idx}"
            prefix_key = f"lex_prefix_{idx}"
            contains_key = f"lex_contains_{idx}"
            # Columns are lowered in SQL; terms are used as-is, so they are
            # presumably already lowercased by extract_query_terms — TODO
            # confirm against that helper.
            params[exact_key] = term
            params[prefix_key] = f"{term}%"
            params[contains_key] = f"%{term}%"
            lexical_filters.append(
                f"(lower(COALESCE(qname, '')) = :{exact_key} "
                f"OR lower(COALESCE(title, '')) = :{exact_key} "
                f"OR lower(COALESCE(path, '')) LIKE :{contains_key} "
                f"OR lower(COALESCE(title, '')) LIKE :{prefix_key} "
                f"OR lower(COALESCE(content, '')) LIKE :{contains_key})"
            )
            # Rank per term: exact qname (0) < exact title (1) < title
            # prefix (2) < path substring (3) < content substring (4);
            # 100 means this term did not match the row at all.
            lexical_ranks.append(
                "CASE "
                f"WHEN lower(COALESCE(qname, '')) = :{exact_key} THEN 0 "
                f"WHEN lower(COALESCE(title, '')) = :{exact_key} THEN 1 "
                f"WHEN lower(COALESCE(title, '')) LIKE :{prefix_key} THEN 2 "
                f"WHEN lower(COALESCE(path, '')) LIKE :{contains_key} THEN 3 "
                f"WHEN lower(COALESCE(content, '')) LIKE :{contains_key} THEN 4 "
                "ELSE 100 END"
            )
        # A row qualifies if ANY term matches; its rank is the best (lowest)
        # rank achieved across all terms.
        filters.append("(" + " OR ".join(lexical_filters) + ")")
        lexical_sql = "LEAST(" + ", ".join(lexical_ranks) + ")"
        test_penalty_sql = self._test_penalty_sql(
            prefer_non_tests,
            params,
            base_key="lex_penalty",
            path_prefixes=exclude_path_prefixes,
            like_patterns=exclude_like_patterns,
        )
        # path/span_start in the ORDER BY make results deterministic when
        # ranks tie.
        sql = f""" SELECT path, content, layer, title, metadata_json, span_start, span_end, {lexical_sql} AS lexical_rank, {test_penalty_sql} AS test_penalty FROM rag_chunks WHERE {' AND '.join(filters)} ORDER BY lexical_rank ASC, test_penalty ASC, path ASC, span_start ASC LIMIT :lim """
        return sql, params

    def _lexical_rank_sql(self, query_text: str, params: dict) -> str:
        """Return a SQL expression ranking how well a row matches the query
        terms (lower is better; the constant ``100`` means no lexical match,
        including the case where no terms could be extracted).

        Bind parameters for each term are added to ``params`` in place.
        """
        term_filters: list[str] = []
        for idx, term in enumerate(extract_query_terms(query_text)):
            exact_key = f"term_exact_{idx}"
            prefix_key = f"term_prefix_{idx}"
            contains_key = f"term_contains_{idx}"
            params[exact_key] = term
            params[prefix_key] = f"{term}%"
            params[contains_key] = f"%{term}%"
            # Exact qname/symbol/title matches (0-2) outrank prefix matches
            # (3-4), which outrank path/content substring matches (5-6).
            term_filters.append(
                "CASE "
                f"WHEN lower(COALESCE(qname, '')) = :{exact_key} THEN 0 "
                f"WHEN lower(COALESCE(symbol_id, '')) = :{exact_key} THEN 1 "
                f"WHEN lower(COALESCE(title, '')) = :{exact_key} THEN 2 "
                f"WHEN lower(COALESCE(qname, '')) LIKE :{prefix_key} THEN 3 "
                f"WHEN lower(COALESCE(title, '')) LIKE :{prefix_key} THEN 4 "
                f"WHEN lower(COALESCE(path, '')) LIKE :{contains_key} THEN 5 "
                f"WHEN lower(COALESCE(content, '')) LIKE :{contains_key} THEN 6 "
                "ELSE 100 END"
            )
        # Best (lowest) rank across all terms wins.
        return "LEAST(" + ", ".join(term_filters) + ")" if term_filters else "100"

    def _append_prefix_group(self, filters: list[str], params: dict, base_key: str, prefixes: list[str] | None, *, negate: bool = False) -> None:
        """Append an OR-group of ``path LIKE '<prefix>%'`` filters.

        Prefixes are escaped via :meth:`_escape_like_value` so any LIKE
        metacharacters in them match literally.  With ``negate=True`` the
        whole group is wrapped in ``NOT (...)``.  No-op for an empty list.
        Note the comparison is case-sensitive (raw ``path`` column).
        """
        if not prefixes:
            return
        items: list[str] = []
        for idx, prefix in enumerate(prefixes):
            key = f"{base_key}_{idx}"
            params[key] = self._escape_like_value(prefix) + "%"
            items.append(f"path LIKE :{key}{_LIKE_ESCAPE_SQL}")
        self._append_group(filters, items, negate=negate)

    def _append_like_group(self, filters: list[str], params: dict, base_key: str, patterns: list[str] | None, *, negate: bool = False) -> None:
        """Append an OR-group of ``lower(path) LIKE <pattern>`` filters.

        Unlike the prefix group, patterns are bound verbatim — callers supply
        ready-made LIKE patterns (with their own ``%``/``_`` wildcards), and
        they are presumably already lowercase to match ``lower(path)`` —
        TODO confirm against callers.  No-op for an empty list.
        """
        if not patterns:
            return
        items: list[str] = []
        for idx, pattern in enumerate(patterns):
            key = f"{base_key}_{idx}"
            params[key] = pattern
            items.append(f"lower(path) LIKE :{key}{_LIKE_ESCAPE_SQL}")
        self._append_group(filters, items, negate=negate)

    def _append_group(self, filters: list[str], parts: list[str], *, negate: bool) -> None:
        """OR-join ``parts`` into one parenthesized filter (negated with
        ``NOT (...)`` when requested) and append it to ``filters``.

        No-op when ``parts`` is empty.
        """
        if not parts:
            return
        joined = " OR ".join(parts)
        filters.append(f"NOT ({joined})" if negate else f"({joined})")

    def _test_penalty_sql(
        self,
        enabled: bool,
        params: dict,
        *,
        base_key: str,
        path_prefixes: list[str] | None,
        like_patterns: list[str] | None,
    ) -> str:
        """Return a SQL expression that evaluates to 1 for rows whose
        ``lower(path)`` matches any given prefix or LIKE pattern, else 0.

        Returns the constant ``"0"`` when disabled or when no patterns are
        supplied.  ``base_key`` namespaces the bind-parameter keys so this
        can coexist with the filter groups in the same statement.
        """
        if not enabled:
            return "0"
        parts: list[str] = []
        for idx, prefix in enumerate(path_prefixes or []):
            key = f"{base_key}_prefix_{idx}"
            # Prefixes are escaped (literal match), then turned into
            # prefix patterns with a trailing '%'.
            params[key] = self._escape_like_value(prefix) + "%"
            parts.append(f"lower(path) LIKE :{key}{_LIKE_ESCAPE_SQL}")
        for idx, pattern in enumerate(like_patterns or []):
            key = f"{base_key}_like_{idx}"
            # Patterns are bound verbatim — they carry their own wildcards.
            params[key] = pattern
            parts.append(f"lower(path) LIKE :{key}{_LIKE_ESCAPE_SQL}")
        if not parts:
            return "0"
        return "CASE WHEN " + " OR ".join(parts) + " THEN 1 ELSE 0 END"

    def _escape_like_value(self, value: str) -> str:
        """Escape LIKE metacharacters (backslash first, then ``%`` and ``_``)
        so ``value`` matches literally inside a LIKE pattern.

        Pairs with the ``ESCAPE`` clause in ``_LIKE_ESCAPE_SQL``, which
        declares backslash as the escape character.
        """
        return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")