# Source: agent/app/modules/rag/services/rag_service.py
# (191 lines, 7.7 KiB, Python)
from __future__ import annotations
import asyncio
import hashlib
import logging
import os
from collections.abc import Awaitable, Callable
from inspect import isawaitable
from app.modules.rag.contracts import RagDocument
from app.modules.rag.indexing.code.pipeline import CodeIndexingPipeline
from app.modules.rag.indexing.common.report import IndexReport
from app.modules.rag.indexing.docs.pipeline import DocsIndexingPipeline
from app.modules.rag.persistence.repository import RagRepository
from app.modules.rag_session.embedding.gigachat_embedder import GigaChatEmbedder
LOGGER = logging.getLogger(__name__)
class RagService:
    """Indexes repository snapshots and incremental change sets into the RAG store."""

    def __init__(
        self,
        embedder: GigaChatEmbedder,
        repository: RagRepository,
        chunker=None,
    ) -> None:
        """Wire the embedder and repository; pipelines are owned by the service.

        ``chunker`` is accepted for interface compatibility but unused here.
        """
        self._embedder = embedder
        self._repo = repository
        self._docs = DocsIndexingPipeline()
        self._code = CodeIndexingPipeline()
async def index_snapshot(
self,
rag_session_id: str,
files: list[dict],
progress_cb: Callable[[int, int, str], Awaitable[None] | None] | None = None,
) -> tuple[int, int, int, int]:
report = await self._index_files(rag_session_id, files, progress_cb=progress_cb)
self._repo.replace_documents(rag_session_id, report.documents_list)
return report.as_tuple()
async def index_changes(
self,
rag_session_id: str,
changed_files: list[dict],
progress_cb: Callable[[int, int, str], Awaitable[None] | None] | None = None,
) -> tuple[int, int, int, int]:
delete_paths: list[str] = []
upserts: list[dict] = []
for item in changed_files:
if str(item.get("op")) == "delete":
delete_paths.append(str(item.get("path", "")))
else:
upserts.append(item)
report = await self._index_files(rag_session_id, upserts, progress_cb=progress_cb)
self._repo.apply_document_changes(rag_session_id, delete_paths, report.documents_list)
return report.as_tuple()
    async def _index_files(
        self,
        rag_session_id: str,
        files: list[dict],
        progress_cb: Callable[[int, int, str], Awaitable[None] | None] | None = None,
    ) -> "_PipelineReport":
        """Index *files* one by one, reusing cached documents keyed by blob sha.

        For each file: compute its blob sha, check the per-repo cache; on a hit
        reuse the cached documents (refreshing file-level metadata), otherwise
        build and embed fresh documents and store them in the cache. A failure
        on one file is recorded and never aborts the batch. The progress
        callback fires after every file, including failed ones.
        """
        total_files = len(files)
        report = _PipelineReport()
        # The document cache is keyed by repo, not session, so resolve once up front.
        repo_id = self._resolve_repo_id(rag_session_id)
        for index, file in enumerate(files, start=1):
            path = str(file.get("path", ""))
            try:
                blob_sha = self._blob_sha(file)
                # Blocking repository/embedding calls run in a worker thread to
                # keep the event loop responsive.
                cached = await asyncio.to_thread(self._repo.get_cached_documents, repo_id, blob_sha)
                pipelines = self._resolve_pipeline_names(path)
                if cached:
                    # Cache hit: reuse stored documents but overwrite file-level
                    # metadata, which may have changed even if the content did not.
                    report.documents_list.extend(self._with_file_metadata(cached, file, repo_id, blob_sha))
                    report.cache_hit_files += 1
                    # NOTE(review): warning level looks deliberate for ingest
                    # visibility — confirm it should not be info.
                    LOGGER.warning(
                        "rag ingest file: rag_session_id=%s path=%s processing=cache pipeline=%s",
                        rag_session_id,
                        path,
                        ",".join(pipelines),
                    )
                else:
                    # Cache miss: build documents, embed them, then cache for next time.
                    built = self._build_documents(repo_id, path, file)
                    embedded = await asyncio.to_thread(self._embed_documents, built, file, repo_id, blob_sha)
                    report.documents_list.extend(embedded)
                    await asyncio.to_thread(self._repo.cache_documents, repo_id, path, blob_sha, embedded)
                    report.cache_miss_files += 1
                    LOGGER.warning(
                        "rag ingest file: rag_session_id=%s path=%s processing=embed pipeline=%s",
                        rag_session_id,
                        path,
                        ",".join(pipelines),
                    )
                report.indexed_files += 1
            except Exception as exc:
                # Best-effort batch: one bad file must not sink the whole run.
                report.failed_files += 1
                report.warnings.append(f"{path}: {exc}")
            await self._notify_progress(progress_cb, index, total_files, path)
        report.documents = len(report.documents_list)
        return report
