Фикс состояния
This commit is contained in:
77
tests/unit_tests/rag/asserts_intent_router.py
Normal file
77
tests/unit_tests/rag/asserts_intent_router.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
from app.modules.rag.intent_router_v2.models import IntentRouterResult
|
||||
|
||||
|
||||
def assert_intent(out: IntentRouterResult, expected: str) -> None:
    """Assert that the router classified the query as the *expected* intent."""
    assert expected == out.intent


def assert_domains(out: IntentRouterResult, expected: list[str]) -> None:
    """Assert that the retrieval spec targets exactly the *expected* domains."""
    assert expected == out.retrieval_spec.domains


def assert_has_file_path(out: IntentRouterResult, path: str) -> None:
    """Assert that the query plan carries a FILE_PATH anchor whose value is *path*."""
    matches = [
        anchor
        for anchor in out.query_plan.anchors
        if anchor.type == "FILE_PATH" and anchor.value == path
    ]
    assert matches
|
||||
|
||||
|
||||
def _scope_of(out: IntentRouterResult) -> list[str]:
    """Return the filter path scope as a plain list (empty when absent/None)."""
    raw = getattr(out.retrieval_spec.filters, "path_scope", []) or []
    return list(raw)


def assert_path_scope(out: IntentRouterResult, file_path: str, dir_path: str | None = None) -> None:
    """Assert *file_path* (and, when given, *dir_path*) is inside the path scope."""
    scope = _scope_of(out)
    assert file_path in scope
    if dir_path is not None:
        assert dir_path in scope


def assert_file_only_scope(out: IntentRouterResult, file_path: str) -> None:
    """Assert the path scope is exactly the single entry *file_path*."""
    assert _scope_of(out) == [file_path]
|
||||
|
||||
|
||||
def assert_spans_valid(out: IntentRouterResult) -> None:
    """Check span integrity for every anchor in the query plan.

    Anchors carried over from conversation state must not have a span;
    anchors extracted from the raw query must carry a non-empty span that
    lies fully inside the raw query text.
    """
    limit = len(out.query_plan.raw)
    for anchor in out.query_plan.anchors:
        span = anchor.span
        if anchor.source == "conversation_state":
            assert span is None
        else:
            assert span is not None
            assert 0 <= span.start
            assert span.start < span.end
            assert span.end <= limit
|
||||
|
||||
|
||||
def assert_test_policy(out: IntentRouterResult, expected: str) -> None:
    """Assert the retrieval filters carry *expected* as their test policy."""
    actual = getattr(out.retrieval_spec.filters, "test_policy", None)
    assert expected == actual


def assert_sub_intent(out: IntentRouterResult, expected: str) -> None:
    """Assert the query plan resolved to the *expected* sub-intent."""
    assert expected == out.query_plan.sub_intent
|
||||
|
||||
|
||||
def assert_no_symbol_keyword(out: IntentRouterResult, forbidden: set[str] | None = None) -> None:
    """Assert that no SYMBOL anchor is a bare language keyword.

    Args:
        out: Router result under test.
        forbidden: Keywords (case-insensitive) that must not appear as SYMBOL
            anchors.  ``None`` selects the default keyword set; an explicitly
            passed empty set disables the check entirely.
    """
    # `forbidden or {...}` would silently treat an explicit empty set as
    # "use the defaults"; compare against None so callers can opt out.
    denied = {"def", "class", "return", "import", "from"} if forbidden is None else forbidden
    symbols = {anchor.value.lower() for anchor in out.query_plan.anchors if anchor.type == "SYMBOL"}
    assert symbols.isdisjoint({token.lower() for token in denied})
|
||||
|
||||
|
||||
def assert_domain_layer_prefixes(out: IntentRouterResult) -> None:
    """Assert layer ids agree with the chosen domains (C* for CODE, D* for DOCS)."""
    prefixes = {layer.layer_id[0] for layer in out.retrieval_spec.layer_queries if layer.layer_id}
    domains = out.retrieval_spec.domains
    if domains == ["CODE"]:
        allowed = {"C"}
    elif domains == ["DOCS"]:
        allowed = {"D"}
    else:
        allowed = {"C", "D"}
    assert prefixes <= allowed
|
||||
|
||||
|
||||
def assert_no_symbol_leakage_from_paths(out: IntentRouterResult) -> None:
    """Assert no SYMBOL anchor duplicates a fragment of any FILE_PATH anchor.

    Path values are split on ``/`` and ``.`` runs; a SYMBOL equal (case
    insensitively) to any resulting fragment counts as leakage.
    """
    paths = [anchor.value for anchor in out.query_plan.anchors if anchor.type == "FILE_PATH"]
    if not paths:
        return
    fragments = {
        token
        for value in paths
        for token in re.split(r"[/.]+", value.lower())
        if token
    }
    for anchor in out.query_plan.anchors:
        if anchor.type == "SYMBOL":
            assert anchor.value.lower() not in fragments
|
||||
48
tests/unit_tests/rag/intent_router_testkit.py
Normal file
48
tests/unit_tests/rag/intent_router_testkit.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from app.modules.rag.contracts.enums import RagLayer
|
||||
from app.modules.rag.intent_router_v2 import ConversationState, IntentRouterV2, RepoContext
|
||||
|
||||
|
||||
def repo_context() -> RepoContext:
    """Build the fixed repository context shared by the intent-router tests."""
    code_layers = [
        RagLayer.CODE_ENTRYPOINTS,
        RagLayer.CODE_SYMBOL_CATALOG,
        RagLayer.CODE_DEPENDENCY_GRAPH,
        RagLayer.CODE_SEMANTIC_ROLES,
        RagLayer.CODE_SOURCE_CHUNKS,
    ]
    docs_layers = [
        RagLayer.DOCS_MODULE_CATALOG,
        RagLayer.DOCS_FACT_INDEX,
        RagLayer.DOCS_SECTION_INDEX,
        RagLayer.DOCS_POLICY_INDEX,
    ]
    return RepoContext(
        languages=["python"],
        available_domains=["CODE", "DOCS"],
        available_layers=code_layers + docs_layers,
    )
|
||||
|
||||
|
||||
def run_sequence(queries: list[str], *, router: IntentRouterV2 | None = None, trace_label: str = "intent-router") -> list:
    """Route *queries* in order through one conversation and return the results.

    Each turn is traced to stdout and its result advances the shared
    conversation state before the next turn runs.

    Args:
        queries: User queries, routed in order.
        router: Router to use; a fresh ``IntentRouterV2`` when omitted.
        trace_label: Label prepended to every trace line.

    Returns:
        Router results, one per query, in input order.
    """
    active_router = router or IntentRouterV2()
    state = ConversationState()
    # The repo context is identical for every turn; build it once instead
    # of re-creating it on each loop iteration.
    context = repo_context()
    results = []
    for index, query in enumerate(queries, start=1):
        result = active_router.route(query, state, context)
        print_trace(index, query, result, label=trace_label)
        results.append(result)
        state = state.advance(result)
    return results
|
||||
|
||||
|
||||
def run_single(query: str, *, router: IntentRouterV2 | None = None, trace_label: str = "intent-router"):
    """Route a single *query* and return its result (traced as a one-turn sequence)."""
    return run_sequence([query], router=router, trace_label=trace_label)[0]
|
||||
|
||||
|
||||
def print_trace(index: int, query: str, result, *, label: str = "intent-router") -> None:
    """Print one turn's input and JSON output followed by a separator line."""
    header = f"[{label}][turn {index}]"
    print(f"{header} input: {query}")
    print()
    payload = json.dumps(result.model_dump(), ensure_ascii=False)
    print(f"{header} output: {payload}")
    print("=" * 50)
