from __future__ import annotations import json from sqlalchemy import text from app.modules.rag.persistence.retrieval_statement_builder import RetrievalStatementBuilder from app.modules.shared.db import get_engine class RagQueryRepository: def __init__(self) -> None: self._builder = RetrievalStatementBuilder() def retrieve( self, rag_session_id: str, query_embedding: list[float], *, query_text: str = "", limit: int = 5, layers: list[str] | None = None, path_prefixes: list[str] | None = None, exclude_path_prefixes: list[str] | None = None, exclude_like_patterns: list[str] | None = None, prefer_non_tests: bool = False, ) -> list[dict]: sql, params = self._builder.build_retrieve( rag_session_id, query_embedding, query_text=query_text, limit=limit, layers=layers, path_prefixes=path_prefixes, exclude_path_prefixes=exclude_path_prefixes, exclude_like_patterns=exclude_like_patterns, prefer_non_tests=prefer_non_tests, ) with get_engine().connect() as conn: rows = conn.execute(text(sql), params).mappings().fetchall() return [self._row_to_dict(row) for row in rows] def retrieve_lexical_code( self, rag_session_id: str, *, query_text: str, limit: int = 5, path_prefixes: list[str] | None = None, exclude_path_prefixes: list[str] | None = None, exclude_like_patterns: list[str] | None = None, prefer_non_tests: bool = False, ) -> list[dict]: sql, params = self._builder.build_lexical_code( rag_session_id, query_text=query_text, limit=limit, path_prefixes=path_prefixes, exclude_path_prefixes=exclude_path_prefixes, exclude_like_patterns=exclude_like_patterns, prefer_non_tests=prefer_non_tests, ) if sql is None: return [] with get_engine().connect() as conn: rows = conn.execute(text(sql), params).mappings().fetchall() return [self._row_to_dict(row) for row in rows] def _row_to_dict(self, row) -> dict: data = dict(row) data["metadata"] = json.loads(str(data.pop("metadata_json") or "{}")) return data