def _build_documents(self, repo_id: str, path: str, file: dict) -> list[RagDocument]:
content = str(file.get("content") or "")
commit_sha = file.get("commit_sha")
docs: list[RagDocument] = []
if self._docs.supports(path):
docs.extend(self._docs.index_file(repo_id=repo_id, commit_sha=commit_sha, path=path, content=content))
if self._code.supports(path):
docs.extend(self._code.index_file(repo_id=repo_id, commit_sha=commit_sha, path=path, content=content))
if not docs:
docs.extend(self._docs.index_file(repo_id=repo_id, commit_sha=commit_sha, path=path, content=content))
return docs
def _resolve_pipeline_names(self, path: str) -> list[str]:
names: list[str] = []
if self._docs.supports(path):
names.append("DOCS")
if self._code.supports(path):
names.append("CODE")
if not names:
names.append("DOCS")
return names
def _embed_documents(self, docs: list[RagDocument], file: dict, repo_id: str, blob_sha: str) -> list[RagDocument]:
if not docs:
return []
batch_size = max(1, int(os.getenv("RAG_EMBED_BATCH_SIZE", "16")))
metadata = self._document_metadata(file, repo_id, blob_sha)
for doc in docs:
doc.metadata.update(metadata)
for start in range(0, len(docs), batch_size):
batch = docs[start : start + batch_size]
vectors = self._embedder.embed([doc.text for doc in batch])
for doc, vector in zip(batch, vectors):
doc.embedding = vector
return docs
def _with_file_metadata(self, docs: list[RagDocument], file: dict, repo_id: str, blob_sha: str) -> list[RagDocument]:
metadata = self._document_metadata(file, repo_id, blob_sha)
for doc in docs:
doc.metadata.update(metadata)
doc.source.repo_id = repo_id
doc.source.path = str(file.get("path", doc.source.path))
return docs
def _document_metadata(self, file: dict, repo_id: str, blob_sha: str) -> dict:
return {
"blob_sha": blob_sha,
"repo_id": repo_id,
"artifact_type": file.get("artifact_type"),
"section": file.get("section"),
"doc_id": file.get("doc_id"),
"doc_version": file.get("doc_version"),
"owner": file.get("owner"),
"system_component": file.get("system_component"),
"last_modified": file.get("last_modified"),
"staleness_score": file.get("staleness_score"),
}
def _resolve_repo_id(self, rag_session_id: str) -> str:
session = self._repo.get_session(rag_session_id)
if not session:
return rag_session_id
return str(session.get("project_id") or rag_session_id)
def _blob_sha(self, file: dict) -> str:
raw = str(file.get("content_hash") or "").strip()
if raw:
return raw
content = str(file.get("content") or "")
return hashlib.sha256(content.encode("utf-8")).hexdigest()
async def _notify_progress(
self,
progress_cb: Callable[[int, int, str], Awaitable[None] | None] | None,
current_file_index: int,
total_files: int,
current_file_name: str,
) -> None:
if not progress_cb:
return
result = progress_cb(current_file_index, total_files, current_file_name)
if isawaitable(result):
await result
class _PipelineReport(IndexReport):
    """IndexReport extended with the concrete documents gathered during a run."""

    def __init__(self) -> None:
        super().__init__()
        # Full RagDocument payloads, not just the counters the base report keeps.
        self.documents_list: list[RagDocument] = []