Files
agent/app/modules/rag/persistence/query_repository.py

73 lines
2.4 KiB
Python

from __future__ import annotations
import json
from sqlalchemy import text
from app.modules.rag.persistence.retrieval_statement_builder import RetrievalStatementBuilder
from app.modules.shared.db import get_engine
class RagQueryRepository:
    """Read-side repository that executes RAG retrieval statements.

    SQL text and bind parameters are produced by ``RetrievalStatementBuilder``;
    this class only executes them on a connection obtained from
    ``get_engine()`` and converts the resulting rows into plain dicts.
    """

    def __init__(self) -> None:
        self._builder = RetrievalStatementBuilder()

    def retrieve(
        self,
        rag_session_id: str,
        query_embedding: list[float],
        *,
        query_text: str = "",
        limit: int = 5,
        layers: list[str] | None = None,
        path_prefixes: list[str] | None = None,
        exclude_path_prefixes: list[str] | None = None,
        exclude_like_patterns: list[str] | None = None,
        prefer_non_tests: bool = False,
    ) -> list[dict]:
        """Execute the statement from ``build_retrieve`` and return its rows.

        Every argument is forwarded unchanged to
        ``RetrievalStatementBuilder.build_retrieve``; filtering and ranking
        semantics are defined by that builder, not here.

        Returns:
            One dict per result row, with the ``metadata_json`` column
            replaced by a decoded ``metadata`` entry (see ``_row_to_dict``).
        """
        sql, params = self._builder.build_retrieve(
            rag_session_id,
            query_embedding,
            query_text=query_text,
            limit=limit,
            layers=layers,
            path_prefixes=path_prefixes,
            exclude_path_prefixes=exclude_path_prefixes,
            exclude_like_patterns=exclude_like_patterns,
            prefer_non_tests=prefer_non_tests,
        )
        return self._fetch_rows(sql, params)

    def retrieve_lexical_code(
        self,
        rag_session_id: str,
        *,
        query_text: str,
        limit: int = 5,
        path_prefixes: list[str] | None = None,
        exclude_path_prefixes: list[str] | None = None,
        exclude_like_patterns: list[str] | None = None,
        prefer_non_tests: bool = False,
    ) -> list[dict]:
        """Execute the statement from ``build_lexical_code`` and return its rows.

        Every argument is forwarded unchanged to
        ``RetrievalStatementBuilder.build_lexical_code``.

        Returns:
            An empty list when the builder yields no SQL (``None``);
            otherwise one dict per result row, converted by ``_row_to_dict``.
        """
        sql, params = self._builder.build_lexical_code(
            rag_session_id,
            query_text=query_text,
            limit=limit,
            path_prefixes=path_prefixes,
            exclude_path_prefixes=exclude_path_prefixes,
            exclude_like_patterns=exclude_like_patterns,
            prefer_non_tests=prefer_non_tests,
        )
        if sql is None:
            # The builder signalled that there is nothing to run.
            return []
        return self._fetch_rows(sql, params)

    def _fetch_rows(self, sql: str, params) -> list[dict]:
        """Run *sql* with *params* on a fresh connection and convert all rows."""
        with get_engine().connect() as conn:
            rows = conn.execute(text(sql), params).mappings().fetchall()
        # Rows are fully fetched above, so decoding can happen after the
        # connection has been released.
        return [self._row_to_dict(row) for row in rows]

    def _row_to_dict(self, row) -> dict:
        """Copy *row* into a dict, replacing ``metadata_json`` with ``metadata``.

        The ``metadata_json`` value may be JSON text (``str``, ``bytes`` or
        ``bytearray``) or a ``dict``/``list`` that was already decoded before
        it reached this method. Falsy values (``None``, ``""``, empty
        containers) become ``{}``.

        Raises:
            KeyError: if *row* has no ``metadata_json`` key.
            json.JSONDecodeError: if the value is text that is not valid JSON.
        """
        data = dict(row)
        raw = data.pop("metadata_json")
        if not raw:
            metadata = {}
        elif isinstance(raw, (dict, list)):
            # Already decoded; str() would yield a Python repr, not JSON.
            metadata = raw
        elif isinstance(raw, (str, bytes, bytearray)):
            metadata = json.loads(raw)
        else:
            # Unknown wrapper type: fall back to its string form.
            metadata = json.loads(str(raw))
        data["metadata"] = metadata
        return data