# Files
# agent/app/modules/rag/services/rag_service.py
# 2026-03-01 14:21:33 +03:00
#
# 198 lines
# 8.1 KiB
# Python

from __future__ import annotations

import asyncio
import hashlib
import logging
import os
from collections.abc import Awaitable, Callable
from inspect import isawaitable

from app.modules.rag.contracts import RagDocument
from app.modules.rag.indexing.code.pipeline import CodeIndexingPipeline
from app.modules.rag.indexing.common.report import IndexReport
from app.modules.rag.indexing.docs.pipeline import DocsIndexingPipeline
from app.modules.rag.persistence.repository import RagRepository
from app.modules.rag.retrieval.query_router import RagQueryRouter
from app.modules.rag_session.embedding.gigachat_embedder import GigaChatEmbedder


class RagService:
    """Index repository files into RAG documents and retrieve them for queries.

    Files are split into documents by the docs and code indexing pipelines,
    embedded with the configured embedder, cached per ``(repo_id, blob_sha)``
    in the repository, and persisted per RAG session.

    All blocking embedder/repository calls made from the async entry points
    are executed via ``asyncio.to_thread`` so the event loop is not stalled.
    """

    def __init__(
        self,
        embedder: GigaChatEmbedder,
        repository: RagRepository,
        chunker=None,
    ) -> None:
        """Create the service.

        Args:
            embedder: Object whose ``embed(texts)`` is expected to return one
                vector per input text.
            repository: Persistence layer for sessions, documents and the
                per-blob document cache.
            chunker: Accepted for backward compatibility; currently unused.
        """
        self._embedder = embedder
        self._repo = repository
        self._docs = DocsIndexingPipeline()
        self._code = CodeIndexingPipeline()
        self._queries = RagQueryRouter()

    async def index_snapshot(
        self,
        rag_session_id: str,
        files: list[dict],
        progress_cb: Callable[[int, int, str], Awaitable[None] | None] | None = None,
    ) -> tuple[int, int, int, int]:
        """Index a full snapshot and replace every document of the session.

        Args:
            rag_session_id: Session whose stored documents are replaced.
            files: File dicts (``path``, ``content``, optional ``content_hash``,
                ``commit_sha`` and metadata fields such as ``owner``).
            progress_cb: Optional sync or async ``(index, total, path)``
                callback, invoked after each file.

        Returns:
            The tuple produced by the report's ``as_tuple()``.
        """
        report = await self._index_files(rag_session_id, files, progress_cb=progress_cb)
        # Persistence is blocking I/O; keep it off the event loop.
        await asyncio.to_thread(self._repo.replace_documents, rag_session_id, report.documents_list)
        return report.as_tuple()

    async def index_changes(
        self,
        rag_session_id: str,
        changed_files: list[dict],
        progress_cb: Callable[[int, int, str], Awaitable[None] | None] | None = None,
    ) -> tuple[int, int, int, int]:
        """Apply an incremental change set to the session's documents.

        Items whose ``op`` is ``"delete"`` remove documents for their ``path``;
        every other item is (re)indexed and upserted.

        Args:
            rag_session_id: Session to update.
            changed_files: Change items; see ``index_snapshot`` for file keys.
            progress_cb: Optional sync or async ``(index, total, path)`` callback.

        Returns:
            The tuple produced by the report's ``as_tuple()``.
        """
        delete_paths: list[str] = []
        upserts: list[dict] = []
        for item in changed_files:
            if str(item.get("op")) == "delete":
                path = str(item.get("path") or "")
                # A delete without a path identifies nothing; don't forward "" / "None".
                if path:
                    delete_paths.append(path)
            else:
                upserts.append(item)
        report = await self._index_files(rag_session_id, upserts, progress_cb=progress_cb)
        await asyncio.to_thread(
            self._repo.apply_document_changes, rag_session_id, delete_paths, report.documents_list
        )
        return report.as_tuple()

    async def retrieve(self, rag_session_id: str, query: str, *, limit: int = 8) -> list[dict]:
        """Return up to ``limit`` chunks relevant to ``query``.

        Tries a vector search first; if that raises, falls back to unranked
        chunks from the same layers, and if still nothing is found for a
        non-docs query, to chunks from the docs layers.

        Args:
            rag_session_id: Session to search.
            query: Free-text user query.
            limit: Maximum number of rows requested from the repository.

        Returns:
            Dicts with ``source``, ``content``, ``layer``, ``title``,
            ``metadata`` and ``score`` (the row's ``distance``, if any).
        """
        mode = self._queries.resolve_mode(query)
        layers = self._queries.layers_for_mode(mode)
        lowered = query.lower()
        # For code questions prefer non-test files unless the query mentions
        # tests ("тест" is the Russian word for "test").
        prefer_non_tests = mode == "code" and "test" not in lowered and "тест" not in lowered
        try:
            rows = await asyncio.to_thread(
                self._vector_search, rag_session_id, query, layers, prefer_non_tests, limit
            )
        except Exception:
            # Best effort: degrade to unranked chunks, but don't hide the cause.
            logging.getLogger(__name__).warning(
                "Vector retrieval failed for session %s; using fallback chunks",
                rag_session_id,
                exc_info=True,
            )
            rows = await asyncio.to_thread(
                self._repo.fallback_chunks, rag_session_id, limit=limit, layers=layers
            )
        if not rows and mode != "docs":
            rows = await asyncio.to_thread(
                self._repo.fallback_chunks,
                rag_session_id,
                limit=limit,
                layers=self._queries.layers_for_mode("docs"),
            )
        return [
            {
                "source": row["path"],
                "content": row["content"],
                "layer": row.get("layer"),
                "title": row.get("title"),
                "metadata": row.get("metadata", {}),
                "score": row.get("distance"),
            }
            for row in rows
        ]

    def _vector_search(
        self,
        rag_session_id: str,
        query: str,
        layers,
        prefer_non_tests: bool,
        limit: int,
    ) -> list[dict]:
        """Embed ``query`` and run the repository vector search (blocking)."""
        query_embedding = self._embedder.embed([query])[0]
        return self._repo.retrieve(
            rag_session_id,
            query_embedding,
            query_text=query,
            limit=limit,
            layers=layers,
            prefer_non_tests=prefer_non_tests,
        )

    async def _index_files(
        self,
        rag_session_id: str,
        files: list[dict],
        progress_cb: Callable[[int, int, str], Awaitable[None] | None] | None = None,
    ) -> "_PipelineReport":
        """Build (or reuse cached) embedded documents for ``files``.

        Each file is processed independently: a failure is recorded in the
        report's ``failed_files``/``warnings`` and does not stop the run.
        Cached documents are looked up by ``(repo_id, blob_sha)``.
        """
        total_files = len(files)
        report = _PipelineReport()
        repo_id = await asyncio.to_thread(self._resolve_repo_id, rag_session_id)
        for index, file in enumerate(files, start=1):
            path = str(file.get("path", ""))
            try:
                blob_sha = self._blob_sha(file)
                cached = await asyncio.to_thread(self._repo.get_cached_documents, repo_id, blob_sha)
                if cached:
                    report.documents_list.extend(self._with_file_metadata(cached, file, repo_id, blob_sha))
                    report.cache_hit_files += 1
                else:
                    built = self._build_documents(repo_id, path, file)
                    embedded = await asyncio.to_thread(self._embed_documents, built, file, repo_id, blob_sha)
                    try:
                        await asyncio.to_thread(self._repo.cache_documents, repo_id, path, blob_sha, embedded)
                    except Exception as cache_exc:
                        # The cache is an optimisation: the documents are valid, so a
                        # failed cache write must not mark the file as failed.
                        report.warnings.append(f"{path}: cache write failed: {cache_exc}")
                    report.documents_list.extend(embedded)
                    report.cache_miss_files += 1
                report.indexed_files += 1
            except Exception as exc:
                report.failed_files += 1
                report.warnings.append(f"{path}: {exc}")
            await self._notify_progress(progress_cb, index, total_files, path)
        report.documents = len(report.documents_list)
        return report

    def _build_documents(self, repo_id: str, path: str, file: dict) -> list[RagDocument]:
        """Run every pipeline that supports ``path``.

        Falls back to the docs pipeline when none of them produced documents.
        """
        content = str(file.get("content") or "")
        commit_sha = file.get("commit_sha")
        docs: list[RagDocument] = []
        if self._docs.supports(path):
            docs.extend(self._docs.index_file(repo_id=repo_id, commit_sha=commit_sha, path=path, content=content))
        if self._code.supports(path):
            docs.extend(self._code.index_file(repo_id=repo_id, commit_sha=commit_sha, path=path, content=content))
        if not docs:
            docs.extend(self._docs.index_file(repo_id=repo_id, commit_sha=commit_sha, path=path, content=content))
        return docs

    @staticmethod
    def _embed_batch_size(default: int = 16) -> int:
        """Read ``RAG_EMBED_BATCH_SIZE`` (min 1); use ``default`` if it is not an int."""
        raw = os.getenv("RAG_EMBED_BATCH_SIZE", str(default))
        try:
            return max(1, int(raw))
        except ValueError:
            return default

    def _embed_documents(self, docs: list[RagDocument], file: dict, repo_id: str, blob_sha: str) -> list[RagDocument]:
        """Attach file metadata and embeddings to ``docs`` in place (blocking).

        Raises:
            ValueError: if the embedder returns a different number of vectors
                than texts in a batch, which would otherwise leave documents
                silently without an embedding.
        """
        if not docs:
            return []
        batch_size = self._embed_batch_size()
        metadata = self._document_metadata(file, repo_id, blob_sha)
        for doc in docs:
            doc.metadata.update(metadata)
        for start in range(0, len(docs), batch_size):
            batch = docs[start : start + batch_size]
            vectors = list(self._embedder.embed([doc.text for doc in batch]))
            if len(vectors) != len(batch):
                raise ValueError(
                    f"embedder returned {len(vectors)} vectors for {len(batch)} texts"
                )
            for doc, vector in zip(batch, vectors):
                doc.embedding = vector
        return docs

    def _with_file_metadata(self, docs: list[RagDocument], file: dict, repo_id: str, blob_sha: str) -> list[RagDocument]:
        """Stamp (cached) ``docs`` in place with this file's metadata, repo and path."""
        metadata = self._document_metadata(file, repo_id, blob_sha)
        for doc in docs:
            doc.metadata.update(metadata)
            doc.source.repo_id = repo_id
            doc.source.path = str(file.get("path", doc.source.path))
        return docs

    def _document_metadata(self, file: dict, repo_id: str, blob_sha: str) -> dict:
        """Return the file-level metadata merged into every document's ``metadata``."""
        return {
            "blob_sha": blob_sha,
            "repo_id": repo_id,
            "artifact_type": file.get("artifact_type"),
            "section": file.get("section"),
            "doc_id": file.get("doc_id"),
            "doc_version": file.get("doc_version"),
            "owner": file.get("owner"),
            "system_component": file.get("system_component"),
            "last_modified": file.get("last_modified"),
            "staleness_score": file.get("staleness_score"),
        }

    def _resolve_repo_id(self, rag_session_id: str) -> str:
        """Return the session's ``project_id``, or the session id when unavailable (blocking)."""
        session = self._repo.get_session(rag_session_id)
        if not session:
            return rag_session_id
        return str(session.get("project_id") or rag_session_id)

    def _blob_sha(self, file: dict) -> str:
        """Return the stripped ``content_hash``, else SHA-256 of the UTF-8 content."""
        raw = str(file.get("content_hash") or "").strip()
        if raw:
            return raw
        content = str(file.get("content") or "")
        return hashlib.sha256(content.encode("utf-8")).hexdigest()

    async def _notify_progress(
        self,
        progress_cb: Callable[[int, int, str], Awaitable[None] | None] | None,
        current_file_index: int,
        total_files: int,
        current_file_name: str,
    ) -> None:
        """Invoke ``progress_cb`` if set, awaiting it when it returns an awaitable."""
        if not progress_cb:
            return
        result = progress_cb(current_file_index, total_files, current_file_name)
        if isawaitable(result):
            await result
class _PipelineReport(IndexReport):
    """``IndexReport`` that additionally keeps the indexed documents themselves."""

    def __init__(self) -> None:
        super().__init__()
        # Documents built or reused from cache during one indexing run.
        collected: list[RagDocument] = []
        self.documents_list = collected