import asyncio

from fastapi import APIRouter
from fastapi.responses import StreamingResponse

from app.core.exceptions import AppError
from app.modules.rag_session.embedding.gigachat_embedder import GigaChatEmbedder
from app.modules.rag_session.indexing_service import IndexingOrchestrator
from app.modules.rag_session.job_store import IndexJobStore
from app.modules.rag_session.repository import RagRepository
from app.modules.rag_session.retrieval.chunker import TextChunker
from app.modules.rag_session.session_store import RagSessionStore
from app.modules.rag_session.service import RagService
from app.modules.shared.event_bus import EventBus
from app.modules.shared.gigachat.client import GigaChatClient
from app.modules.shared.gigachat.settings import GigaChatSettings
from app.modules.shared.gigachat.token_provider import GigaChatTokenProvider
from app.modules.shared.retry_executor import RetryExecutor
from app.schemas.common import ModuleName
from app.schemas.indexing import (
    IndexChangesRequest,
    IndexJobQueuedResponse,
    IndexJobResponse,
    IndexSnapshotRequest,
)
from app.schemas.rag_sessions import (
    RagSessionChangesRequest,
    RagSessionCreateRequest,
    RagSessionCreateResponse,
    RagSessionJobResponse,
)


class RagModule:
    """Wires the RAG session, indexing, and retrieval services into FastAPI routers."""

    def __init__(self, event_bus: EventBus, retry: RetryExecutor, repository: RagRepository) -> None:
        self._events = event_bus
        self.repository = repository
        settings = GigaChatSettings.from_env()
        token_provider = GigaChatTokenProvider(settings)
        client = GigaChatClient(settings, token_provider)
        embedder = GigaChatEmbedder(client)
        self.rag = RagService(embedder=embedder, repository=repository, chunker=TextChunker())
        self.sessions = RagSessionStore(repository)
        self.jobs = IndexJobStore(repository)
        self.indexing = IndexingOrchestrator(
            store=self.jobs,
            rag=self.rag,
            events=event_bus,
            retry=retry,
        )
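
    # Public HTTP surface: session-scoped RAG endpoints plus the legacy /api/index aliases.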
    def public_router(self) -> APIRouter:
        router = APIRouter(tags=["rag"])

        @router.post("/api/rag/sessions", response_model=RagSessionCreateResponse)
        async def create_rag_session(request: RagSessionCreateRequest) -> RagSessionCreateResponse:
            session = self.sessions.create(request.project_id)
            job = await self.indexing.enqueue_snapshot(
                rag_session_id=session.rag_session_id,
                files=[x.model_dump() for x in request.files],
            )
            return RagSessionCreateResponse(
                rag_session_id=session.rag_session_id,
                index_job_id=job.index_job_id,
                status=job.status,
            )

        @router.post("/api/rag/sessions/{rag_session_id}/changes", response_model=IndexJobQueuedResponse)
        async def rag_session_changes(
            rag_session_id: str,
            request: RagSessionChangesRequest,
        ) -> IndexJobQueuedResponse:
            session = self.sessions.get(rag_session_id)
            if not session:
                raise AppError("not_found", f"RAG session not found: {rag_session_id}", ModuleName.RAG)
            job = await self.indexing.enqueue_changes(
                rag_session_id=rag_session_id,
                changed_files=[x.model_dump() for x in request.changed_files],
            )
            return IndexJobQueuedResponse(index_job_id=job.index_job_id, status=job.status.value)

        @router.get("/api/rag/sessions/{rag_session_id}/jobs/{index_job_id}", response_model=RagSessionJobResponse)
        async def rag_session_job(rag_session_id: str, index_job_id: str) -> RagSessionJobResponse:
            job = self.jobs.get(index_job_id)
            if not job or job.rag_session_id != rag_session_id:
                raise AppError("not_found", f"Index job not found: {index_job_id}", ModuleName.RAG)
            return RagSessionJobResponse(
                rag_session_id=rag_session_id,
                index_job_id=job.index_job_id,
                status=job.status,
                indexed_files=job.indexed_files,
                failed_files=job.failed_files,
                cache_hit_files=job.cache_hit_files,
                cache_miss_files=job.cache_miss_files,
                error=job.error.model_dump(mode="json") if job.error else None,
            )
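
        # Server-sent events: replays buffered job events, streams new ones, and sends keepalive comments while idle.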
        @router.get("/api/rag/sessions/{rag_session_id}/jobs/{index_job_id}/events")
        async def rag_session_job_events(rag_session_id: str, index_job_id: str) -> StreamingResponse:
            job = self.jobs.get(index_job_id)
            if not job or job.rag_session_id != rag_session_id:
                raise AppError("not_found", f"Index job not found: {index_job_id}", ModuleName.RAG)
            queue = await self._events.subscribe(index_job_id, replay=True)

            async def event_stream():
                heartbeat = 10
                try:
                    while True:
                        try:
                            event = await asyncio.wait_for(queue.get(), timeout=heartbeat)
                            yield EventBus.as_sse(event)
                            if event.name == "terminal":
                                break
                        except asyncio.TimeoutError:
                            yield ": keepalive\n\n"
                finally:
                    await self._events.unsubscribe(index_job_id, queue)

            return StreamingResponse(
                event_stream(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache, no-transform",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no",
                },
            )

        # Legacy compatibility endpoints.
        legacy = APIRouter(prefix="/api/index", tags=["index"])

        @legacy.post("/snapshot", response_model=IndexJobQueuedResponse)
        async def index_snapshot(request: IndexSnapshotRequest) -> IndexJobQueuedResponse:
            session = self.sessions.put(
                rag_session_id=request.project_id,
                project_id=request.project_id,
            )
            job = await self.indexing.enqueue_snapshot(
                rag_session_id=session.rag_session_id,
                files=[x.model_dump() for x in request.files],
            )
            return IndexJobQueuedResponse(index_job_id=job.index_job_id, status=job.status.value)

        @legacy.post("/changes", response_model=IndexJobQueuedResponse)
        async def index_changes(request: IndexChangesRequest) -> IndexJobQueuedResponse:
            rag_session_id = request.project_id
            if not self.sessions.get(rag_session_id):
                self.sessions.put(rag_session_id=rag_session_id, project_id=rag_session_id)
            job = await self.indexing.enqueue_changes(
                rag_session_id=rag_session_id,
                changed_files=[x.model_dump() for x in request.changed_files],
            )
            return IndexJobQueuedResponse(index_job_id=job.index_job_id, status=job.status.value)

        @legacy.get("/jobs/{index_job_id}", response_model=IndexJobResponse)
        async def get_index_job(index_job_id: str) -> IndexJobResponse:
            job = self.jobs.get(index_job_id)
            if not job:
                raise AppError("not_found", f"Index job not found: {index_job_id}", ModuleName.RAG)
            return IndexJobResponse(
                index_job_id=job.index_job_id,
                status=job.status,
                indexed_files=job.indexed_files,
                failed_files=job.failed_files,
                cache_hit_files=job.cache_hit_files,
                cache_miss_files=job.cache_miss_files,
                error=job.error,
            )

        @legacy.get("/jobs/{index_job_id}/events")
        async def get_index_job_events(index_job_id: str) -> StreamingResponse:
            job = self.jobs.get(index_job_id)
            if not job:
                raise AppError("not_found", f"Index job not found: {index_job_id}", ModuleName.RAG)
            queue = await self._events.subscribe(index_job_id, replay=True)

            async def event_stream():
                heartbeat = 10
                try:
                    while True:
                        try:
                            event = await asyncio.wait_for(queue.get(), timeout=heartbeat)
                            yield EventBus.as_sse(event)
                            if event.name == "terminal":
                                break
                        except asyncio.TimeoutError:
                            yield ": keepalive\n\n"
                finally:
                    await self._events.unsubscribe(index_job_id, queue)

            return StreamingResponse(
                event_stream(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache, no-transform",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no",
                },
            )

        router.include_router(legacy)
        return router
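
    # Internal, service-to-service endpoints: synchronous indexing and retrieval without the job queue.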
    def internal_router(self) -> APIRouter:
        router = APIRouter(prefix="/internal/rag", tags=["internal-rag"])

        @router.post("/index/snapshot")
        async def index_snapshot(request: IndexSnapshotRequest) -> dict:
            rag_session_id = request.project_id
            if not self.sessions.get(rag_session_id):
                self.sessions.put(rag_session_id=rag_session_id, project_id=rag_session_id)
            indexed, failed, cache_hits, cache_misses = await self.rag.index_snapshot(
                rag_session_id=rag_session_id,
                files=[x.model_dump() for x in request.files],
            )
            return {
                "indexed_files": indexed,
                "failed_files": failed,
                "cache_hit_files": cache_hits,
                "cache_miss_files": cache_misses,
            }

        @router.post("/index/changes")
        async def index_changes(request: IndexChangesRequest) -> dict:
            rag_session_id = request.project_id
            indexed, failed, cache_hits, cache_misses = await self.rag.index_changes(
                rag_session_id=rag_session_id,
                changed_files=[x.model_dump() for x in request.changed_files],
            )
            return {
                "indexed_files": indexed,
                "failed_files": failed,
                "cache_hit_files": cache_hits,
                "cache_miss_files": cache_misses,
            }

        @router.get("/index/jobs/{index_job_id}")
        async def get_job(index_job_id: str) -> dict:
            job = self.jobs.get(index_job_id)
            if not job:
                return {"status": "not_found"}
            return {
                "index_job_id": job.index_job_id,
                "status": job.status.value,
                "indexed_files": job.indexed_files,
                "failed_files": job.failed_files,
                "cache_hit_files": job.cache_hit_files,
                "cache_miss_files": job.cache_miss_files,
                "error": job.error.model_dump(mode="json") if job.error else None,
            }

        @router.post("/retrieve")
        async def retrieve(payload: dict) -> dict:
            rag_session_id = payload.get("rag_session_id") or payload.get("project_id", "")
            ctx = await self.rag.retrieve(
                rag_session_id=rag_session_id,
                query=payload.get("query", ""),
            )
            return {"items": ctx}

        return router
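

# Example wiring (a minimal sketch, not part of this module): the argument-free EventBus,
# RetryExecutor, and RagRepository constructors and the FastAPI `app` instance below are
# assumptions about the surrounding application, shown only to illustrate how the routers
# are mounted.
#
#     from fastapi import FastAPI
#
#     app = FastAPI()
#     module = RagModule(EventBus(), RetryExecutor(), RagRepository())
#     app.include_router(module.public_router())
#     app.include_router(module.internal_router())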