Новый RAG
This commit is contained in:
@@ -1,211 +1,3 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import os
|
||||
from collections.abc import Awaitable, Callable
|
||||
from inspect import isawaitable
|
||||
from app.modules.rag.services.rag_service import RagService
|
||||
|
||||
from app.modules.rag_session.embedding.gigachat_embedder import GigaChatEmbedder
|
||||
from app.modules.rag_session.repository import RagRepository
|
||||
from app.modules.rag_session.retrieval.chunker import TextChunker
|
||||
|
||||
|
||||
class RagService:
    """Index project files into a per-session RAG store and retrieve chunks.

    Files are split into chunks, embedded, and persisted through the
    repository.  Chunk embeddings are cached per ``(repo_id, blob_sha)`` so
    files whose content has not changed are never re-embedded.  All blocking
    repository/embedder calls are dispatched via ``asyncio.to_thread`` so the
    event loop stays responsive.
    """

    def __init__(
        self,
        embedder: GigaChatEmbedder,
        repository: RagRepository,
        chunker: TextChunker | None = None,
    ) -> None:
        self._embedder = embedder
        self._repo = repository
        # Fall back to the default chunking strategy when none is supplied.
        self._chunker = chunker or TextChunker()

    async def index_snapshot(
        self,
        rag_session_id: str,
        files: list[dict],
        progress_cb: Callable[[int, int, str], Awaitable[None] | None] | None = None,
    ) -> tuple[int, int, int, int]:
        """Fully (re)index *files* for a session, replacing its chunk set.

        Args:
            rag_session_id: Session whose stored chunks are replaced.
            files: File dicts; each is read for at least ``path`` and
                ``content`` (exact schema defined by the caller).
            progress_cb: Optional ``(current, total, path)`` callback, sync
                or async, invoked once per file.

        Returns:
            ``(indexed_files, failed_files, cache_hit_files, cache_miss_files)``.
        """
        total_files = len(files)
        indexed_files = 0
        failed_files = 0
        cache_hit_files = 0
        cache_miss_files = 0
        all_chunks: list[dict] = []
        repo_id = self._resolve_repo_id(rag_session_id)
        for index, file in enumerate(files, start=1):
            path = str(file.get("path", ""))
            try:
                blob_sha = self._blob_sha(file)
                # Blocking repository I/O runs off the event loop.
                cached = await asyncio.to_thread(self._repo.get_cached_chunks, repo_id, blob_sha)
                if cached:
                    all_chunks.extend(self._build_cached_items(path, file, repo_id, blob_sha, cached))
                    cache_hit_files += 1
                else:
                    chunks = self._build_chunks_for_file(file)
                    embedded_chunks = await asyncio.to_thread(self._embed_chunks, chunks, file, repo_id, blob_sha)
                    all_chunks.extend(embedded_chunks)
                    await asyncio.to_thread(self._repo.cache_file_chunks, repo_id, path, blob_sha, embedded_chunks)
                    cache_miss_files += 1
                indexed_files += 1
            except Exception:
                # Best-effort indexing: one bad file must not abort the run.
                failed_files += 1
            await self._notify_progress(progress_cb, index, total_files, path)
        await asyncio.to_thread(self._repo.replace_chunks, rag_session_id, all_chunks)
        return indexed_files, failed_files, cache_hit_files, cache_miss_files

    async def index_changes(
        self,
        rag_session_id: str,
        changed_files: list[dict],
        progress_cb: Callable[[int, int, str], Awaitable[None] | None] | None = None,
    ) -> tuple[int, int, int, int]:
        """Incrementally apply *changed_files* (``op`` of ``delete``/``upsert``).

        Files with an unrecognized ``op``, or an ``upsert`` without content,
        count as failed.  Deletions and upserts are accumulated and applied
        to the store in a single ``apply_changes`` call at the end.

        Returns:
            ``(indexed_files, failed_files, cache_hit_files, cache_miss_files)``.
        """
        total_files = len(changed_files)
        indexed_files = 0
        failed_files = 0
        cache_hit_files = 0
        cache_miss_files = 0
        delete_paths: list[str] = []
        upsert_chunks: list[dict] = []
        repo_id = self._resolve_repo_id(rag_session_id)

        for index, file in enumerate(changed_files, start=1):
            path = str(file.get("path", ""))
            op = str(file.get("op", ""))
            try:
                if op == "delete":
                    delete_paths.append(path)
                    indexed_files += 1
                    await self._notify_progress(progress_cb, index, total_files, path)
                    continue
                if op == "upsert" and file.get("content") is not None:
                    blob_sha = self._blob_sha(file)
                    cached = await asyncio.to_thread(self._repo.get_cached_chunks, repo_id, blob_sha)
                    if cached:
                        upsert_chunks.extend(self._build_cached_items(path, file, repo_id, blob_sha, cached))
                        cache_hit_files += 1
                    else:
                        chunks = self._build_chunks_for_file(file)
                        embedded_chunks = await asyncio.to_thread(self._embed_chunks, chunks, file, repo_id, blob_sha)
                        upsert_chunks.extend(embedded_chunks)
                        await asyncio.to_thread(self._repo.cache_file_chunks, repo_id, path, blob_sha, embedded_chunks)
                        cache_miss_files += 1
                    indexed_files += 1
                    await self._notify_progress(progress_cb, index, total_files, path)
                    continue
                # Unknown op, or upsert without content: count as failed and
                # fall through to the trailing progress notification.
                failed_files += 1
            except Exception:
                failed_files += 1
            await self._notify_progress(progress_cb, index, total_files, path)

        await asyncio.to_thread(
            self._repo.apply_changes,
            rag_session_id,
            delete_paths,
            upsert_chunks,
        )
        return indexed_files, failed_files, cache_hit_files, cache_miss_files

    async def retrieve(self, rag_session_id: str, query: str) -> list[dict]:
        """Return up to five chunks relevant to *query* for the session.

        Falls back to a plain (non-vector) chunk listing when embedding or
        vector search fails, so retrieval degrades instead of raising.
        """
        try:
            # FIX: the embedder and repository are synchronous — run them in
            # a worker thread (as every other call in this class does) so the
            # event loop is not blocked.
            query_embedding = (await asyncio.to_thread(self._embedder.embed, [query]))[0]
            rows = await asyncio.to_thread(self._repo.retrieve, rag_session_id, query_embedding, limit=5)
        except Exception:
            rows = await asyncio.to_thread(self._repo.fallback_chunks, rag_session_id, limit=5)
        return [{"source": row["path"], "content": row["content"]} for row in rows]

    def _build_chunks_for_file(self, file: dict) -> list[tuple[str, int, str]]:
        """Chunk the file's content into ``(path, chunk_index, text)`` tuples."""
        path = str(file.get("path", ""))
        content = str(file.get("content", ""))
        return [(path, idx, chunk) for idx, chunk in enumerate(self._chunker.chunk(content))]

    def _embed_chunks(self, raw_chunks: list[tuple[str, int, str]], file: dict, repo_id: str, blob_sha: str) -> list[dict]:
        """Embed chunk texts in batches and attach storage metadata.

        Batch size comes from ``RAG_EMBED_BATCH_SIZE`` (default 16, min 1).
        Runs synchronously; callers dispatch it via ``asyncio.to_thread``.
        """
        if not raw_chunks:
            return []
        batch_size = max(1, int(os.getenv("RAG_EMBED_BATCH_SIZE", "16")))
        metadata = self._chunk_metadata(file)

        indexed: list[dict] = []
        for i in range(0, len(raw_chunks), batch_size):
            batch = raw_chunks[i : i + batch_size]
            texts = [x[2] for x in batch]
            vectors = self._embedder.embed(texts)
            for (path, chunk_index, content), vector in zip(batch, vectors):
                indexed.append(
                    {
                        "path": path,
                        "chunk_index": chunk_index,
                        "content": content,
                        "embedding": vector,
                        "repo_id": repo_id,
                        "blob_sha": blob_sha,
                        **metadata,
                    }
                )
        return indexed

    def _build_cached_items(
        self,
        path: str,
        file: dict,
        repo_id: str,
        blob_sha: str,
        cached: list[dict],
    ) -> list[dict]:
        """Rehydrate cached chunk rows into storable chunk dicts.

        The cached per-chunk ``section`` wins over the file-level metadata
        value, which is why it is re-assigned after the ``**metadata`` spread.
        """
        metadata = self._chunk_metadata(file)
        output: list[dict] = []
        for item in cached:
            output.append(
                {
                    "path": path,
                    "chunk_index": int(item["chunk_index"]),
                    "content": str(item["content"]),
                    "embedding": item.get("embedding") or [],
                    "repo_id": repo_id,
                    "blob_sha": blob_sha,
                    **metadata,
                    "section": item.get("section") or metadata.get("section"),
                }
            )
        return output

    def _resolve_repo_id(self, rag_session_id: str) -> str:
        """Map a session to its project id for cache keying; fall back to the session id."""
        session = self._repo.get_session(rag_session_id)
        if not session:
            return rag_session_id
        return str(session.get("project_id") or rag_session_id)

    def _blob_sha(self, file: dict) -> str:
        """Return the file's content hash, computing SHA-256 when absent."""
        raw = str(file.get("content_hash") or "").strip()
        if raw:
            return raw
        content = str(file.get("content") or "")
        return hashlib.sha256(content.encode("utf-8")).hexdigest()

    def _chunk_metadata(self, file: dict) -> dict:
        """Extract the file-level metadata copied onto every chunk dict."""
        return {
            "artifact_type": file.get("artifact_type"),
            "section": file.get("section"),
            "doc_id": file.get("doc_id"),
            "doc_version": file.get("doc_version"),
            "owner": file.get("owner"),
            "system_component": file.get("system_component"),
            "last_modified": file.get("last_modified"),
            "staleness_score": file.get("staleness_score"),
        }

    async def _notify_progress(
        self,
        progress_cb: Callable[[int, int, str], Awaitable[None] | None] | None,
        current_file_index: int,
        total_files: int,
        current_file_name: str,
    ) -> None:
        """Invoke the progress callback, awaiting it only when it is async."""
        if not progress_cb:
            return
        result = progress_cb(current_file_index, total_files, current_file_name)
        if isawaitable(result):
            await result
# Explicit public API of this module.
__all__ = ["RagService"]
|
||||
|
||||
Reference in New Issue
Block a user