import asyncio
import logging
import os
from collections.abc import Awaitable, Callable
from inspect import isawaitable

from app.modules.rag.embedding.gigachat_embedder import GigaChatEmbedder
from app.modules.rag.repository import RagRepository
from app.modules.rag.retrieval.chunker import TextChunker

logger = logging.getLogger(__name__)

# Progress callback signature: (current_file_index, total_files, current_file_name).
# The callback may be a plain function or return an awaitable.
_ProgressCallback = Callable[[int, int, str], Awaitable[None] | None]

# Used when RAG_EMBED_BATCH_SIZE is unset or not a valid integer.
_DEFAULT_EMBED_BATCH_SIZE = 16


class RagService:
    """Index files into a RAG repository and retrieve relevant chunks.

    Files are split by a ``TextChunker``, embedded with a ``GigaChatEmbedder``
    and persisted through a ``RagRepository``. Embedder and repository calls are
    synchronous, so they are dispatched with ``asyncio.to_thread`` to keep the
    event loop responsive.
    """

    def __init__(
        self,
        embedder: GigaChatEmbedder,
        repository: RagRepository,
        chunker: TextChunker | None = None,
    ) -> None:
        self._embedder = embedder
        self._repo = repository
        self._chunker = chunker or TextChunker()

    async def index_snapshot(
        self,
        rag_session_id: str,
        files: list[dict],
        progress_cb: _ProgressCallback | None = None,
    ) -> tuple[int, int]:
        """Re-index a full snapshot, replacing all chunks of the session.

        Each file is chunked and embedded independently. A file that fails is
        logged and counted, but does not abort the snapshot. After all files
        are processed, the session's chunks are replaced in one repository call
        with the chunks that embedded successfully.

        Args:
            rag_session_id: Session whose chunks are replaced.
            files: File dicts; the ``"path"`` and ``"content"`` keys are read.
            progress_cb: Optional callback ``(index, total, path)`` invoked once
                after each file; may be sync or async.

        Returns:
            ``(indexed_files, failed_files)``.
        """
        total_files = len(files)
        indexed_files = 0
        failed_files = 0
        all_chunks: list[dict] = []
        for index, file in enumerate(files, start=1):
            path = str(file.get("path", ""))
            try:
                all_chunks.extend(await self._embed_file(file))
            except Exception:
                logger.exception(
                    "RAG snapshot indexing failed for %r (session %s)", path, rag_session_id
                )
                failed_files += 1
            else:
                indexed_files += 1
            await self._notify_progress(progress_cb, index, total_files, path)
        await asyncio.to_thread(self._repo.replace_chunks, rag_session_id, all_chunks)
        return indexed_files, failed_files

    async def index_changes(
        self,
        rag_session_id: str,
        changed_files: list[dict],
        progress_cb: _ProgressCallback | None = None,
    ) -> tuple[int, int]:
        """Apply incremental file changes (deletes and upserts) to a session.

        An ``"op"`` of ``"delete"`` schedules the path for deletion. An
        ``"op"`` of ``"upsert"`` with non-``None`` ``"content"`` is chunked
        and embedded. Anything else counts as a failure. All deletions and
        upserts are applied in one repository call at the end.

        Args:
            rag_session_id: Session to update.
            changed_files: Change dicts; the ``"path"``, ``"op"`` and
                ``"content"`` keys are read.
            progress_cb: Optional callback ``(index, total, path)`` invoked once
                after each change; may be sync or async.

        Returns:
            ``(indexed_files, failed_files)``.
        """
        total_files = len(changed_files)
        indexed_files = 0
        failed_files = 0
        delete_paths: list[str] = []
        upsert_chunks: list[dict] = []
        for index, file in enumerate(changed_files, start=1):
            path = str(file.get("path", ""))
            op = str(file.get("op", ""))
            try:
                if op == "delete":
                    delete_paths.append(path)
                    indexed_files += 1
                elif op == "upsert" and file.get("content") is not None:
                    upsert_chunks.extend(await self._embed_file(file))
                    indexed_files += 1
                else:
                    logger.warning(
                        "Skipping RAG change for %r: unsupported op %r or missing content",
                        path,
                        op,
                    )
                    failed_files += 1
            except Exception:
                logger.exception(
                    "RAG change indexing failed for %r (session %s)", path, rag_session_id
                )
                failed_files += 1
            # Notify exactly once per file, outside the try: a failing callback must
            # not turn an already-counted file into a "failed" one or fire twice.
            await self._notify_progress(progress_cb, index, total_files, path)
        await asyncio.to_thread(
            self._repo.apply_changes,
            rag_session_id,
            delete_paths,
            upsert_chunks,
        )
        return indexed_files, failed_files

    async def retrieve(self, rag_session_id: str, query: str, *, limit: int = 5) -> list[dict]:
        """Return up to ``limit`` chunks relevant to ``query``.

        The query is embedded and passed to the repository's ``retrieve``. If
        embedding or retrieval raises, ``fallback_chunks`` is used instead.

        Returns:
            A list of ``{"source": path, "content": text}`` dicts.
        """
        try:
            # Blocking calls go to a worker thread, as in the indexing methods.
            embeddings = await asyncio.to_thread(self._embedder.embed, [query])
            query_embedding = embeddings[0]
            rows = await asyncio.to_thread(
                self._repo.retrieve, rag_session_id, query_embedding, limit=limit
            )
        except Exception:
            logger.warning(
                "RAG retrieval failed for session %s; using fallback chunks",
                rag_session_id,
                exc_info=True,
            )
            rows = await asyncio.to_thread(
                self._repo.fallback_chunks, rag_session_id, limit=limit
            )
        return [{"source": row["path"], "content": row["content"]} for row in rows]

    async def _embed_file(self, file: dict) -> list[dict]:
        """Chunk one file and embed its chunks in a worker thread."""
        chunks = self._build_chunks_for_file(file)
        return await asyncio.to_thread(self._embed_chunks, chunks)

    def _build_chunks_for_file(self, file: dict) -> list[tuple[str, int, str]]:
        """Split a file's content into ``(path, chunk_index, chunk_text)`` tuples."""
        path = str(file.get("path", ""))
        content = file.get("content")
        # An explicit ``"content": None`` must not be indexed as the text "None".
        text = "" if content is None else str(content)
        return [(path, idx, chunk) for idx, chunk in enumerate(self._chunker.chunk(text))]

    def _embed_chunks(self, raw_chunks: list[tuple[str, int, str]]) -> list[dict]:
        """Embed chunk texts in batches and return repository-ready chunk dicts.

        Raises:
            ValueError: if the embedder returns a different number of vectors
                than texts in a batch (would otherwise silently drop chunks).
        """
        if not raw_chunks:
            return []
        batch_size = self._embed_batch_size()
        indexed: list[dict] = []
        for start in range(0, len(raw_chunks), batch_size):
            batch = raw_chunks[start : start + batch_size]
            vectors = self._embedder.embed([text for _, _, text in batch])
            indexed.extend(
                {
                    "path": path,
                    "chunk_index": chunk_index,
                    "content": content,
                    "embedding": vector,
                }
                for (path, chunk_index, content), vector in zip(batch, vectors, strict=True)
            )
        return indexed

    @staticmethod
    def _embed_batch_size() -> int:
        """Read RAG_EMBED_BATCH_SIZE (at least 1), falling back to the default if invalid."""
        raw = os.getenv("RAG_EMBED_BATCH_SIZE", str(_DEFAULT_EMBED_BATCH_SIZE))
        try:
            return max(1, int(raw))
        except ValueError:
            logger.warning(
                "Invalid RAG_EMBED_BATCH_SIZE=%r; using %d", raw, _DEFAULT_EMBED_BATCH_SIZE
            )
            return _DEFAULT_EMBED_BATCH_SIZE

    async def _notify_progress(
        self,
        progress_cb: _ProgressCallback | None,
        current_file_index: int,
        total_files: int,
        current_file_name: str,
    ) -> None:
        """Invoke ``progress_cb`` if given, awaiting its result when awaitable."""
        if not progress_cb:
            return
        result = progress_cb(current_file_index, total_files, current_file_name)
        if isawaitable(result):
            await result