Files
agent/app/modules/chat/module.py
2026-02-25 14:47:19 +03:00

105 lines
4.0 KiB
Python

import asyncio

from fastapi import APIRouter, Header
from fastapi.responses import StreamingResponse

from app.core.exceptions import AppError
from app.modules.chat.dialog_store import DialogSessionStore
from app.modules.chat.repository import ChatRepository
from app.modules.chat.service import ChatOrchestrator
from app.modules.chat.task_store import TaskStore
from app.modules.contracts import AgentRunner
from app.modules.rag.session_store import RagSessionStore
from app.modules.shared.event_bus import EventBus
from app.modules.shared.idempotency_store import IdempotencyStore
from app.modules.shared.retry_executor import RetryExecutor
from app.schemas.chat import (
    ChatMessageRequest,
    DialogCreateRequest,
    DialogCreateResponse,
    TaskQueuedResponse,
    TaskResultResponse,
)
from app.schemas.common import ModuleName
class ChatModule:
    """Wire up the chat feature's stores and orchestrator and expose its HTTP routes.

    Creates in-process ``TaskStore``, ``DialogSessionStore`` and
    ``IdempotencyStore`` instances, builds a ``ChatOrchestrator`` on top of
    them, and serves dialog creation, message submission, task polling and a
    server-sent-events stream through ``public_router``.
    """

    # Seconds the SSE stream waits for an event before emitting a keepalive
    # comment. Override on the class or an instance to tune idle behaviour.
    sse_heartbeat_seconds: float = 10.0

    def __init__(
        self,
        agent_runner: AgentRunner,
        event_bus: EventBus,
        retry: RetryExecutor,
        rag_sessions: RagSessionStore,
        repository: ChatRepository,
    ) -> None:
        """Create the chat stores and the orchestrator.

        Args:
            agent_runner: Passed to the orchestrator as its ``runtime``.
            event_bus: Shared by the orchestrator and the ``/api/events`` stream.
            retry: Passed to the orchestrator as its ``retry`` executor.
            rag_sessions: Store used to check that a ``rag_session_id`` exists.
            repository: Backs the dialog store; its ``add_message`` is the
                orchestrator's message sink.
        """
        self._rag_sessions = rag_sessions
        self.tasks = TaskStore()
        self.dialogs = DialogSessionStore(repository)
        self.idempotency = IdempotencyStore()
        self.events = event_bus
        self.chat = ChatOrchestrator(
            task_store=self.tasks,
            dialogs=self.dialogs,
            idempotency=self.idempotency,
            runtime=agent_runner,
            events=self.events,
            retry=retry,
            # Same predicate the dialog endpoint uses, so both agree on
            # what counts as an existing RAG session.
            rag_session_exists=self._rag_session_exists,
            message_sink=repository.add_message,
        )

    def _rag_session_exists(self, rag_session_id) -> bool:
        """Return True if the RAG session store has an entry for ``rag_session_id``."""
        return self._rag_sessions.get(rag_session_id) is not None

    def public_router(self) -> APIRouter:
        """Build the router serving the chat, task and event-stream endpoints."""
        router = APIRouter(tags=["chat"])

        # NOTE: route handlers use comments rather than docstrings because
        # FastAPI publishes handler docstrings as OpenAPI descriptions.

        # Create a dialog bound to an existing RAG session.
        @router.post("/api/chat/dialogs", response_model=DialogCreateResponse)
        async def create_dialog(request: DialogCreateRequest) -> DialogCreateResponse:
            if not self._rag_session_exists(request.rag_session_id):
                raise AppError("rag_session_not_found", "RAG session not found", ModuleName.RAG)
            dialog = self.dialogs.create(request.rag_session_id)
            return DialogCreateResponse(
                dialog_session_id=dialog.dialog_session_id,
                rag_session_id=dialog.rag_session_id,
            )

        # Queue a chat message; the optional Idempotency-Key header is
        # forwarded to the orchestrator.
        @router.post("/api/chat/messages", response_model=TaskQueuedResponse)
        async def send_message(
            request: ChatMessageRequest,
            idempotency_key: str | None = Header(default=None, alias="Idempotency-Key"),
        ) -> TaskQueuedResponse:
            task = await self.chat.enqueue_message(request, idempotency_key)
            return TaskQueuedResponse(task_id=task.task_id, status=task.status.value)

        # Poll a task's current status and result fields.
        @router.get("/api/tasks/{task_id}", response_model=TaskResultResponse)
        async def get_task(task_id: str) -> TaskResultResponse:
            task = self.tasks.get(task_id)
            if task is None:
                raise AppError("not_found", f"Task not found: {task_id}", ModuleName.BACKEND)
            return TaskResultResponse(
                task_id=task.task_id,
                status=task.status,
                result_type=task.result_type,
                answer=task.answer,
                changeset=task.changeset,
                error=task.error,
            )

        # Stream a task's events as server-sent events.
        @router.get("/api/events")
        async def stream_events(task_id: str) -> StreamingResponse:
            # Subscribe before building the response so the queue exists
            # before any bytes are streamed.
            # NOTE(review): if the response is never iterated, the generator's
            # `finally` never runs and this subscription is not released —
            # confirm EventBus tolerates that.
            queue = await self.events.subscribe(task_id)
            heartbeat = self.sse_heartbeat_seconds

            async def event_stream():
                # No terminal condition: the stream stays open until the
                # generator is closed or cancelled (e.g. client disconnect).
                try:
                    while True:
                        try:
                            event = await asyncio.wait_for(queue.get(), timeout=heartbeat)
                        except asyncio.TimeoutError:
                            # SSE comment line. It must end with a real blank
                            # line, or it merges with the next event's fields.
                            yield ": keepalive\n\n"
                        else:
                            yield EventBus.as_sse(event)
                finally:
                    await self.events.unsubscribe(task_id, queue)

            return StreamingResponse(event_stream(), media_type="text/event-stream")

        return router