# Files
# agent/tests/rag/test_layered_gateway.py
#
# 79 lines
# 2.6 KiB
# Python

from app.modules.rag.explain.layered_gateway import LayeredRetrievalGateway
class _Embedder:
    """Stub embedder that always produces one fixed two-component vector."""

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Return a single canned embedding, ignoring ``texts``.

        The result contains exactly one vector regardless of how many texts
        are passed in (including none).
        """
        return [[0.1, 0.2]]
class _RetryingRepository:
    """Fake repository whose queries fail whenever a path-prefix filter is set.

    Each call's keyword arguments are appended to ``calls``. A call that
    passes a truthy ``exclude_path_prefixes`` raises ``RuntimeError``; any
    other call returns one canned hit for ``app/users/service.py``.
    """

    def __init__(self) -> None:
        # Keyword arguments of every call, in call order (both methods share it).
        self.calls: list[dict] = []

    def retrieve(self, *args, **kwargs):
        """Record the call, then raise if filtered or return one C1 hit.

        Raises:
            RuntimeError: when ``exclude_path_prefixes`` is truthy.
        """
        # Recorded before the check, so failing calls are logged as well.
        self.calls.append(kwargs)
        if kwargs.get("exclude_path_prefixes"):
            raise RuntimeError("syntax error at or near ')'")
        return [
            {
                "path": "app/users/service.py",
                "content": "def get_user(): pass",
                "layer": "C1_SYMBOL_CATALOG",
                "title": "get_user",
                "metadata": {"symbol_id": "symbol-1"},
                "distance": 0.1,
                "span_start": 10,
                "span_end": 11,
            }
        ]

    def retrieve_lexical_code(self, *args, **kwargs):
        """Record the call, then raise if filtered or return one C0 hit.

        Unlike ``retrieve``, the returned hit carries no ``distance`` key.

        Raises:
            RuntimeError: when ``exclude_path_prefixes`` is truthy.
        """
        # Recorded before the check, so failing calls are logged as well.
        self.calls.append(kwargs)
        if kwargs.get("exclude_path_prefixes"):
            raise RuntimeError("broken lexical filter")
        return [
            {
                "path": "app/users/service.py",
                "content": "def get_user(): pass",
                "layer": "C0_SOURCE_CHUNKS",
                "title": "get_user",
                "metadata": {"symbol_id": "symbol-1"},
                "span_start": 10,
                "span_end": 11,
            }
        ]
class _RecordingRepository:
    """Fake repository that records call keyword arguments and finds nothing."""

    def __init__(self) -> None:
        # Keyword arguments of every call, in call order (both methods share it).
        self.calls: list[dict] = []

    def retrieve(self, *args, **kwargs):
        """Record the keyword arguments and return an empty result list."""
        self.calls.append(kwargs)
        return []

    def retrieve_lexical_code(self, *args, **kwargs):
        """Record the keyword arguments and return an empty result list."""
        self.calls.append(kwargs)
        return []
def test_gateway_retries_without_test_filter_on_vector_failure() -> None:
    """Check the gateway still returns a hit when the filtered query fails.

    ``_RetryingRepository`` raises whenever ``exclude_path_prefixes`` is set,
    so the single returned item can only come from an unfiltered call. The
    gateway is also expected to report the retry in ``result.missing``.
    """
    gateway = LayeredRetrievalGateway(_RetryingRepository(), _Embedder())
    result = gateway.retrieve_layer(
        "rag-1",
        "Explain get_user",
        "C1_SYMBOL_CATALOG",
        limit=3,
        exclude_tests=True,
    )
    assert len(result.items) == 1
    assert "layer:C1_SYMBOL_CATALOG retrieval_failed:retried_without_test_filter" in result.missing
def test_gateway_honors_debug_disable_test_filter(monkeypatch) -> None:
    """Check that the debug env flag suppresses the test-path filters.

    With ``RAG_DEBUG_DISABLE_TEST_FILTER`` set to ``"true"``, the first
    repository call must receive ``None`` for both exclusion arguments even
    though the caller asked for ``exclude_tests=True``.
    """
    monkeypatch.setenv("RAG_DEBUG_DISABLE_TEST_FILTER", "true")
    repository = _RecordingRepository()
    gateway = LayeredRetrievalGateway(repository, _Embedder())
    gateway.retrieve_layer(
        "rag-1",
        "Explain get_user",
        "C1_SYMBOL_CATALOG",
        limit=3,
        exclude_tests=True,
    )
    # Guard first so an unused repository fails here, not with an IndexError.
    assert repository.calls
    first_call = repository.calls[0]
    assert first_call["exclude_path_prefixes"] is None
    assert first_call["exclude_like_patterns"] is None