Первый коммит
This commit is contained in:
247
app/modules/rag/module.py
Normal file
247
app/modules/rag/module.py
Normal file
@@ -0,0 +1,247 @@
|
||||
import asyncio

from fastapi import APIRouter
from fastapi.responses import StreamingResponse

from app.core.exceptions import AppError
from app.modules.rag.embedding.gigachat_embedder import GigaChatEmbedder
from app.modules.rag.indexing_service import IndexingOrchestrator
from app.modules.rag.job_store import IndexJobStore
from app.modules.rag.repository import RagRepository
from app.modules.rag.retrieval.chunker import TextChunker
from app.modules.rag.service import RagService
from app.modules.rag.session_store import RagSessionStore
from app.modules.shared.event_bus import EventBus
from app.modules.shared.gigachat.client import GigaChatClient
from app.modules.shared.gigachat.settings import GigaChatSettings
from app.modules.shared.gigachat.token_provider import GigaChatTokenProvider
from app.modules.shared.retry_executor import RetryExecutor
from app.schemas.common import ModuleName
from app.schemas.indexing import (
    IndexChangesRequest,
    IndexJobQueuedResponse,
    IndexJobResponse,
    IndexSnapshotRequest,
)
from app.schemas.rag_sessions import (
    RagSessionChangesRequest,
    RagSessionCreateRequest,
    RagSessionCreateResponse,
    RagSessionJobResponse,
)
|
||||
|
||||
|
||||
class RagModule:
    """Composition root for the RAG subsystem.

    Wires the GigaChat embedding stack, session/job stores and the indexing
    orchestrator together, and exposes two routers:

    * ``public_router`` — session-scoped RAG API plus legacy ``/api/index``
      endpoints kept for backward compatibility (keyed by ``project_id``).
    * ``internal_router`` — service-to-service endpoints that index
      synchronously and serve retrieval.
    """

    # Headers that disable proxy/browser buffering so SSE events are
    # delivered immediately (X-Accel-Buffering targets nginx).
    _SSE_HEADERS = {
        "Cache-Control": "no-cache, no-transform",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }

    # Seconds of silence after which an SSE keepalive comment is emitted.
    _SSE_HEARTBEAT_SECONDS = 10

    def __init__(self, event_bus: EventBus, retry: RetryExecutor, repository: RagRepository) -> None:
        """Build the RAG service graph.

        Args:
            event_bus: Pub/sub bus used for per-job progress events.
            retry: Retry policy applied by the indexing orchestrator.
            repository: Persistence backend shared by service and stores.
        """
        self._events = event_bus
        self.repository = repository
        # GigaChat stack: env-driven settings -> token provider -> HTTP client -> embedder.
        settings = GigaChatSettings.from_env()
        token_provider = GigaChatTokenProvider(settings)
        client = GigaChatClient(settings, token_provider)
        embedder = GigaChatEmbedder(client)
        self.rag = RagService(embedder=embedder, repository=repository, chunker=TextChunker())
        self.sessions = RagSessionStore(repository)
        self.jobs = IndexJobStore(repository)
        self.indexing = IndexingOrchestrator(
            store=self.jobs,
            rag=self.rag,
            events=event_bus,
            retry=retry,
        )

    def _require_job(self, index_job_id: str, rag_session_id: "str | None" = None):
        """Return the index job or raise a ``not_found`` :class:`AppError`.

        When ``rag_session_id`` is given, the job must also belong to that
        session; otherwise it is treated as not found (no existence leak
        across sessions).
        """
        job = self.jobs.get(index_job_id)
        if not job or (rag_session_id is not None and job.rag_session_id != rag_session_id):
            raise AppError("not_found", f"Index job not found: {index_job_id}", ModuleName.RAG)
        return job

    async def _job_events_response(self, index_job_id: str) -> StreamingResponse:
        """Subscribe to a job's event topic and return it as an SSE stream.

        Buffered events are replayed first (``replay=True``), then live
        events are streamed.  A keepalive comment is sent after every
        ``_SSE_HEARTBEAT_SECONDS`` of silence, the stream closes once a
        ``terminal`` event is delivered, and the subscription is always
        released on exit (including client disconnect).
        """
        queue = await self._events.subscribe(index_job_id, replay=True)

        async def event_stream():
            try:
                while True:
                    try:
                        event = await asyncio.wait_for(
                            queue.get(), timeout=self._SSE_HEARTBEAT_SECONDS
                        )
                        yield EventBus.as_sse(event)
                        if event.name == "terminal":
                            break
                    except asyncio.TimeoutError:
                        # SSE comment line: ignored by clients, keeps
                        # intermediaries from dropping the idle connection.
                        yield ": keepalive\n\n"
            finally:
                await self._events.unsubscribe(index_job_id, queue)

        return StreamingResponse(
            event_stream(),
            media_type="text/event-stream",
            headers=self._SSE_HEADERS,
        )

    def public_router(self) -> APIRouter:
        """Return the externally-exposed router (session API + legacy API)."""
        router = APIRouter(tags=["rag"])

        @router.post("/api/rag/sessions", response_model=RagSessionCreateResponse)
        async def create_rag_session(request: RagSessionCreateRequest) -> RagSessionCreateResponse:
            """Create a RAG session and enqueue a full snapshot index job."""
            session = self.sessions.create(request.project_id)
            job = await self.indexing.enqueue_snapshot(
                rag_session_id=session.rag_session_id,
                files=[x.model_dump() for x in request.files],
            )
            return RagSessionCreateResponse(
                rag_session_id=session.rag_session_id,
                index_job_id=job.index_job_id,
                status=job.status,
            )

        @router.post("/api/rag/sessions/{rag_session_id}/changes", response_model=IndexJobQueuedResponse)
        async def rag_session_changes(
            rag_session_id: str,
            request: RagSessionChangesRequest,
        ) -> IndexJobQueuedResponse:
            """Enqueue incremental re-indexing of changed files for a session."""
            if not self.sessions.get(rag_session_id):
                raise AppError("not_found", f"RAG session not found: {rag_session_id}", ModuleName.RAG)
            job = await self.indexing.enqueue_changes(
                rag_session_id=rag_session_id,
                changed_files=[x.model_dump() for x in request.changed_files],
            )
            return IndexJobQueuedResponse(index_job_id=job.index_job_id, status=job.status.value)

        @router.get("/api/rag/sessions/{rag_session_id}/jobs/{index_job_id}", response_model=RagSessionJobResponse)
        async def rag_session_job(rag_session_id: str, index_job_id: str) -> RagSessionJobResponse:
            """Return the status of an index job owned by the given session."""
            job = self._require_job(index_job_id, rag_session_id)
            return RagSessionJobResponse(
                rag_session_id=rag_session_id,
                index_job_id=job.index_job_id,
                status=job.status,
                indexed_files=job.indexed_files,
                failed_files=job.failed_files,
                error=job.error.model_dump(mode="json") if job.error else None,
            )

        @router.get("/api/rag/sessions/{rag_session_id}/jobs/{index_job_id}/events")
        async def rag_session_job_events(rag_session_id: str, index_job_id: str) -> StreamingResponse:
            """Stream SSE progress events for a session-owned index job."""
            self._require_job(index_job_id, rag_session_id)
            return await self._job_events_response(index_job_id)

        # Legacy compatibility endpoints: pre-session API where the
        # project_id doubles as the rag_session_id.
        legacy = APIRouter(prefix="/api/index", tags=["index"])

        @legacy.post("/snapshot", response_model=IndexJobQueuedResponse)
        async def index_snapshot(request: IndexSnapshotRequest) -> IndexJobQueuedResponse:
            """Legacy: (re)create the project-keyed session and enqueue a snapshot."""
            session = self.sessions.put(
                rag_session_id=request.project_id,
                project_id=request.project_id,
            )
            job = await self.indexing.enqueue_snapshot(
                rag_session_id=session.rag_session_id,
                files=[x.model_dump() for x in request.files],
            )
            return IndexJobQueuedResponse(index_job_id=job.index_job_id, status=job.status.value)

        @legacy.post("/changes", response_model=IndexJobQueuedResponse)
        async def index_changes(request: IndexChangesRequest) -> IndexJobQueuedResponse:
            """Legacy: enqueue changed files, creating the session on first use."""
            rag_session_id = request.project_id
            if not self.sessions.get(rag_session_id):
                self.sessions.put(rag_session_id=rag_session_id, project_id=rag_session_id)
            job = await self.indexing.enqueue_changes(
                rag_session_id=rag_session_id,
                changed_files=[x.model_dump() for x in request.changed_files],
            )
            return IndexJobQueuedResponse(index_job_id=job.index_job_id, status=job.status.value)

        @legacy.get("/jobs/{index_job_id}", response_model=IndexJobResponse)
        async def get_index_job(index_job_id: str) -> IndexJobResponse:
            """Legacy: return index job status without session ownership check."""
            job = self._require_job(index_job_id)
            return IndexJobResponse(
                index_job_id=job.index_job_id,
                status=job.status,
                indexed_files=job.indexed_files,
                failed_files=job.failed_files,
                error=job.error,
            )

        @legacy.get("/jobs/{index_job_id}/events")
        async def get_index_job_events(index_job_id: str) -> StreamingResponse:
            """Legacy: stream SSE progress events for an index job."""
            self._require_job(index_job_id)
            return await self._job_events_response(index_job_id)

        router.include_router(legacy)
        return router

    def internal_router(self) -> APIRouter:
        """Return the service-to-service router (synchronous indexing + retrieval)."""
        router = APIRouter(prefix="/internal/rag", tags=["internal-rag"])

        @router.post("/index/snapshot")
        async def index_snapshot(request: IndexSnapshotRequest) -> dict:
            """Index a full snapshot synchronously, creating the session if needed."""
            rag_session_id = request.project_id
            if not self.sessions.get(rag_session_id):
                self.sessions.put(rag_session_id=rag_session_id, project_id=rag_session_id)
            indexed, failed = await self.rag.index_snapshot(
                rag_session_id=rag_session_id,
                files=[x.model_dump() for x in request.files],
            )
            return {"indexed_files": indexed, "failed_files": failed}

        @router.post("/index/changes")
        async def index_changes(request: IndexChangesRequest) -> dict:
            """Index changed files synchronously for a project-keyed session."""
            indexed, failed = await self.rag.index_changes(
                rag_session_id=request.project_id,
                changed_files=[x.model_dump() for x in request.changed_files],
            )
            return {"indexed_files": indexed, "failed_files": failed}

        @router.get("/index/jobs/{index_job_id}")
        async def get_job(index_job_id: str) -> dict:
            """Return job status as a plain dict.

            Internal API reports a missing job as data (``status: not_found``)
            rather than raising, unlike the public endpoints.
            """
            job = self.jobs.get(index_job_id)
            if not job:
                return {"status": "not_found"}
            return {
                "index_job_id": job.index_job_id,
                "status": job.status.value,
                "indexed_files": job.indexed_files,
                "failed_files": job.failed_files,
                "error": job.error.model_dump(mode="json") if job.error else None,
            }

        @router.post("/retrieve")
        async def retrieve(payload: dict) -> dict:
            """Retrieve context chunks for a query.

            Accepts either ``rag_session_id`` or legacy ``project_id`` in the
            payload; missing query defaults to the empty string.
            """
            rag_session_id = payload.get("rag_session_id") or payload.get("project_id", "")
            ctx = await self.rag.retrieve(
                rag_session_id=rag_session_id,
                query=payload.get("query", ""),
            )
            return {"items": ctx}

        return router
|
||||
Reference in New Issue
Block a user