71 lines
2.0 KiB
Python
71 lines
2.0 KiB
Python
import asyncio
|
|
|
|
from app.modules.chat.module import ChatModule
|
|
from app.modules.chat.task_store import TaskStore
|
|
from app.schemas.chat import ChatMessageRequest
|
|
from app.schemas.chat import TaskQueuedResponse
|
|
from app.modules.shared.event_bus import EventBus
|
|
from app.modules.shared.retry_executor import RetryExecutor
|
|
|
|
|
|
class _FakeRuntime:
    """Stand-in for the legacy agent runtime that fails loudly if used.

    Passed to ``ChatModule`` as ``agent_runner``. Any call to :meth:`run`
    raises ``AssertionError``, which is meant to expose an accidental
    fallback to the legacy code path.
    """

    async def run(self, **kwargs):
        """Reject every invocation, regardless of the keyword arguments given."""
        reason = "legacy runtime must not be called"
        raise AssertionError(reason)
|
|
|
|
|
|
class _FakeDirectChat:
    """Stub direct-chat service that records how many messages it handled."""

    def __init__(self) -> None:
        # Bumped once per handle_message call; the test asserts on this count.
        self.calls = 0

    async def handle_message(self, request):
        """Count the call and return a fixed response; ``request`` is ignored."""
        self.calls = self.calls + 1
        canned = TaskQueuedResponse(task_id="task-1", status="done")
        return canned
|
|
|
|
|
|
class _FakeRagSessions:
    """Stub RAG-session lookup that returns a record for any session id."""

    def get(self, rag_session_id: str):
        """Return a minimal session record that echoes ``rag_session_id``."""
        record = dict(rag_session_id=rag_session_id)
        return record
|
|
|
|
|
|
class _FakeRepository:
|
|
def create_dialog(self, dialog_session_id: str, rag_session_id: str) -> None:
|
|
return None
|
|
|
|
def get_dialog(self, dialog_session_id: str):
|
|
return None
|
|
|
|
def add_message(self, dialog_session_id: str, role: str, content: str, task_id: str | None = None, payload: dict | None = None) -> None:
|
|
return None
|
|
|
|
|
|
def test_chat_messages_endpoint_uses_direct_service(monkeypatch) -> None:
    """The ``/api/chat/messages`` endpoint must delegate to the direct chat service.

    ``_FakeRuntime`` raises ``AssertionError`` if the legacy agent runtime is
    invoked, and ``_FakeDirectChat`` counts its calls. The test asserts that
    the direct service handled the request exactly once and that its
    ``task_id`` reached the endpoint's response.
    """
    # NOTE(review): presumably this flag selects the direct-chat path inside
    # ChatModule; its effect is not visible from this file — confirm.
    monkeypatch.setenv("SIMPLE_CODE_EXPLAIN_ONLY", "true")
    direct_chat = _FakeDirectChat()
    module = ChatModule(
        agent_runner=_FakeRuntime(),
        event_bus=EventBus(),
        retry=RetryExecutor(),
        rag_sessions=_FakeRagSessions(),
        repository=_FakeRepository(),
        direct_chat=direct_chat,
        task_store=TaskStore(),
    )
    router = module.public_router()

    # Use a default so a missing or renamed route fails with a readable
    # assertion rather than a bare StopIteration escaping from next().
    endpoint = next(
        (
            route.endpoint
            for route in router.routes
            if getattr(route, "path", "") == "/api/chat/messages"
        ),
        None,
    )
    assert endpoint is not None, "route /api/chat/messages is not registered on the public router"

    response = asyncio.run(
        endpoint(
            ChatMessageRequest(
                session_id="dialog-1",
                project_id="rag-1",
                message="Explain get_user",
            ),
            # NOTE(review): second positional endpoint parameter; its name and
            # meaning are not visible here — confirm against the route signature.
            None,
        )
    )

    assert response.task_id == "task-1"
    assert direct_chat.calls == 1
|