73 lines
2.4 KiB
Python
73 lines
2.4 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
|
|
from sqlalchemy import text
|
|
|
|
from app.modules.rag.persistence.retrieval_statement_builder import RetrievalStatementBuilder
|
|
from app.modules.shared.db import get_engine
|
|
|
|
|
|
class RagQueryRepository:
    """Execute RAG retrieval queries and normalize the returned rows.

    SQL text and bind parameters are produced by ``RetrievalStatementBuilder``;
    this class runs them on a connection from the shared engine
    (``get_engine()``) and converts each result row into a plain ``dict``.
    """

    def __init__(self) -> None:
        self._builder = RetrievalStatementBuilder()

    def retrieve(
        self,
        rag_session_id: str,
        query_embedding: list[float],
        *,
        query_text: str = "",
        limit: int = 5,
        layers: list[str] | None = None,
        path_prefixes: list[str] | None = None,
        exclude_path_prefixes: list[str] | None = None,
        exclude_like_patterns: list[str] | None = None,
        prefer_non_tests: bool = False,
    ) -> list[dict]:
        """Run the embedding-based retrieval query for one RAG session.

        Args:
            rag_session_id: RAG session identifier, forwarded to the builder.
            query_embedding: Query vector, forwarded to the builder.
            query_text: Raw query text, forwarded to the builder.
            limit: Row limit, forwarded to the builder.
            layers, path_prefixes, exclude_path_prefixes,
            exclude_like_patterns, prefer_non_tests: Filtering/ranking
                options forwarded unchanged to
                ``RetrievalStatementBuilder.build_retrieve``; see the builder
                for their exact semantics.

        Returns:
            One dict per result row, with the ``metadata_json`` column
            replaced by a decoded ``metadata`` entry.
        """
        sql, params = self._builder.build_retrieve(
            rag_session_id,
            query_embedding,
            query_text=query_text,
            limit=limit,
            layers=layers,
            path_prefixes=path_prefixes,
            exclude_path_prefixes=exclude_path_prefixes,
            exclude_like_patterns=exclude_like_patterns,
            prefer_non_tests=prefer_non_tests,
        )
        return self._fetch_rows(sql, params)

    def retrieve_lexical_code(
        self,
        rag_session_id: str,
        *,
        query_text: str,
        limit: int = 5,
        path_prefixes: list[str] | None = None,
        exclude_path_prefixes: list[str] | None = None,
        exclude_like_patterns: list[str] | None = None,
        prefer_non_tests: bool = False,
    ) -> list[dict]:
        """Run the lexical (text-based) code retrieval query for one session.

        Args:
            rag_session_id: RAG session identifier, forwarded to the builder.
            query_text: Raw query text, forwarded to the builder.
            limit: Row limit, forwarded to the builder.
            path_prefixes, exclude_path_prefixes, exclude_like_patterns,
            prefer_non_tests: Filtering/ranking options forwarded unchanged
                to ``RetrievalStatementBuilder.build_lexical_code``.

        Returns:
            One dict per result row (see ``_row_to_dict``), or ``[]`` without
            touching the database when the builder returns no SQL.
        """
        sql, params = self._builder.build_lexical_code(
            rag_session_id,
            query_text=query_text,
            limit=limit,
            path_prefixes=path_prefixes,
            exclude_path_prefixes=exclude_path_prefixes,
            exclude_like_patterns=exclude_like_patterns,
            prefer_non_tests=prefer_non_tests,
        )
        # The builder signals "nothing to search" with sql=None (presumably
        # when query_text yields no usable terms — confirm in the builder).
        if sql is None:
            return []
        return self._fetch_rows(sql, params)

    def _fetch_rows(self, sql: str, params) -> list[dict]:
        """Execute ``sql`` with bound ``params`` and convert every row to a dict."""
        with get_engine().connect() as conn:
            rows = conn.execute(text(sql), params).mappings().fetchall()
        return [self._row_to_dict(row) for row in rows]

    def _row_to_dict(self, row) -> dict:
        """Copy a result row into a dict, decoding ``metadata_json`` into ``metadata``.

        The input row is not modified; the ``metadata_json`` key is removed
        from the returned copy and replaced by ``metadata``.
        """
        data = dict(row)
        data["metadata"] = self._decode_metadata(data.pop("metadata_json"))
        return data

    @staticmethod
    def _decode_metadata(raw: object) -> object:
        """Decode a ``metadata_json`` column value into a Python object.

        Empty/NULL values decode to ``{}``. Values a DB driver has already
        decoded (``dict``/``list``, e.g. from a JSON/JSONB column) are
        returned as-is instead of being stringified into an invalid Python
        repr. Bytes-like payloads and strings are parsed as JSON.
        """
        if not raw:
            return {}
        if isinstance(raw, (dict, list)):
            return raw
        if isinstance(raw, (bytes, bytearray, memoryview)):
            return json.loads(bytes(raw))
        return json.loads(str(raw))
|