Files
agent/tests/unit_tests/rag/test_retriever_v2_pack.py
2026-03-12 16:55:23 +03:00

106 lines
3.8 KiB
Python

from types import SimpleNamespace

from app.modules.rag.explain.models import CodeLocation, LayeredRetrievalItem
from app.modules.rag.explain.retriever_v2 import CodeExplainRetrieverV2


class _FakeGateway:
def retrieve_layer(
self,
rag_session_id: str,
query: str,
layer: str,
*,
limit: int,
path_prefixes: list[str] | None = None,
exclude_tests: bool = True,
prefer_non_tests: bool = False,
include_spans: bool = False,
):
if layer == "C3_ENTRYPOINTS":
return __import__("types").SimpleNamespace(
items=[
LayeredRetrievalItem(
source="app/api/users.py",
content="GET /users/{id}",
layer=layer,
title="GET /users/{id}",
metadata={"entry_type": "http", "handler_symbol_id": "handler-1"},
location=CodeLocation(path="app/api/users.py", start_line=10, end_line=10),
)
],
missing=[],
)
if layer == "C1_SYMBOL_CATALOG":
return __import__("types").SimpleNamespace(
items=[
LayeredRetrievalItem(
source="app/api/users.py",
content="def get_user_handler",
layer=layer,
title="get_user_handler",
metadata={"symbol_id": "handler-1"},
location=CodeLocation(path="app/api/users.py", start_line=10, end_line=18),
)
],
missing=[],
)
raise AssertionError(layer)
def retrieve_lexical_code(
self,
rag_session_id: str,
query: str,
*,
limit: int,
path_prefixes: list[str] | None = None,
exclude_tests: bool = True,
include_spans: bool = False,
):
return __import__("types").SimpleNamespace(items=[], missing=[])
class _FakeGraphRepository:
def get_symbols_by_ids(self, rag_session_id: str, symbol_ids: list[str]):
return [
LayeredRetrievalItem(
source="app/api/users.py",
content="def get_user_handler",
layer="C1_SYMBOL_CATALOG",
title="get_user_handler",
metadata={"symbol_id": "handler-1"},
location=CodeLocation(path="app/api/users.py", start_line=10, end_line=18),
)
]
def get_out_edges(self, rag_session_id: str, src_symbol_ids: list[str], edge_types: list[str], limit_per_src: int):
return []
def resolve_symbol_by_ref(self, rag_session_id: str, dst_ref: str, package_hint: str | None = None):
return None
def get_chunks_by_symbol_ids(self, rag_session_id: str, symbol_ids: list[str], prefer_chunk_type: str = "symbol_block"):
return [
LayeredRetrievalItem(
source="app/api/users.py",
content="async def get_user_handler(user_id: str):\n return await service.get_user(user_id)",
layer="C0_SOURCE_CHUNKS",
title="get_user_handler",
metadata={"symbol_id": "handler-1"},
location=CodeLocation(path="app/api/users.py", start_line=10, end_line=18),
)
]
def test_retriever_v2_builds_pack_with_trace_and_excerpts() -> None:
    """A pack built for one endpoint has exactly one entry in each section.

    The fakes expose a single HTTP entrypoint whose handler resolves to one
    catalog symbol and one source chunk, with no outgoing graph edges. The
    resulting pack must hold one selected entrypoint, one seed symbol, one
    trace path and one code excerpt located in ``app/api/users.py``.
    """
    gateway = _FakeGateway()
    graph_repository = _FakeGraphRepository()
    retriever = CodeExplainRetrieverV2(gateway=gateway, graph_repository=graph_repository)

    pack = retriever.build_pack("rag-1", "Explain endpoint get_user")

    # Checked in a fixed order so the first failing section is reported first.
    for section in ("selected_entrypoints", "seed_symbols", "trace_paths", "code_excerpts"):
        assert len(getattr(pack, section)) == 1
    first_excerpt = pack.code_excerpts[0]
    assert first_excerpt.path == "app/api/users.py"