|
||||
217
tests/unit_tests/rag/test_code_indexing_pipeline.py
Normal file
217
tests/unit_tests/rag/test_code_indexing_pipeline.py
Normal file
@@ -0,0 +1,217 @@
|
||||
from app.modules.rag.contracts.enums import RagLayer
|
||||
from app.modules.rag.indexing.code.pipeline import CodeIndexingPipeline
|
||||
|
||||
|
||||
def test_code_pipeline_builds_source_symbols_edges_and_entrypoints() -> None:
    """Indexing a FastAPI module emits every code layer plus full route metadata."""
    pipeline = CodeIndexingPipeline()
    content = """
from fastapi import APIRouter

router = APIRouter()

class UserService:
    def get_user(self, user_id):
        return user_id

@router.get("/users/{user_id}")
async def get_user(user_id: str):
    service = UserService()
    return service.get_user(user_id)
"""
    docs = pipeline.index_file(
        repo_id="acme/proj",
        commit_sha="abc123",
        path="app/api/users.py",
        content=content,
    )

    produced_layers = {doc.layer for doc in docs}
    assert {
        RagLayer.CODE_SOURCE_CHUNKS,
        RagLayer.CODE_SYMBOL_CATALOG,
        RagLayer.CODE_DEPENDENCY_GRAPH,
        RagLayer.CODE_ENTRYPOINTS,
        RagLayer.CODE_SEMANTIC_ROLES,
    } <= produced_layers

    symbol_doc = next(
        doc
        for doc in docs
        if doc.layer == RagLayer.CODE_SYMBOL_CATALOG and doc.metadata["kind"] == "function"
    )
    assert "get_user" in symbol_doc.metadata["qname"]

    edge_doc = next(doc for doc in docs if doc.layer == RagLayer.CODE_DEPENDENCY_GRAPH)
    known_edge_types = {
        "calls",
        "imports",
        "inherits",
        "instantiates",
        "reads_attr",
        "writes_attr",
        "dataflow_slice",
    }
    assert edge_doc.metadata["edge_type"] in known_edge_types

    entry_doc = next(doc for doc in docs if doc.layer == RagLayer.CODE_ENTRYPOINTS)
    entry_meta = entry_doc.metadata
    assert entry_meta["framework"] == "fastapi"
    assert entry_meta["http_method"] == "GET"
    assert entry_meta["route_path"] == "/users/{user_id}"
    assert entry_meta["entrypoint_kind"] == "http_route"
    assert entry_meta["handler_symbol"] == "get_user"
    assert entry_meta["summary_text"] == "GET /users/{user_id} declared in get_user"
    assert "GET /users/{user_id}" in entry_doc.text


def test_code_pipeline_indexes_import_alias_as_symbol() -> None:
    """An import alias is catalogued under its alias name as a const-kind symbol."""
    pipeline = CodeIndexingPipeline()
    content = "from .v2 import ConfigManagerV2 as ConfigManager\n"

    docs = pipeline.index_file(
        repo_id="acme/proj",
        commit_sha="abc123",
        path="src/config_manager/__init__.py",
        content=content,
    )

    alias_doc = next(
        doc
        for doc in docs
        if doc.layer == RagLayer.CODE_SYMBOL_CATALOG and doc.metadata["qname"] == "ConfigManager"
    )
    assert alias_doc.metadata["kind"] == "const"


def test_code_pipeline_marks_test_documents() -> None:
    """Every document produced from a tests/ path is flagged is_test."""
    pipeline = CodeIndexingPipeline()
    content = """
def test_user_service():
    assert True
"""

    docs = pipeline.index_file(
        repo_id="acme/proj",
        commit_sha="abc123",
        path="tests/test_users.py",
        content=content,
    )

    assert docs
    assert all(doc.metadata["is_test"] is True for doc in docs)
|
||||
|
||||
|
||||
def test_code_pipeline_extracts_data_flow_edges() -> None:
    """Instantiation and attribute reads/writes surface as dependency edges."""
    pipeline = CodeIndexingPipeline()
    content = """
class Context:
    def __init__(self):
        self.data = {}

    def set(self, new_context):
        self.data = new_context

def process():
    ctx = Context()
    value = ctx.data
    return value
"""
    docs = pipeline.index_file(
        repo_id="acme/proj",
        commit_sha="abc123",
        path="src/context.py",
        content=content,
    )
    edges = [doc.metadata for doc in docs if doc.layer == RagLayer.CODE_DEPENDENCY_GRAPH]
    edge_pairs = {(str(item.get("edge_type") or ""), str(item.get("dst_ref") or "")) for item in edges}

    for expected in (
        ("instantiates", "Context"),
        ("writes_attr", "Context.data"),
        ("reads_attr", "ctx.data"),
    ):
        assert expected in edge_pairs


def test_code_pipeline_builds_dataflow_slice_documents() -> None:
    """Dataflow slices exist, stay bounded in length, and mention touched symbols."""
    pipeline = CodeIndexingPipeline()
    content = """
class Context:
    def set(self, value):
        self.data = value

def read_data(ctx):
    return ctx.data

def run():
    ctx = Context()
    Context().set({"order_id": 1})
    return read_data(ctx)
"""
    docs = pipeline.index_file(
        repo_id="acme/proj",
        commit_sha="abc123",
        path="src/context_flow.py",
        content=content,
    )

    slices = [
        doc
        for doc in docs
        if doc.layer == RagLayer.CODE_DEPENDENCY_GRAPH and doc.metadata.get("edge_type") == "dataflow_slice"
    ]
    assert slices
    assert any("Context.data" in item.metadata.get("path_symbols", []) for item in slices)
    assert all(item.metadata.get("path_length", 0) <= 6 for item in slices)
|
||||
|
||||
|
||||
def test_code_pipeline_builds_execution_trace_documents() -> None:
    """An HTTP route produces execution traces walking its call chain."""
    pipeline = CodeIndexingPipeline()
    content = """
from fastapi import APIRouter

router = APIRouter()

def parse():
    return "parsed"

def send_email():
    return parse()

@router.post("/run")
def run_pipeline():
    return send_email()
"""
    docs = pipeline.index_file(
        repo_id="acme/proj",
        commit_sha="abc123",
        path="src/pipeline.py",
        content=content,
    )

    traces = [
        doc
        for doc in docs
        if doc.layer == RagLayer.CODE_ENTRYPOINTS and doc.metadata.get("edge_type") == "execution_trace"
    ]
    assert traces
    assert any(item.metadata.get("path_length", 0) >= 2 for item in traces)
    assert any("run_pipeline" in item.metadata.get("path_symbols", []) for item in traces)


def test_code_pipeline_builds_semantic_role_documents() -> None:
    """Adapter-, parser- and handler-shaped classes receive matching roles."""
    pipeline = CodeIndexingPipeline()
    content = """
class EmailAdapter:
    def send(self, payload):
        import requests
        return requests.post("http://localhost", json=payload)

class ExcelParser:
    def parse(self, rows):
        import csv
        return list(csv.reader(rows))

class OrderHandler:
    def handle(self, ctx, adapter):
        ctx.data = {"status": "ready"}
        value = ctx.data
        return adapter.send(value)
"""
    docs = pipeline.index_file(
        repo_id="acme/proj",
        commit_sha="abc123",
        path="src/semantic_roles.py",
        content=content,
    )

    roles = {
        doc.metadata.get("symbol_name"): doc.metadata.get("role")
        for doc in docs
        if doc.layer == RagLayer.CODE_SEMANTIC_ROLES
    }

    assert roles.get("EmailAdapter") == "adapter"
    assert roles.get("ExcelParser") == "parser"
    assert roles.get("OrderHandler") == "handler"
