from __future__ import annotations

from app.modules.rag.contracts import RagDocument
from app.modules.rag.persistence.cache_repository import RagCacheRepository
from app.modules.rag.persistence.document_repository import RagDocumentRepository
from app.modules.rag.persistence.job_repository import RagJobRepository, RagJobRow
from app.modules.rag.persistence.query_repository import RagQueryRepository
from app.modules.rag.persistence.schema_repository import RagSchemaRepository
from app.modules.rag.persistence.session_repository import RagSessionRepository
from app.modules.shared.db import get_engine


class RagRepository:
    """Single entry point that forwards RAG persistence calls to sub-repositories.

    Each public method delegates to exactly one of the specialised
    repositories created in ``__init__`` (schema, sessions, jobs, documents,
    cache, query). The two document-mutation methods additionally open a
    connection from ``get_engine()`` and commit it after the delegated call.
    """

    def __init__(self) -> None:
        """Instantiate one of each specialised repository."""
        self._schema = RagSchemaRepository()
        self._sessions = RagSessionRepository()
        self._jobs = RagJobRepository()
        self._documents = RagDocumentRepository()
        self._cache = RagCacheRepository()
        self._query = RagQueryRepository()

    # ------------------------------------------------------------------
    # Schema
    # ------------------------------------------------------------------

    def ensure_tables(self) -> None:
        """Forward to ``RagSchemaRepository.ensure_tables``."""
        self._schema.ensure_tables()

    # ------------------------------------------------------------------
    # Sessions
    # ------------------------------------------------------------------

    def upsert_session(self, rag_session_id: str, project_id: str) -> None:
        """Forward to ``RagSessionRepository.upsert_session``."""
        self._sessions.upsert_session(rag_session_id, project_id)

    def session_exists(self, rag_session_id: str) -> bool:
        """Return the result of ``RagSessionRepository.session_exists``."""
        exists = self._sessions.session_exists(rag_session_id)
        return exists

    def get_session(self, rag_session_id: str) -> dict | None:
        """Return the result of ``RagSessionRepository.get_session``."""
        session = self._sessions.get_session(rag_session_id)
        return session

    # ------------------------------------------------------------------
    # Index jobs
    # ------------------------------------------------------------------

    def create_job(self, index_job_id: str, rag_session_id: str, status: str) -> None:
        """Forward to ``RagJobRepository.create_job``."""
        self._jobs.create_job(index_job_id, rag_session_id, status)

    def update_job(self, index_job_id: str, **kwargs) -> None:
        """Forward to ``RagJobRepository.update_job``.

        All keyword arguments are passed through unchanged; which keys are
        accepted is decided by the job repository, not by this facade.
        """
        self._jobs.update_job(index_job_id, **kwargs)

    def get_job(self, index_job_id: str) -> RagJobRow | None:
        """Return the result of ``RagJobRepository.get_job``."""
        job = self._jobs.get_job(index_job_id)
        return job

    # ------------------------------------------------------------------
    # Documents
    # ------------------------------------------------------------------

    def replace_documents(self, rag_session_id: str, docs: list[RagDocument]) -> None:
        """Run ``RagDocumentRepository.replace_documents`` and commit.

        A connection is opened from ``get_engine()``, handed to the document
        repository, and committed once the delegated call returns. If that
        call raises, the commit is never reached.
        """
        engine = get_engine()
        with engine.connect() as connection:
            self._documents.replace_documents(connection, rag_session_id, docs)
            connection.commit()

    def apply_document_changes(
        self,
        rag_session_id: str,
        delete_paths: list[str],
        docs: list[RagDocument],
    ) -> None:
        """Run ``RagDocumentRepository.apply_document_changes`` and commit.

        Uses the same open / delegate / commit sequence as
        ``replace_documents``; ``delete_paths`` and ``docs`` are passed
        through to the document repository unchanged.
        """
        engine = get_engine()
        with engine.connect() as connection:
            self._documents.apply_document_changes(
                connection,
                rag_session_id,
                delete_paths,
                docs,
            )
            connection.commit()

    # ------------------------------------------------------------------
    # Cache
    # ------------------------------------------------------------------

    def get_cached_documents(self, repo_id: str, blob_sha: str) -> list[RagDocument]:
        """Return the result of ``RagCacheRepository.get_cached_documents``."""
        cached = self._cache.get_cached_documents(repo_id, blob_sha)
        return cached

    def cache_documents(
        self,
        repo_id: str,
        path: str,
        blob_sha: str,
        docs: list[RagDocument],
    ) -> None:
        """Forward to ``RagCacheRepository.cache_documents``."""
        self._cache.cache_documents(repo_id, path, blob_sha, docs)

    def record_repo_cache(self, **kwargs) -> None:
        """Forward all keyword arguments to ``RagCacheRepository.record_repo_cache``."""
        self._cache.record_repo_cache(**kwargs)

    # ------------------------------------------------------------------
    # Retrieval
    # ------------------------------------------------------------------

    def retrieve(
        self,
        rag_session_id: str,
        query_embedding: list[float],
        *,
        query_text: str = "",
        limit: int = 5,
        layers: list[str] | None = None,
        prefer_non_tests: bool = False,
    ) -> list[dict]:
        """Return the result of ``RagQueryRepository.retrieve``.

        The session id and embedding are passed positionally; every other
        option is keyword-only here and is forwarded by keyword with the
        value the caller supplied (or this method's default).
        """
        rows = self._query.retrieve(
            rag_session_id,
            query_embedding,
            query_text=query_text,
            limit=limit,
            layers=layers,
            prefer_non_tests=prefer_non_tests,
        )
        return rows

    def fallback_chunks(
        self,
        rag_session_id: str,
        limit: int = 5,
        layers: list[str] | None = None,
    ) -> list[dict]:
        """Return the result of ``RagQueryRepository.fallback_chunks``.

        ``limit`` and ``layers`` are forwarded by keyword.
        """
        rows = self._query.fallback_chunks(
            rag_session_id,
            limit=limit,
            layers=layers,
        )
        return rows