Фиксация изменений

This commit is contained in:
2026-03-05 11:03:17 +03:00
parent 1ef0b4d68c
commit 417b8b6f72
261 changed files with 8215 additions and 332 deletions

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio
import hashlib
import logging
import os
from collections.abc import Awaitable, Callable
from inspect import isawaitable
@@ -11,9 +12,10 @@ from app.modules.rag.indexing.code.pipeline import CodeIndexingPipeline
from app.modules.rag.indexing.common.report import IndexReport
from app.modules.rag.indexing.docs.pipeline import DocsIndexingPipeline
from app.modules.rag.persistence.repository import RagRepository
from app.modules.rag.retrieval.query_router import RagQueryRouter
from app.modules.rag_session.embedding.gigachat_embedder import GigaChatEmbedder
LOGGER = logging.getLogger(__name__)
class RagService:
def __init__(
@@ -26,7 +28,6 @@ class RagService:
self._repo = repository
self._docs = DocsIndexingPipeline()
self._code = CodeIndexingPipeline()
self._queries = RagQueryRouter()
async def index_snapshot(
self,
@@ -55,36 +56,6 @@ class RagService:
self._repo.apply_document_changes(rag_session_id, delete_paths, report.documents_list)
return report.as_tuple()
async def retrieve(self, rag_session_id: str, query: str, limit: int = 8) -> list[dict]:
    """Return the top chunks matching *query* for a RAG session.

    Resolves a retrieval mode and layer set from the query, embeds the
    query and asks the repository for the nearest chunks.  On any
    embedding/retrieval failure the method degrades to the repository's
    non-vector fallback instead of raising; if the mode-specific layers
    yield nothing, the docs layers are tried as a last resort.

    Args:
        rag_session_id: Session whose indexed chunks are searched.
        query: Free-text user query (English or Russian).
        limit: Maximum number of chunks to fetch per repository call
            (kept at the historical default of 8 for existing callers).

    Returns:
        A list of dicts with ``source``/``content``/``layer``/``title``/
        ``metadata``/``score`` keys; ``score`` is the raw distance if the
        repository provided one.
    """
    mode = self._queries.resolve_mode(query)
    layers = self._queries.layers_for_mode(mode)
    # Deprioritise test files for code queries unless the user explicitly
    # mentions tests (checked in both English and Russian spellings).
    prefer_non_tests = mode == "code" and "test" not in query.lower() and "тест" not in query.lower()
    try:
        query_embedding = self._embedder.embed([query])[0]
        rows = self._repo.retrieve(
            rag_session_id,
            query_embedding,
            query_text=query,
            limit=limit,
            layers=layers,
            prefer_non_tests=prefer_non_tests,
        )
    except Exception:
        # Best-effort degradation: keep answering from stored chunks, but
        # record the failure instead of swallowing it silently.
        LOGGER.exception(
            "rag retrieve failed, using fallback: rag_session_id=%s", rag_session_id
        )
        rows = self._repo.fallback_chunks(rag_session_id, limit=limit, layers=layers)
    if not rows and mode != "docs":
        rows = self._repo.fallback_chunks(rag_session_id, limit=limit, layers=self._queries.layers_for_mode("docs"))
    return [
        {
            "source": row["path"],
            "content": row["content"],
            "layer": row.get("layer"),
            "title": row.get("title"),
            "metadata": row.get("metadata", {}),
            "score": row.get("distance"),
        }
        for row in rows
    ]
async def _index_files(
self,
rag_session_id: str,
@@ -99,15 +70,28 @@ class RagService:
try:
blob_sha = self._blob_sha(file)
cached = await asyncio.to_thread(self._repo.get_cached_documents, repo_id, blob_sha)
pipelines = self._resolve_pipeline_names(path)
if cached:
report.documents_list.extend(self._with_file_metadata(cached, file, repo_id, blob_sha))
report.cache_hit_files += 1
LOGGER.warning(
"rag ingest file: rag_session_id=%s path=%s processing=cache pipeline=%s",
rag_session_id,
path,
",".join(pipelines),
)
else:
built = self._build_documents(repo_id, path, file)
embedded = await asyncio.to_thread(self._embed_documents, built, file, repo_id, blob_sha)
report.documents_list.extend(embedded)
await asyncio.to_thread(self._repo.cache_documents, repo_id, path, blob_sha, embedded)
report.cache_miss_files += 1
LOGGER.warning(
"rag ingest file: rag_session_id=%s path=%s processing=embed pipeline=%s",
rag_session_id,
path,
",".join(pipelines),
)
report.indexed_files += 1
except Exception as exc:
report.failed_files += 1
@@ -128,6 +112,16 @@ class RagService:
docs.extend(self._docs.index_file(repo_id=repo_id, commit_sha=commit_sha, path=path, content=content))
return docs
def _resolve_pipeline_names(self, path: str) -> list[str]:
names: list[str] = []
if self._docs.supports(path):
names.append("DOCS")
if self._code.supports(path):
names.append("CODE")
if not names:
names.append("DOCS")
return names
def _embed_documents(self, docs: list[RagDocument], file: dict, repo_id: str, blob_sha: str) -> list[RagDocument]:
if not docs:
return []
@@ -190,7 +184,6 @@ class RagService:
if isawaitable(result):
await result
class _PipelineReport(IndexReport):
def __init__(self) -> None:
super().__init__()