|
||||
63
tests/unit_tests/rag/test_docs_indexing_pipeline.py
Normal file
63
tests/unit_tests/rag/test_docs_indexing_pipeline.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from app.modules.rag.contracts.enums import RagLayer
|
||||
from app.modules.rag.indexing.docs.pipeline import DocsIndexingPipeline
|
||||
|
||||
|
||||
def test_docs_pipeline_builds_catalog_facts_sections_and_policy() -> None:
    """A front-matter markdown spec must populate every docs layer."""
    pipeline = DocsIndexingPipeline()
    content = """---
id: api.billing.create_invoice
type: policy
domain: billing
links:
  calls_api:
    - api.billing.validate_invoice
tags: [billing]
status: active
---
# Create Invoice

## Spec Summary

Creates an invoice in billing.

## Request Contract

| field | type | required | validation |
| --- | --- | --- | --- |
| amount | decimal | yes | > 0 |

## Error Matrix

| status | error | client action |
| --- | --- | --- |
| 400 | invalid_amount | fix request |

## Rules

- metric: billing.invoice.created
- rule: amount must be positive
"""
    docs = pipeline.index_file(
        repo_id="acme/proj",
        commit_sha="abc123",
        path="docs/billing/create_invoice.md",
        content=content,
    )

    produced_layers = {doc.layer for doc in docs}
    assert {
        RagLayer.DOCS_MODULE_CATALOG,
        RagLayer.DOCS_FACT_INDEX,
        RagLayer.DOCS_SECTION_INDEX,
        RagLayer.DOCS_POLICY_INDEX,
    } <= produced_layers

    module_doc = next(doc for doc in docs if doc.layer == RagLayer.DOCS_MODULE_CATALOG)
    assert module_doc.metadata["module_id"] == "api.billing.create_invoice"
    assert module_doc.metadata["type"] == "policy"

    fact_texts = [doc.text for doc in docs if doc.layer == RagLayer.DOCS_FACT_INDEX]
    for marker in ("calls_api", "has_field", "returns_error"):
        assert any(marker in text for text in fact_texts)

    section_doc = next(doc for doc in docs if doc.layer == RagLayer.DOCS_SECTION_INDEX)
    assert section_doc.metadata["section_path"]
|
||||
22
tests/unit_tests/rag/test_explain_intent_builder.py
Normal file
22
tests/unit_tests/rag/test_explain_intent_builder.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from app.modules.rag.explain.intent_builder import ExplainIntentBuilder
|
||||
|
||||
|
||||
def test_explain_intent_builder_extracts_route_symbol_and_file_hints() -> None:
    """Route, dotted symbol and file path are all captured as explain hints."""
    intent_builder = ExplainIntentBuilder()

    intent = intent_builder.build("Explain how /users/{user_id} reaches UserService.get_user in app/api/users.py")

    hints = intent.hints
    assert "/users/{user_id}" in hints.endpoints
    assert "UserService.get_user" in hints.symbols
    assert "app/api/users.py" in hints.paths
    assert intent.expected_entry_types == ["http"]
    assert intent.include_tests is False
    assert intent.depth == "medium"


def test_explain_intent_builder_enables_tests_when_user_asks_for_them() -> None:
    """An explicit testing question switches include_tests on."""
    intent_builder = ExplainIntentBuilder()

    intent = intent_builder.build("Покажи как это тестируется в pytest и какие tests покрывают UserService")

    assert intent.include_tests is True
|
||||
126
tests/unit_tests/rag/test_intent_router_e2e_flows.py
Normal file
126
tests/unit_tests/rag/test_intent_router_e2e_flows.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from app.modules.rag.intent_router_v2 import GigaChatIntentRouterFactory
|
||||
from app.modules.shared.env_loader import load_workspace_env
|
||||
from tests.unit_tests.rag.asserts_intent_router import (
|
||||
assert_domains,
|
||||
assert_file_only_scope,
|
||||
assert_intent,
|
||||
assert_test_policy,
|
||||
)
|
||||
from tests.unit_tests.rag.intent_router_testkit import run_sequence
|
||||
|
||||
pytestmark = pytest.mark.intent_router
|
||||
|
||||
|
||||
def _live_gigachat_enabled() -> bool:
    """True when live GigaChat smoke tests are opted in and a token is present."""
    load_workspace_env()
    flag = os.getenv("RUN_INTENT_ROUTER_V2_LIVE", "").strip()
    token = os.getenv("GIGACHAT_TOKEN", "").strip()
    return flag == "1" and bool(token)
|
||||
|
||||
|
||||
def test_e2e_path_carryover_flow() -> None:
    """A file mentioned in turn one stays in scope and anchors for later turns."""
    first, second, third = run_sequence(
        [
            "Посмотри файл app/core/config.py",
            "Теперь объясни функцию load_config",
            "Почему так?",
        ]
    )

    target = "app/core/config.py"
    assert_file_only_scope(first, target)
    assert target in second.retrieval_spec.filters.path_scope
    assert target in third.retrieval_spec.filters.path_scope

    carried_files = [
        anchor.value
        for anchor in second.query_plan.anchors
        if anchor.type == "FILE_PATH" and anchor.source == "conversation_state"
    ]
    assert carried_files == [target]
    assert target in second.query_plan.keyword_hints
    assert "app/core" not in second.query_plan.keyword_hints

    assert any(
        anchor.type == "FILE_PATH" and anchor.source == "conversation_state" and anchor.span is None
        for anchor in third.query_plan.anchors
    )
    carried_symbols = [
        anchor.value
        for anchor in third.query_plan.anchors
        if anchor.type == "SYMBOL" and anchor.source == "conversation_state"
    ]
    assert carried_symbols in ([], ["load_config"])
    assert third.query_plan.sub_intent == "EXPLAIN_LOCAL"
    assert "C3_ENTRYPOINTS" not in [item.layer_id for item in third.retrieval_spec.layer_queries]


def test_e2e_docs_switch_from_code_topic() -> None:
    """Switching from a code question to docs carries the symbol across domains."""
    first, second = run_sequence(
        [
            "Объясни как работает ConfigManager",
            "А что про это сказано в документации?",
        ]
    )

    assert_intent(first, "CODE_QA")
    assert_intent(second, "DOCS_QA")
    assert second.conversation_mode == "SWITCH"
    assert_domains(second, ["DOCS"])

    carried = [
        anchor
        for anchor in second.query_plan.anchors
        if anchor.type == "SYMBOL" and anchor.value == "ConfigManager" and anchor.source == "conversation_state"
    ]
    assert carried
    assert carried[0].span is None
    assert "ConfigManager" in second.query_plan.expansions
    assert "ConfigManager" in second.query_plan.keyword_hints


def test_e2e_tests_toggle_flow() -> None:
    """Asking for tests, then production code, flips the test policy accordingly."""
    first, second = run_sequence(
        [
            "Покажи тесты для ConfigManager",
            "А теперь не про тесты, а про прод код",
        ]
    )

    assert_intent(first, "CODE_QA")
    assert_intent(second, "CODE_QA")
    assert_test_policy(first, "INCLUDE")
    assert_test_policy(second, "EXCLUDE")
    assert first.query_plan.sub_intent == "FIND_TESTS"
    assert second.query_plan.sub_intent == "EXPLAIN"
    assert "tests" in second.query_plan.negations
    assert not second.query_plan.expansions
    assert second.evidence_policy.require_flow is False


