"""Tests for how ``CodeExplainRetrieverV2.build_pack`` treats test-file evidence.

The fakes below stand in for the retrieval gateway and the graph repository
so that each test can observe which lexical lookups the retriever performs
(via ``lexical_calls``) and which excerpts end up in the resulting pack.
"""

from types import SimpleNamespace

from app.modules.rag.explain.models import CodeLocation, LayeredRetrievalItem
from app.modules.rag.explain.retriever_v2 import CodeExplainRetrieverV2


def _empty_result() -> SimpleNamespace:
    """Return a fresh retrieval result carrying no items and no missing entries."""
    return SimpleNamespace(items=[], missing=[])


def _source_chunk(
    path: str,
    title: str,
    content: str,
    symbol_id: str,
    first_line: int,
    last_line: int,
    *,
    is_test: bool,
) -> LayeredRetrievalItem:
    """Build a ``C0_SOURCE_CHUNKS`` item whose source and location share ``path``."""
    return LayeredRetrievalItem(
        source=path,
        content=content,
        layer="C0_SOURCE_CHUNKS",
        title=title,
        metadata={"symbol_id": symbol_id, "is_test": is_test},
        location=CodeLocation(path=path, start_line=first_line, end_line=last_line),
    )


def _test_file_result() -> SimpleNamespace:
    """Return a result holding the single chunk taken from ``tests/test_users.py``."""
    test_chunk = _source_chunk(
        "tests/test_users.py",
        "test_get_user",
        "def test_get_user():\n assert service.get_user()",
        "test-user",
        5,
        6,
        is_test=True,
    )
    return SimpleNamespace(items=[test_chunk], missing=[])


class _ProductionFirstGateway:
    """Fake gateway with production chunks available for lexical lookups.

    Every ``retrieve_lexical_code`` call appends its ``exclude_tests`` flag to
    ``lexical_calls`` so tests can assert on the sequence of lookups.
    """

    def __init__(self) -> None:
        self.lexical_calls: list[bool] = []

    def retrieve_layer(
        self,
        rag_session_id: str,
        query: str,
        layer: str,
        *,
        limit: int,
        path_prefixes: list[str] | None = None,
        exclude_tests: bool = True,
        prefer_non_tests: bool = False,
        include_spans: bool = False,
    ):
        """Return an empty result for the two known layers; fail on any other."""
        if layer in ("C3_ENTRYPOINTS", "C1_SYMBOL_CATALOG"):
            return _empty_result()
        # An unexpected layer is surfaced as a test failure naming that layer.
        raise AssertionError(layer)

    def retrieve_lexical_code(
        self,
        rag_session_id: str,
        query: str,
        *,
        limit: int,
        path_prefixes: list[str] | None = None,
        exclude_tests: bool = True,
        include_spans: bool = False,
    ):
        """Record the call, then return production chunks or the test chunk.

        Two non-test chunks are returned when ``exclude_tests`` is true;
        otherwise the single test-file chunk is returned.
        """
        self.lexical_calls.append(exclude_tests)
        if not exclude_tests:
            return _test_file_result()

        production_chunks = [
            _source_chunk(
                "app/users/service.py",
                "get_user",
                "def get_user():\n return repo.get_user()",
                "user-service",
                10,
                11,
                is_test=False,
            ),
            _source_chunk(
                "app/users/repository.py",
                "get_user_repo",
                "def get_user_repo():\n return {}",
                "user-repo",
                20,
                21,
                is_test=False,
            ),
        ]
        return SimpleNamespace(items=production_chunks, missing=[])


class _TestsOnlyGateway(_ProductionFirstGateway):
    """Fake gateway whose lexical lookup finds nothing unless tests are included."""

    def retrieve_lexical_code(
        self,
        rag_session_id: str,
        query: str,
        *,
        limit: int,
        path_prefixes: list[str] | None = None,
        exclude_tests: bool = True,
        include_spans: bool = False,
    ):
        """Record the call; return nothing when tests are excluded, else the test chunk."""
        self.lexical_calls.append(exclude_tests)
        if exclude_tests:
            return _empty_result()
        return _test_file_result()


class _FakeGraphRepository:
    """Graph repository stub: every lookup comes back empty."""

    def get_symbols_by_ids(self, rag_session_id: str, symbol_ids: list[str]):
        """Return no symbols."""
        return []

    def get_chunks_by_symbol_ids(
        self,
        rag_session_id: str,
        symbol_ids: list[str],
        prefer_chunk_type: str = "symbol_block",
    ):
        """Return no chunks."""
        return []

    def get_out_edges(
        self,
        rag_session_id: str,
        src_symbol_ids: list[str],
        edge_types: list[str],
        limit_per_src: int,
    ):
        """Return no outgoing edges."""
        return []

    def resolve_symbol_by_ref(
        self,
        rag_session_id: str,
        dst_ref: str,
        package_hint: str | None = None,
    ):
        """Resolve nothing."""
        return None


def test_retriever_prefers_prod_chunks_and_skips_test_fallback_when_enough_evidence() -> None:
    """With production chunks available, only the tests-excluded lookup is made."""
    fake_gateway = _ProductionFirstGateway()
    explain_retriever = CodeExplainRetrieverV2(
        gateway=fake_gateway,
        graph_repository=_FakeGraphRepository(),
    )

    evidence = explain_retriever.build_pack("rag-1", "Explain get_user")

    # Exactly one lexical lookup, and it excluded tests.
    assert fake_gateway.lexical_calls == [True]
    excerpt_paths = [item.path for item in evidence.code_excerpts]
    assert excerpt_paths == [
        "app/users/service.py",
        "app/users/repository.py",
    ]
    assert not any(item.focus.startswith("test:") for item in evidence.code_excerpts)


def test_retriever_uses_test_fallback_when_production_evidence_is_missing() -> None:
    """With no production chunks, a second lookup that includes tests is made."""
    fake_gateway = _TestsOnlyGateway()
    explain_retriever = CodeExplainRetrieverV2(
        gateway=fake_gateway,
        graph_repository=_FakeGraphRepository(),
    )

    evidence = explain_retriever.build_pack("rag-1", "Explain get_user")

    # First lookup excluded tests, the second one included them.
    assert fake_gateway.lexical_calls == [True, False]
    excerpt_paths = [item.path for item in evidence.code_excerpts]
    assert excerpt_paths == ["tests/test_users.py"]
    assert evidence.code_excerpts[0].focus == "test:lexical"