import asyncio
from collections import defaultdict

from app.schemas.common import ErrorPayload, ModuleName
from app.schemas.indexing import IndexJobStatus
from app.modules.contracts import RagIndexer
from app.modules.rag_session.job_store import IndexJob, IndexJobStore
from app.modules.shared.event_bus import EventBus
from app.modules.shared.retry_executor import RetryExecutor


class IndexingOrchestrator:
    """Coordinates background indexing jobs for RAG sessions.

    Each ``rag_session_id`` gets its own ``asyncio.Lock`` so snapshot and
    change jobs for the same session never run concurrently.  Job lifecycle
    ("index_status", "index_progress", "terminal") is published through the
    shared :class:`EventBus`, keyed by the job id.
    """

    def __init__(
        self,
        store: IndexJobStore,
        rag: RagIndexer,
        events: EventBus,
        retry: RetryExecutor,
    ) -> None:
        self._store = store
        self._rag = rag
        self._events = events
        self._retry = retry
        # One lock per rag_session_id; serializes jobs within a session.
        # NOTE(review): entries are never evicted, so this dict grows with
        # the number of distinct sessions seen over the process lifetime.
        self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
        # BUG FIX: the event loop keeps only a *weak* reference to tasks
        # created with asyncio.create_task(); previously the returned tasks
        # were discarded, so an in-flight indexing task could be garbage
        # collected mid-run.  Hold strong references until completion.
        self._tasks: set[asyncio.Task] = set()

    def _spawn(self, coro) -> None:
        """Schedule *coro* as a background task, retaining a strong reference.

        The done-callback drops the reference once the task finishes, so the
        set only ever holds in-flight tasks.
        """
        task = asyncio.create_task(coro)
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)

    async def enqueue_snapshot(self, rag_session_id: str, files: list[dict]) -> IndexJob:
        """Create a job and start full-snapshot indexing in the background.

        Returns immediately with the created :class:`IndexJob`; progress is
        reported asynchronously via the event bus.
        """
        job = self._store.create(rag_session_id)
        self._spawn(self._process_snapshot(job.index_job_id, rag_session_id, files))
        return job

    async def enqueue_changes(self, rag_session_id: str, changed_files: list[dict]) -> IndexJob:
        """Create a job and start incremental indexing in the background.

        Returns immediately with the created :class:`IndexJob`; progress is
        reported asynchronously via the event bus.
        """
        job = self._store.create(rag_session_id)
        self._spawn(self._process_changes(job.index_job_id, rag_session_id, changed_files))
        return job

    async def _process_snapshot(self, job_id: str, rag_session_id: str, files: list[dict]) -> None:
        """Run a full-snapshot index for *files* under the session lock."""
        await self._run_with_project_lock(
            job_id=job_id,
            rag_session_id=rag_session_id,
            total_files=len(files),
            operation=lambda progress_cb: self._rag.index_snapshot(
                rag_session_id=rag_session_id,
                files=files,
                progress_cb=progress_cb,
            ),
        )

    async def _process_changes(self, job_id: str, rag_session_id: str, changed_files: list[dict]) -> None:
        """Run an incremental index of *changed_files* under the session lock."""
        await self._run_with_project_lock(
            job_id=job_id,
            rag_session_id=rag_session_id,
            total_files=len(changed_files),
            operation=lambda progress_cb: self._rag.index_changes(
                rag_session_id=rag_session_id,
                changed_files=changed_files,
                progress_cb=progress_cb,
            ),
        )

    async def _run_with_project_lock(self, job_id: str, rag_session_id: str, total_files: int, operation) -> None:
        """Execute *operation* for a job while holding the session's lock.

        Publishes "index_status" (RUNNING) on start, "index_progress" during
        the run, then a final "index_status" + "terminal" pair on success or
        on a retry-exhausted transient failure.

        Args:
            job_id: Id of the job record in the store; also the event key.
            rag_session_id: Session whose lock serializes this job.
            total_files: File count advertised in status events.
            operation: Callable taking a progress callback and returning an
                awaitable resolving to a 4-tuple
                ``(indexed, failed, cache_hits, cache_misses)``.
        """
        lock = self._locks[rag_session_id]
        async with lock:
            job = self._store.get(job_id)
            if not job:
                # Job vanished (e.g. store was cleared) while we waited for
                # the lock — nothing to do.
                return
            job.status = IndexJobStatus.RUNNING
            self._store.save(job)
            await self._events.publish(
                job_id,
                "index_status",
                {"index_job_id": job_id, "status": job.status.value, "total_files": total_files},
            )
            try:
                async def progress_cb(current_file_index: int, total: int, current_file_name: str) -> None:
                    # Forward per-file progress to subscribers.  Both
                    # *_path and *_name carry the same value here — the
                    # indexer only supplies one string per file.
                    await self._events.publish(
                        job_id,
                        "index_progress",
                        {
                            "index_job_id": job_id,
                            "current_file_index": current_file_index,
                            "total_files": total,
                            "processed_files": current_file_index,
                            "current_file_path": current_file_name,
                            "current_file_name": current_file_name,
                        },
                    )

                indexed, failed, cache_hits, cache_misses = await self._retry.run(lambda: operation(progress_cb))
                job.status = IndexJobStatus.DONE
                job.indexed_files = indexed
                job.failed_files = failed
                job.cache_hit_files = cache_hits
                job.cache_miss_files = cache_misses
                self._store.save(job)
                await self._events.publish(
                    job_id,
                    "index_status",
                    {
                        "index_job_id": job_id,
                        "status": job.status.value,
                        "indexed_files": indexed,
                        "failed_files": failed,
                        "cache_hit_files": cache_hits,
                        "cache_miss_files": cache_misses,
                        "total_files": total_files,
                    },
                )
                await self._events.publish(
                    job_id,
                    "terminal",
                    {
                        "index_job_id": job_id,
                        "status": "done",
                        "indexed_files": indexed,
                        "failed_files": failed,
                        "cache_hit_files": cache_hits,
                        "cache_miss_files": cache_misses,
                        "total_files": total_files,
                    },
                )
            # ConnectionError is a subclass of OSError — listed explicitly
            # for readability.  NOTE(review): any exception OUTSIDE this
            # tuple escapes and leaves the job stuck in RUNNING with no
            # terminal event; confirm whether a catch-all error path is
            # wanted before changing behavior.
            except (TimeoutError, ConnectionError, OSError) as exc:
                job.status = IndexJobStatus.ERROR
                job.error = ErrorPayload(
                    code="index_retry_exhausted",
                    desc=f"Temporary indexing failure after retries: {exc}",
                    module=ModuleName.RAG,
                )
                self._store.save(job)
                await self._events.publish(
                    job_id,
                    "index_status",
                    {"index_job_id": job_id, "status": job.status.value, "total_files": total_files},
                )
                await self._events.publish(
                    job_id,
                    "terminal",
                    {
                        "index_job_id": job_id,
                        "status": "error",
                        "total_files": total_files,
                        "error": {
                            "code": job.error.code,
                            "desc": job.error.desc,
                            "module": job.error.module.value,
                        },
                    },
                )