def test_e2e_open_file_then_generic_next_steps_is_lightweight() -> None:
    """A vague follow-up after opening a file stays scoped and skips heavy layers."""
    first, second = run_sequence(
        [
            "Открой файл app/core/config.py",
            "Что дальше?",
        ]
    )

    assert_file_only_scope(first, "app/core/config.py")
    assert_file_only_scope(second, "app/core/config.py")
    assert second.query_plan.sub_intent in {"EXPLAIN_LOCAL", "NEXT_STEPS"}
    layer_ids = [item.layer_id for item in second.retrieval_spec.layer_queries]
    assert "C3_ENTRYPOINTS" not in layer_ids
    assert second.evidence_policy.require_flow is False
    assert "app/core/config.py" in second.query_plan.keyword_hints
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not _live_gigachat_enabled(),
    reason="requires RUN_INTENT_ROUTER_V2_LIVE=1 and GIGACHAT_TOKEN in environment or .env",
)
def test_intent_router_live_smoke_path_carryover() -> None:
    """Live smoke: a real GigaChat-backed router keeps a file in scope across turns."""
    live_router = GigaChatIntentRouterFactory().build()
    first, second = run_sequence(
        [
            "Открой файл app/core/config.py",
            "Что дальше?",
        ],
        router=live_router,
        trace_label="intent-router-live",
    )

    assert_file_only_scope(first, "app/core/config.py")
    assert "app/core/config.py" in second.retrieval_spec.filters.path_scope
    assert second.query_plan.sub_intent in {"EXPLAIN_LOCAL", "NEXT_STEPS"}
    layer_ids = [item.layer_id for item in second.retrieval_spec.layer_queries]
    assert "C3_ENTRYPOINTS" not in layer_ids
    assert second.evidence_policy.require_flow is False
|
||||
120
tests/unit_tests/rag/test_intent_router_invariants.py
Normal file
120
tests/unit_tests/rag/test_intent_router_invariants.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import pytest
|
||||
|
||||
from tests.unit_tests.rag.asserts_intent_router import (
|
||||
assert_domain_layer_prefixes,
|
||||
assert_domains,
|
||||
assert_file_only_scope,
|
||||
assert_has_file_path,
|
||||
assert_intent,
|
||||
assert_no_symbol_keyword,
|
||||
assert_no_symbol_leakage_from_paths,
|
||||
assert_spans_valid,
|
||||
assert_sub_intent,
|
||||
assert_test_policy,
|
||||
)
|
||||
from tests.unit_tests.rag.intent_router_testkit import run_sequence
|
||||
|
||||
pytestmark = pytest.mark.intent_router
|
||||
|
||||
|
||||
def test_invariant_code_file_path_with_canonical_key_term() -> None:
    """A file mention routes to CODE with a lemmatized key term and valid spans."""
    result = run_sequence(["Уточни по файлу app/core/config.py"])[0]

    assert_intent(result, "CODE_QA")
    assert_has_file_path(result, "app/core/config.py")
    assert_file_only_scope(result, "app/core/config.py")
    key_terms = [anchor.value for anchor in result.query_plan.anchors if anchor.type == "KEY_TERM"]
    assert "файл" in key_terms
    assert "файлу" not in key_terms
    assert_spans_valid(result)
    assert_domain_layer_prefixes(result)


def test_invariant_open_file_for_specified_file_phrase_uses_narrow_layers() -> None:
    """An explicit file phrase becomes OPEN_FILE with only the source-chunk layer."""
    result = run_sequence(["Уточни по файлу app/core/config.py"])[0]

    assert_intent(result, "CODE_QA")
    assert_sub_intent(result, "OPEN_FILE")
    assert_file_only_scope(result, "app/core/config.py")
    assert [item.layer_id for item in result.retrieval_spec.layer_queries] == ["C0_SOURCE_CHUNKS"]
    assert result.evidence_policy.require_flow is False


def test_invariant_inline_code_span_routes_to_code_and_extracts_symbol() -> None:
    """A backtick code span routes to CODE; identifiers, not keywords, are SYMBOLs."""
    result = run_sequence(["Уточни по коду `def build(x): return x`"])[0]

    assert_intent(result, "CODE_QA")
    assert_spans_valid(result)
    assert_no_symbol_keyword(result)
    anchors = result.query_plan.anchors
    symbols = [anchor.value for anchor in anchors if anchor.type == "SYMBOL"]
    key_terms = [anchor.value for anchor in anchors if anchor.type == "KEY_TERM"]
    assert "build" in symbols
    assert "def" in key_terms


def test_invariant_docs_cyrillic_path_with_quotes() -> None:
    """A guillemet-quoted Cyrillic doc path routes to DOCS with a DOC_REF anchor."""
    result = run_sequence(["Что сказано в «docs/архитектура.md»?"])[0]

    assert_intent(result, "DOCS_QA")
    assert_sub_intent(result, "EXPLAIN")
    assert_domains(result, ["DOCS"])
    assert "docs/архитектура.md" in result.query_plan.normalized
    assert_has_file_path(result, "docs/архитектура.md")
    assert any(anchor.type == "DOC_REF" for anchor in result.query_plan.anchors)
    assert result.retrieval_spec.filters.doc_kinds == []
    assert_spans_valid(result)
    assert_domain_layer_prefixes(result)
|
||||
|
||||
|
||||
def test_invariant_file_check_phrase_not_project_misc() -> None:
    """'Check <path> and explain' stays CODE-only with no path-fragment leakage."""
    result = run_sequence(["Проверь app/modules/rag/explain/intent_builder.py и объясни"])[0]

    assert_intent(result, "CODE_QA")
    assert_domains(result, ["CODE"])
    assert_no_symbol_leakage_from_paths(result)
    assert_domain_layer_prefixes(result)


def test_invariant_tests_include_routing() -> None:
    """Asking where the tests are includes tests and anchors the symbol."""
    result = run_sequence(["Где тесты на ConfigManager?"])[0]

    assert_intent(result, "CODE_QA")
    assert_test_policy(result, "INCLUDE")
    anchors = result.query_plan.anchors
    symbols = [anchor.value for anchor in anchors if anchor.type == "SYMBOL"]
    key_terms = [anchor.value for anchor in anchors if anchor.type == "KEY_TERM"]
    assert "ConfigManager" in symbols
    assert "тест" in key_terms


def test_invariant_keyword_hints_and_expansions_for_function_identifier() -> None:
    """Identifiers become keyword hints; filler words and keywords do not expand."""
    result = run_sequence(["Теперь объясни функцию load_config"])[0]

    assert_intent(result, "CODE_QA")
    assert "load_config" in result.query_plan.keyword_hints
    assert "функция" not in result.query_plan.keyword_hints
    assert "def" not in result.query_plan.expansions


def test_invariant_open_file_sub_intent_uses_narrow_retrieval_profile() -> None:
    """OPEN_FILE keeps source chunks and drops catalog/graph/entrypoint layers."""
    result = run_sequence(["Открой файл app/core/config.py"])[0]

    assert_intent(result, "CODE_QA")
    assert_sub_intent(result, "OPEN_FILE")
    assert_file_only_scope(result, "app/core/config.py")
    layer_ids = [item.layer_id for item in result.retrieval_spec.layer_queries]
    assert "C0_SOURCE_CHUNKS" in layer_ids
    for heavy_layer in ("C1_SYMBOL_CATALOG", "C2_DEPENDENCY_GRAPH", "C3_ENTRYPOINTS"):
        assert heavy_layer not in layer_ids
    assert result.evidence_policy.require_flow is False


def test_invariant_docs_question_routes_to_docs() -> None:
    """A pure documentation question routes to DOCS with hint-derived expansions."""
    result = run_sequence(["Что сказано в документации?"])[0]

    assert_intent(result, "DOCS_QA")
    assert_domains(result, ["DOCS"])
    assert_domain_layer_prefixes(result)
    assert result.query_plan.keyword_hints
    assert any(item in result.query_plan.expansions for item in result.query_plan.keyword_hints)
|
||||
78
tests/unit_tests/rag/test_layered_gateway.py
Normal file
78
tests/unit_tests/rag/test_layered_gateway.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from app.modules.rag.explain.layered_gateway import LayeredRetrievalGateway
|
||||
|
||||
|
||||
class _Embedder:
    """Stub embedder returning a fixed two-dimensional vector per input text."""

    def embed(self, texts: list[str]) -> list[list[float]]:
        # Return one vector per input so callers embedding more (or fewer)
        # than one text get a correctly sized batch; the original always
        # returned a single vector regardless of len(texts).
        return [[0.1, 0.2] for _ in texts]
