106 lines
3.8 KiB
Python
106 lines
3.8 KiB
Python
from types import SimpleNamespace

from app.modules.rag.explain.models import CodeLocation, LayeredRetrievalItem
from app.modules.rag.explain.retriever_v2 import CodeExplainRetrieverV2
|
|
|
|
|
|
class _FakeGateway:
|
|
def retrieve_layer(
|
|
self,
|
|
rag_session_id: str,
|
|
query: str,
|
|
layer: str,
|
|
*,
|
|
limit: int,
|
|
path_prefixes: list[str] | None = None,
|
|
exclude_tests: bool = True,
|
|
prefer_non_tests: bool = False,
|
|
include_spans: bool = False,
|
|
):
|
|
if layer == "C3_ENTRYPOINTS":
|
|
return __import__("types").SimpleNamespace(
|
|
items=[
|
|
LayeredRetrievalItem(
|
|
source="app/api/users.py",
|
|
content="GET /users/{id}",
|
|
layer=layer,
|
|
title="GET /users/{id}",
|
|
metadata={"entry_type": "http", "handler_symbol_id": "handler-1"},
|
|
location=CodeLocation(path="app/api/users.py", start_line=10, end_line=10),
|
|
)
|
|
],
|
|
missing=[],
|
|
)
|
|
if layer == "C1_SYMBOL_CATALOG":
|
|
return __import__("types").SimpleNamespace(
|
|
items=[
|
|
LayeredRetrievalItem(
|
|
source="app/api/users.py",
|
|
content="def get_user_handler",
|
|
layer=layer,
|
|
title="get_user_handler",
|
|
metadata={"symbol_id": "handler-1"},
|
|
location=CodeLocation(path="app/api/users.py", start_line=10, end_line=18),
|
|
)
|
|
],
|
|
missing=[],
|
|
)
|
|
raise AssertionError(layer)
|
|
|
|
def retrieve_lexical_code(
|
|
self,
|
|
rag_session_id: str,
|
|
query: str,
|
|
*,
|
|
limit: int,
|
|
path_prefixes: list[str] | None = None,
|
|
exclude_tests: bool = True,
|
|
include_spans: bool = False,
|
|
):
|
|
return __import__("types").SimpleNamespace(items=[], missing=[])
|
|
|
|
|
|
class _FakeGraphRepository:
|
|
def get_symbols_by_ids(self, rag_session_id: str, symbol_ids: list[str]):
|
|
return [
|
|
LayeredRetrievalItem(
|
|
source="app/api/users.py",
|
|
content="def get_user_handler",
|
|
layer="C1_SYMBOL_CATALOG",
|
|
title="get_user_handler",
|
|
metadata={"symbol_id": "handler-1"},
|
|
location=CodeLocation(path="app/api/users.py", start_line=10, end_line=18),
|
|
)
|
|
]
|
|
|
|
def get_out_edges(self, rag_session_id: str, src_symbol_ids: list[str], edge_types: list[str], limit_per_src: int):
|
|
return []
|
|
|
|
def resolve_symbol_by_ref(self, rag_session_id: str, dst_ref: str, package_hint: str | None = None):
|
|
return None
|
|
|
|
def get_chunks_by_symbol_ids(self, rag_session_id: str, symbol_ids: list[str], prefer_chunk_type: str = "symbol_block"):
|
|
return [
|
|
LayeredRetrievalItem(
|
|
source="app/api/users.py",
|
|
content="async def get_user_handler(user_id: str):\n return await service.get_user(user_id)",
|
|
layer="C0_SOURCE_CHUNKS",
|
|
title="get_user_handler",
|
|
metadata={"symbol_id": "handler-1"},
|
|
location=CodeLocation(path="app/api/users.py", start_line=10, end_line=18),
|
|
)
|
|
]
|
|
|
|
|
|
def test_retriever_v2_builds_pack_with_trace_and_excerpts() -> None:
    """Build a pack from the fake gateway and graph repository and check its contents.

    Expects exactly one selected entrypoint, one seed symbol, one trace path
    and one code excerpt, with the excerpt pointing at ``app/api/users.py``.
    """
    fake_gateway = _FakeGateway()
    fake_graph = _FakeGraphRepository()
    retriever = CodeExplainRetrieverV2(gateway=fake_gateway, graph_repository=fake_graph)

    pack = retriever.build_pack("rag-1", "Explain endpoint get_user")

    assert len(pack.selected_entrypoints) == 1
    assert len(pack.seed_symbols) == 1
    assert len(pack.trace_paths) == 1
    assert len(pack.code_excerpts) == 1
    assert pack.code_excerpts[0].path == "app/api/users.py"
|