import asyncio
import hashlib
import os
from collections.abc import Awaitable, Callable
from inspect import isawaitable

from app.modules.rag_session.embedding.gigachat_embedder import GigaChatEmbedder
from app.modules.rag_session.repository import RagRepository
from app.modules.rag_session.retrieval.chunker import TextChunker

# Progress callback signature: (current_file_index, total_files, current_file_name).
# It may be a plain function (returns None) or a coroutine function.
ProgressCallback = Callable[[int, int, str], Awaitable[None] | None]


class RagService:
    """Index per-session file snapshots/changes into a RAG store and retrieve chunks.

    The embedder and repository expose blocking (synchronous) APIs, so every
    call to them from an async method is dispatched via ``asyncio.to_thread``
    to avoid stalling the event loop.
    """

    def __init__(
        self,
        embedder: GigaChatEmbedder,
        repository: RagRepository,
        chunker: TextChunker | None = None,
    ) -> None:
        self._embedder = embedder
        self._repo = repository
        self._chunker = chunker or TextChunker()

    async def index_snapshot(
        self,
        rag_session_id: str,
        files: list[dict],
        progress_cb: ProgressCallback | None = None,
    ) -> tuple[int, int, int, int]:
        """Index a full file snapshot, replacing all stored chunks for the session.

        Each *files* item is a dict with at least ``path`` and ``content``
        (plus the optional metadata fields read by ``_chunk_metadata``).

        Returns:
            ``(indexed_files, failed_files, cache_hit_files, cache_miss_files)``.
        """
        total_files = len(files)
        indexed_files = 0
        failed_files = 0
        cache_hit_files = 0
        cache_miss_files = 0
        all_chunks: list[dict] = []
        repo_id = self._resolve_repo_id(rag_session_id)
        for index, file in enumerate(files, start=1):
            path = str(file.get("path", ""))
            try:
                chunks, cache_hit = await self._chunks_for_file(repo_id, path, file)
                all_chunks.extend(chunks)
                if cache_hit:
                    cache_hit_files += 1
                else:
                    cache_miss_files += 1
                indexed_files += 1
            except Exception:
                # Best-effort indexing: one broken file must not abort the run.
                failed_files += 1
            await self._notify_progress(progress_cb, index, total_files, path)
        await asyncio.to_thread(self._repo.replace_chunks, rag_session_id, all_chunks)
        return indexed_files, failed_files, cache_hit_files, cache_miss_files

    async def index_changes(
        self,
        rag_session_id: str,
        changed_files: list[dict],
        progress_cb: ProgressCallback | None = None,
    ) -> tuple[int, int, int, int]:
        """Apply incremental changes: delete removed paths, upsert new content.

        Each *changed_files* item carries an ``op`` of ``"delete"`` or
        ``"upsert"``; any other op — or an upsert without content — is
        counted as a failed file.

        Returns:
            ``(indexed_files, failed_files, cache_hit_files, cache_miss_files)``.
        """
        total_files = len(changed_files)
        indexed_files = 0
        failed_files = 0
        cache_hit_files = 0
        cache_miss_files = 0
        delete_paths: list[str] = []
        upsert_chunks: list[dict] = []
        repo_id = self._resolve_repo_id(rag_session_id)
        for index, file in enumerate(changed_files, start=1):
            path = str(file.get("path", ""))
            op = str(file.get("op", ""))
            try:
                if op == "delete":
                    delete_paths.append(path)
                    indexed_files += 1
                    await self._notify_progress(progress_cb, index, total_files, path)
                    continue
                if op == "upsert" and file.get("content") is not None:
                    chunks, cache_hit = await self._chunks_for_file(repo_id, path, file)
                    upsert_chunks.extend(chunks)
                    if cache_hit:
                        cache_hit_files += 1
                    else:
                        cache_miss_files += 1
                    indexed_files += 1
                    await self._notify_progress(progress_cb, index, total_files, path)
                    continue
                # Unknown op, or an upsert with no content to index.
                failed_files += 1
            except Exception:
                failed_files += 1
            await self._notify_progress(progress_cb, index, total_files, path)
        await asyncio.to_thread(
            self._repo.apply_changes,
            rag_session_id,
            delete_paths,
            upsert_chunks,
        )
        return indexed_files, failed_files, cache_hit_files, cache_miss_files

    async def _chunks_for_file(
        self, repo_id: str, path: str, file: dict
    ) -> tuple[list[dict], bool]:
        """Return ``(chunk dicts, cache_hit)`` for one file.

        Shared by ``index_snapshot`` and ``index_changes``. On a cache hit the
        stored chunks are rehydrated with the file's current metadata; on a
        miss the file is chunked, embedded, and written back to the cache.
        """
        blob_sha = self._blob_sha(file)
        cached = await asyncio.to_thread(self._repo.get_cached_chunks, repo_id, blob_sha)
        if cached:
            return self._build_cached_items(path, file, repo_id, blob_sha, cached), True
        raw_chunks = self._build_chunks_for_file(file)
        embedded = await asyncio.to_thread(
            self._embed_chunks, raw_chunks, file, repo_id, blob_sha
        )
        await asyncio.to_thread(
            self._repo.cache_file_chunks, repo_id, path, blob_sha, embedded
        )
        return embedded, False

    async def retrieve(
        self, rag_session_id: str, query: str, limit: int = 5
    ) -> list[dict]:
        """Return up to *limit* relevant chunks as ``{"source", "content"}`` dicts.

        Falls back to unranked chunks when embedding or vector search fails.
        """
        try:
            # Blocking embedder/repo calls run off the event loop, matching
            # the indexing paths above (previously they were called inline
            # and would have stalled the loop).
            vectors = await asyncio.to_thread(self._embedder.embed, [query])
            rows = await asyncio.to_thread(
                self._repo.retrieve, rag_session_id, vectors[0], limit=limit
            )
        except Exception:
            rows = await asyncio.to_thread(
                self._repo.fallback_chunks, rag_session_id, limit=limit
            )
        return [{"source": row["path"], "content": row["content"]} for row in rows]

    def _build_chunks_for_file(self, file: dict) -> list[tuple[str, int, str]]:
        """Split the file's content into ``(path, chunk_index, text)`` tuples."""
        path = str(file.get("path", ""))
        content = str(file.get("content", ""))
        return [
            (path, idx, chunk)
            for idx, chunk in enumerate(self._chunker.chunk(content))
        ]

    def _embed_chunks(
        self,
        raw_chunks: list[tuple[str, int, str]],
        file: dict,
        repo_id: str,
        blob_sha: str,
    ) -> list[dict]:
        """Embed chunk texts in batches and return storable chunk dicts."""
        if not raw_chunks:
            return []
        # Batch size is tunable via env to respect embedder rate/size limits.
        batch_size = max(1, int(os.getenv("RAG_EMBED_BATCH_SIZE", "16")))
        metadata = self._chunk_metadata(file)
        indexed: list[dict] = []
        for start in range(0, len(raw_chunks), batch_size):
            batch = raw_chunks[start : start + batch_size]
            vectors = self._embedder.embed([text for _, _, text in batch])
            # strict=True: a vector-count mismatch would otherwise silently
            # drop chunks (or vectors) and corrupt the index; fail loudly so
            # the caller's per-file error handling counts it as failed.
            for (path, chunk_index, content), vector in zip(batch, vectors, strict=True):
                indexed.append(
                    {
                        "path": path,
                        "chunk_index": chunk_index,
                        "content": content,
                        "embedding": vector,
                        "repo_id": repo_id,
                        "blob_sha": blob_sha,
                        **metadata,
                    }
                )
        return indexed

    def _build_cached_items(
        self,
        path: str,
        file: dict,
        repo_id: str,
        blob_sha: str,
        cached: list[dict],
    ) -> list[dict]:
        """Rebuild chunk dicts from cached rows, refreshing file-level metadata.

        A per-chunk cached ``section`` wins over the file-level one when present.
        """
        metadata = self._chunk_metadata(file)
        output: list[dict] = []
        for item in cached:
            output.append(
                {
                    "path": path,
                    "chunk_index": int(item["chunk_index"]),
                    "content": str(item["content"]),
                    "embedding": item.get("embedding") or [],
                    "repo_id": repo_id,
                    "blob_sha": blob_sha,
                    **metadata,
                    "section": item.get("section") or metadata.get("section"),
                }
            )
        return output

    def _resolve_repo_id(self, rag_session_id: str) -> str:
        """Map a session to its project id; fall back to the session id itself."""
        session = self._repo.get_session(rag_session_id)
        if not session:
            return rag_session_id
        return str(session.get("project_id") or rag_session_id)

    def _blob_sha(self, file: dict) -> str:
        """Prefer the caller-supplied content hash; otherwise SHA-256 the content."""
        raw = str(file.get("content_hash") or "").strip()
        if raw:
            return raw
        content = str(file.get("content") or "")
        return hashlib.sha256(content.encode("utf-8")).hexdigest()

    def _chunk_metadata(self, file: dict) -> dict:
        """Extract the optional per-file metadata stored alongside each chunk."""
        return {
            "artifact_type": file.get("artifact_type"),
            "section": file.get("section"),
            "doc_id": file.get("doc_id"),
            "doc_version": file.get("doc_version"),
            "owner": file.get("owner"),
            "system_component": file.get("system_component"),
            "last_modified": file.get("last_modified"),
            "staleness_score": file.get("staleness_score"),
        }

    async def _notify_progress(
        self,
        progress_cb: ProgressCallback | None,
        current_file_index: int,
        total_files: int,
        current_file_name: str,
    ) -> None:
        """Invoke the progress callback, awaiting its result if it is awaitable."""
        if not progress_cb:
            return
        result = progress_cb(current_file_index, total_files, current_file_name)
        if isawaitable(result):
            await result