|
||||
|
||||
|
||||
class _RetryingRepository:
    """Repository double that fails whenever a test-exclusion filter is passed.

    Every call is recorded in ``calls``; a truthy ``exclude_path_prefixes``
    kwarg triggers a RuntimeError so the gateway's retry path is exercised.
    """

    def __init__(self) -> None:
        self.calls: list[dict] = []

    def _record_and_maybe_fail(self, kwargs: dict, message: str) -> None:
        # Shared bookkeeping for both retrieval entry points.
        self.calls.append(kwargs)
        if kwargs.get("exclude_path_prefixes"):
            raise RuntimeError(message)

    def retrieve(self, *args, **kwargs):
        self._record_and_maybe_fail(kwargs, "syntax error at or near ')'")
        return [
            {
                "path": "app/users/service.py",
                "content": "def get_user(): pass",
                "layer": "C1_SYMBOL_CATALOG",
                "title": "get_user",
                "metadata": {"symbol_id": "symbol-1"},
                "distance": 0.1,
                "span_start": 10,
                "span_end": 11,
            }
        ]

    def retrieve_lexical_code(self, *args, **kwargs):
        self._record_and_maybe_fail(kwargs, "broken lexical filter")
        return [
            {
                "path": "app/users/service.py",
                "content": "def get_user(): pass",
                "layer": "C0_SOURCE_CHUNKS",
                "title": "get_user",
                "metadata": {"symbol_id": "symbol-1"},
                "span_start": 10,
                "span_end": 11,
            }
        ]
|
||||
|
||||
|
||||
class _RecordingRepository:
    """Repository double that records every call's kwargs and returns no rows."""

    def __init__(self) -> None:
        self.calls: list[dict] = []

    def retrieve(self, *args, **kwargs):
        self.calls.append(kwargs)
        return []

    # Lexical retrieval behaves identically: record kwargs, return nothing.
    retrieve_lexical_code = retrieve
|
||||
|
||||
|
||||
def test_gateway_retries_without_test_filter_on_vector_failure() -> None:
    """A failing filtered vector query is retried without the test filter."""
    gateway = LayeredRetrievalGateway(_RetryingRepository(), _Embedder())

    outcome = gateway.retrieve_layer("rag-1", "Explain get_user", "C1_SYMBOL_CATALOG", limit=3, exclude_tests=True)

    assert len(outcome.items) == 1
    assert "layer:C1_SYMBOL_CATALOG retrieval_failed:retried_without_test_filter" in outcome.missing


def test_gateway_honors_debug_disable_test_filter(monkeypatch) -> None:
    """RAG_DEBUG_DISABLE_TEST_FILTER=true suppresses all test-exclusion kwargs."""
    monkeypatch.setenv("RAG_DEBUG_DISABLE_TEST_FILTER", "true")
    repository = _RecordingRepository()
    gateway = LayeredRetrievalGateway(repository, _Embedder())

    gateway.retrieve_layer("rag-1", "Explain get_user", "C1_SYMBOL_CATALOG", limit=3, exclude_tests=True)

    assert repository.calls
    first_call = repository.calls[0]
    assert first_call["exclude_path_prefixes"] is None
    assert first_call["exclude_like_patterns"] is None
|
||||
44
tests/unit_tests/rag/test_path_filter.py
Normal file
44
tests/unit_tests/rag/test_path_filter.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.modules.rag.indexing.common.path_filter import (
|
||||
count_indexable_change_upserts,
|
||||
filter_changes_for_indexing,
|
||||
filter_snapshot_files,
|
||||
should_skip_indexing_path,
|
||||
)
|
||||
|
||||
|
||||
def test_should_skip_indexing_path_for_hidden_and_cache_paths() -> None:
    """Hidden files/dirs and bytecode caches are skipped; normal sources are kept."""
    skipped_paths = [
        ".env",
        ".venv/lib/site-packages/a.py",
        "src/.secrets/config.py",
        "src/__pycache__/module.cpython-312.pyc",
    ]
    for path in skipped_paths:
        assert should_skip_indexing_path(path) is True
    assert should_skip_indexing_path("src/main.py") is False


def test_filter_snapshot_files_excludes_hidden_and_cache_paths() -> None:
    """Snapshot filtering drops hidden and cache entries and keeps the rest."""
    files = [
        {"path": ".env", "content": "A"},
        {"path": "src/__pycache__/x.py", "content": "B"},
        {"path": "src/main.py", "content": "C"},
    ]

    kept = filter_snapshot_files(files)

    assert [entry["path"] for entry in kept] == ["src/main.py"]


def test_filter_changes_for_indexing_keeps_deletes_and_filters_upserts() -> None:
    """Upserts for skipped paths are dropped; delete ops always pass through."""
    changed_files = [
        {"op": "upsert", "path": ".env", "content": "A"},
        {"op": "upsert", "path": "src/main.py", "content": "B"},
        {"op": "delete", "path": ".cache/legacy.txt"},
    ]

    kept = filter_changes_for_indexing(changed_files)

    assert kept == [
        {"op": "upsert", "path": "src/main.py", "content": "B"},
        {"op": "delete", "path": ".cache/legacy.txt"},
    ]
    assert count_indexable_change_upserts(kept) == 1
|
||||
63
tests/unit_tests/rag/test_query_normalization.py
Normal file
63
tests/unit_tests/rag/test_query_normalization.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import pytest
|
||||
|
||||
from app.modules.rag.intent_router_v2.analysis.normalization import QueryNormalizer
|
||||
|
||||
pytestmark = pytest.mark.intent_router
|
||||
|
||||
|
||||
def test_query_normalizer_collapses_whitespace() -> None:
|
||||
normalizer = QueryNormalizer()
|
||||
|
||||
normalized = normalizer.normalize(" Объясни как работает \n класс X ")
|
||||
|
||||
assert normalized == "Объясни как работает класс X"
|
||||
|
||||
|
||||
def test_query_normalizer_canonicalizes_quotes() -> None:
|
||||
normalizer = QueryNormalizer()
|
||||
|
||||
normalized = normalizer.normalize('Уточни «текст» и “текст”')
|
||||
|
||||
assert normalized == 'Уточни "текст" и "текст"'
|
||||
|
||||
|
||||
def test_query_normalizer_preserves_backticks_verbatim() -> None:
    """Backtick-quoted code is kept as-is; only surrounding whitespace is trimmed."""
    result = QueryNormalizer().normalize("Уточни по коду `def build(x):` ")
    assert result == "Уточни по коду `def build(x):`"
|
||||
|
||||
|
||||
def test_query_normalizer_preserves_latin_and_cyrillic_file_paths() -> None:
    """File paths survive normalization without a space injected before the extension."""
    result = QueryNormalizer().normalize("Сверь app/core/config.py и «docs/руководство.md»")

    for path in ("app/core/config.py", "docs/руководство.md"):
        assert path in result
    for broken in ("config. py", "руководство. md"):
        assert broken not in result
|
||||
|
||||
|
||||
def test_query_normalizer_punctuation_spacing_does_not_break_extensions() -> None:
    """Punctuation fix-ups around paths must not split file extensions."""
    result = QueryNormalizer().normalize("Проверь docs/spec.md , затем app/main.py !")

    for path in ("docs/spec.md", "app/main.py"):
        assert path in result
    for broken in ("spec. md", "main. py"):
        assert broken not in result
|
||||
|
||||
|
||||
def test_query_normalizer_idempotent_and_without_enrichment() -> None:
    """Normalizing twice is a no-op, and no semantic terms are injected."""
    qn = QueryNormalizer()
    raw = ' Прочитай «README.md» и docs/spec.md '

    first = qn.normalize(raw)

    assert qn.normalize(first) == first
    lowered = first.lower()
    assert "documentation" not in lowered
    assert "class" not in lowered
