Files
agent/app/modules/rag/persistence/repository.py
2026-03-01 14:21:33 +03:00

83 lines
3.3 KiB
Python

from __future__ import annotations
from app.modules.rag.contracts import RagDocument
from app.modules.rag.persistence.cache_repository import RagCacheRepository
from app.modules.rag.persistence.document_repository import RagDocumentRepository
from app.modules.rag.persistence.job_repository import RagJobRepository, RagJobRow
from app.modules.rag.persistence.query_repository import RagQueryRepository
from app.modules.rag.persistence.schema_repository import RagSchemaRepository
from app.modules.rag.persistence.session_repository import RagSessionRepository
from app.modules.shared.db import get_engine
class RagRepository:
    """Single entry point over the RAG persistence sub-repositories.

    Every public method forwards to exactly one collaborator built in
    ``__init__`` (schema, session, job, document, cache or query repository).
    Only the two document-mutating methods manage a connection here: they open
    one with ``get_engine().connect()``, pass it to the document repository,
    and commit after the delegate returns.
    """

    def __init__(self) -> None:
        # Collaborators are constructed in this fixed order; none of them
        # receives arguments from this facade.
        self._schema = RagSchemaRepository()
        self._sessions = RagSessionRepository()
        self._jobs = RagJobRepository()
        self._documents = RagDocumentRepository()
        self._cache = RagCacheRepository()
        self._query = RagQueryRepository()

    # ------------------------------------------------------------------ schema

    def ensure_tables(self) -> None:
        """Ask the schema repository to ensure its tables exist."""
        self._schema.ensure_tables()

    # ---------------------------------------------------------------- sessions

    def upsert_session(self, rag_session_id: str, project_id: str) -> None:
        """Forward a session upsert for ``rag_session_id`` / ``project_id``."""
        self._sessions.upsert_session(rag_session_id, project_id)

    def session_exists(self, rag_session_id: str) -> bool:
        """Return the session repository's existence check for the id."""
        return self._sessions.session_exists(rag_session_id)

    def get_session(self, rag_session_id: str) -> dict | None:
        """Return the session record from the session repository, or ``None``."""
        return self._sessions.get_session(rag_session_id)

    # -------------------------------------------------------------------- jobs

    def create_job(self, index_job_id: str, rag_session_id: str, status: str) -> None:
        """Forward creation of index job ``index_job_id`` with its initial status."""
        self._jobs.create_job(index_job_id, rag_session_id, status)

    def update_job(self, index_job_id: str, **kwargs) -> None:
        """Forward an update of ``index_job_id``; keyword fields pass through untouched."""
        self._jobs.update_job(index_job_id, **kwargs)

    def get_job(self, index_job_id: str) -> RagJobRow | None:
        """Return the job row from the job repository, or ``None``."""
        return self._jobs.get_job(index_job_id)

    # --------------------------------------------------------------- documents

    def _commit_with(self, operation) -> None:
        """Open an engine connection, hand it to ``operation``, then commit.

        ``operation`` is called with the open connection as its only argument.
        The commit runs only if ``operation`` returns normally; an exception
        leaves the ``with`` block before ``commit()`` is reached.
        """
        with get_engine().connect() as connection:
            operation(connection)
            connection.commit()

    def replace_documents(self, rag_session_id: str, docs: list[RagDocument]) -> None:
        """Run the document repository's ``replace_documents`` and commit it."""
        self._commit_with(
            lambda connection: self._documents.replace_documents(
                connection, rag_session_id, docs
            )
        )

    def apply_document_changes(self, rag_session_id: str, delete_paths: list[str], docs: list[RagDocument]) -> None:
        """Run the document repository's ``apply_document_changes`` and commit it."""
        self._commit_with(
            lambda connection: self._documents.apply_document_changes(
                connection, rag_session_id, delete_paths, docs
            )
        )

    # ------------------------------------------------------------------- cache

    def get_cached_documents(self, repo_id: str, blob_sha: str) -> list[RagDocument]:
        """Return the cache repository's documents for ``repo_id`` / ``blob_sha``."""
        return self._cache.get_cached_documents(repo_id, blob_sha)

    def cache_documents(self, repo_id: str, path: str, blob_sha: str, docs: list[RagDocument]) -> None:
        """Forward ``docs`` to the cache repository under repo, path and blob sha."""
        self._cache.cache_documents(repo_id, path, blob_sha, docs)

    def record_repo_cache(self, **kwargs) -> None:
        """Forward keyword-only cache bookkeeping to the cache repository."""
        self._cache.record_repo_cache(**kwargs)

    # ------------------------------------------------------------------- query

    def retrieve(
        self,
        rag_session_id: str,
        query_embedding: list[float],
        *,
        query_text: str = "",
        limit: int = 5,
        layers: list[str] | None = None,
        prefer_non_tests: bool = False,
    ) -> list[dict]:
        """Return the query repository's retrieval results for the embedding.

        All tuning options are forwarded to the query repository as keyword
        arguments with the same names and values received here.
        """
        query_options = {
            "query_text": query_text,
            "limit": limit,
            "layers": layers,
            "prefer_non_tests": prefer_non_tests,
        }
        return self._query.retrieve(rag_session_id, query_embedding, **query_options)

    def fallback_chunks(self, rag_session_id: str, limit: int = 5, layers: list[str] | None = None) -> list[dict]:
        """Return the query repository's fallback chunks for the session."""
        return self._query.fallback_chunks(
            rag_session_id,
            limit=limit,
            layers=layers,
        )