143 lines
5.2 KiB
Python
143 lines
5.2 KiB
Python
from types import SimpleNamespace
|
|
|
|
from app.modules.rag.explain.models import CodeLocation, LayeredRetrievalItem
|
|
from app.modules.rag.explain.retriever_v2 import CodeExplainRetrieverV2
|
|
|
|
|
|
class _ProductionFirstGateway:
|
|
def __init__(self) -> None:
|
|
self.lexical_calls: list[bool] = []
|
|
|
|
def retrieve_layer(
|
|
self,
|
|
rag_session_id: str,
|
|
query: str,
|
|
layer: str,
|
|
*,
|
|
limit: int,
|
|
path_prefixes: list[str] | None = None,
|
|
exclude_tests: bool = True,
|
|
prefer_non_tests: bool = False,
|
|
include_spans: bool = False,
|
|
):
|
|
if layer == "C3_ENTRYPOINTS":
|
|
return SimpleNamespace(items=[], missing=[])
|
|
if layer == "C1_SYMBOL_CATALOG":
|
|
return SimpleNamespace(items=[], missing=[])
|
|
raise AssertionError(layer)
|
|
|
|
def retrieve_lexical_code(
|
|
self,
|
|
rag_session_id: str,
|
|
query: str,
|
|
*,
|
|
limit: int,
|
|
path_prefixes: list[str] | None = None,
|
|
exclude_tests: bool = True,
|
|
include_spans: bool = False,
|
|
):
|
|
self.lexical_calls.append(exclude_tests)
|
|
if exclude_tests:
|
|
return SimpleNamespace(
|
|
items=[
|
|
LayeredRetrievalItem(
|
|
source="app/users/service.py",
|
|
content="def get_user():\n return repo.get_user()",
|
|
layer="C0_SOURCE_CHUNKS",
|
|
title="get_user",
|
|
metadata={"symbol_id": "user-service", "is_test": False},
|
|
location=CodeLocation(path="app/users/service.py", start_line=10, end_line=11),
|
|
),
|
|
LayeredRetrievalItem(
|
|
source="app/users/repository.py",
|
|
content="def get_user_repo():\n return {}",
|
|
layer="C0_SOURCE_CHUNKS",
|
|
title="get_user_repo",
|
|
metadata={"symbol_id": "user-repo", "is_test": False},
|
|
location=CodeLocation(path="app/users/repository.py", start_line=20, end_line=21),
|
|
),
|
|
],
|
|
missing=[],
|
|
)
|
|
return SimpleNamespace(
|
|
items=[
|
|
LayeredRetrievalItem(
|
|
source="tests/test_users.py",
|
|
content="def test_get_user():\n assert service.get_user()",
|
|
layer="C0_SOURCE_CHUNKS",
|
|
title="test_get_user",
|
|
metadata={"symbol_id": "test-user", "is_test": True},
|
|
location=CodeLocation(path="tests/test_users.py", start_line=5, end_line=6),
|
|
)
|
|
],
|
|
missing=[],
|
|
)
|
|
|
|
|
|
class _TestsOnlyGateway(_ProductionFirstGateway):
    """Gateway variant with no production code for the query.

    The tests-excluded lexical pass finds nothing. The tests-allowed pass is
    handed to the parent, which records the call and returns the single
    ``test_get_user`` chunk.
    """

    def retrieve_lexical_code(
        self,
        rag_session_id: str,
        query: str,
        *,
        limit: int,
        path_prefixes: list[str] | None = None,
        exclude_tests: bool = True,
        include_spans: bool = False,
    ):
        """Return no hits when tests are excluded; otherwise defer to the parent."""
        if not exclude_tests:
            # The parent appends ``exclude_tests`` itself and yields only the test chunk.
            return super().retrieve_lexical_code(
                rag_session_id,
                query,
                limit=limit,
                path_prefixes=path_prefixes,
                exclude_tests=exclude_tests,
                include_spans=include_spans,
            )
        self.lexical_calls.append(exclude_tests)
        return SimpleNamespace(items=[], missing=[])
|
|
|
|
|
|
class _FakeGraphRepository:
|
|
def get_symbols_by_ids(self, rag_session_id: str, symbol_ids: list[str]):
|
|
return []
|
|
|
|
def get_chunks_by_symbol_ids(self, rag_session_id: str, symbol_ids: list[str], prefer_chunk_type: str = "symbol_block"):
|
|
return []
|
|
|
|
def get_out_edges(self, rag_session_id: str, src_symbol_ids: list[str], edge_types: list[str], limit_per_src: int):
|
|
return []
|
|
|
|
def resolve_symbol_by_ref(self, rag_session_id: str, dst_ref: str, package_hint: str | None = None):
|
|
return None
|
|
|
|
|
|
def test_retriever_prefers_prod_chunks_and_skips_test_fallback_when_enough_evidence() -> None:
    """Production hits on the first lexical pass should end the search.

    The retriever is expected to make a single tests-excluded lexical call and
    return only the two production excerpts, none of them test-focused.
    """
    fake_gateway = _ProductionFirstGateway()
    retriever = CodeExplainRetrieverV2(gateway=fake_gateway, graph_repository=_FakeGraphRepository())

    pack = retriever.build_pack("rag-1", "Explain get_user")

    # Only the tests-excluded pass ran; no tests-allowed retry followed.
    assert fake_gateway.lexical_calls == [True]
    excerpt_paths = [item.path for item in pack.code_excerpts]
    assert excerpt_paths == ["app/users/service.py", "app/users/repository.py"]
    assert not any(item.focus.startswith("test:") for item in pack.code_excerpts)
|
|
|
|
|
|
def test_retriever_uses_test_fallback_when_production_evidence_is_missing() -> None:
    """An empty production pass should trigger a tests-allowed retry.

    The retriever is expected to call lexical search twice (tests excluded,
    then included) and surface the single test excerpt with a ``test:lexical``
    focus.
    """
    fake_gateway = _TestsOnlyGateway()
    retriever = CodeExplainRetrieverV2(gateway=fake_gateway, graph_repository=_FakeGraphRepository())

    pack = retriever.build_pack("rag-1", "Explain get_user")

    assert fake_gateway.lexical_calls == [True, False]
    excerpt_paths = [item.path for item in pack.code_excerpts]
    assert excerpt_paths == ["tests/test_users.py"]
    fallback_excerpt = pack.code_excerpts[0]
    assert fallback_excerpt.focus == "test:lexical"
|