|
||||
9
tests/unit_tests/rag/test_query_terms.py
Normal file
9
tests/unit_tests/rag/test_query_terms.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from app.modules.rag.retrieval.query_terms import extract_query_terms
|
||||
|
||||
|
||||
def test_extract_query_terms_from_code_question() -> None:
    """Extracted terms include casefolded, snake_cased, and translated variants."""
    terms = extract_query_terms("Объясни по коду как можно управлять COnfigmanager?")

    for expected in ("configmanager", "config_manager", "control"):
        assert expected in terms
|
||||
52
tests/unit_tests/rag/test_rag_service_filtering.py
Normal file
52
tests/unit_tests/rag/test_rag_service_filtering.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from app.modules.rag.services.rag_service import RagService
|
||||
|
||||
|
||||
class _FakeEmbedder:
|
||||
def embed(self, texts: list[str]) -> list[list[float]]:
|
||||
return [[0.0, 0.1, 0.2] for _ in texts]
|
||||
|
||||
|
||||
class _FakeRepository:
|
||||
def __init__(self) -> None:
|
||||
self.replaced_docs = []
|
||||
|
||||
def get_session(self, rag_session_id: str) -> dict:
|
||||
return {"project_id": rag_session_id}
|
||||
|
||||
def get_cached_documents(self, repo_id: str, blob_sha: str) -> list:
|
||||
return []
|
||||
|
||||
def cache_documents(self, repo_id: str, path: str, blob_sha: str, docs: list) -> None:
|
||||
return None
|
||||
|
||||
def replace_documents(self, rag_session_id: str, docs: list) -> None:
|
||||
self.replaced_docs = docs
|
||||
|
||||
|
||||
def test_rag_service_progress_uses_only_indexable_files() -> None:
    """Progress totals must count only files that survive snapshot filtering."""
    repository = _FakeRepository()
    service = RagService(embedder=_FakeEmbedder(), repository=repository)
    snapshot = [
        {"path": ".env", "content": "SECRET=1", "content_hash": "h1"},
        {"path": "src/.hidden/config.py", "content": "A=1", "content_hash": "h2"},
        {"path": "src/__pycache__/cache.py", "content": "A=2", "content_hash": "h3"},
        {"path": "src/main.py", "content": "def main():\n return 1\n", "content_hash": "h4"},
    ]
    events: list[tuple[int, int, str]] = []

    def progress_cb(current: int, total: int, path: str) -> None:
        events.append((current, total, path))

    indexed, failed, cache_hits, cache_misses = asyncio.run(
        service.index_snapshot("project-1", snapshot, progress_cb=progress_cb)
    )

    # Only src/main.py is indexable, so the total in every progress event is 1.
    assert (indexed, failed, cache_hits, cache_misses) == (1, 0, 0, 1)
    assert events == [(1, 1, "src/main.py")]
|
||||
65
tests/unit_tests/rag/test_retrieval_statement_builder.py
Normal file
65
tests/unit_tests/rag/test_retrieval_statement_builder.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from app.modules.rag.persistence.retrieval_statement_builder import RetrievalStatementBuilder
|
||||
from app.modules.rag.retrieval.test_filter import build_test_filters, is_test_path
|
||||
|
||||
|
||||
def test_retrieve_builder_adds_test_exclusion_filters() -> None:
    """Exclusion filters from build_test_filters must appear in the SQL and params."""
    filters = build_test_filters()

    sql, params = RetrievalStatementBuilder().build_retrieve(
        "rag-1",
        [0.1, 0.2],
        query_text="Explain user service",
        layers=["C0_SOURCE_CHUNKS"],
        exclude_path_prefixes=filters.exclude_path_prefixes,
        exclude_like_patterns=filters.exclude_like_patterns,
    )

    expected_fragments = (
        "NOT (",
        "vector_dims(embedding) = vector_dims(CAST(:emb AS vector))",
        "path LIKE :exclude_prefix_0",
        "lower(path) LIKE :exclude_like_0",
        "ESCAPE E'\\\\'",
    )
    for fragment in expected_fragments:
        assert fragment in sql
    assert params["exclude_prefix_0"] == "tests/%"
    bound_values = params.values()
    assert "%.test.%" in bound_values
    assert "%\\_test.%" in bound_values
|
||||
|
||||
|
||||
def test_retrieve_builder_adds_prefer_bonus_sorting() -> None:
    """Prefer-path inputs produce bonus/ranking SQL and matching bind params."""
    sql, params = RetrievalStatementBuilder().build_retrieve(
        "rag-1",
        [0.1, 0.2],
        query_text="find context tests",
        layers=["C1_SYMBOL_CATALOG"],
        prefer_path_prefixes=["tests/"],
        prefer_like_patterns=["%/test\\_%.py"],
    )

    for fragment in (
        "AS prefer_bonus",
        "AS structural_rank",
        "WHEN layer = 'C4_SEMANTIC_ROLES' THEN 2",
        "ORDER BY prefer_bonus ASC, test_penalty ASC, layer_rank ASC",
    ):
        assert fragment in sql
    assert params["prefer_prefix_0"] == "tests/%"
    assert params["prefer_like_0"] == "%/test\\_%.py"
|
||||
|
||||
|
||||
def test_lexical_builder_omits_test_filters_when_not_requested() -> None:
    """Without prefer_non_tests, no exclusion SQL or bind params may be emitted."""
    sql, params = RetrievalStatementBuilder().build_lexical_code(
        "rag-1",
        query_text="Explain user service",
        prefer_non_tests=False,
    )

    assert sql is not None
    assert "exclude_prefix" not in sql
    assert "exclude_like" not in sql
    assert all(not key.startswith("exclude_") for key in params)
|
||||
|
||||
|
||||
def test_test_filter_does_not_treat_contest_file_as_test() -> None:
    """'contest' contains 'test' as a substring but must not be classed as a test path."""
    assert is_test_path("tests/test_users.py") is True
    assert is_test_path("app/contest.py") is False
|
||||
52
tests/unit_tests/rag/test_retriever_v2_no_fallback.py
Normal file
52
tests/unit_tests/rag/test_retriever_v2_no_fallback.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from app.modules.rag.explain import CodeExplainRetrieverV2, LayeredRetrievalGateway
|
||||
|
||||
|
||||
class _ExplodingEmbedder:
|
||||
def embed(self, texts: list[str]) -> list[list[float]]:
|
||||
raise RuntimeError("embedding unavailable")
|
||||
|
||||
|
||||
class _RepositoryWithoutFallback:
|
||||
def retrieve(self, *args, **kwargs):
|
||||
raise RuntimeError("vector retrieval unavailable")
|
||||
|
||||
def retrieve_lexical_code(
|
||||
self,
|
||||
rag_session_id: str,
|
||||
query_text: str,
|
||||
*,
|
||||
limit: int = 5,
|
||||
path_prefixes: list[str] | None = None,
|
||||
exclude_path_prefixes: list[str] | None = None,
|
||||
exclude_like_patterns: list[str] | None = None,
|
||||
prefer_non_tests: bool = False,
|
||||
):
|
||||
return []
|
||||
|
||||
|
||||
class _FakeGraphRepository:
|
||||
def get_symbols_by_ids(self, rag_session_id: str, symbol_ids: list[str]):
|
||||
return []
|
||||
|
||||
def get_chunks_by_symbol_ids(self, rag_session_id: str, symbol_ids: list[str], prefer_chunk_type: str = "symbol_block"):
|
||||
return []
|
||||
|
||||
def get_out_edges(self, rag_session_id: str, src_symbol_ids: list[str], edge_types: list[str], limit_per_src: int):
|
||||
return []
|
||||
|
||||
def resolve_symbol_by_ref(self, rag_session_id: str, dst_ref: str, package_hint: str | None = None):
|
||||
return None
|
||||
|
||||
|
||||
def test_retriever_v2_returns_pack_without_fallback_method() -> None:
    """When every retrieval path fails, the pack reports gaps instead of raising."""
    retriever = CodeExplainRetrieverV2(
        gateway=LayeredRetrievalGateway(_RepositoryWithoutFallback(), _ExplodingEmbedder()),
        graph_repository=_FakeGraphRepository(),
    )

    pack = retriever.build_pack("rag-1", "Explain get_user")

    assert pack.code_excerpts == []
    missing = pack.missing
    assert any(entry.startswith("layer:C3_ENTRYPOINTS retrieval_failed") for entry in missing)
    assert any(entry.startswith("layer:C1_SYMBOL_CATALOG retrieval_failed") for entry in missing)
    assert "layer:C0 empty" in missing
