119 lines
4.7 KiB
Python
119 lines
4.7 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
from typing import TYPE_CHECKING
|
|
|
|
from fastapi import APIRouter, Header
|
|
from fastapi.responses import StreamingResponse
|
|
|
|
from app.core.exceptions import AppError
|
|
from app.modules.chat.direct_service import CodeExplainChatService
|
|
from app.modules.chat.dialog_store import DialogSessionStore
|
|
from app.modules.chat.service import ChatOrchestrator
|
|
from app.modules.chat.task_store import TaskStore
|
|
from app.modules.shared.event_bus import EventBus
|
|
from app.modules.shared.idempotency_store import IdempotencyStore
|
|
from app.modules.shared.retry_executor import RetryExecutor
|
|
from app.schemas.chat import (
|
|
ChatMessageRequest,
|
|
DialogCreateRequest,
|
|
DialogCreateResponse,
|
|
TaskQueuedResponse,
|
|
TaskResultResponse,
|
|
)
|
|
from app.schemas.common import ModuleName
|
|
|
|
if TYPE_CHECKING:
|
|
from app.modules.chat.repository import ChatRepository
|
|
from app.modules.contracts import AgentRunner
|
|
from app.modules.rag_session.session_store import RagSessionStore
|
|
|
|
|
|
class ChatModule:
    """Chat feature module: owns the chat stores and exposes the HTTP routes.

    The constructor builds the dialog store, the idempotency store and the
    ``ChatOrchestrator`` from the injected collaborators.  ``public_router``
    returns a FastAPI router with the dialog, message, task and event-stream
    endpoints bound to this instance.
    """

    # Environment values (compared lower-cased) that switch a flag on.
    _TRUTHY_ENV_VALUES = frozenset({"1", "true", "yes"})

    # Seconds to wait for an event before a keepalive is sent on the SSE stream.
    _SSE_HEARTBEAT_SECONDS = 10

    # SSE comment frame.  It must end with real line terminators: a line that
    # starts with ":" is a comment, and the blank line closes the frame.  An
    # escaped "\\n" would leave the comment unterminated, so it would merge
    # with (and hide) the first line of the next event.
    _SSE_KEEPALIVE_FRAME = ": keepalive\n\n"

    def __init__(
        self,
        agent_runner: AgentRunner,
        event_bus: EventBus,
        retry: RetryExecutor,
        rag_sessions: RagSessionStore,
        repository: ChatRepository,
        direct_chat: CodeExplainChatService | None = None,
        task_store: TaskStore | None = None,
    ) -> None:
        """Wire the chat stores and the orchestrator.

        Args:
            agent_runner: Passed to the orchestrator as its ``runtime``.
            event_bus: Event bus used by the orchestrator and the SSE route.
            retry: Retry executor handed to the orchestrator.
            rag_sessions: Store used to check that a RAG session exists.
            repository: Backs the dialog store; its ``add_message`` is used
                as the orchestrator's message sink.
            direct_chat: Optional service that answers messages directly
                when the ``SIMPLE_CODE_EXPLAIN_ONLY`` flag is on.
            task_store: Optional task store; a new ``TaskStore`` is created
                when omitted (or falsy).
        """
        self._rag_sessions = rag_sessions
        # The flag defaults to "true"; it is read once, at construction time.
        self._simple_code_explain_only = (
            os.getenv("SIMPLE_CODE_EXPLAIN_ONLY", "true").lower() in self._TRUTHY_ENV_VALUES
        )
        self.tasks = task_store or TaskStore()
        self.dialogs = DialogSessionStore(repository)
        self.idempotency = IdempotencyStore()
        self.events = event_bus
        self.direct_chat = direct_chat
        self.chat = ChatOrchestrator(
            task_store=self.tasks,
            dialogs=self.dialogs,
            idempotency=self.idempotency,
            runtime=agent_runner,
            events=self.events,
            retry=retry,
            rag_session_exists=lambda rag_session_id: rag_sessions.get(rag_session_id) is not None,
            message_sink=repository.add_message,
        )

    def public_router(self) -> APIRouter:
        """Build the FastAPI router exposing the chat endpoints.

        Routes:
            POST /api/chat/dialogs   - create a dialog for an existing RAG session.
            POST /api/chat/messages  - answer directly or enqueue a task.
            GET  /api/tasks/{task_id} - fetch the state of a task.
            GET  /api/events         - server-sent events for a task.
        """
        router = APIRouter(tags=["chat"])

        @router.post("/api/chat/dialogs", response_model=DialogCreateResponse)
        async def create_dialog(request: DialogCreateRequest) -> DialogCreateResponse:
            """Create a dialog session; raises ``AppError`` if the RAG session is unknown."""
            if not self._rag_sessions.get(request.rag_session_id):
                raise AppError("rag_session_not_found", "RAG session not found", ModuleName.RAG)
            dialog = self.dialogs.create(request.rag_session_id)
            return DialogCreateResponse(
                dialog_session_id=dialog.dialog_session_id,
                rag_session_id=dialog.rag_session_id,
            )

        @router.post("/api/chat/messages", response_model=TaskQueuedResponse | TaskResultResponse)
        async def send_message(
            request: ChatMessageRequest,
            idempotency_key: str | None = Header(default=None, alias="Idempotency-Key"),
        ) -> TaskQueuedResponse | TaskResultResponse:
            """Handle a chat message.

            When the simple-explain flag is on and a direct chat service was
            injected, the message is delegated to it and its result returned
            (the idempotency key is not used on that path).  Otherwise the
            message is enqueued and the queued task's id and status returned.
            """
            if self._simple_code_explain_only and self.direct_chat is not None:
                return await self.direct_chat.handle_message(request)
            task = await self.chat.enqueue_message(request, idempotency_key)
            return TaskQueuedResponse(task_id=task.task_id, status=task.status.value)

        @router.get("/api/tasks/{task_id}", response_model=TaskResultResponse)
        async def get_task(task_id: str) -> TaskResultResponse:
            """Return the stored task; raises ``AppError`` when it does not exist."""
            task = self.tasks.get(task_id)
            if not task:
                raise AppError("not_found", f"Task not found: {task_id}", ModuleName.BACKEND)
            return TaskResultResponse(
                task_id=task.task_id,
                status=task.status,
                result_type=task.result_type,
                answer=task.answer,
                changeset=task.changeset,
                error=task.error,
            )

        @router.get("/api/events")
        async def stream_events(task_id: str) -> StreamingResponse:
            """Stream the events of ``task_id`` as ``text/event-stream``."""
            # Subscribe before the response starts so the queue already
            # exists when the generator begins draining it.
            queue = await self.events.subscribe(task_id)

            async def event_stream():
                import asyncio

                try:
                    while True:
                        try:
                            event = await asyncio.wait_for(
                                queue.get(), timeout=self._SSE_HEARTBEAT_SECONDS
                            )
                            yield EventBus.as_sse(event)
                        except asyncio.TimeoutError:
                            # Nothing arrived within the heartbeat window:
                            # send a comment frame instead of an event.
                            yield self._SSE_KEEPALIVE_FRAME
                finally:
                    # Runs when the generator is closed or cancelled.
                    await self.events.unsubscribe(task_id, queue)

            return StreamingResponse(event_stream(), media_type="text/event-stream")

        return router
|