# Files
# agent/tests/rag/test_trace_builder.py
#
# 84 lines
# 3.3 KiB
# Python

from app.modules.rag.explain.models import CodeLocation, LayeredRetrievalItem
from app.modules.rag.explain.trace_builder import TraceBuilder
class _FakeGraphRepository:
def get_out_edges(self, rag_session_id: str, src_symbol_ids: list[str], edge_types: list[str], limit_per_src: int):
assert rag_session_id == "rag-1"
assert edge_types == ["calls", "imports", "inherits"]
if src_symbol_ids == ["handler-1"]:
return [
LayeredRetrievalItem(
source="app/api/users.py",
content="handler calls get_user",
layer="C2_DEPENDENCY_GRAPH",
title="handler:calls",
metadata={
"src_symbol_id": "handler-1",
"dst_symbol_id": None,
"dst_ref": "UserService.get_user",
"resolution": "partial",
"edge_type": "calls",
},
location=CodeLocation(path="app/api/users.py", start_line=12, end_line=12),
)
]
return []
def resolve_symbol_by_ref(self, rag_session_id: str, dst_ref: str, package_hint: str | None = None):
assert rag_session_id == "rag-1"
assert dst_ref == "UserService.get_user"
assert package_hint == "app.api"
return LayeredRetrievalItem(
source="app/services/users.py",
content="method UserService.get_user",
layer="C1_SYMBOL_CATALOG",
title="UserService.get_user",
metadata={
"symbol_id": "service-1",
"package_or_module": "app.api.users",
},
location=CodeLocation(path="app/services/users.py", start_line=4, end_line=10),
)
def get_symbols_by_ids(self, rag_session_id: str, symbol_ids: list[str]):
assert rag_session_id == "rag-1"
if symbol_ids == ["service-1"]:
return [
LayeredRetrievalItem(
source="app/services/users.py",
content="method UserService.get_user",
layer="C1_SYMBOL_CATALOG",
title="UserService.get_user",
metadata={
"symbol_id": "service-1",
"package_or_module": "app.api.users",
},
location=CodeLocation(path="app/services/users.py", start_line=4, end_line=10),
)
]
return []
def test_trace_builder_resolves_partial_edges_across_files() -> None:
    """A partial ``calls`` edge from the seed handler is followed to its target.

    The seed is symbol ``handler-1``; the fake repository exposes one partial
    edge from it that resolves to ``service-1``.  The first path built must
    therefore list both symbol ids in order and carry a ``resolved:`` note
    for the reference.
    """
    trace_builder = TraceBuilder(_FakeGraphRepository())

    handler_location = CodeLocation(path="app/api/users.py", start_line=10, end_line=18)
    handler_seed = LayeredRetrievalItem(
        source="app/api/users.py",
        content="function handler",
        layer="C1_SYMBOL_CATALOG",
        title="get_user",
        metadata={
            "symbol_id": "handler-1",
            "package_or_module": "app.api.users",
        },
        location=handler_location,
    )

    paths = trace_builder.build_paths("rag-1", [handler_seed], max_depth=3)

    # At least one path must exist before the first one is inspected.
    assert len(paths) > 0
    first_path = paths[0]
    assert first_path.symbol_ids == ["handler-1", "service-1"]
    assert "resolved:UserService.get_user" in first_path.notes