|
||||
105
tests/unit_tests/rag/test_retriever_v2_pack.py
Normal file
105
tests/unit_tests/rag/test_retriever_v2_pack.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from app.modules.rag.explain.models import CodeLocation, LayeredRetrievalItem
|
||||
from app.modules.rag.explain.retriever_v2 import CodeExplainRetrieverV2
|
||||
|
||||
|
||||
class _FakeGateway:
    """Layered-retrieval gateway double.

    Serves one hard-coded entrypoint hit for C3_ENTRYPOINTS and one symbol
    hit for C1_SYMBOL_CATALOG; requesting any other layer is a test error.
    Lexical retrieval always comes back empty.
    """

    def retrieve_layer(
        self,
        rag_session_id: str,
        query: str,
        layer: str,
        *,
        limit: int,
        path_prefixes: list[str] | None = None,
        exclude_tests: bool = True,
        prefer_non_tests: bool = False,
        include_spans: bool = False,
    ):
        # Local import replaces the original `__import__("types")` anti-idiom;
        # behavior is identical, the name is just brought into scope cleanly.
        from types import SimpleNamespace

        if layer == "C3_ENTRYPOINTS":
            return SimpleNamespace(
                items=[
                    LayeredRetrievalItem(
                        source="app/api/users.py",
                        content="GET /users/{id}",
                        layer=layer,
                        title="GET /users/{id}",
                        metadata={"entry_type": "http", "handler_symbol_id": "handler-1"},
                        location=CodeLocation(path="app/api/users.py", start_line=10, end_line=10),
                    )
                ],
                missing=[],
            )
        if layer == "C1_SYMBOL_CATALOG":
            return SimpleNamespace(
                items=[
                    LayeredRetrievalItem(
                        source="app/api/users.py",
                        content="def get_user_handler",
                        layer=layer,
                        title="get_user_handler",
                        metadata={"symbol_id": "handler-1"},
                        location=CodeLocation(path="app/api/users.py", start_line=10, end_line=18),
                    )
                ],
                missing=[],
            )
        raise AssertionError(layer)

    def retrieve_lexical_code(
        self,
        rag_session_id: str,
        query: str,
        *,
        limit: int,
        path_prefixes: list[str] | None = None,
        exclude_tests: bool = True,
        include_spans: bool = False,
    ):
        from types import SimpleNamespace

        # Lexical search never contributes anything in this fixture.
        return SimpleNamespace(items=[], missing=[])
|
||||
|
||||
|
||||
class _FakeGraphRepository:
    """Graph repository double with one known symbol ("handler-1") and its chunk."""

    def get_symbols_by_ids(self, rag_session_id: str, symbol_ids: list[str]):
        # Returns the single handler symbol unconditionally, regardless of ids.
        return [
            LayeredRetrievalItem(
                source="app/api/users.py",
                content="def get_user_handler",
                layer="C1_SYMBOL_CATALOG",
                title="get_user_handler",
                metadata={"symbol_id": "handler-1"},
                location=CodeLocation(path="app/api/users.py", start_line=10, end_line=18),
            )
        ]

    def get_out_edges(self, rag_session_id: str, src_symbol_ids: list[str], edge_types: list[str], limit_per_src: int):
        # No outgoing edges: traces stop at the seed symbol.
        return []

    def resolve_symbol_by_ref(self, rag_session_id: str, dst_ref: str, package_hint: str | None = None):
        return None

    def get_chunks_by_symbol_ids(self, rag_session_id: str, symbol_ids: list[str], prefer_chunk_type: str = "symbol_block"):
        # One source chunk backing the handler symbol.
        return [
            LayeredRetrievalItem(
                source="app/api/users.py",
                content="async def get_user_handler(user_id: str):\n return await service.get_user(user_id)",
                layer="C0_SOURCE_CHUNKS",
                title="get_user_handler",
                metadata={"symbol_id": "handler-1"},
                location=CodeLocation(path="app/api/users.py", start_line=10, end_line=18),
            )
        ]
|
||||
|
||||
|
||||
def test_retriever_v2_builds_pack_with_trace_and_excerpts() -> None:
    """A single entrypoint+symbol hit yields one trace path and one code excerpt."""
    retriever = CodeExplainRetrieverV2(
        gateway=_FakeGateway(),
        graph_repository=_FakeGraphRepository(),
    )

    pack = retriever.build_pack("rag-1", "Explain endpoint get_user")

    counts = (
        len(pack.selected_entrypoints),
        len(pack.seed_symbols),
        len(pack.trace_paths),
        len(pack.code_excerpts),
    )
    assert counts == (1, 1, 1, 1)
    assert pack.code_excerpts[0].path == "app/api/users.py"
|
||||
142
tests/unit_tests/rag/test_retriever_v2_production_first.py
Normal file
142
tests/unit_tests/rag/test_retriever_v2_production_first.py
Normal file
@@ -0,0 +1,142 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from app.modules.rag.explain.models import CodeLocation, LayeredRetrievalItem
|
||||
from app.modules.rag.explain.retriever_v2 import CodeExplainRetrieverV2
|
||||
|
||||
|
||||
class _ProductionFirstGateway:
    """Gateway double for the production-first retrieval strategy.

    Both vector layers yield nothing, forcing the retriever onto lexical
    search.  A lexical pass with exclude_tests=True returns two production
    chunks; a pass with exclude_tests=False returns a single test chunk.
    ``lexical_calls`` records the exclude_tests flag of every lexical call so
    tests can assert how many passes were made and in what order.
    """

    def __init__(self) -> None:
        # One entry per retrieve_lexical_code call: its exclude_tests flag.
        self.lexical_calls: list[bool] = []

    def retrieve_layer(
        self,
        rag_session_id: str,
        query: str,
        layer: str,
        *,
        limit: int,
        path_prefixes: list[str] | None = None,
        exclude_tests: bool = True,
        prefer_non_tests: bool = False,
        include_spans: bool = False,
    ):
        # Both vector layers are intentionally empty; any other layer being
        # requested is a programming error in the retriever under test.
        if layer == "C3_ENTRYPOINTS":
            return SimpleNamespace(items=[], missing=[])
        if layer == "C1_SYMBOL_CATALOG":
            return SimpleNamespace(items=[], missing=[])
        raise AssertionError(layer)

    def retrieve_lexical_code(
        self,
        rag_session_id: str,
        query: str,
        *,
        limit: int,
        path_prefixes: list[str] | None = None,
        exclude_tests: bool = True,
        include_spans: bool = False,
    ):
        self.lexical_calls.append(exclude_tests)
        if exclude_tests:
            # Production-only pass: two non-test chunks give enough evidence.
            return SimpleNamespace(
                items=[
                    LayeredRetrievalItem(
                        source="app/users/service.py",
                        content="def get_user():\n return repo.get_user()",
                        layer="C0_SOURCE_CHUNKS",
                        title="get_user",
                        metadata={"symbol_id": "user-service", "is_test": False},
                        location=CodeLocation(path="app/users/service.py", start_line=10, end_line=11),
                    ),
                    LayeredRetrievalItem(
                        source="app/users/repository.py",
                        content="def get_user_repo():\n return {}",
                        layer="C0_SOURCE_CHUNKS",
                        title="get_user_repo",
                        metadata={"symbol_id": "user-repo", "is_test": False},
                        location=CodeLocation(path="app/users/repository.py", start_line=20, end_line=21),
                    ),
                ],
                missing=[],
            )
        # Test-inclusive fallback pass: only a test chunk is available.
        return SimpleNamespace(
            items=[
                LayeredRetrievalItem(
                    source="tests/test_users.py",
                    content="def test_get_user():\n assert service.get_user()",
                    layer="C0_SOURCE_CHUNKS",
                    title="test_get_user",
                    metadata={"symbol_id": "test-user", "is_test": True},
                    location=CodeLocation(path="tests/test_users.py", start_line=5, end_line=6),
                )
            ],
            missing=[],
        )
