27 lines
1.1 KiB
Python
27 lines
1.1 KiB
Python
"""Smoke-тест стандартного retrieval API: один embed и вызов repository."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from unittest.mock import MagicMock
|
|
|
|
from app.core.rag.embedding.gigachat_embedder import GigaChatEmbedder
|
|
from app.core.rag.retrieval.session_retriever import RagSessionRetriever, RetrievalPlan
|
|
|
|
|
|
def test_rag_session_retriever_calls_repository() -> None:
    """Run one retrieval and check how it is wired to the embedder and repository.

    Both collaborators are ``MagicMock`` doubles. The test asserts that:

    * a single retrieval triggers exactly one ``embed`` call;
    * it triggers exactly one ``repository.retrieve`` call;
    * that repository call receives the session id positionally, and the
      plan's layers and limit as keyword arguments;
    * the single row returned by the repository double comes back from
      ``retriever.retrieve``.
    """
    # ``embed`` is mocked as a plain (non-async) callable returning one vector.
    embedder = MagicMock(spec=GigaChatEmbedder)
    embedder.embed = MagicMock(return_value=[[0.1, 0.2]])

    repo = MagicMock()
    repo.retrieve = MagicMock(return_value=[{"path": "a.md", "layer": "D0_DOC_CHUNKS"}])

    retriever = RagSessionRetriever(repository=repo, embedder=embedder)
    plan = RetrievalPlan(profile="test", layers=["D0_DOC_CHUNKS", "D1_DOCUMENT_CATALOG"], limit=5)

    # ``retrieve`` is a coroutine; drive it to completion on a fresh event loop.
    rows = asyncio.run(retriever.retrieve("sid-1", "hello", plan))

    assert len(rows) == 1
    # Assert "exactly once" rather than merely "called": the contract under
    # test is a single embed and a single repository round-trip per retrieval.
    embedder.embed.assert_called_once()
    repo.retrieve.assert_called_once()

    repo_call = repo.retrieve.call_args
    assert repo_call.args[0] == "sid-1"
    assert repo_call.kwargs["layers"] == plan.layers
    assert repo_call.kwargs["limit"] == 5
|