Files
agent/tests/unit_tests/agent/test_v2_rag_adapter.py
T

82 lines
2.6 KiB
Python

from __future__ import annotations
import asyncio
from app.core.agent.processes.v2.retrieval.v2_rag_adapter import V2RagRetrievalAdapter
from app.core.rag.retrieval.session_retriever import RetrievalPlan
class FakeRetriever:
    """In-memory stand-in for the session retriever consumed by ``V2RagRetrievalAdapter``.

    Every lookup is appended to ``calls`` as a ``(strategy, argument)`` tuple so
    tests can assert which retrieval strategies the adapter exercised.
    """

    # The only document the exact-path lookup "knows" about.
    _HEALTH_DOC = "docs/api/health-endpoint.md"

    def __init__(self) -> None:
        # Ordered log of (strategy, argument); the argument is None for semantic search.
        self.calls: list[tuple[str, object]] = []

    @staticmethod
    def _row(path: str, layer: str) -> dict:
        """Build a minimal retrieval row; a brand-new dict is returned on every call."""
        return {"path": path, "layer": layer, "metadata": {}}

    async def retrieve(self, _rag_session_id: str, _query_text: str, _plan: RetrievalPlan) -> list[dict]:
        """Simulate semantic search: ignore all inputs and return two fixed rows."""
        self.calls.append(("semantic", None))
        return [
            self._row(self._HEALTH_DOC, "D1_DOCUMENT_CATALOG"),
            self._row("docs/api/secondary.md", "D0_DOC_CHUNKS"),
        ]

    async def retrieve_exact_files(
        self,
        _rag_session_id: str,
        *,
        paths: list[str],
        layers: list[str] | None = None,
        limit: int = 200,
    ) -> list[dict]:
        """Simulate exact-path lookup; only the health-endpoint document resolves."""
        del layers, limit  # accepted only to mirror the real retriever's signature
        self.calls.append(("exact", list(paths)))
        if self._HEALTH_DOC not in paths:
            return []
        return [self._row(self._HEALTH_DOC, "D1_DOCUMENT_CATALOG")]

    async def retrieve_chunks_by_path_substrings(
        self,
        _rag_session_id: str,
        *,
        path_needles: list[str],
        layers: list[str] | None = None,
        limit: int = 200,
    ) -> list[dict]:
        """Simulate substring path lookup: record the needles, never match anything."""
        del layers, limit  # accepted only to mirror the real retriever's signature
        self.calls.append(("substring", list(path_needles)))
        return []
def test_v2_rag_adapter_seeds_exact_rows_from_plan_hints() -> None:
    """Documents named in ``target_doc_hints`` are fetched by exact path and lead the results."""
    retriever = FakeRetriever()
    adapter = V2RagRetrievalAdapter(retriever)
    plan = RetrievalPlan(
        profile="docs_summary_api_endpoint",
        layers=["D1_DOCUMENT_CATALOG", "D2_FACT_INDEX", "D0_DOC_CHUNKS"],
        limit=8,
        filters={"target_doc_hints": ["docs/api/health-endpoint.md"]},
    )
    rows = asyncio.run(adapter.fetch_rows("rag-1", "explain /health", plan))
    # The semantic fake already ranks the health doc first, so row order alone
    # cannot prove that seeding happened. Require that an exact-path lookup was
    # actually issued for the hinted document.
    assert any(
        kind == "exact" and "docs/api/health-endpoint.md" in arg
        for kind, arg in retriever.calls
    )
    assert rows[0]["path"] == "docs/api/health-endpoint.md"
    # The exact and semantic lookups both return the health doc (same path and
    # layer); together with the secondary doc that is three rows, so two means
    # the duplicate was collapsed.
    assert len(rows) == 2
def test_v2_rag_adapter_uses_substring_fallback_for_missing_hint() -> None:
    """A hinted path with no exact-path match triggers a basename substring lookup.

    The fake's exact lookup only resolves ``docs/api/health-endpoint.md``, so the
    hinted ``missing-health-endpoint.md`` path finds nothing there.
    """
    fake = FakeRetriever()
    hinted_plan = RetrievalPlan(
        profile="file_lookup",
        layers=["D1_DOCUMENT_CATALOG", "D3_ENTITY_CATALOG"],
        limit=12,
        filters={"target_doc_hints": ["docs/api/missing-health-endpoint.md"]},
    )
    asyncio.run(V2RagRetrievalAdapter(fake).fetch_rows("rag-1", "find file", hinted_plan))
    expected_call = ("substring", ["missing-health-endpoint.md"])
    assert expected_call in fake.calls