37 lines
1.5 KiB
Python
37 lines
1.5 KiB
Python
from __future__ import annotations

from typing import TYPE_CHECKING

from app.core.exceptions import AppError
from app.schemas.chat import ChatMessageRequest
from app.schemas.common import ModuleName

if TYPE_CHECKING:
    from collections.abc import Callable

    from app.modules.chat.dialog_store import DialogSessionStore
|
|
|
|
|
|
class ChatSessionResolver:
    """Resolve the ``(dialog_session_id, rag_session_id)`` pair for a chat message.

    Two request shapes are accepted, checked in this order:

    * ``dialog_session_id`` + ``rag_session_id``: the dialog is fetched from
      the dialog store and must be bound to the given RAG session.
    * ``session_id`` + ``project_id`` (presumably an older/alternative request
      shape — confirm with API owners): ``project_id`` is validated with the
      ``rag_session_exists`` callback and returned as the RAG session id,
      while ``session_id`` is returned as the dialog session id.
    """

    def __init__(
        self,
        dialogs: DialogSessionStore,
        rag_session_exists: Callable[[str], bool],
    ) -> None:
        """Store the collaborators used by :meth:`resolve`.

        Args:
            dialogs: Store queried with ``get(dialog_session_id)``; a falsy
                result is treated as "dialog not found", otherwise the result
                must expose a ``rag_session_id`` attribute.
            rag_session_exists: Predicate called with a RAG session id; a falsy
                result is treated as "RAG session not found".
        """
        self._dialogs = dialogs
        self._rag_session_exists = rag_session_exists

    def resolve(self, request: ChatMessageRequest) -> tuple[str, str]:
        """Return ``(dialog_session_id, rag_session_id)`` for *request*.

        The ``dialog_session_id``/``rag_session_id`` pair takes precedence; the
        ``session_id``/``project_id`` pair is consulted only when the first
        pair is not fully populated (both values truthy).

        Raises:
            AppError: ``dialog_not_found`` if the dialog store has no such
                dialog; ``dialog_rag_mismatch`` if the dialog is bound to a
                different RAG session; ``rag_session_not_found`` if
                ``project_id`` fails the ``rag_session_exists`` check; or
                ``missing_sessions`` if neither id pair is complete.
        """
        if request.dialog_session_id and request.rag_session_id:
            dialog = self._dialogs.get(request.dialog_session_id)
            if not dialog:
                raise AppError("dialog_not_found", "Dialog session not found", ModuleName.BACKEND)
            # A dialog may only be used with the RAG session it was created for.
            if dialog.rag_session_id != request.rag_session_id:
                raise AppError(
                    "dialog_rag_mismatch",
                    "Dialog session does not belong to rag session",
                    ModuleName.BACKEND,
                )
            # NOTE(review): unlike the branch below, this path never calls
            # rag_session_exists — presumably the dialog store only holds
            # dialogs whose RAG session is still live; confirm.
            return request.dialog_session_id, request.rag_session_id

        # Alternative field names: project_id carries the RAG session id and
        # session_id the dialog session id (see tuple order above).
        if request.session_id and request.project_id:
            if not self._rag_session_exists(request.project_id):
                raise AppError("rag_session_not_found", "RAG session not found", ModuleName.RAG)
            return request.session_id, request.project_id

        raise AppError(
            "missing_sessions",
            "dialog_session_id and rag_session_id are required",
            ModuleName.BACKEND,
        )
|