Фиксация изменений
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
181
tests/agent/engine/router/test_router_service_intent_policy.py
Normal file
181
tests/agent/engine/router/test_router_service_intent_policy.py
Normal file
@@ -0,0 +1,181 @@
|
||||
import sys
|
||||
import types
|
||||
|
||||
sqlalchemy = types.ModuleType("sqlalchemy")
|
||||
sqlalchemy.text = lambda value: value
|
||||
sqlalchemy.create_engine = lambda *args, **kwargs: object()
|
||||
sys.modules.setdefault("sqlalchemy", sqlalchemy)
|
||||
|
||||
sqlalchemy_engine = types.ModuleType("sqlalchemy.engine")
|
||||
sqlalchemy_engine.Engine = object
|
||||
sys.modules.setdefault("sqlalchemy.engine", sqlalchemy_engine)
|
||||
|
||||
sqlalchemy_orm = types.ModuleType("sqlalchemy.orm")
|
||||
sqlalchemy_orm.sessionmaker = lambda *args, **kwargs: object()
|
||||
sys.modules.setdefault("sqlalchemy.orm", sqlalchemy_orm)
|
||||
|
||||
sqlalchemy_pool = types.ModuleType("sqlalchemy.pool")
|
||||
sqlalchemy_pool.NullPool = object
|
||||
sys.modules.setdefault("sqlalchemy.pool", sqlalchemy_pool)
|
||||
|
||||
from app.modules.agent.engine.router.router_service import RouterService
|
||||
from app.modules.agent.engine.router.schemas import RouteDecision, RouterContext
|
||||
|
||||
|
||||
class _FakeRegistry:
|
||||
def is_valid(self, domain_id: str, process_id: str) -> bool:
|
||||
return (domain_id, process_id) in {
|
||||
("default", "general"),
|
||||
("project", "qa"),
|
||||
("project", "edits"),
|
||||
("docs", "generation"),
|
||||
}
|
||||
|
||||
def get_factory(self, domain_id: str, process_id: str):
|
||||
return object()
|
||||
|
||||
|
||||
class _FakeClassifier:
|
||||
def __init__(self, decision: RouteDecision | None = None, forced: RouteDecision | None = None) -> None:
|
||||
self._decision = decision or RouteDecision(domain_id="project", process_id="qa", confidence=0.95, reason="new_intent")
|
||||
self._forced = forced
|
||||
self.calls = 0
|
||||
|
||||
def from_mode(self, mode: str) -> RouteDecision | None:
|
||||
return self._forced if mode != "auto" else None
|
||||
|
||||
def classify_new_intent(self, user_message: str, context: RouterContext) -> RouteDecision:
|
||||
self.calls += 1
|
||||
return self._decision
|
||||
|
||||
|
||||
class _FakeContextStore:
|
||||
def __init__(self, context: RouterContext) -> None:
|
||||
self._context = context
|
||||
self.updated: list[dict] = []
|
||||
|
||||
def get(self, conversation_key: str) -> RouterContext:
|
||||
return self._context
|
||||
|
||||
def update(self, conversation_key: str, **kwargs) -> None:
|
||||
self.updated.append({"conversation_key": conversation_key, **kwargs})
|
||||
|
||||
|
||||
class _FakeSwitchDetector:
|
||||
def __init__(self, should_switch: bool) -> None:
|
||||
self._should_switch = should_switch
|
||||
|
||||
def should_switch(self, user_message: str, context: RouterContext) -> bool:
|
||||
return self._should_switch
|
||||
|
||||
|
||||
def test_router_service_classifies_first_message() -> None:
|
||||
service = RouterService(
|
||||
registry=_FakeRegistry(),
|
||||
classifier=_FakeClassifier(),
|
||||
context_store=_FakeContextStore(RouterContext()),
|
||||
switch_detector=_FakeSwitchDetector(False),
|
||||
)
|
||||
|
||||
route = service.resolve("Объясни как работает endpoint", "dialog-1")
|
||||
|
||||
assert route.domain_id == "project"
|
||||
assert route.process_id == "qa"
|
||||
assert route.decision_type == "start"
|
||||
|
||||
|
||||
def test_router_service_keeps_current_intent_for_follow_up() -> None:
|
||||
context = RouterContext(
|
||||
active_intent={"domain_id": "project", "process_id": "qa"},
|
||||
last_routing={"domain_id": "project", "process_id": "qa"},
|
||||
dialog_started=True,
|
||||
turn_index=1,
|
||||
)
|
||||
classifier = _FakeClassifier(
|
||||
decision=RouteDecision(domain_id="docs", process_id="generation", confidence=0.99, reason="should_not_run")
|
||||
)
|
||||
service = RouterService(
|
||||
registry=_FakeRegistry(),
|
||||
classifier=classifier,
|
||||
context_store=_FakeContextStore(context),
|
||||
switch_detector=_FakeSwitchDetector(False),
|
||||
)
|
||||
|
||||
route = service.resolve("Покажи подробнее", "dialog-1")
|
||||
|
||||
assert route.domain_id == "project"
|
||||
assert route.process_id == "qa"
|
||||
assert route.decision_type == "continue"
|
||||
assert classifier.calls == 0
|
||||
|
||||
|
||||
def test_router_service_switches_only_on_explicit_new_intent() -> None:
|
||||
context = RouterContext(
|
||||
active_intent={"domain_id": "project", "process_id": "qa"},
|
||||
last_routing={"domain_id": "project", "process_id": "qa"},
|
||||
dialog_started=True,
|
||||
turn_index=2,
|
||||
)
|
||||
classifier = _FakeClassifier(
|
||||
decision=RouteDecision(domain_id="project", process_id="edits", confidence=0.96, reason="explicit_edit")
|
||||
)
|
||||
service = RouterService(
|
||||
registry=_FakeRegistry(),
|
||||
classifier=classifier,
|
||||
context_store=_FakeContextStore(context),
|
||||
switch_detector=_FakeSwitchDetector(True),
|
||||
)
|
||||
|
||||
route = service.resolve("Теперь измени файл README.md", "dialog-1")
|
||||
|
||||
assert route.domain_id == "project"
|
||||
assert route.process_id == "edits"
|
||||
assert route.decision_type == "switch"
|
||||
assert route.explicit_switch is True
|
||||
assert classifier.calls == 1
|
||||
|
||||
|
||||
def test_router_service_keeps_current_when_explicit_switch_is_unresolved() -> None:
|
||||
context = RouterContext(
|
||||
active_intent={"domain_id": "project", "process_id": "qa"},
|
||||
last_routing={"domain_id": "project", "process_id": "qa"},
|
||||
dialog_started=True,
|
||||
turn_index=2,
|
||||
)
|
||||
classifier = _FakeClassifier(
|
||||
decision=RouteDecision(domain_id="docs", process_id="generation", confidence=0.2, reason="low_confidence")
|
||||
)
|
||||
service = RouterService(
|
||||
registry=_FakeRegistry(),
|
||||
classifier=classifier,
|
||||
context_store=_FakeContextStore(context),
|
||||
switch_detector=_FakeSwitchDetector(True),
|
||||
)
|
||||
|
||||
route = service.resolve("Теперь сделай что-то другое", "dialog-1")
|
||||
|
||||
assert route.domain_id == "project"
|
||||
assert route.process_id == "qa"
|
||||
assert route.decision_type == "continue"
|
||||
assert route.reason == "explicit_switch_unresolved_keep_current"
|
||||
|
||||
|
||||
def test_router_service_persists_decision_type() -> None:
|
||||
store = _FakeContextStore(RouterContext())
|
||||
service = RouterService(
|
||||
registry=_FakeRegistry(),
|
||||
classifier=_FakeClassifier(),
|
||||
context_store=store,
|
||||
switch_detector=_FakeSwitchDetector(False),
|
||||
)
|
||||
|
||||
service.persist_context(
|
||||
"dialog-1",
|
||||
domain_id="project",
|
||||
process_id="qa",
|
||||
user_message="Объясни",
|
||||
assistant_message="Ответ",
|
||||
decision_type="continue",
|
||||
)
|
||||
|
||||
assert store.updated[0]["decision_type"] == "continue"
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
59
tests/agent/orchestrator/test_code_explain_actions.py
Normal file
59
tests/agent/orchestrator/test_code_explain_actions.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from app.modules.agent.engine.orchestrator.actions.code_explain_actions import CodeExplainActions
|
||||
from app.modules.agent.engine.orchestrator.execution_context import ExecutionContext
|
||||
from app.modules.agent.engine.orchestrator.models import (
|
||||
ArtifactType,
|
||||
ExecutionPlan,
|
||||
OutputContract,
|
||||
RoutingMeta,
|
||||
Scenario,
|
||||
TaskConstraints,
|
||||
TaskSpec,
|
||||
)
|
||||
from app.modules.rag.explain.models import ExplainIntent, ExplainPack
|
||||
|
||||
|
||||
class _FakeRetriever:
|
||||
def build_pack(self, rag_session_id: str, user_query: str, *, file_candidates: list[dict] | None = None) -> ExplainPack:
|
||||
assert rag_session_id == "rag-1"
|
||||
assert "endpoint" in user_query
|
||||
assert file_candidates == [{"path": "app/api/users.py", "content": "..." }]
|
||||
return ExplainPack(intent=ExplainIntent(raw_query=user_query, normalized_query=user_query))
|
||||
|
||||
|
||||
def _ctx() -> ExecutionContext:
|
||||
task = TaskSpec(
|
||||
task_id="task-1",
|
||||
dialog_session_id="dialog-1",
|
||||
rag_session_id="rag-1",
|
||||
user_message="Explain endpoint get_user",
|
||||
scenario=Scenario.EXPLAIN_PART,
|
||||
routing=RoutingMeta(domain_id="project", process_id="qa", confidence=0.9, reason="test"),
|
||||
constraints=TaskConstraints(),
|
||||
output_contract=OutputContract(result_type="answer"),
|
||||
metadata={"rag_context": "", "confluence_context": "", "files_map": {}},
|
||||
)
|
||||
plan = ExecutionPlan(
|
||||
plan_id="plan-1",
|
||||
task_id="task-1",
|
||||
scenario=Scenario.EXPLAIN_PART,
|
||||
template_id="tpl",
|
||||
template_version="1",
|
||||
steps=[],
|
||||
)
|
||||
ctx = ExecutionContext(task=task, plan=plan, graph_resolver=lambda *_: None, graph_invoker=lambda *_: {})
|
||||
ctx.artifacts.put(
|
||||
key="source_bundle",
|
||||
artifact_type=ArtifactType.STRUCTURED_JSON,
|
||||
content={"file_candidates": [{"path": "app/api/users.py", "content": "..."}]},
|
||||
)
|
||||
return ctx
|
||||
|
||||
|
||||
def test_code_explain_actions_store_explain_pack() -> None:
|
||||
ctx = _ctx()
|
||||
actions = CodeExplainActions(_FakeRetriever())
|
||||
|
||||
actions.build_code_explain_pack(ctx)
|
||||
|
||||
stored = ctx.artifacts.get_content("explain_pack", {})
|
||||
assert stored["intent"]["raw_query"] == "Explain endpoint get_user"
|
||||
@@ -14,7 +14,7 @@ class DummyGraph:
|
||||
pass
|
||||
|
||||
|
||||
def _task(scenario: Scenario) -> TaskSpec:
|
||||
def _task(scenario: Scenario, *, domain_id: str = "project", process_id: str = "qa") -> TaskSpec:
|
||||
allow_writes = scenario in {Scenario.DOCS_FROM_ANALYTICS, Scenario.TARGETED_EDIT, Scenario.GHERKIN_MODEL}
|
||||
return TaskSpec(
|
||||
task_id="task-1",
|
||||
@@ -23,7 +23,7 @@ def _task(scenario: Scenario) -> TaskSpec:
|
||||
mode="auto",
|
||||
user_message="Explain this module",
|
||||
scenario=scenario,
|
||||
routing=RoutingMeta(domain_id="project", process_id="qa", confidence=0.95, reason="unit-test"),
|
||||
routing=RoutingMeta(domain_id=domain_id, process_id=process_id, confidence=0.95, reason="unit-test"),
|
||||
constraints=TaskConstraints(allow_writes=allow_writes, max_steps=16, max_retries_per_step=2, step_timeout_sec=90),
|
||||
output_contract=OutputContract(result_type="answer"),
|
||||
metadata={
|
||||
@@ -38,8 +38,8 @@ def test_orchestrator_service_returns_answer() -> None:
|
||||
service = OrchestratorService()
|
||||
|
||||
def graph_resolver(domain_id: str, process_id: str):
|
||||
assert domain_id == "project"
|
||||
assert process_id == "qa"
|
||||
assert domain_id == "default"
|
||||
assert process_id == "general"
|
||||
return DummyGraph()
|
||||
|
||||
def graph_invoker(_graph, state: dict, dialog_session_id: str):
|
||||
@@ -47,7 +47,13 @@ def test_orchestrator_service_returns_answer() -> None:
|
||||
assert dialog_session_id == "dialog-1"
|
||||
return {"answer": "It works.", "changeset": []}
|
||||
|
||||
result = asyncio.run(service.run(task=_task(Scenario.GENERAL_QA), graph_resolver=graph_resolver, graph_invoker=graph_invoker))
|
||||
result = asyncio.run(
|
||||
service.run(
|
||||
task=_task(Scenario.GENERAL_QA, domain_id="default", process_id="general"),
|
||||
graph_resolver=graph_resolver,
|
||||
graph_invoker=graph_invoker,
|
||||
)
|
||||
)
|
||||
assert result.answer == "It works."
|
||||
assert result.meta["plan"]["status"] == "completed"
|
||||
|
||||
@@ -70,3 +76,100 @@ def test_orchestrator_service_generates_changeset_for_docs_scenario() -> None:
|
||||
)
|
||||
assert result.meta["plan"]["status"] == "completed"
|
||||
assert len(result.changeset) > 0
|
||||
|
||||
|
||||
def test_orchestrator_service_uses_project_qa_reasoning_without_graph() -> None:
|
||||
service = OrchestratorService()
|
||||
requested_graphs: list[tuple[str, str]] = []
|
||||
|
||||
def graph_resolver(domain_id: str, process_id: str):
|
||||
requested_graphs.append((domain_id, process_id))
|
||||
return DummyGraph()
|
||||
|
||||
def graph_invoker(_graph, state: dict, _dialog_session_id: str):
|
||||
if "resolved_request" not in state:
|
||||
return {
|
||||
"resolved_request": {
|
||||
"original_message": state["message"],
|
||||
"normalized_message": state["message"],
|
||||
"subject_hint": "",
|
||||
"source_hint": "code",
|
||||
"russian": True,
|
||||
}
|
||||
}
|
||||
if "question_profile" not in state:
|
||||
return {
|
||||
"question_profile": {
|
||||
"domain": "code",
|
||||
"intent": "inventory",
|
||||
"terms": ["control", "channel"],
|
||||
"entities": [],
|
||||
"russian": True,
|
||||
}
|
||||
}
|
||||
if "source_bundle" not in state:
|
||||
return {
|
||||
"source_bundle": {
|
||||
"profile": state["question_profile"],
|
||||
"rag_items": [],
|
||||
"file_candidates": [
|
||||
{"path": "src/config_manager/v2/control/base.py", "content": "class ControlChannel: pass"},
|
||||
{"path": "src/config_manager/v2/control/http_channel.py", "content": "class HttpControlChannel(ControlChannel): pass # http api"},
|
||||
],
|
||||
"rag_total": 0,
|
||||
"files_total": 2,
|
||||
}
|
||||
}
|
||||
if "analysis_brief" not in state:
|
||||
return {
|
||||
"analysis_brief": {
|
||||
"subject": "management channels",
|
||||
"findings": ["В коде найдены конкретные реализации каналов управления: http channel (`src/config_manager/v2/control/http_channel.py`)."],
|
||||
"evidence": ["src/config_manager/v2/control/http_channel.py"],
|
||||
"gaps": [],
|
||||
"answer_mode": "inventory",
|
||||
}
|
||||
}
|
||||
return {
|
||||
"answer_brief": {
|
||||
"question_profile": state["question_profile"],
|
||||
"resolved_subject": "management channels",
|
||||
"key_findings": ["В коде найдены конкретные реализации каналов управления: http channel (`src/config_manager/v2/control/http_channel.py`)."],
|
||||
"supporting_evidence": ["src/config_manager/v2/control/http_channel.py"],
|
||||
"missing_evidence": [],
|
||||
"answer_mode": "inventory",
|
||||
},
|
||||
"final_answer": "## Кратко\n### Что реализовано\n- В коде найдены конкретные реализации каналов управления: http channel (`src/config_manager/v2/control/http_channel.py`).",
|
||||
}
|
||||
|
||||
task = _task(Scenario.GENERAL_QA).model_copy(
|
||||
update={
|
||||
"user_message": "Какие каналы управления уже реализованы?",
|
||||
"metadata": {
|
||||
"rag_context": "",
|
||||
"confluence_context": "",
|
||||
"files_map": {
|
||||
"src/config_manager/v2/control/base.py": {
|
||||
"content": "class ControlChannel:\n async def start(self):\n ..."
|
||||
},
|
||||
"src/config_manager/v2/control/http_channel.py": {
|
||||
"content": "class HttpControlChannel(ControlChannel):\n async def start(self):\n ...\n# http api"
|
||||
},
|
||||
},
|
||||
"rag_items": [],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
result = asyncio.run(service.run(task=task, graph_resolver=graph_resolver, graph_invoker=graph_invoker))
|
||||
|
||||
assert "Что реализовано" in result.answer
|
||||
assert "http channel" in result.answer.lower()
|
||||
assert result.meta["plan"]["status"] == "completed"
|
||||
assert requested_graphs == [
|
||||
("project_qa", "conversation_understanding"),
|
||||
("project_qa", "question_classification"),
|
||||
("project_qa", "context_retrieval"),
|
||||
("project_qa", "context_analysis"),
|
||||
("project_qa", "answer_composition"),
|
||||
]
|
||||
|
||||
71
tests/agent/orchestrator/test_project_qa_actions.py
Normal file
71
tests/agent/orchestrator/test_project_qa_actions.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from app.modules.agent.engine.orchestrator.actions.project_qa_actions import ProjectQaActions
|
||||
from app.modules.agent.engine.orchestrator.execution_context import ExecutionContext
|
||||
from app.modules.agent.engine.orchestrator.models import (
|
||||
ExecutionPlan,
|
||||
OutputContract,
|
||||
RoutingMeta,
|
||||
Scenario,
|
||||
TaskConstraints,
|
||||
TaskSpec,
|
||||
)
|
||||
|
||||
|
||||
def _ctx(message: str, rag_items: list[dict], files_map: dict[str, dict]) -> ExecutionContext:
|
||||
task = TaskSpec(
|
||||
task_id="task-1",
|
||||
dialog_session_id="dialog-1",
|
||||
rag_session_id="rag-1",
|
||||
user_message=message,
|
||||
scenario=Scenario.GENERAL_QA,
|
||||
routing=RoutingMeta(domain_id="project", process_id="qa", confidence=0.9, reason="test"),
|
||||
constraints=TaskConstraints(),
|
||||
output_contract=OutputContract(result_type="answer"),
|
||||
metadata={
|
||||
"rag_items": rag_items,
|
||||
"rag_context": "",
|
||||
"confluence_context": "",
|
||||
"files_map": files_map,
|
||||
},
|
||||
)
|
||||
plan = ExecutionPlan(
|
||||
plan_id="plan-1",
|
||||
task_id="task-1",
|
||||
scenario=Scenario.GENERAL_QA,
|
||||
template_id="tpl",
|
||||
template_version="1",
|
||||
steps=[],
|
||||
)
|
||||
return ExecutionContext(task=task, plan=plan, graph_resolver=lambda *_: None, graph_invoker=lambda *_: {})
|
||||
|
||||
|
||||
def test_project_qa_actions_build_inventory_answer_from_code_sources() -> None:
|
||||
ctx = _ctx(
|
||||
"Какие каналы управления уже реализованы?",
|
||||
[],
|
||||
{
|
||||
"src/config_manager/v2/control/base.py": {"content": "class ControlChannel:\n async def start(self):\n ..."},
|
||||
"src/config_manager/v2/core/control_bridge.py": {
|
||||
"content": "class ControlChannelBridge:\n async def on_start(self):\n ...\n async def on_status(self):\n ..."
|
||||
},
|
||||
"src/config_manager/v2/control/http_channel.py": {
|
||||
"content": "class HttpControlChannel(ControlChannel):\n async def start(self):\n ...\n# http api"
|
||||
},
|
||||
"src/config_manager/v2/control/telegram_channel.py": {
|
||||
"content": "class TelegramControlChannel(ControlChannel):\n async def start(self):\n ...\n# telegram bot"
|
||||
},
|
||||
},
|
||||
)
|
||||
actions = ProjectQaActions()
|
||||
|
||||
actions.classify_project_question(ctx)
|
||||
actions.collect_project_sources(ctx)
|
||||
actions.analyze_project_sources(ctx)
|
||||
actions.build_project_answer_brief(ctx)
|
||||
actions.compose_project_answer(ctx)
|
||||
|
||||
answer = str(ctx.artifacts.get_content("final_answer", ""))
|
||||
assert "### Что реализовано" in answer
|
||||
assert "http channel" in answer.lower()
|
||||
assert "telegram channel" in answer.lower()
|
||||
assert "### Где смотреть в проекте" in answer
|
||||
|
||||
74
tests/agent/orchestrator/test_project_qa_answer_graph_v2.py
Normal file
74
tests/agent/orchestrator/test_project_qa_answer_graph_v2.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import sys
|
||||
import types
|
||||
|
||||
langgraph = types.ModuleType("langgraph")
|
||||
langgraph_graph = types.ModuleType("langgraph.graph")
|
||||
langgraph_graph.END = "END"
|
||||
langgraph_graph.START = "START"
|
||||
langgraph_graph.StateGraph = object
|
||||
sys.modules.setdefault("langgraph", langgraph)
|
||||
sys.modules.setdefault("langgraph.graph", langgraph_graph)
|
||||
|
||||
from app.modules.agent.engine.graphs.project_qa_step_graphs import ProjectQaAnswerGraphFactory
|
||||
|
||||
|
||||
class _FakeLlm:
|
||||
def __init__(self) -> None:
|
||||
self.calls: list[tuple[str, str, str | None]] = []
|
||||
|
||||
def generate(self, prompt_name: str, user_input: str, *, log_context: str | None = None) -> str:
|
||||
self.calls.append((prompt_name, user_input, log_context))
|
||||
return "## Summary\n[entrypoint_1] [excerpt_1]"
|
||||
|
||||
|
||||
def test_project_qa_answer_graph_uses_v2_prompt_when_explain_pack_present() -> None:
|
||||
llm = _FakeLlm()
|
||||
factory = ProjectQaAnswerGraphFactory(llm)
|
||||
|
||||
result = factory._compose_answer(
|
||||
{
|
||||
"message": "Explain endpoint get_user",
|
||||
"question_profile": {"russian": False},
|
||||
"analysis_brief": {"findings": [], "evidence": [], "gaps": [], "answer_mode": "summary"},
|
||||
"explain_pack": {
|
||||
"intent": {
|
||||
"raw_query": "Explain endpoint get_user",
|
||||
"normalized_query": "Explain endpoint get_user",
|
||||
"keywords": ["get_user"],
|
||||
"hints": {"paths": [], "symbols": [], "endpoints": [], "commands": []},
|
||||
"expected_entry_types": ["http"],
|
||||
"depth": "medium",
|
||||
},
|
||||
"selected_entrypoints": [],
|
||||
"seed_symbols": [],
|
||||
"trace_paths": [],
|
||||
"evidence_index": {
|
||||
"entrypoint_1": {
|
||||
"evidence_id": "entrypoint_1",
|
||||
"kind": "entrypoint",
|
||||
"summary": "/users/{id}",
|
||||
"location": {"path": "app/api/users.py", "start_line": 10, "end_line": 10},
|
||||
"supports": ["handler-1"],
|
||||
}
|
||||
},
|
||||
"code_excerpts": [
|
||||
{
|
||||
"evidence_id": "excerpt_1",
|
||||
"symbol_id": "handler-1",
|
||||
"title": "get_user",
|
||||
"path": "app/api/users.py",
|
||||
"start_line": 10,
|
||||
"end_line": 18,
|
||||
"content": "async def get_user():\n return 1",
|
||||
"focus": "overview",
|
||||
}
|
||||
],
|
||||
"missing": [],
|
||||
"conflicts": [],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
assert result["final_answer"].startswith("## Summary")
|
||||
assert llm.calls[0][0] == "code_explain_answer_v2"
|
||||
assert '"evidence_id": "excerpt_1"' in llm.calls[0][1]
|
||||
@@ -0,0 +1,49 @@
|
||||
import sys
|
||||
import types
|
||||
|
||||
langgraph = types.ModuleType("langgraph")
|
||||
langgraph_graph = types.ModuleType("langgraph.graph")
|
||||
langgraph_graph.END = "END"
|
||||
langgraph_graph.START = "START"
|
||||
langgraph_graph.StateGraph = object
|
||||
sys.modules.setdefault("langgraph", langgraph)
|
||||
sys.modules.setdefault("langgraph.graph", langgraph_graph)
|
||||
|
||||
from app.modules.agent.engine.graphs.project_qa_step_graphs import ProjectQaRetrievalGraphFactory
|
||||
|
||||
|
||||
class _FailingRag:
|
||||
async def retrieve(self, rag_session_id: str, query: str):
|
||||
raise AssertionError("legacy rag should not be called for explain_part")
|
||||
|
||||
|
||||
def test_project_qa_retrieval_skips_legacy_rag_for_explain_part() -> None:
|
||||
factory = ProjectQaRetrievalGraphFactory(_FailingRag())
|
||||
|
||||
result = factory._retrieve_context(
|
||||
{
|
||||
"scenario": "explain_part",
|
||||
"project_id": "rag-1",
|
||||
"resolved_request": {
|
||||
"original_message": "Explain how ConfigManager works",
|
||||
"normalized_message": "Explain how ConfigManager works",
|
||||
},
|
||||
"question_profile": {
|
||||
"domain": "code",
|
||||
"intent": "explain",
|
||||
"terms": ["configmanager"],
|
||||
"entities": ["ConfigManager"],
|
||||
"russian": False,
|
||||
},
|
||||
"files_map": {
|
||||
"src/config_manager/__init__.py": {
|
||||
"content": "from .v2 import ConfigManagerV2 as ConfigManager",
|
||||
"content_hash": "hash-1",
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
bundle = result["source_bundle"]
|
||||
assert bundle["rag_items"] == []
|
||||
assert bundle["files_total"] >= 1
|
||||
@@ -36,3 +36,13 @@ def test_template_registry_has_multi_step_review_docs_edit_gherkin() -> None:
|
||||
assert len(docs_steps) >= 9
|
||||
assert len(edit_steps) >= 7
|
||||
assert len(gherkin_steps) >= 8
|
||||
|
||||
|
||||
def test_template_registry_adds_code_explain_pack_step_for_project_explain() -> None:
|
||||
registry = ScenarioTemplateRegistry()
|
||||
|
||||
steps = [step.step_id for step in registry.build(_task(Scenario.EXPLAIN_PART)).steps]
|
||||
|
||||
assert "code_explain_pack_step" in steps
|
||||
assert steps.index("code_explain_pack_step") > steps.index("context_retrieval")
|
||||
assert steps.index("code_explain_pack_step") < steps.index("context_analysis")
|
||||
|
||||
48
tests/agent/test_gigachat_client_retry.py
Normal file
48
tests/agent/test_gigachat_client_retry.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import requests
|
||||
|
||||
from app.modules.shared.gigachat.client import GigaChatClient
|
||||
from app.modules.shared.gigachat.settings import GigaChatSettings
|
||||
|
||||
|
||||
class _FakeTokenProvider:
|
||||
def get_access_token(self) -> str:
|
||||
return "token"
|
||||
|
||||
|
||||
class _FakeResponse:
|
||||
def __init__(self, status_code: int, payload: dict, text: str = "") -> None:
|
||||
self.status_code = status_code
|
||||
self._payload = payload
|
||||
self.text = text
|
||||
|
||||
def json(self) -> dict:
|
||||
return self._payload
|
||||
|
||||
|
||||
def test_gigachat_client_retries_transient_http_errors(monkeypatch) -> None:
|
||||
calls = {"count": 0}
|
||||
|
||||
def fake_post(*args, **kwargs):
|
||||
calls["count"] += 1
|
||||
if calls["count"] == 1:
|
||||
return _FakeResponse(503, {}, "temporary")
|
||||
return _FakeResponse(200, {"choices": [{"message": {"content": "ok"}}]})
|
||||
|
||||
monkeypatch.setattr(requests, "post", fake_post)
|
||||
client = GigaChatClient(
|
||||
GigaChatSettings(
|
||||
auth_url="https://auth.example.test",
|
||||
api_url="https://api.example.test",
|
||||
scope="scope",
|
||||
credentials="secret",
|
||||
ssl_verify=True,
|
||||
model="model",
|
||||
embedding_model="embed",
|
||||
),
|
||||
_FakeTokenProvider(),
|
||||
)
|
||||
|
||||
result = client.complete("system", "user")
|
||||
|
||||
assert result == "ok"
|
||||
assert calls["count"] == 2
|
||||
30
tests/agent/test_llm_service_logging.py
Normal file
30
tests/agent/test_llm_service_logging.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import logging
|
||||
|
||||
from app.modules.agent.llm.service import AgentLlmService
|
||||
|
||||
|
||||
class _FakeClient:
|
||||
def complete(self, *, system_prompt: str, user_prompt: str) -> str:
|
||||
assert system_prompt == "System prompt"
|
||||
assert user_prompt == "User input"
|
||||
return "LLM output"
|
||||
|
||||
|
||||
class _FakePrompts:
|
||||
def load(self, prompt_name: str) -> str:
|
||||
assert prompt_name == "general_answer"
|
||||
return "System prompt"
|
||||
|
||||
|
||||
def test_llm_service_logs_input_and_output_for_graph_context(caplog) -> None:
|
||||
service = AgentLlmService(_FakeClient(), _FakePrompts())
|
||||
|
||||
with caplog.at_level(logging.WARNING, logger="app.modules.agent.llm.service"):
|
||||
result = service.generate("general_answer", "User input", log_context="graph.default.answer")
|
||||
|
||||
assert result == "LLM output"
|
||||
messages = [record.getMessage() for record in caplog.records]
|
||||
assert any("graph llm input: context=graph.default.answer" in message for message in messages)
|
||||
assert any("graph llm output: context=graph.default.answer" in message for message in messages)
|
||||
assert any("User input" in message for message in messages)
|
||||
assert any("LLM output" in message for message in messages)
|
||||
24
tests/agent/test_logging_setup.py
Normal file
24
tests/agent/test_logging_setup.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import logging
|
||||
|
||||
from app.core.logging_setup import ScrubbingFormatter
|
||||
|
||||
|
||||
def test_scrubbing_formatter_redacts_identifiers_and_adds_blank_line() -> None:
|
||||
formatter = ScrubbingFormatter("%(levelname)s:%(name)s:%(message)s")
|
||||
record = logging.LogRecord(
|
||||
name="test.logger",
|
||||
level=logging.WARNING,
|
||||
pathname=__file__,
|
||||
lineno=1,
|
||||
msg="router decision: task_id=task-1 dialog_session_id=dialog-1 graph_id=project_qa/context_retrieval",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
|
||||
rendered = formatter.format(record)
|
||||
|
||||
assert "task_id=<redacted>" in rendered
|
||||
assert "dialog_session_id=<redacted>" in rendered
|
||||
assert "graph_id=<redacted>" in rendered
|
||||
assert "task-1" not in rendered
|
||||
assert rendered.endswith("\n")
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
tests/chat/__pycache__/test_direct_service.cpython-312.pyc
Normal file
BIN
tests/chat/__pycache__/test_direct_service.cpython-312.pyc
Normal file
Binary file not shown.
70
tests/chat/test_chat_api_simple_code_explain.py
Normal file
70
tests/chat/test_chat_api_simple_code_explain.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import asyncio
|
||||
|
||||
from app.modules.chat.module import ChatModule
|
||||
from app.modules.chat.task_store import TaskStore
|
||||
from app.schemas.chat import ChatMessageRequest
|
||||
from app.schemas.chat import TaskQueuedResponse
|
||||
from app.modules.shared.event_bus import EventBus
|
||||
from app.modules.shared.retry_executor import RetryExecutor
|
||||
|
||||
|
||||
class _FakeRuntime:
|
||||
async def run(self, **kwargs):
|
||||
raise AssertionError("legacy runtime must not be called")
|
||||
|
||||
|
||||
class _FakeDirectChat:
|
||||
def __init__(self) -> None:
|
||||
self.calls = 0
|
||||
|
||||
async def handle_message(self, request):
|
||||
self.calls += 1
|
||||
return TaskQueuedResponse(
|
||||
task_id="task-1",
|
||||
status="done",
|
||||
)
|
||||
|
||||
|
||||
class _FakeRagSessions:
|
||||
def get(self, rag_session_id: str):
|
||||
return {"rag_session_id": rag_session_id}
|
||||
|
||||
|
||||
class _FakeRepository:
|
||||
def create_dialog(self, dialog_session_id: str, rag_session_id: str) -> None:
|
||||
return None
|
||||
|
||||
def get_dialog(self, dialog_session_id: str):
|
||||
return None
|
||||
|
||||
def add_message(self, dialog_session_id: str, role: str, content: str, task_id: str | None = None, payload: dict | None = None) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def test_chat_messages_endpoint_uses_direct_service(monkeypatch) -> None:
|
||||
monkeypatch.setenv("SIMPLE_CODE_EXPLAIN_ONLY", "true")
|
||||
direct_chat = _FakeDirectChat()
|
||||
module = ChatModule(
|
||||
agent_runner=_FakeRuntime(),
|
||||
event_bus=EventBus(),
|
||||
retry=RetryExecutor(),
|
||||
rag_sessions=_FakeRagSessions(),
|
||||
repository=_FakeRepository(),
|
||||
direct_chat=direct_chat,
|
||||
task_store=TaskStore(),
|
||||
)
|
||||
router = module.public_router()
|
||||
endpoint = next(route.endpoint for route in router.routes if getattr(route, "path", "") == "/api/chat/messages")
|
||||
response = asyncio.run(
|
||||
endpoint(
|
||||
ChatMessageRequest(
|
||||
session_id="dialog-1",
|
||||
project_id="rag-1",
|
||||
message="Explain get_user",
|
||||
),
|
||||
None,
|
||||
)
|
||||
)
|
||||
|
||||
assert response.task_id == "task-1"
|
||||
assert direct_chat.calls == 1
|
||||
61
tests/chat/test_direct_service.py
Normal file
61
tests/chat/test_direct_service.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import asyncio
|
||||
|
||||
from app.modules.chat.direct_service import CodeExplainChatService
|
||||
from app.modules.chat.session_resolver import ChatSessionResolver
|
||||
from app.modules.chat.task_store import TaskStore
|
||||
from app.modules.rag.explain.models import ExplainIntent, ExplainPack
|
||||
from app.schemas.chat import ChatFileContext, ChatMessageRequest
|
||||
|
||||
|
||||
class _FakeRetriever:
|
||||
def build_pack(self, rag_session_id: str, user_query: str, *, file_candidates: list[dict] | None = None) -> ExplainPack:
|
||||
return ExplainPack(
|
||||
intent=ExplainIntent(raw_query=user_query, normalized_query=user_query),
|
||||
missing=["code_excerpts"],
|
||||
)
|
||||
|
||||
|
||||
class _FakeLlm:
|
||||
def __init__(self) -> None:
|
||||
self.calls = 0
|
||||
|
||||
def generate(self, prompt_name: str, user_input: str, *, log_context: str | None = None) -> str:
|
||||
self.calls += 1
|
||||
return "should not be called"
|
||||
|
||||
|
||||
class _FakeDialogs:
|
||||
def get(self, dialog_session_id: str):
|
||||
return None
|
||||
|
||||
|
||||
def test_direct_service_skips_llm_when_evidence_is_insufficient() -> None:
|
||||
messages: list[tuple[str, str, str, str | None]] = []
|
||||
llm = _FakeLlm()
|
||||
task_store = TaskStore()
|
||||
service = CodeExplainChatService(
|
||||
retriever=_FakeRetriever(),
|
||||
llm=llm,
|
||||
session_resolver=ChatSessionResolver(_FakeDialogs(), lambda rag_session_id: rag_session_id == "rag-1"),
|
||||
task_store=task_store,
|
||||
message_sink=lambda dialog_session_id, role, content, task_id=None: messages.append((dialog_session_id, role, content, task_id)),
|
||||
)
|
||||
|
||||
result = asyncio.run(
|
||||
service.handle_message(
|
||||
ChatMessageRequest(
|
||||
session_id="dialog-1",
|
||||
project_id="rag-1",
|
||||
message="Explain get_user",
|
||||
files=[ChatFileContext(path="app/api/users.py", content="", content_hash="x")],
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
task = task_store.get(result.task_id)
|
||||
assert task is not None
|
||||
assert task.answer is not None
|
||||
assert "Недостаточно опоры в коде" in task.answer
|
||||
assert result.status == "done"
|
||||
assert llm.calls == 0
|
||||
assert [item[1] for item in messages] == ["user", "assistant"]
|
||||
BIN
tests/rag/__pycache__/asserts_intent_router.cpython-312.pyc
Normal file
BIN
tests/rag/__pycache__/asserts_intent_router.cpython-312.pyc
Normal file
Binary file not shown.
BIN
tests/rag/__pycache__/intent_router_testkit.cpython-312.pyc
Normal file
BIN
tests/rag/__pycache__/intent_router_testkit.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
tests/rag/__pycache__/test_intent_router_v2.cpython-312.pyc
Normal file
BIN
tests/rag/__pycache__/test_intent_router_v2.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
tests/rag/__pycache__/test_query_normalization.cpython-312.pyc
Normal file
BIN
tests/rag/__pycache__/test_query_normalization.cpython-312.pyc
Normal file
Binary file not shown.
BIN
tests/rag/__pycache__/test_query_terms.cpython-312.pyc
Normal file
BIN
tests/rag/__pycache__/test_query_terms.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
tests/rag/__pycache__/test_retriever_v2_pack.cpython-312.pyc
Normal file
BIN
tests/rag/__pycache__/test_retriever_v2_pack.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
tests/rag/__pycache__/test_trace_builder.cpython-312.pyc
Normal file
BIN
tests/rag/__pycache__/test_trace_builder.cpython-312.pyc
Normal file
Binary file not shown.
77
tests/rag/asserts_intent_router.py
Normal file
77
tests/rag/asserts_intent_router.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
from app.modules.rag.intent_router_v2.models import IntentRouterResult
|
||||
|
||||
|
||||
def assert_intent(out: IntentRouterResult, expected: str) -> None:
|
||||
assert out.intent == expected
|
||||
|
||||
|
||||
def assert_domains(out: IntentRouterResult, expected: list[str]) -> None:
|
||||
assert out.retrieval_spec.domains == expected
|
||||
|
||||
|
||||
def assert_has_file_path(out: IntentRouterResult, path: str) -> None:
|
||||
assert any(anchor.type == "FILE_PATH" and anchor.value == path for anchor in out.query_plan.anchors)
|
||||
|
||||
|
||||
def assert_path_scope(out: IntentRouterResult, file_path: str, dir_path: str | None = None) -> None:
|
||||
scope = list(getattr(out.retrieval_spec.filters, "path_scope", []) or [])
|
||||
assert file_path in scope
|
||||
if dir_path is not None:
|
||||
assert dir_path in scope
|
||||
|
||||
|
||||
def assert_file_only_scope(out: IntentRouterResult, file_path: str) -> None:
|
||||
scope = list(getattr(out.retrieval_spec.filters, "path_scope", []) or [])
|
||||
assert scope == [file_path]
|
||||
|
||||
|
||||
def assert_spans_valid(out: IntentRouterResult) -> None:
|
||||
raw_len = len(out.query_plan.raw)
|
||||
for anchor in out.query_plan.anchors:
|
||||
if anchor.source == "conversation_state":
|
||||
assert anchor.span is None
|
||||
continue
|
||||
assert anchor.span is not None
|
||||
assert 0 <= anchor.span.start < anchor.span.end <= raw_len
|
||||
|
||||
|
||||
def assert_test_policy(out: IntentRouterResult, expected: str) -> None:
|
||||
assert getattr(out.retrieval_spec.filters, "test_policy", None) == expected
|
||||
|
||||
|
||||
def assert_sub_intent(out: IntentRouterResult, expected: str) -> None:
|
||||
assert out.query_plan.sub_intent == expected
|
||||
|
||||
|
||||
def assert_no_symbol_keyword(out: IntentRouterResult, forbidden: set[str] | None = None) -> None:
|
||||
denied = forbidden or {"def", "class", "return", "import", "from"}
|
||||
symbols = {anchor.value.lower() for anchor in out.query_plan.anchors if anchor.type == "SYMBOL"}
|
||||
assert symbols.isdisjoint({token.lower() for token in denied})
|
||||
|
||||
|
||||
def assert_domain_layer_prefixes(out: IntentRouterResult) -> None:
|
||||
prefixes = {layer.layer_id[0] for layer in out.retrieval_spec.layer_queries if layer.layer_id}
|
||||
if out.retrieval_spec.domains == ["CODE"]:
|
||||
assert prefixes <= {"C"}
|
||||
elif out.retrieval_spec.domains == ["DOCS"]:
|
||||
assert prefixes <= {"D"}
|
||||
else:
|
||||
assert prefixes <= {"C", "D"}
|
||||
|
||||
|
||||
def assert_no_symbol_leakage_from_paths(out: IntentRouterResult) -> None:
|
||||
file_values = [anchor.value for anchor in out.query_plan.anchors if anchor.type == "FILE_PATH"]
|
||||
if not file_values:
|
||||
return
|
||||
parts: set[str] = set()
|
||||
for value in file_values:
|
||||
for token in re.split(r"[/.]+", value.lower()):
|
||||
if token:
|
||||
parts.add(token)
|
||||
for anchor in out.query_plan.anchors:
|
||||
if anchor.type == "SYMBOL":
|
||||
assert anchor.value.lower() not in parts
|
||||
45
tests/rag/intent_router_testkit.py
Normal file
45
tests/rag/intent_router_testkit.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.modules.rag.contracts.enums import RagLayer
|
||||
from app.modules.rag.intent_router_v2 import ConversationState, IntentRouterV2, RepoContext
|
||||
|
||||
|
||||
def repo_context() -> RepoContext:
|
||||
return RepoContext(
|
||||
languages=["python"],
|
||||
available_domains=["CODE", "DOCS"],
|
||||
available_layers=[
|
||||
RagLayer.CODE_ENTRYPOINTS,
|
||||
RagLayer.CODE_SYMBOL_CATALOG,
|
||||
RagLayer.CODE_DEPENDENCY_GRAPH,
|
||||
RagLayer.CODE_SOURCE_CHUNKS,
|
||||
RagLayer.DOCS_MODULE_CATALOG,
|
||||
RagLayer.DOCS_FACT_INDEX,
|
||||
RagLayer.DOCS_SECTION_INDEX,
|
||||
RagLayer.DOCS_POLICY_INDEX,
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def run_sequence(queries: list[str], *, router: IntentRouterV2 | None = None, trace_label: str = "intent-router") -> list:
|
||||
active_router = router or IntentRouterV2()
|
||||
state = ConversationState()
|
||||
results = []
|
||||
for index, query in enumerate(queries, start=1):
|
||||
result = active_router.route(query, state, repo_context())
|
||||
print_trace(index, query, result, label=trace_label)
|
||||
results.append(result)
|
||||
state = state.advance(result)
|
||||
return results
|
||||
|
||||
|
||||
def run_single(query: str, *, router: IntentRouterV2 | None = None, trace_label: str = "intent-router"):
|
||||
result = run_sequence([query], router=router, trace_label=trace_label)[0]
|
||||
return result
|
||||
|
||||
|
||||
def print_trace(index: int, query: str, result, *, label: str = "intent-router") -> None:
|
||||
print(f"[{label}][turn {index}] input: {query}")
|
||||
print()
|
||||
print(f"[{label}][turn {index}] output: {result.model_dump_json(ensure_ascii=False)}")
|
||||
print("=" * 50)
|
||||
@@ -55,3 +55,21 @@ def test_code_pipeline_indexes_import_alias_as_symbol() -> None:
|
||||
alias_doc = next(doc for doc in docs if doc.layer == RagLayer.CODE_SYMBOL_CATALOG and doc.metadata["qname"] == "ConfigManager")
|
||||
assert alias_doc.metadata["kind"] == "const"
|
||||
assert alias_doc.metadata["lang_payload"]["import_alias"] is True
|
||||
|
||||
|
||||
def test_code_pipeline_marks_test_documents() -> None:
|
||||
pipeline = CodeIndexingPipeline()
|
||||
content = """
|
||||
def test_user_service():
|
||||
assert True
|
||||
"""
|
||||
|
||||
docs = pipeline.index_file(
|
||||
repo_id="acme/proj",
|
||||
commit_sha="abc123",
|
||||
path="tests/test_users.py",
|
||||
content=content,
|
||||
)
|
||||
|
||||
assert docs
|
||||
assert all(doc.metadata["is_test"] is True for doc in docs)
|
||||
|
||||
22
tests/rag/test_explain_intent_builder.py
Normal file
22
tests/rag/test_explain_intent_builder.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from app.modules.rag.explain.intent_builder import ExplainIntentBuilder
|
||||
|
||||
|
||||
def test_explain_intent_builder_extracts_route_symbol_and_file_hints() -> None:
|
||||
builder = ExplainIntentBuilder()
|
||||
|
||||
intent = builder.build("Explain how /users/{user_id} reaches UserService.get_user in app/api/users.py")
|
||||
|
||||
assert "/users/{user_id}" in intent.hints.endpoints
|
||||
assert "UserService.get_user" in intent.hints.symbols
|
||||
assert "app/api/users.py" in intent.hints.paths
|
||||
assert intent.expected_entry_types == ["http"]
|
||||
assert intent.include_tests is False
|
||||
assert intent.depth == "medium"
|
||||
|
||||
|
||||
def test_explain_intent_builder_enables_tests_when_user_asks_for_them() -> None:
|
||||
builder = ExplainIntentBuilder()
|
||||
|
||||
intent = builder.build("Покажи как это тестируется в pytest и какие tests покрывают UserService")
|
||||
|
||||
assert intent.include_tests is True
|
||||
126
tests/rag/test_intent_router_e2e_flows.py
Normal file
126
tests/rag/test_intent_router_e2e_flows.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from app.modules.rag.intent_router_v2 import GigaChatIntentRouterFactory
|
||||
from app.modules.shared.env_loader import load_workspace_env
|
||||
from tests.rag.asserts_intent_router import (
|
||||
assert_domains,
|
||||
assert_file_only_scope,
|
||||
assert_intent,
|
||||
assert_test_policy,
|
||||
)
|
||||
from tests.rag.intent_router_testkit import run_sequence
|
||||
|
||||
pytestmark = pytest.mark.intent_router
|
||||
|
||||
|
||||
def _live_gigachat_enabled() -> bool:
|
||||
load_workspace_env()
|
||||
return os.getenv("RUN_INTENT_ROUTER_V2_LIVE", "").strip() == "1" and bool(os.getenv("GIGACHAT_TOKEN", "").strip())
|
||||
|
||||
|
||||
def test_e2e_path_carryover_flow() -> None:
|
||||
first, second, third = run_sequence(
|
||||
[
|
||||
"Посмотри файл app/core/config.py",
|
||||
"Теперь объясни функцию load_config",
|
||||
"Почему так?",
|
||||
]
|
||||
)
|
||||
|
||||
assert_file_only_scope(first, "app/core/config.py")
|
||||
assert "app/core/config.py" in second.retrieval_spec.filters.path_scope
|
||||
assert "app/core/config.py" in third.retrieval_spec.filters.path_scope
|
||||
second_file_anchors = [anchor.value for anchor in second.query_plan.anchors if anchor.type == "FILE_PATH" and anchor.source == "conversation_state"]
|
||||
assert second_file_anchors == ["app/core/config.py"]
|
||||
assert "app/core/config.py" in second.query_plan.keyword_hints
|
||||
assert "app/core" not in second.query_plan.keyword_hints
|
||||
assert any(anchor.type == "FILE_PATH" and anchor.source == "conversation_state" and anchor.span is None for anchor in third.query_plan.anchors)
|
||||
carried_symbols = [anchor.value for anchor in third.query_plan.anchors if anchor.type == "SYMBOL" and anchor.source == "conversation_state"]
|
||||
assert carried_symbols in ([], ["load_config"])
|
||||
assert third.query_plan.sub_intent == "EXPLAIN_LOCAL"
|
||||
layer_ids = [item.layer_id for item in third.retrieval_spec.layer_queries]
|
||||
assert "C3_ENTRYPOINTS" not in layer_ids
|
||||
|
||||
|
||||
def test_e2e_docs_switch_from_code_topic() -> None:
|
||||
first, second = run_sequence(
|
||||
[
|
||||
"Объясни как работает ConfigManager",
|
||||
"А что про это сказано в документации?",
|
||||
]
|
||||
)
|
||||
|
||||
assert_intent(first, "CODE_QA")
|
||||
assert_intent(second, "DOCS_QA")
|
||||
assert second.conversation_mode == "SWITCH"
|
||||
assert_domains(second, ["DOCS"])
|
||||
carried = [
|
||||
anchor
|
||||
for anchor in second.query_plan.anchors
|
||||
if anchor.type == "SYMBOL" and anchor.value == "ConfigManager" and anchor.source == "conversation_state"
|
||||
]
|
||||
assert carried
|
||||
assert carried[0].span is None
|
||||
assert "ConfigManager" in second.query_plan.expansions
|
||||
assert "ConfigManager" in second.query_plan.keyword_hints
|
||||
|
||||
|
||||
def test_e2e_tests_toggle_flow() -> None:
|
||||
first, second = run_sequence(
|
||||
[
|
||||
"Покажи тесты для ConfigManager",
|
||||
"А теперь не про тесты, а про прод код",
|
||||
]
|
||||
)
|
||||
|
||||
assert_intent(first, "CODE_QA")
|
||||
assert_intent(second, "CODE_QA")
|
||||
assert_test_policy(first, "INCLUDE")
|
||||
assert_test_policy(second, "EXCLUDE")
|
||||
assert first.query_plan.sub_intent == "FIND_TESTS"
|
||||
assert second.query_plan.sub_intent == "EXPLAIN"
|
||||
assert "tests" in second.query_plan.negations
|
||||
assert not second.query_plan.expansions
|
||||
assert second.evidence_policy.require_flow is False
|
||||
|
||||
|
||||
def test_e2e_open_file_then_generic_next_steps_is_lightweight() -> None:
|
||||
first, second = run_sequence(
|
||||
[
|
||||
"Открой файл app/core/config.py",
|
||||
"Что дальше?",
|
||||
]
|
||||
)
|
||||
|
||||
assert_file_only_scope(first, "app/core/config.py")
|
||||
assert_file_only_scope(second, "app/core/config.py")
|
||||
assert second.query_plan.sub_intent in {"EXPLAIN_LOCAL", "NEXT_STEPS"}
|
||||
layer_ids = [item.layer_id for item in second.retrieval_spec.layer_queries]
|
||||
assert "C3_ENTRYPOINTS" not in layer_ids
|
||||
assert second.evidence_policy.require_flow is False
|
||||
assert "app/core/config.py" in second.query_plan.keyword_hints
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not _live_gigachat_enabled(),
|
||||
reason="requires RUN_INTENT_ROUTER_V2_LIVE=1 and GIGACHAT_TOKEN in environment or .env",
|
||||
)
|
||||
def test_intent_router_live_smoke_path_carryover() -> None:
|
||||
router = GigaChatIntentRouterFactory().build()
|
||||
first, second = run_sequence(
|
||||
[
|
||||
"Открой файл app/core/config.py",
|
||||
"Что дальше?",
|
||||
],
|
||||
router=router,
|
||||
trace_label="intent-router-live",
|
||||
)
|
||||
|
||||
assert_file_only_scope(first, "app/core/config.py")
|
||||
assert "app/core/config.py" in second.retrieval_spec.filters.path_scope
|
||||
assert second.query_plan.sub_intent in {"EXPLAIN_LOCAL", "NEXT_STEPS"}
|
||||
layer_ids = [item.layer_id for item in second.retrieval_spec.layer_queries]
|
||||
assert "C3_ENTRYPOINTS" not in layer_ids
|
||||
assert second.evidence_policy.require_flow is False
|
||||
120
tests/rag/test_intent_router_invariants.py
Normal file
120
tests/rag/test_intent_router_invariants.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import pytest
|
||||
|
||||
from tests.rag.asserts_intent_router import (
|
||||
assert_domain_layer_prefixes,
|
||||
assert_domains,
|
||||
assert_file_only_scope,
|
||||
assert_has_file_path,
|
||||
assert_intent,
|
||||
assert_no_symbol_keyword,
|
||||
assert_no_symbol_leakage_from_paths,
|
||||
assert_spans_valid,
|
||||
assert_sub_intent,
|
||||
assert_test_policy,
|
||||
)
|
||||
from tests.rag.intent_router_testkit import run_sequence
|
||||
|
||||
pytestmark = pytest.mark.intent_router
|
||||
|
||||
|
||||
def test_invariant_code_file_path_with_canonical_key_term() -> None:
|
||||
result = run_sequence(["Уточни по файлу app/core/config.py"])[0]
|
||||
|
||||
assert_intent(result, "CODE_QA")
|
||||
assert_has_file_path(result, "app/core/config.py")
|
||||
assert_file_only_scope(result, "app/core/config.py")
|
||||
key_terms = [anchor.value for anchor in result.query_plan.anchors if anchor.type == "KEY_TERM"]
|
||||
assert "файл" in key_terms
|
||||
assert "файлу" not in key_terms
|
||||
assert_spans_valid(result)
|
||||
assert_domain_layer_prefixes(result)
|
||||
|
||||
|
||||
def test_invariant_open_file_for_specified_file_phrase_uses_narrow_layers() -> None:
|
||||
result = run_sequence(["Уточни по файлу app/core/config.py"])[0]
|
||||
|
||||
assert_intent(result, "CODE_QA")
|
||||
assert_sub_intent(result, "OPEN_FILE")
|
||||
assert_file_only_scope(result, "app/core/config.py")
|
||||
layer_ids = [item.layer_id for item in result.retrieval_spec.layer_queries]
|
||||
assert layer_ids == ["C0_SOURCE_CHUNKS"]
|
||||
assert result.evidence_policy.require_flow is False
|
||||
|
||||
|
||||
def test_invariant_inline_code_span_routes_to_code_and_extracts_symbol() -> None:
|
||||
result = run_sequence(["Уточни по коду `def build(x): return x`"])[0]
|
||||
|
||||
assert_intent(result, "CODE_QA")
|
||||
assert_spans_valid(result)
|
||||
assert_no_symbol_keyword(result)
|
||||
symbols = [anchor.value for anchor in result.query_plan.anchors if anchor.type == "SYMBOL"]
|
||||
key_terms = [anchor.value for anchor in result.query_plan.anchors if anchor.type == "KEY_TERM"]
|
||||
assert "build" in symbols
|
||||
assert "def" in key_terms
|
||||
|
||||
|
||||
def test_invariant_docs_cyrillic_path_with_quotes() -> None:
|
||||
result = run_sequence(["Что сказано в «docs/архитектура.md»?"])[0]
|
||||
|
||||
assert_intent(result, "DOCS_QA")
|
||||
assert_sub_intent(result, "EXPLAIN")
|
||||
assert_domains(result, ["DOCS"])
|
||||
assert "docs/архитектура.md" in result.query_plan.normalized
|
||||
assert_has_file_path(result, "docs/архитектура.md")
|
||||
assert any(anchor.type == "DOC_REF" for anchor in result.query_plan.anchors)
|
||||
assert result.retrieval_spec.filters.doc_kinds == []
|
||||
assert_spans_valid(result)
|
||||
assert_domain_layer_prefixes(result)
|
||||
|
||||
|
||||
def test_invariant_file_check_phrase_not_project_misc() -> None:
|
||||
result = run_sequence(["Проверь app/modules/rag/explain/intent_builder.py и объясни"])[0]
|
||||
|
||||
assert_intent(result, "CODE_QA")
|
||||
assert_domains(result, ["CODE"])
|
||||
assert_no_symbol_leakage_from_paths(result)
|
||||
assert_domain_layer_prefixes(result)
|
||||
|
||||
|
||||
def test_invariant_tests_include_routing() -> None:
|
||||
result = run_sequence(["Где тесты на ConfigManager?"])[0]
|
||||
|
||||
assert_intent(result, "CODE_QA")
|
||||
assert_test_policy(result, "INCLUDE")
|
||||
symbols = [anchor.value for anchor in result.query_plan.anchors if anchor.type == "SYMBOL"]
|
||||
key_terms = [anchor.value for anchor in result.query_plan.anchors if anchor.type == "KEY_TERM"]
|
||||
assert "ConfigManager" in symbols
|
||||
assert "тест" in key_terms
|
||||
|
||||
|
||||
def test_invariant_keyword_hints_and_expansions_for_function_identifier() -> None:
|
||||
result = run_sequence(["Теперь объясни функцию load_config"])[0]
|
||||
|
||||
assert_intent(result, "CODE_QA")
|
||||
assert "load_config" in result.query_plan.keyword_hints
|
||||
assert "функция" not in result.query_plan.keyword_hints
|
||||
assert "def" not in result.query_plan.expansions
|
||||
|
||||
|
||||
def test_invariant_open_file_sub_intent_uses_narrow_retrieval_profile() -> None:
|
||||
result = run_sequence(["Открой файл app/core/config.py"])[0]
|
||||
|
||||
assert_intent(result, "CODE_QA")
|
||||
assert_sub_intent(result, "OPEN_FILE")
|
||||
assert_file_only_scope(result, "app/core/config.py")
|
||||
layer_ids = [item.layer_id for item in result.retrieval_spec.layer_queries]
|
||||
assert "C0_SOURCE_CHUNKS" in layer_ids
|
||||
assert "C1_SYMBOL_CATALOG" not in layer_ids
|
||||
assert "C2_DEPENDENCY_GRAPH" not in layer_ids
|
||||
assert "C3_ENTRYPOINTS" not in layer_ids
|
||||
assert result.evidence_policy.require_flow is False
|
||||
|
||||
|
||||
def test_invariant_docs_question_routes_to_docs() -> None:
|
||||
result = run_sequence(["Что сказано в документации?"])[0]
|
||||
|
||||
assert_intent(result, "DOCS_QA")
|
||||
assert_domains(result, ["DOCS"])
|
||||
assert_domain_layer_prefixes(result)
|
||||
assert result.query_plan.keyword_hints
|
||||
assert any(item in result.query_plan.expansions for item in result.query_plan.keyword_hints)
|
||||
78
tests/rag/test_layered_gateway.py
Normal file
78
tests/rag/test_layered_gateway.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from app.modules.rag.explain.layered_gateway import LayeredRetrievalGateway
|
||||
|
||||
|
||||
class _Embedder:
|
||||
def embed(self, texts: list[str]) -> list[list[float]]:
|
||||
return [[0.1, 0.2]]
|
||||
|
||||
|
||||
class _RetryingRepository:
|
||||
def __init__(self) -> None:
|
||||
self.calls: list[dict] = []
|
||||
|
||||
def retrieve(self, *args, **kwargs):
|
||||
self.calls.append(kwargs)
|
||||
if kwargs.get("exclude_path_prefixes"):
|
||||
raise RuntimeError("syntax error at or near ')'")
|
||||
return [
|
||||
{
|
||||
"path": "app/users/service.py",
|
||||
"content": "def get_user(): pass",
|
||||
"layer": "C1_SYMBOL_CATALOG",
|
||||
"title": "get_user",
|
||||
"metadata": {"symbol_id": "symbol-1"},
|
||||
"distance": 0.1,
|
||||
"span_start": 10,
|
||||
"span_end": 11,
|
||||
}
|
||||
]
|
||||
|
||||
def retrieve_lexical_code(self, *args, **kwargs):
|
||||
self.calls.append(kwargs)
|
||||
if kwargs.get("exclude_path_prefixes"):
|
||||
raise RuntimeError("broken lexical filter")
|
||||
return [
|
||||
{
|
||||
"path": "app/users/service.py",
|
||||
"content": "def get_user(): pass",
|
||||
"layer": "C0_SOURCE_CHUNKS",
|
||||
"title": "get_user",
|
||||
"metadata": {"symbol_id": "symbol-1"},
|
||||
"span_start": 10,
|
||||
"span_end": 11,
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
class _RecordingRepository:
|
||||
def __init__(self) -> None:
|
||||
self.calls: list[dict] = []
|
||||
|
||||
def retrieve(self, *args, **kwargs):
|
||||
self.calls.append(kwargs)
|
||||
return []
|
||||
|
||||
def retrieve_lexical_code(self, *args, **kwargs):
|
||||
self.calls.append(kwargs)
|
||||
return []
|
||||
|
||||
|
||||
def test_gateway_retries_without_test_filter_on_vector_failure() -> None:
|
||||
gateway = LayeredRetrievalGateway(_RetryingRepository(), _Embedder())
|
||||
|
||||
result = gateway.retrieve_layer("rag-1", "Explain get_user", "C1_SYMBOL_CATALOG", limit=3, exclude_tests=True)
|
||||
|
||||
assert len(result.items) == 1
|
||||
assert "layer:C1_SYMBOL_CATALOG retrieval_failed:retried_without_test_filter" in result.missing
|
||||
|
||||
|
||||
def test_gateway_honors_debug_disable_test_filter(monkeypatch) -> None:
|
||||
monkeypatch.setenv("RAG_DEBUG_DISABLE_TEST_FILTER", "true")
|
||||
repository = _RecordingRepository()
|
||||
gateway = LayeredRetrievalGateway(repository, _Embedder())
|
||||
|
||||
gateway.retrieve_layer("rag-1", "Explain get_user", "C1_SYMBOL_CATALOG", limit=3, exclude_tests=True)
|
||||
|
||||
assert repository.calls
|
||||
assert repository.calls[0]["exclude_path_prefixes"] is None
|
||||
assert repository.calls[0]["exclude_like_patterns"] is None
|
||||
63
tests/rag/test_query_normalization.py
Normal file
63
tests/rag/test_query_normalization.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import pytest
|
||||
|
||||
from app.modules.rag.intent_router_v2.normalization import QueryNormalizer
|
||||
|
||||
pytestmark = pytest.mark.intent_router
|
||||
|
||||
|
||||
def test_query_normalizer_collapses_whitespace() -> None:
|
||||
normalizer = QueryNormalizer()
|
||||
|
||||
normalized = normalizer.normalize(" Объясни как работает \n класс X ")
|
||||
|
||||
assert normalized == "Объясни как работает класс X"
|
||||
|
||||
|
||||
def test_query_normalizer_canonicalizes_quotes() -> None:
|
||||
normalizer = QueryNormalizer()
|
||||
|
||||
normalized = normalizer.normalize('Уточни «текст» и “текст”')
|
||||
|
||||
assert normalized == 'Уточни "текст" и "текст"'
|
||||
|
||||
|
||||
def test_query_normalizer_preserves_backticks_verbatim() -> None:
|
||||
normalizer = QueryNormalizer()
|
||||
|
||||
normalized = normalizer.normalize("Уточни по коду `def build(x):` ")
|
||||
|
||||
assert normalized == "Уточни по коду `def build(x):`"
|
||||
|
||||
|
||||
def test_query_normalizer_preserves_latin_and_cyrillic_file_paths() -> None:
|
||||
normalizer = QueryNormalizer()
|
||||
|
||||
normalized = normalizer.normalize("Сверь app/core/config.py и «docs/руководство.md»")
|
||||
|
||||
assert "app/core/config.py" in normalized
|
||||
assert "docs/руководство.md" in normalized
|
||||
assert "config. py" not in normalized
|
||||
assert "руководство. md" not in normalized
|
||||
|
||||
|
||||
def test_query_normalizer_punctuation_spacing_does_not_break_extensions() -> None:
|
||||
normalizer = QueryNormalizer()
|
||||
|
||||
normalized = normalizer.normalize("Проверь docs/spec.md , затем app/main.py !")
|
||||
|
||||
assert "docs/spec.md" in normalized
|
||||
assert "app/main.py" in normalized
|
||||
assert "spec. md" not in normalized
|
||||
assert "main. py" not in normalized
|
||||
|
||||
|
||||
def test_query_normalizer_idempotent_and_without_enrichment() -> None:
|
||||
normalizer = QueryNormalizer()
|
||||
raw = ' Прочитай «README.md» и docs/spec.md '
|
||||
|
||||
once = normalizer.normalize(raw)
|
||||
twice = normalizer.normalize(once)
|
||||
|
||||
assert twice == once
|
||||
assert "documentation" not in once.lower()
|
||||
assert "class" not in once.lower()
|
||||
@@ -1,12 +0,0 @@
|
||||
from app.modules.rag.contracts.enums import RetrievalMode
|
||||
from app.modules.rag.retrieval.query_router import RagQueryRouter
|
||||
|
||||
|
||||
def test_query_router_uses_docs_by_default() -> None:
|
||||
router = RagQueryRouter()
|
||||
assert router.resolve_mode("Какие есть требования по биллингу?") == RetrievalMode.DOCS
|
||||
|
||||
|
||||
def test_query_router_switches_to_code_on_explicit_code_requests() -> None:
|
||||
router = RagQueryRouter()
|
||||
assert router.resolve_mode("Объясни как работает код endpoint create invoice") == RetrievalMode.CODE
|
||||
44
tests/rag/test_retrieval_statement_builder.py
Normal file
44
tests/rag/test_retrieval_statement_builder.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from app.modules.rag.persistence.retrieval_statement_builder import RetrievalStatementBuilder
|
||||
from app.modules.rag.retrieval.test_filter import build_test_filters, is_test_path
|
||||
|
||||
|
||||
def test_retrieve_builder_adds_test_exclusion_filters() -> None:
|
||||
builder = RetrievalStatementBuilder()
|
||||
test_filters = build_test_filters()
|
||||
|
||||
sql, params = builder.build_retrieve(
|
||||
"rag-1",
|
||||
[0.1, 0.2],
|
||||
query_text="Explain user service",
|
||||
layers=["C0_SOURCE_CHUNKS"],
|
||||
exclude_path_prefixes=test_filters.exclude_path_prefixes,
|
||||
exclude_like_patterns=test_filters.exclude_like_patterns,
|
||||
)
|
||||
|
||||
assert "NOT (" in sql
|
||||
assert "path LIKE :exclude_prefix_0" in sql
|
||||
assert "lower(path) LIKE :exclude_like_0" in sql
|
||||
assert "ESCAPE E'\\\\'" in sql
|
||||
assert params["exclude_prefix_0"] == "tests/%"
|
||||
assert "%.test.%" in params.values()
|
||||
assert "%\\_test.%" in params.values()
|
||||
|
||||
|
||||
def test_lexical_builder_omits_test_filters_when_not_requested() -> None:
|
||||
builder = RetrievalStatementBuilder()
|
||||
|
||||
sql, params = builder.build_lexical_code(
|
||||
"rag-1",
|
||||
query_text="Explain user service",
|
||||
prefer_non_tests=False,
|
||||
)
|
||||
|
||||
assert sql is not None
|
||||
assert "exclude_prefix" not in sql
|
||||
assert "exclude_like" not in sql
|
||||
assert not any(key.startswith("exclude_") for key in params)
|
||||
|
||||
|
||||
def test_test_filter_does_not_treat_contest_file_as_test() -> None:
|
||||
assert is_test_path("app/contest.py") is False
|
||||
assert is_test_path("tests/test_users.py") is True
|
||||
52
tests/rag/test_retriever_v2_no_fallback.py
Normal file
52
tests/rag/test_retriever_v2_no_fallback.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from app.modules.rag.explain import CodeExplainRetrieverV2, LayeredRetrievalGateway
|
||||
|
||||
|
||||
class _ExplodingEmbedder:
|
||||
def embed(self, texts: list[str]) -> list[list[float]]:
|
||||
raise RuntimeError("embedding unavailable")
|
||||
|
||||
|
||||
class _RepositoryWithoutFallback:
|
||||
def retrieve(self, *args, **kwargs):
|
||||
raise RuntimeError("vector retrieval unavailable")
|
||||
|
||||
def retrieve_lexical_code(
|
||||
self,
|
||||
rag_session_id: str,
|
||||
query_text: str,
|
||||
*,
|
||||
limit: int = 5,
|
||||
path_prefixes: list[str] | None = None,
|
||||
exclude_path_prefixes: list[str] | None = None,
|
||||
exclude_like_patterns: list[str] | None = None,
|
||||
prefer_non_tests: bool = False,
|
||||
):
|
||||
return []
|
||||
|
||||
|
||||
class _FakeGraphRepository:
|
||||
def get_symbols_by_ids(self, rag_session_id: str, symbol_ids: list[str]):
|
||||
return []
|
||||
|
||||
def get_chunks_by_symbol_ids(self, rag_session_id: str, symbol_ids: list[str], prefer_chunk_type: str = "symbol_block"):
|
||||
return []
|
||||
|
||||
def get_out_edges(self, rag_session_id: str, src_symbol_ids: list[str], edge_types: list[str], limit_per_src: int):
|
||||
return []
|
||||
|
||||
def resolve_symbol_by_ref(self, rag_session_id: str, dst_ref: str, package_hint: str | None = None):
|
||||
return None
|
||||
|
||||
|
||||
def test_retriever_v2_returns_pack_without_fallback_method() -> None:
|
||||
retriever = CodeExplainRetrieverV2(
|
||||
gateway=LayeredRetrievalGateway(_RepositoryWithoutFallback(), _ExplodingEmbedder()),
|
||||
graph_repository=_FakeGraphRepository(),
|
||||
)
|
||||
|
||||
pack = retriever.build_pack("rag-1", "Explain get_user")
|
||||
|
||||
assert pack.code_excerpts == []
|
||||
assert any(item.startswith("layer:C3_ENTRYPOINTS retrieval_failed") for item in pack.missing)
|
||||
assert any(item.startswith("layer:C1_SYMBOL_CATALOG retrieval_failed") for item in pack.missing)
|
||||
assert "layer:C0 empty" in pack.missing
|
||||
105
tests/rag/test_retriever_v2_pack.py
Normal file
105
tests/rag/test_retriever_v2_pack.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from app.modules.rag.explain.models import CodeLocation, LayeredRetrievalItem
|
||||
from app.modules.rag.explain.retriever_v2 import CodeExplainRetrieverV2
|
||||
|
||||
|
||||
class _FakeGateway:
|
||||
def retrieve_layer(
|
||||
self,
|
||||
rag_session_id: str,
|
||||
query: str,
|
||||
layer: str,
|
||||
*,
|
||||
limit: int,
|
||||
path_prefixes: list[str] | None = None,
|
||||
exclude_tests: bool = True,
|
||||
prefer_non_tests: bool = False,
|
||||
include_spans: bool = False,
|
||||
):
|
||||
if layer == "C3_ENTRYPOINTS":
|
||||
return __import__("types").SimpleNamespace(
|
||||
items=[
|
||||
LayeredRetrievalItem(
|
||||
source="app/api/users.py",
|
||||
content="GET /users/{id}",
|
||||
layer=layer,
|
||||
title="GET /users/{id}",
|
||||
metadata={"entry_type": "http", "handler_symbol_id": "handler-1"},
|
||||
location=CodeLocation(path="app/api/users.py", start_line=10, end_line=10),
|
||||
)
|
||||
],
|
||||
missing=[],
|
||||
)
|
||||
if layer == "C1_SYMBOL_CATALOG":
|
||||
return __import__("types").SimpleNamespace(
|
||||
items=[
|
||||
LayeredRetrievalItem(
|
||||
source="app/api/users.py",
|
||||
content="def get_user_handler",
|
||||
layer=layer,
|
||||
title="get_user_handler",
|
||||
metadata={"symbol_id": "handler-1"},
|
||||
location=CodeLocation(path="app/api/users.py", start_line=10, end_line=18),
|
||||
)
|
||||
],
|
||||
missing=[],
|
||||
)
|
||||
raise AssertionError(layer)
|
||||
|
||||
def retrieve_lexical_code(
|
||||
self,
|
||||
rag_session_id: str,
|
||||
query: str,
|
||||
*,
|
||||
limit: int,
|
||||
path_prefixes: list[str] | None = None,
|
||||
exclude_tests: bool = True,
|
||||
include_spans: bool = False,
|
||||
):
|
||||
return __import__("types").SimpleNamespace(items=[], missing=[])
|
||||
|
||||
|
||||
class _FakeGraphRepository:
|
||||
def get_symbols_by_ids(self, rag_session_id: str, symbol_ids: list[str]):
|
||||
return [
|
||||
LayeredRetrievalItem(
|
||||
source="app/api/users.py",
|
||||
content="def get_user_handler",
|
||||
layer="C1_SYMBOL_CATALOG",
|
||||
title="get_user_handler",
|
||||
metadata={"symbol_id": "handler-1"},
|
||||
location=CodeLocation(path="app/api/users.py", start_line=10, end_line=18),
|
||||
)
|
||||
]
|
||||
|
||||
def get_out_edges(self, rag_session_id: str, src_symbol_ids: list[str], edge_types: list[str], limit_per_src: int):
|
||||
return []
|
||||
|
||||
def resolve_symbol_by_ref(self, rag_session_id: str, dst_ref: str, package_hint: str | None = None):
|
||||
return None
|
||||
|
||||
def get_chunks_by_symbol_ids(self, rag_session_id: str, symbol_ids: list[str], prefer_chunk_type: str = "symbol_block"):
|
||||
return [
|
||||
LayeredRetrievalItem(
|
||||
source="app/api/users.py",
|
||||
content="async def get_user_handler(user_id: str):\n return await service.get_user(user_id)",
|
||||
layer="C0_SOURCE_CHUNKS",
|
||||
title="get_user_handler",
|
||||
metadata={"symbol_id": "handler-1"},
|
||||
location=CodeLocation(path="app/api/users.py", start_line=10, end_line=18),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def test_retriever_v2_builds_pack_with_trace_and_excerpts() -> None:
|
||||
retriever = CodeExplainRetrieverV2(
|
||||
gateway=_FakeGateway(),
|
||||
graph_repository=_FakeGraphRepository(),
|
||||
)
|
||||
|
||||
pack = retriever.build_pack("rag-1", "Explain endpoint get_user")
|
||||
|
||||
assert len(pack.selected_entrypoints) == 1
|
||||
assert len(pack.seed_symbols) == 1
|
||||
assert len(pack.trace_paths) == 1
|
||||
assert len(pack.code_excerpts) == 1
|
||||
assert pack.code_excerpts[0].path == "app/api/users.py"
|
||||
142
tests/rag/test_retriever_v2_production_first.py
Normal file
142
tests/rag/test_retriever_v2_production_first.py
Normal file
@@ -0,0 +1,142 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from app.modules.rag.explain.models import CodeLocation, LayeredRetrievalItem
|
||||
from app.modules.rag.explain.retriever_v2 import CodeExplainRetrieverV2
|
||||
|
||||
|
||||
class _ProductionFirstGateway:
|
||||
def __init__(self) -> None:
|
||||
self.lexical_calls: list[bool] = []
|
||||
|
||||
def retrieve_layer(
|
||||
self,
|
||||
rag_session_id: str,
|
||||
query: str,
|
||||
layer: str,
|
||||
*,
|
||||
limit: int,
|
||||
path_prefixes: list[str] | None = None,
|
||||
exclude_tests: bool = True,
|
||||
prefer_non_tests: bool = False,
|
||||
include_spans: bool = False,
|
||||
):
|
||||
if layer == "C3_ENTRYPOINTS":
|
||||
return SimpleNamespace(items=[], missing=[])
|
||||
if layer == "C1_SYMBOL_CATALOG":
|
||||
return SimpleNamespace(items=[], missing=[])
|
||||
raise AssertionError(layer)
|
||||
|
||||
def retrieve_lexical_code(
|
||||
self,
|
||||
rag_session_id: str,
|
||||
query: str,
|
||||
*,
|
||||
limit: int,
|
||||
path_prefixes: list[str] | None = None,
|
||||
exclude_tests: bool = True,
|
||||
include_spans: bool = False,
|
||||
):
|
||||
self.lexical_calls.append(exclude_tests)
|
||||
if exclude_tests:
|
||||
return SimpleNamespace(
|
||||
items=[
|
||||
LayeredRetrievalItem(
|
||||
source="app/users/service.py",
|
||||
content="def get_user():\n return repo.get_user()",
|
||||
layer="C0_SOURCE_CHUNKS",
|
||||
title="get_user",
|
||||
metadata={"symbol_id": "user-service", "is_test": False},
|
||||
location=CodeLocation(path="app/users/service.py", start_line=10, end_line=11),
|
||||
),
|
||||
LayeredRetrievalItem(
|
||||
source="app/users/repository.py",
|
||||
content="def get_user_repo():\n return {}",
|
||||
layer="C0_SOURCE_CHUNKS",
|
||||
title="get_user_repo",
|
||||
metadata={"symbol_id": "user-repo", "is_test": False},
|
||||
location=CodeLocation(path="app/users/repository.py", start_line=20, end_line=21),
|
||||
),
|
||||
],
|
||||
missing=[],
|
||||
)
|
||||
return SimpleNamespace(
|
||||
items=[
|
||||
LayeredRetrievalItem(
|
||||
source="tests/test_users.py",
|
||||
content="def test_get_user():\n assert service.get_user()",
|
||||
layer="C0_SOURCE_CHUNKS",
|
||||
title="test_get_user",
|
||||
metadata={"symbol_id": "test-user", "is_test": True},
|
||||
location=CodeLocation(path="tests/test_users.py", start_line=5, end_line=6),
|
||||
)
|
||||
],
|
||||
missing=[],
|
||||
)
|
||||
|
||||
|
||||
class _TestsOnlyGateway(_ProductionFirstGateway):
|
||||
def retrieve_lexical_code(
|
||||
self,
|
||||
rag_session_id: str,
|
||||
query: str,
|
||||
*,
|
||||
limit: int,
|
||||
path_prefixes: list[str] | None = None,
|
||||
exclude_tests: bool = True,
|
||||
include_spans: bool = False,
|
||||
):
|
||||
self.lexical_calls.append(exclude_tests)
|
||||
if exclude_tests:
|
||||
return SimpleNamespace(items=[], missing=[])
|
||||
return SimpleNamespace(
|
||||
items=[
|
||||
LayeredRetrievalItem(
|
||||
source="tests/test_users.py",
|
||||
content="def test_get_user():\n assert service.get_user()",
|
||||
layer="C0_SOURCE_CHUNKS",
|
||||
title="test_get_user",
|
||||
metadata={"symbol_id": "test-user", "is_test": True},
|
||||
location=CodeLocation(path="tests/test_users.py", start_line=5, end_line=6),
|
||||
)
|
||||
],
|
||||
missing=[],
|
||||
)
|
||||
|
||||
|
||||
class _FakeGraphRepository:
|
||||
def get_symbols_by_ids(self, rag_session_id: str, symbol_ids: list[str]):
|
||||
return []
|
||||
|
||||
def get_chunks_by_symbol_ids(self, rag_session_id: str, symbol_ids: list[str], prefer_chunk_type: str = "symbol_block"):
|
||||
return []
|
||||
|
||||
def get_out_edges(self, rag_session_id: str, src_symbol_ids: list[str], edge_types: list[str], limit_per_src: int):
|
||||
return []
|
||||
|
||||
def resolve_symbol_by_ref(self, rag_session_id: str, dst_ref: str, package_hint: str | None = None):
|
||||
return None
|
||||
|
||||
|
||||
def test_retriever_prefers_prod_chunks_and_skips_test_fallback_when_enough_evidence() -> None:
|
||||
gateway = _ProductionFirstGateway()
|
||||
retriever = CodeExplainRetrieverV2(gateway=gateway, graph_repository=_FakeGraphRepository())
|
||||
|
||||
pack = retriever.build_pack("rag-1", "Explain get_user")
|
||||
|
||||
assert gateway.lexical_calls == [True]
|
||||
assert [excerpt.path for excerpt in pack.code_excerpts] == [
|
||||
"app/users/service.py",
|
||||
"app/users/repository.py",
|
||||
]
|
||||
assert all(not excerpt.focus.startswith("test:") for excerpt in pack.code_excerpts)
|
||||
|
||||
|
||||
def test_retriever_uses_test_fallback_when_production_evidence_is_missing() -> None:
|
||||
gateway = _TestsOnlyGateway()
|
||||
retriever = CodeExplainRetrieverV2(gateway=gateway, graph_repository=_FakeGraphRepository())
|
||||
|
||||
pack = retriever.build_pack("rag-1", "Explain get_user")
|
||||
|
||||
assert gateway.lexical_calls == [True, False]
|
||||
assert [excerpt.path for excerpt in pack.code_excerpts] == ["tests/test_users.py"]
|
||||
assert pack.code_excerpts[0].focus == "test:lexical"
|
||||
83
tests/rag/test_trace_builder.py
Normal file
83
tests/rag/test_trace_builder.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from app.modules.rag.explain.models import CodeLocation, LayeredRetrievalItem
|
||||
from app.modules.rag.explain.trace_builder import TraceBuilder
|
||||
|
||||
|
||||
class _FakeGraphRepository:
|
||||
def get_out_edges(self, rag_session_id: str, src_symbol_ids: list[str], edge_types: list[str], limit_per_src: int):
|
||||
assert rag_session_id == "rag-1"
|
||||
assert edge_types == ["calls", "imports", "inherits"]
|
||||
if src_symbol_ids == ["handler-1"]:
|
||||
return [
|
||||
LayeredRetrievalItem(
|
||||
source="app/api/users.py",
|
||||
content="handler calls get_user",
|
||||
layer="C2_DEPENDENCY_GRAPH",
|
||||
title="handler:calls",
|
||||
metadata={
|
||||
"src_symbol_id": "handler-1",
|
||||
"dst_symbol_id": None,
|
||||
"dst_ref": "UserService.get_user",
|
||||
"resolution": "partial",
|
||||
"edge_type": "calls",
|
||||
},
|
||||
location=CodeLocation(path="app/api/users.py", start_line=12, end_line=12),
|
||||
)
|
||||
]
|
||||
return []
|
||||
|
||||
def resolve_symbol_by_ref(self, rag_session_id: str, dst_ref: str, package_hint: str | None = None):
|
||||
assert rag_session_id == "rag-1"
|
||||
assert dst_ref == "UserService.get_user"
|
||||
assert package_hint == "app.api"
|
||||
return LayeredRetrievalItem(
|
||||
source="app/services/users.py",
|
||||
content="method UserService.get_user",
|
||||
layer="C1_SYMBOL_CATALOG",
|
||||
title="UserService.get_user",
|
||||
metadata={
|
||||
"symbol_id": "service-1",
|
||||
"package_or_module": "app.api.users",
|
||||
},
|
||||
location=CodeLocation(path="app/services/users.py", start_line=4, end_line=10),
|
||||
)
|
||||
|
||||
def get_symbols_by_ids(self, rag_session_id: str, symbol_ids: list[str]):
|
||||
assert rag_session_id == "rag-1"
|
||||
if symbol_ids == ["service-1"]:
|
||||
return [
|
||||
LayeredRetrievalItem(
|
||||
source="app/services/users.py",
|
||||
content="method UserService.get_user",
|
||||
layer="C1_SYMBOL_CATALOG",
|
||||
title="UserService.get_user",
|
||||
metadata={
|
||||
"symbol_id": "service-1",
|
||||
"package_or_module": "app.api.users",
|
||||
},
|
||||
location=CodeLocation(path="app/services/users.py", start_line=4, end_line=10),
|
||||
)
|
||||
]
|
||||
return []
|
||||
|
||||
|
||||
def test_trace_builder_resolves_partial_edges_across_files() -> None:
|
||||
builder = TraceBuilder(_FakeGraphRepository())
|
||||
seeds = [
|
||||
LayeredRetrievalItem(
|
||||
source="app/api/users.py",
|
||||
content="function handler",
|
||||
layer="C1_SYMBOL_CATALOG",
|
||||
title="get_user",
|
||||
metadata={
|
||||
"symbol_id": "handler-1",
|
||||
"package_or_module": "app.api.users",
|
||||
},
|
||||
location=CodeLocation(path="app/api/users.py", start_line=10, end_line=18),
|
||||
)
|
||||
]
|
||||
|
||||
paths = builder.build_paths("rag-1", seeds, max_depth=3)
|
||||
|
||||
assert len(paths) >= 1
|
||||
assert paths[0].symbol_ids == ["handler-1", "service-1"]
|
||||
assert "resolved:UserService.get_user" in paths[0].notes
|
||||
Reference in New Issue
Block a user