142 lines
5.6 KiB
Python
142 lines
5.6 KiB
Python
import asyncio
|
|
from collections import defaultdict
|
|
|
|
from app.schemas.common import ErrorPayload, ModuleName
|
|
from app.schemas.indexing import IndexJobStatus
|
|
from app.modules.contracts import RagIndexer
|
|
from app.modules.rag.job_store import IndexJob, IndexJobStore
|
|
from app.modules.shared.event_bus import EventBus
|
|
from app.modules.shared.retry_executor import RetryExecutor
|
|
|
|
|
|
class IndexingOrchestrator:
    """Coordinates background RAG indexing jobs.

    Jobs are created in the store, then processed asynchronously. A per-session
    ``asyncio.Lock`` serializes jobs that target the same ``rag_session_id`` so
    snapshot and change indexing never interleave for one session. Progress and
    completion are reported through the event bus on the ``index_status``,
    ``index_progress`` and ``terminal`` channels.
    """

    def __init__(
        self,
        store: IndexJobStore,
        rag: RagIndexer,
        events: EventBus,
        retry: RetryExecutor,
    ) -> None:
        """Wire up collaborators.

        Args:
            store: Persistence for ``IndexJob`` records.
            rag: Indexer implementing ``index_snapshot`` / ``index_changes``.
            events: Bus used to publish job status/progress events.
            retry: Executor that retries the indexing operation on failure.
        """
        self._store = store
        self._rag = rag
        self._events = events
        self._retry = retry
        # One lock per rag_session_id so jobs for the same session run serially.
        self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
        # Strong references to in-flight background tasks. The event loop keeps
        # only weak references to tasks, so without this a fire-and-forget job
        # could be garbage-collected before it finishes (see asyncio docs for
        # create_task).
        self._tasks: set[asyncio.Task] = set()

    def _spawn(self, coro) -> None:
        """Schedule *coro* as a background task, holding a strong reference
        until it completes so it cannot be garbage-collected mid-flight."""
        task = asyncio.create_task(coro)
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)

    async def enqueue_snapshot(self, rag_session_id: str, files: list[dict]) -> IndexJob:
        """Create a job for a full-snapshot index and start it in the background.

        Returns the freshly created ``IndexJob`` immediately; processing
        happens asynchronously under the session lock.
        """
        job = self._store.create(rag_session_id)
        self._spawn(self._process_snapshot(job.index_job_id, rag_session_id, files))
        return job

    async def enqueue_changes(self, rag_session_id: str, changed_files: list[dict]) -> IndexJob:
        """Create a job for an incremental index and start it in the background.

        Returns the freshly created ``IndexJob`` immediately; processing
        happens asynchronously under the session lock.
        """
        job = self._store.create(rag_session_id)
        self._spawn(self._process_changes(job.index_job_id, rag_session_id, changed_files))
        return job

    async def _process_snapshot(self, job_id: str, rag_session_id: str, files: list[dict]) -> None:
        """Run a full-snapshot indexing job under the session lock."""
        await self._run_with_project_lock(
            job_id=job_id,
            rag_session_id=rag_session_id,
            total_files=len(files),
            operation=lambda progress_cb: self._rag.index_snapshot(
                rag_session_id=rag_session_id,
                files=files,
                progress_cb=progress_cb,
            ),
        )

    async def _process_changes(self, job_id: str, rag_session_id: str, changed_files: list[dict]) -> None:
        """Run an incremental indexing job under the session lock."""
        await self._run_with_project_lock(
            job_id=job_id,
            rag_session_id=rag_session_id,
            total_files=len(changed_files),
            operation=lambda progress_cb: self._rag.index_changes(
                rag_session_id=rag_session_id,
                changed_files=changed_files,
                progress_cb=progress_cb,
            ),
        )

    async def _run_with_project_lock(self, job_id: str, rag_session_id: str, total_files: int, operation) -> None:
        """Execute *operation* for *job_id*, serialized per session.

        Marks the job RUNNING, runs *operation* (awaitable factory taking a
        progress callback) via the retry executor, then marks the job DONE or
        ERROR and publishes the corresponding events. Any exception — not just
        the transient ones — results in a persisted ERROR state and a terminal
        event, so a job can never be left in RUNNING forever.
        """
        lock = self._locks[rag_session_id]
        async with lock:
            job = self._store.get(job_id)
            if not job:
                # Job vanished from the store (e.g. cancelled); nothing to do.
                return
            job.status = IndexJobStatus.RUNNING
            self._store.save(job)
            await self._events.publish(
                job_id,
                "index_status",
                {"index_job_id": job_id, "status": job.status.value, "total_files": total_files},
            )
            try:
                async def progress_cb(current_file_index: int, total: int, current_file_name: str) -> None:
                    # Forward per-file progress to subscribers of this job.
                    await self._events.publish(
                        job_id,
                        "index_progress",
                        {
                            "index_job_id": job_id,
                            "current_file_index": current_file_index,
                            "total_files": total,
                            "processed_files": current_file_index,
                            "current_file_path": current_file_name,
                            "current_file_name": current_file_name,
                        },
                    )

                indexed, failed = await self._retry.run(lambda: operation(progress_cb))
            except (TimeoutError, ConnectionError, OSError) as exc:
                # Transient infrastructure failure that survived all retries.
                await self._fail_job(
                    job,
                    job_id,
                    total_files,
                    code="index_retry_exhausted",
                    desc=f"Temporary indexing failure after retries: {exc}",
                )
            except Exception as exc:
                # Unexpected failure: previously this escaped into the detached
                # task, leaving the job stuck in RUNNING with no terminal event.
                await self._fail_job(
                    job,
                    job_id,
                    total_files,
                    code="index_failed",
                    desc=f"Unexpected indexing failure: {exc}",
                )
            else:
                job.status = IndexJobStatus.DONE
                job.indexed_files = indexed
                job.failed_files = failed
                self._store.save(job)
                await self._events.publish(
                    job_id,
                    "index_status",
                    {
                        "index_job_id": job_id,
                        "status": job.status.value,
                        "indexed_files": indexed,
                        "failed_files": failed,
                        "total_files": total_files,
                    },
                )
                await self._events.publish(
                    job_id,
                    "terminal",
                    {
                        "index_job_id": job_id,
                        "status": "done",
                        "indexed_files": indexed,
                        "failed_files": failed,
                        "total_files": total_files,
                    },
                )

    async def _fail_job(self, job: IndexJob, job_id: str, total_files: int, *, code: str, desc: str) -> None:
        """Persist *job* as ERROR with the given payload and publish the
        ``index_status`` + ``terminal`` error events."""
        job.status = IndexJobStatus.ERROR
        job.error = ErrorPayload(
            code=code,
            desc=desc,
            module=ModuleName.RAG,
        )
        self._store.save(job)
        await self._events.publish(
            job_id,
            "index_status",
            {"index_job_id": job_id, "status": job.status.value, "total_files": total_files},
        )
        await self._events.publish(
            job_id,
            "terminal",
            {
                "index_job_id": job_id,
                "status": "error",
                "total_files": total_files,
                "error": {
                    "code": job.error.code,
                    "desc": job.error.desc,
                    "module": job.error.module.value,
                },
            },
        )