Новый раг
This commit is contained in:
197
app/modules/rag/services/rag_service.py
Normal file
197
app/modules/rag/services/rag_service.py
Normal file
@@ -0,0 +1,197 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import os
|
||||
from collections.abc import Awaitable, Callable
|
||||
from inspect import isawaitable
|
||||
|
||||
from app.modules.rag.contracts import RagDocument
|
||||
from app.modules.rag.indexing.code.pipeline import CodeIndexingPipeline
|
||||
from app.modules.rag.indexing.common.report import IndexReport
|
||||
from app.modules.rag.indexing.docs.pipeline import DocsIndexingPipeline
|
||||
from app.modules.rag.persistence.repository import RagRepository
|
||||
from app.modules.rag.retrieval.query_router import RagQueryRouter
|
||||
from app.modules.rag_session.embedding.gigachat_embedder import GigaChatEmbedder
|
||||
|
||||
|
||||
class RagService:
|
||||
def __init__(
|
||||
self,
|
||||
embedder: GigaChatEmbedder,
|
||||
repository: RagRepository,
|
||||
chunker=None,
|
||||
) -> None:
|
||||
self._embedder = embedder
|
||||
self._repo = repository
|
||||
self._docs = DocsIndexingPipeline()
|
||||
self._code = CodeIndexingPipeline()
|
||||
self._queries = RagQueryRouter()
|
||||
|
||||
async def index_snapshot(
|
||||
self,
|
||||
rag_session_id: str,
|
||||
files: list[dict],
|
||||
progress_cb: Callable[[int, int, str], Awaitable[None] | None] | None = None,
|
||||
) -> tuple[int, int, int, int]:
|
||||
report = await self._index_files(rag_session_id, files, progress_cb=progress_cb)
|
||||
self._repo.replace_documents(rag_session_id, report.documents_list)
|
||||
return report.as_tuple()
|
||||
|
||||
async def index_changes(
|
||||
self,
|
||||
rag_session_id: str,
|
||||
changed_files: list[dict],
|
||||
progress_cb: Callable[[int, int, str], Awaitable[None] | None] | None = None,
|
||||
) -> tuple[int, int, int, int]:
|
||||
delete_paths: list[str] = []
|
||||
upserts: list[dict] = []
|
||||
for item in changed_files:
|
||||
if str(item.get("op")) == "delete":
|
||||
delete_paths.append(str(item.get("path", "")))
|
||||
else:
|
||||
upserts.append(item)
|
||||
report = await self._index_files(rag_session_id, upserts, progress_cb=progress_cb)
|
||||
self._repo.apply_document_changes(rag_session_id, delete_paths, report.documents_list)
|
||||
return report.as_tuple()
|
||||
|
||||
async def retrieve(self, rag_session_id: str, query: str) -> list[dict]:
|
||||
mode = self._queries.resolve_mode(query)
|
||||
layers = self._queries.layers_for_mode(mode)
|
||||
prefer_non_tests = mode == "code" and "test" not in query.lower() and "тест" not in query.lower()
|
||||
try:
|
||||
query_embedding = self._embedder.embed([query])[0]
|
||||
rows = self._repo.retrieve(
|
||||
rag_session_id,
|
||||
query_embedding,
|
||||
query_text=query,
|
||||
limit=8,
|
||||
layers=layers,
|
||||
prefer_non_tests=prefer_non_tests,
|
||||
)
|
||||
except Exception:
|
||||
rows = self._repo.fallback_chunks(rag_session_id, limit=8, layers=layers)
|
||||
if not rows and mode != "docs":
|
||||
rows = self._repo.fallback_chunks(rag_session_id, limit=8, layers=self._queries.layers_for_mode("docs"))
|
||||
return [
|
||||
{
|
||||
"source": row["path"],
|
||||
"content": row["content"],
|
||||
"layer": row.get("layer"),
|
||||
"title": row.get("title"),
|
||||
"metadata": row.get("metadata", {}),
|
||||
"score": row.get("distance"),
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
|
||||
async def _index_files(
|
||||
self,
|
||||
rag_session_id: str,
|
||||
files: list[dict],
|
||||
progress_cb: Callable[[int, int, str], Awaitable[None] | None] | None = None,
|
||||
) -> "_PipelineReport":
|
||||
total_files = len(files)
|
||||
report = _PipelineReport()
|
||||
repo_id = self._resolve_repo_id(rag_session_id)
|
||||
for index, file in enumerate(files, start=1):
|
||||
path = str(file.get("path", ""))
|
||||
try:
|
||||
blob_sha = self._blob_sha(file)
|
||||
cached = await asyncio.to_thread(self._repo.get_cached_documents, repo_id, blob_sha)
|
||||
if cached:
|
||||
report.documents_list.extend(self._with_file_metadata(cached, file, repo_id, blob_sha))
|
||||
report.cache_hit_files += 1
|
||||
else:
|
||||
built = self._build_documents(repo_id, path, file)
|
||||
embedded = await asyncio.to_thread(self._embed_documents, built, file, repo_id, blob_sha)
|
||||
report.documents_list.extend(embedded)
|
||||
await asyncio.to_thread(self._repo.cache_documents, repo_id, path, blob_sha, embedded)
|
||||
report.cache_miss_files += 1
|
||||
report.indexed_files += 1
|
||||
except Exception as exc:
|
||||
report.failed_files += 1
|
||||
report.warnings.append(f"{path}: {exc}")
|
||||
await self._notify_progress(progress_cb, index, total_files, path)
|
||||
report.documents = len(report.documents_list)
|
||||
return report
|
||||
|
||||
def _build_documents(self, repo_id: str, path: str, file: dict) -> list[RagDocument]:
|
||||
content = str(file.get("content") or "")
|
||||
commit_sha = file.get("commit_sha")
|
||||
docs: list[RagDocument] = []
|
||||
if self._docs.supports(path):
|
||||
docs.extend(self._docs.index_file(repo_id=repo_id, commit_sha=commit_sha, path=path, content=content))
|
||||
if self._code.supports(path):
|
||||
docs.extend(self._code.index_file(repo_id=repo_id, commit_sha=commit_sha, path=path, content=content))
|
||||
if not docs:
|
||||
docs.extend(self._docs.index_file(repo_id=repo_id, commit_sha=commit_sha, path=path, content=content))
|
||||
return docs
|
||||
|
||||
def _embed_documents(self, docs: list[RagDocument], file: dict, repo_id: str, blob_sha: str) -> list[RagDocument]:
|
||||
if not docs:
|
||||
return []
|
||||
batch_size = max(1, int(os.getenv("RAG_EMBED_BATCH_SIZE", "16")))
|
||||
metadata = self._document_metadata(file, repo_id, blob_sha)
|
||||
for doc in docs:
|
||||
doc.metadata.update(metadata)
|
||||
for start in range(0, len(docs), batch_size):
|
||||
batch = docs[start : start + batch_size]
|
||||
vectors = self._embedder.embed([doc.text for doc in batch])
|
||||
for doc, vector in zip(batch, vectors):
|
||||
doc.embedding = vector
|
||||
return docs
|
||||
|
||||
def _with_file_metadata(self, docs: list[RagDocument], file: dict, repo_id: str, blob_sha: str) -> list[RagDocument]:
|
||||
metadata = self._document_metadata(file, repo_id, blob_sha)
|
||||
for doc in docs:
|
||||
doc.metadata.update(metadata)
|
||||
doc.source.repo_id = repo_id
|
||||
doc.source.path = str(file.get("path", doc.source.path))
|
||||
return docs
|
||||
|
||||
def _document_metadata(self, file: dict, repo_id: str, blob_sha: str) -> dict:
|
||||
return {
|
||||
"blob_sha": blob_sha,
|
||||
"repo_id": repo_id,
|
||||
"artifact_type": file.get("artifact_type"),
|
||||
"section": file.get("section"),
|
||||
"doc_id": file.get("doc_id"),
|
||||
"doc_version": file.get("doc_version"),
|
||||
"owner": file.get("owner"),
|
||||
"system_component": file.get("system_component"),
|
||||
"last_modified": file.get("last_modified"),
|
||||
"staleness_score": file.get("staleness_score"),
|
||||
}
|
||||
|
||||
def _resolve_repo_id(self, rag_session_id: str) -> str:
|
||||
session = self._repo.get_session(rag_session_id)
|
||||
if not session:
|
||||
return rag_session_id
|
||||
return str(session.get("project_id") or rag_session_id)
|
||||
|
||||
def _blob_sha(self, file: dict) -> str:
|
||||
raw = str(file.get("content_hash") or "").strip()
|
||||
if raw:
|
||||
return raw
|
||||
content = str(file.get("content") or "")
|
||||
return hashlib.sha256(content.encode("utf-8")).hexdigest()
|
||||
|
||||
async def _notify_progress(
|
||||
self,
|
||||
progress_cb: Callable[[int, int, str], Awaitable[None] | None] | None,
|
||||
current_file_index: int,
|
||||
total_files: int,
|
||||
current_file_name: str,
|
||||
) -> None:
|
||||
if not progress_cb:
|
||||
return
|
||||
result = progress_cb(current_file_index, total_files, current_file_name)
|
||||
if isawaitable(result):
|
||||
await result
|
||||
|
||||
|
||||
class _PipelineReport(IndexReport):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.documents_list: list[RagDocument] = []
|
||||
Reference in New Issue
Block a user