|
||||
|
||||
|
||||
class _TestsOnlyGateway(_ProductionFirstGateway):
    """Gateway double where only test code matches the query.

    The production-only lexical pass (exclude_tests=True) is empty, so the
    retriever must make a second, test-inclusive pass, which yields a single
    test chunk.  Vector layers and call recording are inherited from
    _ProductionFirstGateway.
    """

    def retrieve_lexical_code(
        self,
        rag_session_id: str,
        query: str,
        *,
        limit: int,
        path_prefixes: list[str] | None = None,
        exclude_tests: bool = True,
        include_spans: bool = False,
    ):
        self.lexical_calls.append(exclude_tests)
        if exclude_tests:
            # No production evidence at all for this query.
            return SimpleNamespace(items=[], missing=[])
        # Test-inclusive fallback: exactly one test chunk.
        return SimpleNamespace(
            items=[
                LayeredRetrievalItem(
                    source="tests/test_users.py",
                    content="def test_get_user():\n assert service.get_user()",
                    layer="C0_SOURCE_CHUNKS",
                    title="test_get_user",
                    metadata={"symbol_id": "test-user", "is_test": True},
                    location=CodeLocation(path="tests/test_users.py", start_line=5, end_line=6),
                )
            ],
            missing=[],
        )
|
||||
|
||||
|
||||
class _FakeGraphRepository:
|
||||
def get_symbols_by_ids(self, rag_session_id: str, symbol_ids: list[str]):
|
||||
return []
|
||||
|
||||
def get_chunks_by_symbol_ids(self, rag_session_id: str, symbol_ids: list[str], prefer_chunk_type: str = "symbol_block"):
|
||||
return []
|
||||
|
||||
def get_out_edges(self, rag_session_id: str, src_symbol_ids: list[str], edge_types: list[str], limit_per_src: int):
|
||||
return []
|
||||
|
||||
def resolve_symbol_by_ref(self, rag_session_id: str, dst_ref: str, package_hint: str | None = None):
|
||||
return None
|
||||
|
||||
|
||||
def test_retriever_prefers_prod_chunks_and_skips_test_fallback_when_enough_evidence() -> None:
    """With enough production evidence, no test-inclusive lexical pass is made."""
    gateway = _ProductionFirstGateway()
    retriever = CodeExplainRetrieverV2(gateway=gateway, graph_repository=_FakeGraphRepository())

    pack = retriever.build_pack("rag-1", "Explain get_user")

    assert gateway.lexical_calls == [True]
    excerpt_paths = [excerpt.path for excerpt in pack.code_excerpts]
    assert excerpt_paths == ["app/users/service.py", "app/users/repository.py"]
    assert not any(excerpt.focus.startswith("test:") for excerpt in pack.code_excerpts)
|
||||
|
||||
|
||||
def test_retriever_uses_test_fallback_when_production_evidence_is_missing() -> None:
    """With no production hits, a second (test-inclusive) lexical pass runs."""
    gateway = _TestsOnlyGateway()
    retriever = CodeExplainRetrieverV2(gateway=gateway, graph_repository=_FakeGraphRepository())

    pack = retriever.build_pack("rag-1", "Explain get_user")

    assert gateway.lexical_calls == [True, False]
    assert [excerpt.path for excerpt in pack.code_excerpts] == ["tests/test_users.py"]
    assert pack.code_excerpts[0].focus == "test:lexical"
|
||||
83
tests/unit_tests/rag/test_trace_builder.py
Normal file
83
tests/unit_tests/rag/test_trace_builder.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from app.modules.rag.explain.models import CodeLocation, LayeredRetrievalItem
|
||||
from app.modules.rag.explain.trace_builder import TraceBuilder
|
||||
|
||||
|
||||
class _FakeGraphRepository:
    """Graph double modelling one partially-resolved call edge.

    Symbol "handler-1" calls "UserService.get_user" through an edge whose
    dst_symbol_id is unknown (resolution="partial"); resolve_symbol_by_ref
    resolves that textual reference to symbol "service-1".  The in-method
    assertions double as a contract check on how TraceBuilder calls the
    repository.
    """

    def get_out_edges(self, rag_session_id: str, src_symbol_ids: list[str], edge_types: list[str], limit_per_src: int):
        assert rag_session_id == "rag-1"
        # TraceBuilder is expected to ask for exactly this edge-type set.
        assert edge_types == ["calls", "imports", "inherits", "instantiates", "reads_attr", "writes_attr"]
        if src_symbol_ids == ["handler-1"]:
            return [
                LayeredRetrievalItem(
                    source="app/api/users.py",
                    content="handler calls get_user",
                    layer="C2_DEPENDENCY_GRAPH",
                    title="handler:calls",
                    metadata={
                        # Partially resolved edge: only a textual dst_ref, no id.
                        "src_symbol_id": "handler-1",
                        "dst_symbol_id": None,
                        "dst_ref": "UserService.get_user",
                        "resolution": "partial",
                        "edge_type": "calls",
                    },
                    location=CodeLocation(path="app/api/users.py", start_line=12, end_line=12),
                )
            ]
        return []

    def resolve_symbol_by_ref(self, rag_session_id: str, dst_ref: str, package_hint: str | None = None):
        assert rag_session_id == "rag-1"
        assert dst_ref == "UserService.get_user"
        # NOTE(review): hint looks derived from the seed's package
        # "app.api.users" -> "app.api" — confirm in TraceBuilder.
        assert package_hint == "app.api"
        return LayeredRetrievalItem(
            source="app/services/users.py",
            content="method UserService.get_user",
            layer="C1_SYMBOL_CATALOG",
            title="UserService.get_user",
            metadata={
                "symbol_id": "service-1",
                "package_or_module": "app.api.users",
            },
            location=CodeLocation(path="app/services/users.py", start_line=4, end_line=10),
        )

    def get_symbols_by_ids(self, rag_session_id: str, symbol_ids: list[str]):
        assert rag_session_id == "rag-1"
        if symbol_ids == ["service-1"]:
            return [
                LayeredRetrievalItem(
                    source="app/services/users.py",
                    content="method UserService.get_user",
                    layer="C1_SYMBOL_CATALOG",
                    title="UserService.get_user",
                    metadata={
                        "symbol_id": "service-1",
                        "package_or_module": "app.api.users",
                    },
                    location=CodeLocation(path="app/services/users.py", start_line=4, end_line=10),
                )
            ]
        return []
|
||||
|
||||
|
||||
def test_trace_builder_resolves_partial_edges_across_files() -> None:
    """A partial edge is resolved via its dst_ref and extends the trace path."""
    seed = LayeredRetrievalItem(
        source="app/api/users.py",
        content="function handler",
        layer="C1_SYMBOL_CATALOG",
        title="get_user",
        metadata={
            "symbol_id": "handler-1",
            "package_or_module": "app.api.users",
        },
        location=CodeLocation(path="app/api/users.py", start_line=10, end_line=18),
    )

    paths = TraceBuilder(_FakeGraphRepository()).build_paths("rag-1", [seed], max_depth=3)

    assert len(paths) >= 1
    assert paths[0].symbol_ids == ["handler-1", "service-1"]
    assert "resolved:UserService.get_user" in paths[0].notes
|
||||
Reference in New Issue
Block a user