from __future__ import annotations

import asyncio
import os
from typing import TYPE_CHECKING

from fastapi import APIRouter, Header
from fastapi.responses import StreamingResponse

from app.core.exceptions import AppError
from app.modules.chat.direct_service import CodeExplainChatService
from app.modules.chat.dialog_store import DialogSessionStore
from app.modules.chat.service import ChatOrchestrator
from app.modules.chat.task_store import TaskStore
from app.modules.shared.event_bus import EventBus
from app.modules.shared.idempotency_store import IdempotencyStore
from app.modules.shared.retry_executor import RetryExecutor
from app.schemas.chat import (
    ChatMessageRequest,
    DialogCreateRequest,
    DialogCreateResponse,
    TaskQueuedResponse,
    TaskResultResponse,
)
from app.schemas.common import ModuleName

if TYPE_CHECKING:
    from app.modules.chat.repository import ChatRepository
    from app.modules.contracts import AgentRunner
    from app.modules.rag_session.session_store import RagSessionStore

# Seconds to wait for an event before emitting an SSE keepalive comment.
_SSE_HEARTBEAT_SECONDS = 10

# SSE comment frame. It must end with real line terminators: an unterminated
# comment line would absorb the first line of the next event in the stream.
_SSE_KEEPALIVE_FRAME = ": keepalive\n\n"

# Values of SIMPLE_CODE_EXPLAIN_ONLY (compared lower-cased) that enable the flag.
_TRUTHY_ENV_VALUES = frozenset({"1", "true", "yes"})


class ChatModule:
    """Wire the chat stores and orchestrator together and expose them over HTTP.

    The constructor builds the task, dialog and idempotency stores plus a
    ``ChatOrchestrator``; ``public_router`` returns the FastAPI routes that
    operate on them.
    """

    def __init__(
        self,
        agent_runner: AgentRunner,
        event_bus: EventBus,
        retry: RetryExecutor,
        rag_sessions: RagSessionStore,
        repository: ChatRepository,
        direct_chat: CodeExplainChatService | None = None,
        task_store: TaskStore | None = None,
    ) -> None:
        """Build the module's collaborators.

        Args:
            agent_runner: Passed to the orchestrator as its ``runtime``.
            event_bus: Shared with the orchestrator and used by the
                ``/api/events`` stream.
            retry: Passed to the orchestrator as its ``retry`` executor.
            rag_sessions: Looked up by id to decide whether a RAG session
                exists.
            repository: Backs the dialog store; its ``add_message`` is the
                orchestrator's message sink.
            direct_chat: Optional service that answers messages directly when
                the ``SIMPLE_CODE_EXPLAIN_ONLY`` environment flag is enabled
                (the flag defaults to ``"true"``).
            task_store: Optional pre-built task store; a new ``TaskStore`` is
                created only when this is ``None``.
        """
        self._rag_sessions = rag_sessions
        self._simple_code_explain_only = (
            os.getenv("SIMPLE_CODE_EXPLAIN_ONLY", "true").lower() in _TRUTHY_ENV_VALUES
        )
        # Explicit None check: `task_store or TaskStore()` would throw away an
        # injected store that happens to evaluate as falsy.
        self.tasks = task_store if task_store is not None else TaskStore()
        self.dialogs = DialogSessionStore(repository)
        self.idempotency = IdempotencyStore()
        self.events = event_bus
        self.direct_chat = direct_chat
        self.chat = ChatOrchestrator(
            task_store=self.tasks,
            dialogs=self.dialogs,
            idempotency=self.idempotency,
            runtime=agent_runner,
            events=self.events,
            retry=retry,
            rag_session_exists=lambda rag_session_id: rag_sessions.get(rag_session_id) is not None,
            message_sink=repository.add_message,
        )

    def public_router(self) -> APIRouter:
        """Return an ``APIRouter`` (tagged ``chat``) with the chat endpoints.

        Routes: ``POST /api/chat/dialogs``, ``POST /api/chat/messages``,
        ``GET /api/tasks/{task_id}`` and ``GET /api/events``.

        The handlers are documented with comments rather than docstrings
        because FastAPI publishes a handler's docstring as the operation
        description in the generated OpenAPI schema.
        """
        router = APIRouter(tags=["chat"])

        # Create a dialog bound to an existing RAG session.
        # Raises AppError("rag_session_not_found") when the session lookup
        # yields None (same existence test the orchestrator is given).
        @router.post("/api/chat/dialogs", response_model=DialogCreateResponse)
        async def create_dialog(request: DialogCreateRequest) -> DialogCreateResponse:
            if self._rag_sessions.get(request.rag_session_id) is None:
                raise AppError("rag_session_not_found", "RAG session not found", ModuleName.RAG)
            dialog = self.dialogs.create(request.rag_session_id)
            return DialogCreateResponse(
                dialog_session_id=dialog.dialog_session_id,
                rag_session_id=dialog.rag_session_id,
            )

        # Accept a chat message. With the simple-explain flag on and a direct
        # chat service configured, the service's result is returned as-is;
        # otherwise the message is enqueued and the queued task is reported.
        @router.post("/api/chat/messages", response_model=TaskQueuedResponse | TaskResultResponse)
        async def send_message(
            request: ChatMessageRequest,
            idempotency_key: str | None = Header(default=None, alias="Idempotency-Key"),
        ) -> TaskQueuedResponse | TaskResultResponse:
            if self._simple_code_explain_only and self.direct_chat is not None:
                return await self.direct_chat.handle_message(request)
            task = await self.chat.enqueue_message(request, idempotency_key)
            return TaskQueuedResponse(task_id=task.task_id, status=task.status.value)

        # Report the stored state of a task.
        # Raises AppError("not_found") when the task store has no usable entry.
        @router.get("/api/tasks/{task_id}", response_model=TaskResultResponse)
        async def get_task(task_id: str) -> TaskResultResponse:
            task = self.tasks.get(task_id)
            if not task:
                raise AppError("not_found", f"Task not found: {task_id}", ModuleName.BACKEND)
            return TaskResultResponse(
                task_id=task.task_id,
                status=task.status,
                result_type=task.result_type,
                answer=task.answer,
                changeset=task.changeset,
                error=task.error,
            )

        # Stream a task's events as Server-Sent Events.
        @router.get("/api/events")
        async def stream_events(task_id: str) -> StreamingResponse:
            # Subscribe before the response starts so a subscription failure
            # surfaces as a normal error response rather than mid-stream.
            queue = await self.events.subscribe(task_id)

            async def event_stream():
                try:
                    while True:
                        try:
                            event = await asyncio.wait_for(
                                queue.get(), timeout=_SSE_HEARTBEAT_SECONDS
                            )
                        except asyncio.TimeoutError:
                            # Nothing arrived within the heartbeat window:
                            # send a comment frame so the stream is not idle.
                            yield _SSE_KEEPALIVE_FRAME
                        else:
                            yield EventBus.as_sse(event)
                finally:
                    # Runs when the generator is closed or fails, releasing
                    # the subscription taken above.
                    await self.events.unsubscribe(task_id, queue)

            return StreamingResponse(event_stream(), media_type="text/event-stream")

        return router