import asyncio

from fastapi import APIRouter
from fastapi.responses import JSONResponse, StreamingResponse

from app.core.exceptions import AppError
from app.modules.rag_session.embedding.gigachat_embedder import GigaChatEmbedder
from app.modules.rag_session.indexing_service import IndexingOrchestrator
from app.modules.rag_session.job_store import IndexJobStore
from app.modules.rag_session.repository import RagRepository
from app.modules.rag_session.retrieval.chunker import TextChunker
from app.modules.rag_session.session_store import RagSessionStore
from app.modules.rag_session.service import RagService
from app.modules.shared.event_bus import EventBus
from app.modules.shared.gigachat.client import GigaChatClient
from app.modules.shared.gigachat.settings import GigaChatSettings
from app.modules.shared.gigachat.token_provider import GigaChatTokenProvider
from app.modules.shared.retry_executor import RetryExecutor
from app.schemas.common import ModuleName
from app.schemas.indexing import (
    IndexChangesRequest,
    IndexJobQueuedResponse,
    IndexJobResponse,
    IndexSnapshotRequest,
)
from app.schemas.rag_sessions import (
    RagSessionChangesRequest,
    RagSessionCreateRequest,
    RagSessionCreateResponse,
    RagSessionJobResponse,
)


class RagModule:
    """Wires together the RAG indexing stack and exposes its HTTP routers.

    Owns the GigaChat embedding client, the RAG service, the session and
    index-job stores, and the indexing orchestrator. ``public_router`` serves
    the session-scoped API plus legacy ``/api/index`` endpoints;
    ``internal_router`` serves synchronous service-to-service endpoints.
    """

    # Seconds to wait for the next job event before emitting an SSE keepalive
    # comment so proxies don't drop the idle connection.
    _SSE_HEARTBEAT_SECONDS = 10

    def __init__(self, event_bus: EventBus, retry: RetryExecutor, repository: RagRepository) -> None:
        self._events = event_bus
        self.repository = repository

        # GigaChat client chain: env settings -> token provider -> HTTP client -> embedder.
        settings = GigaChatSettings.from_env()
        token_provider = GigaChatTokenProvider(settings)
        client = GigaChatClient(settings, token_provider)
        embedder = GigaChatEmbedder(client)
        self.embedder = embedder

        self.rag = RagService(embedder=embedder, repository=repository, chunker=TextChunker())
        self.sessions = RagSessionStore(repository)
        self.jobs = IndexJobStore(repository)
        self.indexing = IndexingOrchestrator(
            store=self.jobs,
            rag=self.rag,
            events=event_bus,
            retry=retry,
        )

    async def _job_event_response(self, index_job_id: str) -> StreamingResponse:
        """Subscribe to a job's event stream and return it as an SSE response.

        Replays already-published events, forwards new ones until a "terminal"
        event arrives, emits keepalive comments while idle, and always
        unsubscribes when the client disconnects or the stream ends.
        """
        queue = await self._events.subscribe(index_job_id, replay=True)

        async def event_stream():
            try:
                while True:
                    try:
                        event = await asyncio.wait_for(queue.get(), timeout=self._SSE_HEARTBEAT_SECONDS)
                        yield EventBus.as_sse(event)
                        if event.name == "terminal":
                            break
                    except asyncio.TimeoutError:
                        # SSE comment line: ignored by clients, keeps proxies alive.
                        yield ": keepalive\n\n"
            finally:
                await self._events.unsubscribe(index_job_id, queue)

        return StreamingResponse(
            event_stream(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache, no-transform",
                "Connection": "keep-alive",
                # Disable nginx response buffering so events flush immediately.
                "X-Accel-Buffering": "no",
            },
        )

    def public_router(self) -> APIRouter:
        """Build the public API router: session-scoped RAG endpoints plus legacy aliases."""
        router = APIRouter(tags=["rag"])

        @router.post("/api/rag/sessions", response_model=RagSessionCreateResponse)
        async def create_rag_session(request: RagSessionCreateRequest) -> RagSessionCreateResponse:
            """Create a RAG session and queue an initial full-snapshot index job."""
            session = self.sessions.create(request.project_id)
            job = await self.indexing.enqueue_snapshot(
                rag_session_id=session.rag_session_id,
                files=[x.model_dump() for x in request.files],
            )
            return RagSessionCreateResponse(
                rag_session_id=session.rag_session_id,
                index_job_id=job.index_job_id,
                status=job.status,
            )

        @router.post("/api/rag/sessions/{rag_session_id}/changes", response_model=IndexJobQueuedResponse)
        async def rag_session_changes(
            rag_session_id: str,
            request: RagSessionChangesRequest,
        ) -> IndexJobQueuedResponse:
            """Queue an incremental index job for changed files in an existing session."""
            session = self.sessions.get(rag_session_id)
            if not session:
                raise AppError("not_found", f"RAG session not found: {rag_session_id}", ModuleName.RAG)
            job = await self.indexing.enqueue_changes(
                rag_session_id=rag_session_id,
                changed_files=[x.model_dump() for x in request.changed_files],
            )
            return IndexJobQueuedResponse(index_job_id=job.index_job_id, status=job.status.value)

        @router.get("/api/rag/sessions/{rag_session_id}/jobs/{index_job_id}", response_model=RagSessionJobResponse)
        async def rag_session_job(rag_session_id: str, index_job_id: str) -> RagSessionJobResponse:
            """Return the status and counters of one index job within a session."""
            job = self.jobs.get(index_job_id)
            # Treat a job belonging to a different session as not found.
            if not job or job.rag_session_id != rag_session_id:
                raise AppError("not_found", f"Index job not found: {index_job_id}", ModuleName.RAG)
            return RagSessionJobResponse(
                rag_session_id=rag_session_id,
                index_job_id=job.index_job_id,
                status=job.status,
                indexed_files=job.indexed_files,
                failed_files=job.failed_files,
                cache_hit_files=job.cache_hit_files,
                cache_miss_files=job.cache_miss_files,
                error=job.error.model_dump(mode="json") if job.error else None,
            )

        @router.get("/api/rag/sessions/{rag_session_id}/jobs/{index_job_id}/events")
        async def rag_session_job_events(rag_session_id: str, index_job_id: str) -> StreamingResponse:
            """Stream a session-scoped job's progress events over SSE."""
            job = self.jobs.get(index_job_id)
            if not job or job.rag_session_id != rag_session_id:
                raise AppError("not_found", f"Index job not found: {index_job_id}", ModuleName.RAG)
            return await self._job_event_response(index_job_id)

        # Legacy compatibility endpoints.
        legacy = APIRouter(prefix="/api/index", tags=["index"])

        @legacy.post("/snapshot", response_model=IndexJobQueuedResponse)
        async def index_snapshot(request: IndexSnapshotRequest) -> IndexJobQueuedResponse:
            """Legacy snapshot indexing: the project id doubles as the session id."""
            session = self.sessions.put(
                rag_session_id=request.project_id,
                project_id=request.project_id,
            )
            job = await self.indexing.enqueue_snapshot(
                rag_session_id=session.rag_session_id,
                files=[x.model_dump() for x in request.files],
            )
            return IndexJobQueuedResponse(index_job_id=job.index_job_id, status=job.status.value)

        @legacy.post("/changes", response_model=IndexJobQueuedResponse)
        async def index_changes(request: IndexChangesRequest) -> IndexJobQueuedResponse:
            """Legacy incremental indexing; auto-creates the implicit session if missing."""
            rag_session_id = request.project_id
            if not self.sessions.get(rag_session_id):
                self.sessions.put(rag_session_id=rag_session_id, project_id=rag_session_id)
            job = await self.indexing.enqueue_changes(
                rag_session_id=rag_session_id,
                changed_files=[x.model_dump() for x in request.changed_files],
            )
            return IndexJobQueuedResponse(index_job_id=job.index_job_id, status=job.status.value)

        @legacy.get("/jobs/{index_job_id}", response_model=IndexJobResponse)
        async def get_index_job(index_job_id: str) -> IndexJobResponse:
            """Legacy job status lookup (no session scoping)."""
            job = self.jobs.get(index_job_id)
            if not job:
                raise AppError("not_found", f"Index job not found: {index_job_id}", ModuleName.RAG)
            return IndexJobResponse(
                index_job_id=job.index_job_id,
                status=job.status,
                indexed_files=job.indexed_files,
                failed_files=job.failed_files,
                cache_hit_files=job.cache_hit_files,
                cache_miss_files=job.cache_miss_files,
                error=job.error,
            )

        @legacy.get("/jobs/{index_job_id}/events")
        async def get_index_job_events(index_job_id: str) -> StreamingResponse:
            """Legacy SSE progress stream for a job (no session scoping)."""
            job = self.jobs.get(index_job_id)
            if not job:
                raise AppError("not_found", f"Index job not found: {index_job_id}", ModuleName.RAG)
            return await self._job_event_response(index_job_id)

        router.include_router(legacy)
        return router

    def internal_router(self) -> APIRouter:
        """Build the internal router: synchronous indexing for service-to-service calls."""
        router = APIRouter(prefix="/internal/rag", tags=["internal-rag"])

        @router.post("/index/snapshot")
        async def index_snapshot(request: IndexSnapshotRequest) -> dict:
            """Index a full snapshot synchronously, auto-creating the implicit session."""
            rag_session_id = request.project_id
            if not self.sessions.get(rag_session_id):
                self.sessions.put(rag_session_id=rag_session_id, project_id=rag_session_id)
            indexed, failed, cache_hits, cache_misses = await self.rag.index_snapshot(
                rag_session_id=rag_session_id,
                files=[x.model_dump() for x in request.files],
            )
            return {
                "indexed_files": indexed,
                "failed_files": failed,
                "cache_hit_files": cache_hits,
                "cache_miss_files": cache_misses,
            }

        @router.post("/index/changes")
        async def index_changes(request: IndexChangesRequest) -> dict:
            """Index changed files synchronously."""
            rag_session_id = request.project_id
            indexed, failed, cache_hits, cache_misses = await self.rag.index_changes(
                rag_session_id=rag_session_id,
                changed_files=[x.model_dump() for x in request.changed_files],
            )
            return {
                "indexed_files": indexed,
                "failed_files": failed,
                "cache_hit_files": cache_hits,
                "cache_miss_files": cache_misses,
            }

        @router.get("/index/jobs/{index_job_id}")
        async def get_job(index_job_id: str) -> dict:
            """Return job status as a plain dict; missing jobs yield {"status": "not_found"}."""
            job = self.jobs.get(index_job_id)
            if not job:
                return {"status": "not_found"}
            return {
                "index_job_id": job.index_job_id,
                "status": job.status.value,
                "indexed_files": job.indexed_files,
                "failed_files": job.failed_files,
                "cache_hit_files": job.cache_hit_files,
                "cache_miss_files": job.cache_miss_files,
                "error": job.error.model_dump(mode="json") if job.error else None,
            }

        @router.post("/retrieve")
        async def retrieve() -> JSONResponse:
            """Deprecated retrieval endpoint; always answers 410 Gone."""
            return JSONResponse(
                status_code=410,
                content={
                    "error": "deprecated",
                    "message": "POST /internal/rag/retrieve is deprecated.",
                },
            )

        return router