# agent/app/modules/rag/persistence/retrieval_statement_builder.py

from __future__ import annotations

from app.modules.rag.retrieval.query_terms import extract_query_terms

# Appended to LIKE clauses so the backslashes produced by _escape_like_value act
# as the escape character; E'\\' is Postgres escaped-string syntax for a single
# backslash, i.e. the clause renders as: LIKE :param ESCAPE E'\'.
_LIKE_ESCAPE_SQL = " ESCAPE E'\\\\'"
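
# Note on extract_query_terms (imported above, defined elsewhere in the package):
# the SQL below compares raw term parameters against lower()-ed columns, so the
# helper is assumed to return lowercased, identifier-like terms; mixed-case
# output would make every exact and LIKE match fail silently.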


class RetrievalStatementBuilder:
    """Builds parameterized SQL for hybrid (vector + lexical) retrieval over rag_chunks."""

    def build_retrieve(
        self,
        rag_session_id: str,
        query_embedding: list[float],
        *,
        query_text: str = "",
        limit: int = 5,
        layers: list[str] | None = None,
        path_prefixes: list[str] | None = None,
        exclude_path_prefixes: list[str] | None = None,
        exclude_like_patterns: list[str] | None = None,
        prefer_non_tests: bool = False,
    ) -> tuple[str, dict]:
        """Return (sql, params) for a vector search re-ranked by lexical and layer signals."""
emb = "[" + ",".join(str(x) for x in query_embedding) + "]"
filters = ["rag_session_id = :sid"]
params: dict = {"sid": rag_session_id, "emb": emb, "lim": limit}
self._append_prefix_group(filters, params, "path", path_prefixes)
self._append_prefix_group(filters, params, "exclude_prefix", exclude_path_prefixes, negate=True)
self._append_like_group(filters, params, "exclude_like", exclude_like_patterns, negate=True)
if layers:
filters.append("layer = ANY(:layers)")
params["layers"] = layers
        lexical_sql = self._lexical_rank_sql(query_text, params)
        test_penalty_sql = self._test_penalty_sql(
            prefer_non_tests,
            params,
            base_key="penalty",
            path_prefixes=exclude_path_prefixes,
            like_patterns=exclude_like_patterns,
        )
        # Code layers (C*) and document layers (D*) are ranked on parallel 0-3
        # scales; unknown layers sort last.
        layer_rank_sql = (
            "CASE "
            "WHEN layer = 'C3_ENTRYPOINTS' THEN 0 "
            "WHEN layer = 'C1_SYMBOL_CATALOG' THEN 1 "
            "WHEN layer = 'C2_DEPENDENCY_GRAPH' THEN 2 "
            "WHEN layer = 'C0_SOURCE_CHUNKS' THEN 3 "
            "WHEN layer = 'D1_MODULE_CATALOG' THEN 0 "
            "WHEN layer = 'D2_FACT_INDEX' THEN 1 "
            "WHEN layer = 'D3_SECTION_INDEX' THEN 2 "
            "WHEN layer = 'D4_POLICY_INDEX' THEN 3 "
            "ELSE 10 END"
        )
sql = f"""
SELECT path, content, layer, title, metadata_json, span_start, span_end,
{lexical_sql} AS lexical_rank,
{test_penalty_sql} AS test_penalty,
{layer_rank_sql} AS layer_rank,
(embedding <=> CAST(:emb AS vector)) AS distance
FROM rag_chunks
WHERE {' AND '.join(filters)}
ORDER BY lexical_rank ASC, test_penalty ASC, layer_rank ASC, embedding <=> CAST(:emb AS vector)
LIMIT :lim
"""
return sql, params

    def build_lexical_code(
        self,
        rag_session_id: str,
        *,
        query_text: str,
        limit: int = 5,
        path_prefixes: list[str] | None = None,
        exclude_path_prefixes: list[str] | None = None,
        exclude_like_patterns: list[str] | None = None,
        prefer_non_tests: bool = False,
    ) -> tuple[str | None, dict]:
        """Return (sql, params) for a pure lexical search over C0 source chunks.

        Returns (None, {}) when the query yields no usable terms.
        """
        terms = extract_query_terms(query_text)
        if not terms:
            return None, {}
filters = ["rag_session_id = :sid", "layer = 'C0_SOURCE_CHUNKS'"]
params: dict = {"sid": rag_session_id, "lim": limit}
self._append_prefix_group(filters, params, "path", path_prefixes)
self._append_prefix_group(filters, params, "exclude_prefix", exclude_path_prefixes, negate=True)
self._append_like_group(filters, params, "exclude_like", exclude_like_patterns, negate=True)
lexical_filters: list[str] = []
lexical_ranks: list[str] = []
for idx, term in enumerate(terms):
exact_key = f"lex_exact_{idx}"
prefix_key = f"lex_prefix_{idx}"
contains_key = f"lex_contains_{idx}"
params[exact_key] = term
params[prefix_key] = f"{term}%"
params[contains_key] = f"%{term}%"
lexical_filters.append(
f"(lower(COALESCE(qname, '')) = :{exact_key} "
f"OR lower(COALESCE(title, '')) = :{exact_key} "
f"OR lower(COALESCE(path, '')) LIKE :{contains_key} "
f"OR lower(COALESCE(title, '')) LIKE :{prefix_key} "
f"OR lower(COALESCE(content, '')) LIKE :{contains_key})"
)
lexical_ranks.append(
"CASE "
f"WHEN lower(COALESCE(qname, '')) = :{exact_key} THEN 0 "
f"WHEN lower(COALESCE(title, '')) = :{exact_key} THEN 1 "
f"WHEN lower(COALESCE(title, '')) LIKE :{prefix_key} THEN 2 "
f"WHEN lower(COALESCE(path, '')) LIKE :{contains_key} THEN 3 "
f"WHEN lower(COALESCE(content, '')) LIKE :{contains_key} THEN 4 "
"ELSE 100 END"
)
filters.append("(" + " OR ".join(lexical_filters) + ")")
lexical_sql = "LEAST(" + ", ".join(lexical_ranks) + ")"
        test_penalty_sql = self._test_penalty_sql(
            prefer_non_tests,
            params,
            base_key="lex_penalty",
            path_prefixes=exclude_path_prefixes,
            like_patterns=exclude_like_patterns,
        )
        sql = f"""
            SELECT path, content, layer, title, metadata_json, span_start, span_end,
                   {lexical_sql} AS lexical_rank,
                   {test_penalty_sql} AS test_penalty
            FROM rag_chunks
            WHERE {' AND '.join(filters)}
            ORDER BY lexical_rank ASC, test_penalty ASC, path ASC, span_start ASC
            LIMIT :lim
        """
        return sql, params

    def _lexical_rank_sql(self, query_text: str, params: dict) -> str:
        """SQL expression ranking a row by its best lexical match tier (100 = no match)."""
        term_filters: list[str] = []
        for idx, term in enumerate(extract_query_terms(query_text)):
            exact_key = f"term_exact_{idx}"
            prefix_key = f"term_prefix_{idx}"
            contains_key = f"term_contains_{idx}"
            params[exact_key] = term
            params[prefix_key] = f"{term}%"
            params[contains_key] = f"%{term}%"
            term_filters.append(
                "CASE "
                f"WHEN lower(COALESCE(qname, '')) = :{exact_key} THEN 0 "
                f"WHEN lower(COALESCE(symbol_id, '')) = :{exact_key} THEN 1 "
                f"WHEN lower(COALESCE(title, '')) = :{exact_key} THEN 2 "
                f"WHEN lower(COALESCE(qname, '')) LIKE :{prefix_key} THEN 3 "
                f"WHEN lower(COALESCE(title, '')) LIKE :{prefix_key} THEN 4 "
                f"WHEN lower(COALESCE(path, '')) LIKE :{contains_key} THEN 5 "
                f"WHEN lower(COALESCE(content, '')) LIKE :{contains_key} THEN 6 "
                "ELSE 100 END"
            )
        return "LEAST(" + ", ".join(term_filters) + ")" if term_filters else "100"

    def _append_prefix_group(
        self,
        filters: list[str],
        params: dict,
        base_key: str,
        prefixes: list[str] | None,
        *,
        negate: bool = False,
    ) -> None:
        """Add a (possibly negated) OR-group of escaped path-prefix LIKE filters."""
        if not prefixes:
            return
        items: list[str] = []
        for idx, prefix in enumerate(prefixes):
            key = f"{base_key}_{idx}"
            params[key] = self._escape_like_value(prefix) + "%"
            items.append(f"path LIKE :{key}{_LIKE_ESCAPE_SQL}")
        self._append_group(filters, items, negate=negate)

    def _append_like_group(
        self,
        filters: list[str],
        params: dict,
        base_key: str,
        patterns: list[str] | None,
        *,
        negate: bool = False,
    ) -> None:
        """Add a (possibly negated) OR-group of raw LIKE patterns matched against lower(path)."""
        if not patterns:
            return
        items: list[str] = []
        for idx, pattern in enumerate(patterns):
            key = f"{base_key}_{idx}"
            params[key] = pattern
            items.append(f"lower(path) LIKE :{key}{_LIKE_ESCAPE_SQL}")
        self._append_group(filters, items, negate=negate)

    def _append_group(self, filters: list[str], parts: list[str], *, negate: bool) -> None:
        if not parts:
            return
        joined = " OR ".join(parts)
        filters.append(f"NOT ({joined})" if negate else f"({joined})")

    def _test_penalty_sql(
        self,
        enabled: bool,
        params: dict,
        *,
        base_key: str,
        path_prefixes: list[str] | None,
        like_patterns: list[str] | None,
    ) -> str:
        """SQL expression that scores 1 for paths matching the given patterns, else 0."""
        if not enabled:
            return "0"
        parts: list[str] = []
        for idx, prefix in enumerate(path_prefixes or []):
            key = f"{base_key}_prefix_{idx}"
            params[key] = self._escape_like_value(prefix) + "%"
            parts.append(f"lower(path) LIKE :{key}{_LIKE_ESCAPE_SQL}")
        for idx, pattern in enumerate(like_patterns or []):
            key = f"{base_key}_like_{idx}"
            params[key] = pattern
            parts.append(f"lower(path) LIKE :{key}{_LIKE_ESCAPE_SQL}")
        if not parts:
            return "0"
        return "CASE WHEN " + " OR ".join(parts) + " THEN 1 ELSE 0 END"

    def _escape_like_value(self, value: str) -> str:
        # Escape the backslash first so the wildcard escapes below are not
        # double-escaped; pairs with the ESCAPE E'\' clause in _LIKE_ESCAPE_SQL.
        return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
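

if __name__ == "__main__":
    # Illustrative smoke check, not part of the production module: print the SQL
    # and bind parameters for a hypothetical query. Run from the repository root
    # so the `app.modules...` import above resolves; the session id, embedding,
    # and filter values below are made up for demonstration.
    builder = RetrievalStatementBuilder()
    sql, params = builder.build_retrieve(
        "session-123",
        [0.1, 0.2, 0.3],
        query_text="retrieval statement builder",
        layers=["C0_SOURCE_CHUNKS", "C1_SYMBOL_CATALOG"],
        exclude_path_prefixes=["tests/"],
        prefer_non_tests=True,
    )
    print(sql)
    print(params)
    # The :name placeholders match SQLAlchemy's text() bind style, so execution
    # presumably looks like session.execute(text(sql), params); the executing
    # repository lives outside this file, so that is an assumption.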