Фикс состояния
This commit is contained in:
1
tests/unit_tests/__init__.py
Normal file
1
tests/unit_tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Unit tests for application modules: agent, chat, rag."""
|
||||
@@ -0,0 +1,181 @@
|
||||
import sys
|
||||
import types
|
||||
|
||||
sqlalchemy = types.ModuleType("sqlalchemy")
|
||||
sqlalchemy.text = lambda value: value
|
||||
sqlalchemy.create_engine = lambda *args, **kwargs: object()
|
||||
sys.modules.setdefault("sqlalchemy", sqlalchemy)
|
||||
|
||||
sqlalchemy_engine = types.ModuleType("sqlalchemy.engine")
|
||||
sqlalchemy_engine.Engine = object
|
||||
sys.modules.setdefault("sqlalchemy.engine", sqlalchemy_engine)
|
||||
|
||||
sqlalchemy_orm = types.ModuleType("sqlalchemy.orm")
|
||||
sqlalchemy_orm.sessionmaker = lambda *args, **kwargs: object()
|
||||
sys.modules.setdefault("sqlalchemy.orm", sqlalchemy_orm)
|
||||
|
||||
sqlalchemy_pool = types.ModuleType("sqlalchemy.pool")
|
||||
sqlalchemy_pool.NullPool = object
|
||||
sys.modules.setdefault("sqlalchemy.pool", sqlalchemy_pool)
|
||||
|
||||
from app.modules.agent.engine.router.router_service import RouterService
|
||||
from app.modules.agent.engine.router.schemas import RouteDecision, RouterContext
|
||||
|
||||
|
||||
class _FakeRegistry:
|
||||
def is_valid(self, domain_id: str, process_id: str) -> bool:
|
||||
return (domain_id, process_id) in {
|
||||
("default", "general"),
|
||||
("project", "qa"),
|
||||
("project", "edits"),
|
||||
("docs", "generation"),
|
||||
}
|
||||
|
||||
def get_factory(self, domain_id: str, process_id: str):
|
||||
return object()
|
||||
|
||||
|
||||
class _FakeClassifier:
|
||||
def __init__(self, decision: RouteDecision | None = None, forced: RouteDecision | None = None) -> None:
|
||||
self._decision = decision or RouteDecision(domain_id="project", process_id="qa", confidence=0.95, reason="new_intent")
|
||||
self._forced = forced
|
||||
self.calls = 0
|
||||
|
||||
def from_mode(self, mode: str) -> RouteDecision | None:
|
||||
return self._forced if mode != "auto" else None
|
||||
|
||||
def classify_new_intent(self, user_message: str, context: RouterContext) -> RouteDecision:
|
||||
self.calls += 1
|
||||
return self._decision
|
||||
|
||||
|
||||
class _FakeContextStore:
|
||||
def __init__(self, context: RouterContext) -> None:
|
||||
self._context = context
|
||||
self.updated: list[dict] = []
|
||||
|
||||
def get(self, conversation_key: str) -> RouterContext:
|
||||
return self._context
|
||||
|
||||
def update(self, conversation_key: str, **kwargs) -> None:
|
||||
self.updated.append({"conversation_key": conversation_key, **kwargs})
|
||||
|
||||
|
||||
class _FakeSwitchDetector:
|
||||
def __init__(self, should_switch: bool) -> None:
|
||||
self._should_switch = should_switch
|
||||
|
||||
def should_switch(self, user_message: str, context: RouterContext) -> bool:
|
||||
return self._should_switch
|
||||
|
||||
|
||||
def test_router_service_classifies_first_message() -> None:
|
||||
service = RouterService(
|
||||
registry=_FakeRegistry(),
|
||||
classifier=_FakeClassifier(),
|
||||
context_store=_FakeContextStore(RouterContext()),
|
||||
switch_detector=_FakeSwitchDetector(False),
|
||||
)
|
||||
|
||||
route = service.resolve("Объясни как работает endpoint", "dialog-1")
|
||||
|
||||
assert route.domain_id == "project"
|
||||
assert route.process_id == "qa"
|
||||
assert route.decision_type == "start"
|
||||
|
||||
|
||||
def test_router_service_keeps_current_intent_for_follow_up() -> None:
|
||||
context = RouterContext(
|
||||
active_intent={"domain_id": "project", "process_id": "qa"},
|
||||
last_routing={"domain_id": "project", "process_id": "qa"},
|
||||
dialog_started=True,
|
||||
turn_index=1,
|
||||
)
|
||||
classifier = _FakeClassifier(
|
||||
decision=RouteDecision(domain_id="docs", process_id="generation", confidence=0.99, reason="should_not_run")
|
||||
)
|
||||
service = RouterService(
|
||||
registry=_FakeRegistry(),
|
||||
classifier=classifier,
|
||||
context_store=_FakeContextStore(context),
|
||||
switch_detector=_FakeSwitchDetector(False),
|
||||
)
|
||||
|
||||
route = service.resolve("Покажи подробнее", "dialog-1")
|
||||
|
||||
assert route.domain_id == "project"
|
||||
assert route.process_id == "qa"
|
||||
assert route.decision_type == "continue"
|
||||
assert classifier.calls == 0
|
||||
|
||||
|
||||
def test_router_service_switches_only_on_explicit_new_intent() -> None:
|
||||
context = RouterContext(
|
||||
active_intent={"domain_id": "project", "process_id": "qa"},
|
||||
last_routing={"domain_id": "project", "process_id": "qa"},
|
||||
dialog_started=True,
|
||||
turn_index=2,
|
||||
)
|
||||
classifier = _FakeClassifier(
|
||||
decision=RouteDecision(domain_id="project", process_id="edits", confidence=0.96, reason="explicit_edit")
|
||||
)
|
||||
service = RouterService(
|
||||
registry=_FakeRegistry(),
|
||||
classifier=classifier,
|
||||
context_store=_FakeContextStore(context),
|
||||
switch_detector=_FakeSwitchDetector(True),
|
||||
)
|
||||
|
||||
route = service.resolve("Теперь измени файл README.md", "dialog-1")
|
||||
|
||||
assert route.domain_id == "project"
|
||||
assert route.process_id == "edits"
|
||||
assert route.decision_type == "switch"
|
||||
assert route.explicit_switch is True
|
||||
assert classifier.calls == 1
|
||||
|
||||
|
||||
def test_router_service_keeps_current_when_explicit_switch_is_unresolved() -> None:
|
||||
context = RouterContext(
|
||||
active_intent={"domain_id": "project", "process_id": "qa"},
|
||||
last_routing={"domain_id": "project", "process_id": "qa"},
|
||||
dialog_started=True,
|
||||
turn_index=2,
|
||||
)
|
||||
classifier = _FakeClassifier(
|
||||
decision=RouteDecision(domain_id="docs", process_id="generation", confidence=0.2, reason="low_confidence")
|
||||
)
|
||||
service = RouterService(
|
||||
registry=_FakeRegistry(),
|
||||
classifier=classifier,
|
||||
context_store=_FakeContextStore(context),
|
||||
switch_detector=_FakeSwitchDetector(True),
|
||||
)
|
||||
|
||||
route = service.resolve("Теперь сделай что-то другое", "dialog-1")
|
||||
|
||||
assert route.domain_id == "project"
|
||||
assert route.process_id == "qa"
|
||||
assert route.decision_type == "continue"
|
||||
assert route.reason == "explicit_switch_unresolved_keep_current"
|
||||
|
||||
|
||||
def test_router_service_persists_decision_type() -> None:
|
||||
store = _FakeContextStore(RouterContext())
|
||||
service = RouterService(
|
||||
registry=_FakeRegistry(),
|
||||
classifier=_FakeClassifier(),
|
||||
context_store=store,
|
||||
switch_detector=_FakeSwitchDetector(False),
|
||||
)
|
||||
|
||||
service.persist_context(
|
||||
"dialog-1",
|
||||
domain_id="project",
|
||||
process_id="qa",
|
||||
user_message="Объясни",
|
||||
assistant_message="Ответ",
|
||||
decision_type="continue",
|
||||
)
|
||||
|
||||
assert store.updated[0]["decision_type"] == "continue"
|
||||
@@ -0,0 +1,59 @@
|
||||
from app.modules.agent.engine.orchestrator.actions.code_explain_actions import CodeExplainActions
|
||||
from app.modules.agent.engine.orchestrator.execution_context import ExecutionContext
|
||||
from app.modules.agent.engine.orchestrator.models import (
|
||||
ArtifactType,
|
||||
ExecutionPlan,
|
||||
OutputContract,
|
||||
RoutingMeta,
|
||||
Scenario,
|
||||
TaskConstraints,
|
||||
TaskSpec,
|
||||
)
|
||||
from app.modules.rag.explain.models import ExplainIntent, ExplainPack
|
||||
|
||||
|
||||
class _FakeRetriever:
|
||||
def build_pack(self, rag_session_id: str, user_query: str, *, file_candidates: list[dict] | None = None) -> ExplainPack:
|
||||
assert rag_session_id == "rag-1"
|
||||
assert "endpoint" in user_query
|
||||
assert file_candidates == [{"path": "app/api/users.py", "content": "..." }]
|
||||
return ExplainPack(intent=ExplainIntent(raw_query=user_query, normalized_query=user_query))
|
||||
|
||||
|
||||
def _ctx() -> ExecutionContext:
|
||||
task = TaskSpec(
|
||||
task_id="task-1",
|
||||
dialog_session_id="dialog-1",
|
||||
rag_session_id="rag-1",
|
||||
user_message="Explain endpoint get_user",
|
||||
scenario=Scenario.EXPLAIN_PART,
|
||||
routing=RoutingMeta(domain_id="project", process_id="qa", confidence=0.9, reason="test"),
|
||||
constraints=TaskConstraints(),
|
||||
output_contract=OutputContract(result_type="answer"),
|
||||
metadata={"rag_context": "", "confluence_context": "", "files_map": {}},
|
||||
)
|
||||
plan = ExecutionPlan(
|
||||
plan_id="plan-1",
|
||||
task_id="task-1",
|
||||
scenario=Scenario.EXPLAIN_PART,
|
||||
template_id="tpl",
|
||||
template_version="1",
|
||||
steps=[],
|
||||
)
|
||||
ctx = ExecutionContext(task=task, plan=plan, graph_resolver=lambda *_: None, graph_invoker=lambda *_: {})
|
||||
ctx.artifacts.put(
|
||||
key="source_bundle",
|
||||
artifact_type=ArtifactType.STRUCTURED_JSON,
|
||||
content={"file_candidates": [{"path": "app/api/users.py", "content": "..."}]},
|
||||
)
|
||||
return ctx
|
||||
|
||||
|
||||
def test_code_explain_actions_store_explain_pack() -> None:
|
||||
ctx = _ctx()
|
||||
actions = CodeExplainActions(_FakeRetriever())
|
||||
|
||||
actions.build_code_explain_pack(ctx)
|
||||
|
||||
stored = ctx.artifacts.get_content("explain_pack", {})
|
||||
assert stored["intent"]["raw_query"] == "Explain endpoint get_user"
|
||||
@@ -0,0 +1,61 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.modules.agent.engine.orchestrator.actions.edit_actions import EditActions
|
||||
from app.modules.agent.engine.orchestrator.execution_context import ExecutionContext
|
||||
from app.modules.agent.engine.orchestrator.models import (
|
||||
ExecutionPlan,
|
||||
OutputContract,
|
||||
RoutingMeta,
|
||||
Scenario,
|
||||
TaskConstraints,
|
||||
TaskSpec,
|
||||
)
|
||||
|
||||
|
||||
def _ctx() -> ExecutionContext:
|
||||
task = TaskSpec(
|
||||
task_id="task-1",
|
||||
dialog_session_id="dialog-1",
|
||||
rag_session_id="rag-1",
|
||||
mode="auto",
|
||||
user_message="Добавь в readme.md в конце строку про автора",
|
||||
scenario=Scenario.TARGETED_EDIT,
|
||||
routing=RoutingMeta(domain_id="project", process_id="edits", confidence=0.95, reason="test"),
|
||||
constraints=TaskConstraints(allow_writes=True),
|
||||
output_contract=OutputContract(result_type="changeset"),
|
||||
metadata={
|
||||
"files_map": {
|
||||
"README.md": {
|
||||
"path": "README.md",
|
||||
"content": "# Title\n",
|
||||
"content_hash": "hash123",
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
plan = ExecutionPlan(
|
||||
plan_id="plan-1",
|
||||
task_id="task-1",
|
||||
scenario=Scenario.TARGETED_EDIT,
|
||||
template_id="targeted_edit_v1",
|
||||
template_version="1.0",
|
||||
steps=[],
|
||||
)
|
||||
return ExecutionContext(task=task, plan=plan, graph_resolver=lambda *_: None, graph_invoker=lambda *_: {})
|
||||
|
||||
|
||||
def test_edit_actions_resolve_path_case_insensitive_and_keep_update() -> None:
|
||||
actions = EditActions()
|
||||
ctx = _ctx()
|
||||
|
||||
actions.resolve_target(ctx)
|
||||
actions.load_target_context(ctx)
|
||||
actions.plan_minimal_patch(ctx)
|
||||
actions.generate_patch(ctx)
|
||||
|
||||
target = ctx.artifacts.get_content("target_context", {})
|
||||
changeset = ctx.artifacts.get_content("raw_changeset", [])
|
||||
|
||||
assert target["path"] == "README.md"
|
||||
assert changeset[0]["path"] == "README.md"
|
||||
assert changeset[0]["op"] == "update"
|
||||
56
tests/unit_tests/agent/orchestrator/test_eval_suite.py
Normal file
56
tests/unit_tests/agent/orchestrator/test_eval_suite.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from app.modules.agent.engine.orchestrator.models import OutputContract, RoutingMeta, Scenario, TaskConstraints, TaskSpec
|
||||
from app.modules.agent.engine.orchestrator.service import OrchestratorService
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"scenario,expect_changeset",
|
||||
[
|
||||
(Scenario.EXPLAIN_PART, False),
|
||||
(Scenario.ANALYTICS_REVIEW, False),
|
||||
(Scenario.DOCS_FROM_ANALYTICS, True),
|
||||
(Scenario.TARGETED_EDIT, True),
|
||||
(Scenario.GHERKIN_MODEL, True),
|
||||
],
|
||||
)
|
||||
def test_eval_suite_scenarios_run(scenario: Scenario, expect_changeset: bool) -> None:
|
||||
service = OrchestratorService()
|
||||
|
||||
task = TaskSpec(
|
||||
task_id=f"task-{scenario.value}",
|
||||
dialog_session_id="dialog-1",
|
||||
rag_session_id="rag-1",
|
||||
mode="auto",
|
||||
user_message="Please process this scenario using project docs and requirements.",
|
||||
scenario=scenario,
|
||||
routing=RoutingMeta(domain_id="project", process_id="qa", confidence=0.95, reason="eval"),
|
||||
constraints=TaskConstraints(
|
||||
allow_writes=scenario in {Scenario.DOCS_FROM_ANALYTICS, Scenario.TARGETED_EDIT, Scenario.GHERKIN_MODEL},
|
||||
max_steps=20,
|
||||
max_retries_per_step=2,
|
||||
step_timeout_sec=90,
|
||||
),
|
||||
output_contract=OutputContract(result_type="answer"),
|
||||
attachments=[{"type": "http_url", "value": "https://example.com/doc"}],
|
||||
metadata={
|
||||
"rag_context": "Requirements context is available.",
|
||||
"confluence_context": "",
|
||||
"files_map": {"docs/api/increment.md": {"content": "old", "content_hash": "h1"}},
|
||||
},
|
||||
)
|
||||
|
||||
result = asyncio.run(
|
||||
service.run(
|
||||
task=task,
|
||||
graph_resolver=lambda _domain, _process: object(),
|
||||
graph_invoker=lambda _graph, _state, _dialog: {"answer": "fallback", "changeset": []},
|
||||
)
|
||||
)
|
||||
|
||||
assert result.meta["plan"]["status"] in {"completed", "partial"}
|
||||
assert bool(result.changeset) is expect_changeset
|
||||
if not expect_changeset:
|
||||
assert result.answer
|
||||
131
tests/unit_tests/agent/orchestrator/test_explain_actions.py
Normal file
131
tests/unit_tests/agent/orchestrator/test_explain_actions.py
Normal file
@@ -0,0 +1,131 @@
|
||||
from app.modules.agent.engine.orchestrator.actions.explain_actions import ExplainActions
|
||||
from app.modules.agent.engine.orchestrator.execution_context import ExecutionContext
|
||||
from app.modules.agent.engine.orchestrator.models import (
|
||||
ExecutionPlan,
|
||||
OutputContract,
|
||||
RoutingMeta,
|
||||
Scenario,
|
||||
TaskConstraints,
|
||||
TaskSpec,
|
||||
)
|
||||
|
||||
|
||||
def _ctx(rag_items: list[dict]) -> ExecutionContext:
|
||||
task = TaskSpec(
|
||||
task_id="task-1",
|
||||
dialog_session_id="dialog-1",
|
||||
rag_session_id="rag-1",
|
||||
user_message="Объясни по коду как работает task_processor",
|
||||
scenario=Scenario.EXPLAIN_PART,
|
||||
routing=RoutingMeta(domain_id="project", process_id="qa", confidence=0.9, reason="test"),
|
||||
constraints=TaskConstraints(),
|
||||
output_contract=OutputContract(result_type="answer"),
|
||||
metadata={
|
||||
"rag_items": rag_items,
|
||||
"rag_context": "",
|
||||
"confluence_context": "",
|
||||
"files_map": {},
|
||||
},
|
||||
)
|
||||
plan = ExecutionPlan(
|
||||
plan_id="plan-1",
|
||||
task_id="task-1",
|
||||
scenario=Scenario.EXPLAIN_PART,
|
||||
template_id="tpl",
|
||||
template_version="1",
|
||||
steps=[],
|
||||
)
|
||||
return ExecutionContext(task=task, plan=plan, graph_resolver=lambda *_: None, graph_invoker=lambda *_: {})
|
||||
|
||||
|
||||
def test_explain_actions_switch_to_code_profile_when_code_layers_present() -> None:
|
||||
ctx = _ctx(
|
||||
[
|
||||
{
|
||||
"source": "app/task_processor.py",
|
||||
"layer": "C1_SYMBOL_CATALOG",
|
||||
"title": "task_processor.process_task",
|
||||
"content": "function task_processor.process_task(task)",
|
||||
"metadata": {"qname": "task_processor.process_task", "kind": "function"},
|
||||
},
|
||||
{
|
||||
"source": "app/task_processor.py",
|
||||
"layer": "C2_DEPENDENCY_GRAPH",
|
||||
"title": "task_processor.process_task:calls",
|
||||
"content": "task_processor.process_task calls queue.publish",
|
||||
"metadata": {"edge_type": "calls"},
|
||||
},
|
||||
]
|
||||
)
|
||||
actions = ExplainActions()
|
||||
|
||||
actions.collect_sources(ctx)
|
||||
actions.extract_logic(ctx)
|
||||
actions.summarize(ctx)
|
||||
|
||||
sources = ctx.artifacts.get_content("sources", {})
|
||||
assert sources["source_profile"] == "code"
|
||||
answer = str(ctx.artifacts.get_content("final_answer", ""))
|
||||
assert "кодовых слоев индекса" not in answer
|
||||
assert "CodeRAG" not in answer
|
||||
assert "app/task_processor.py" in answer
|
||||
assert "requirements/docs context" not in answer
|
||||
|
||||
|
||||
def test_explain_actions_add_code_details_block() -> None:
|
||||
ctx = _ctx(
|
||||
[
|
||||
{
|
||||
"source": "src/config_manager/__init__.py",
|
||||
"layer": "C1_SYMBOL_CATALOG",
|
||||
"title": "ConfigManager",
|
||||
"content": "const ConfigManager\nConfigManager = config_manager.v2.ConfigManagerV2",
|
||||
"metadata": {
|
||||
"qname": "ConfigManager",
|
||||
"kind": "const",
|
||||
"lang_payload": {"imported_from": "v2.ConfigManagerV2", "import_alias": True},
|
||||
},
|
||||
},
|
||||
{
|
||||
"source": "src/config_manager/v2/control/base.py",
|
||||
"layer": "C1_SYMBOL_CATALOG",
|
||||
"title": "ControlChannel",
|
||||
"content": "class ControlChannel\nControlChannel(ABC)",
|
||||
"metadata": {"qname": "ControlChannel", "kind": "class"},
|
||||
},
|
||||
{
|
||||
"source": "src/config_manager/v2/core/control_bridge.py",
|
||||
"layer": "C1_SYMBOL_CATALOG",
|
||||
"title": "ControlChannelBridge",
|
||||
"content": "class ControlChannelBridge\nПредоставляет halt и status как обработчики start/stop/status",
|
||||
"metadata": {"qname": "ControlChannelBridge", "kind": "class"},
|
||||
},
|
||||
{
|
||||
"source": "src/config_manager/v2/core/control_bridge.py",
|
||||
"layer": "C2_DEPENDENCY_GRAPH",
|
||||
"title": "ControlChannelBridge.on_start:calls",
|
||||
"content": "ControlChannelBridge.on_start calls self._start_runtime",
|
||||
"metadata": {"src_qname": "ControlChannelBridge.on_start", "dst_ref": "self._start_runtime"},
|
||||
},
|
||||
{
|
||||
"source": "src/config_manager/v2/__init__.py",
|
||||
"layer": "C0_SOURCE_CHUNKS",
|
||||
"title": "src/config_manager/v2/__init__.py:1-6",
|
||||
"content": '"""Контракт: управление через API (config.yaml, секция management)."""',
|
||||
"metadata": {},
|
||||
},
|
||||
]
|
||||
)
|
||||
actions = ExplainActions()
|
||||
|
||||
actions.collect_sources(ctx)
|
||||
actions.extract_logic(ctx)
|
||||
actions.summarize(ctx)
|
||||
|
||||
answer = str(ctx.artifacts.get_content("final_answer", ""))
|
||||
assert "### Что видно по коду" in answer
|
||||
assert "ConfigManager` в проекте доступен как alias" in answer
|
||||
assert "ControlChannelBridge.on_start" in answer
|
||||
assert "### Где смотреть в проекте" in answer
|
||||
assert "В индексе нет точного символа" not in answer
|
||||
assert "отдельный интерфейс управления" in answer
|
||||
175
tests/unit_tests/agent/orchestrator/test_orchestrator_service.py
Normal file
175
tests/unit_tests/agent/orchestrator/test_orchestrator_service.py
Normal file
@@ -0,0 +1,175 @@
|
||||
import asyncio
|
||||
|
||||
from app.modules.agent.engine.orchestrator.models import (
|
||||
OutputContract,
|
||||
RoutingMeta,
|
||||
Scenario,
|
||||
TaskConstraints,
|
||||
TaskSpec,
|
||||
)
|
||||
from app.modules.agent.engine.orchestrator.service import OrchestratorService
|
||||
|
||||
|
||||
class DummyGraph:
|
||||
pass
|
||||
|
||||
|
||||
def _task(scenario: Scenario, *, domain_id: str = "project", process_id: str = "qa") -> TaskSpec:
|
||||
allow_writes = scenario in {Scenario.DOCS_FROM_ANALYTICS, Scenario.TARGETED_EDIT, Scenario.GHERKIN_MODEL}
|
||||
return TaskSpec(
|
||||
task_id="task-1",
|
||||
dialog_session_id="dialog-1",
|
||||
rag_session_id="rag-1",
|
||||
mode="auto",
|
||||
user_message="Explain this module",
|
||||
scenario=scenario,
|
||||
routing=RoutingMeta(domain_id=domain_id, process_id=process_id, confidence=0.95, reason="unit-test"),
|
||||
constraints=TaskConstraints(allow_writes=allow_writes, max_steps=16, max_retries_per_step=2, step_timeout_sec=90),
|
||||
output_contract=OutputContract(result_type="answer"),
|
||||
metadata={
|
||||
"rag_context": "RAG",
|
||||
"confluence_context": "",
|
||||
"files_map": {},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_orchestrator_service_returns_answer() -> None:
|
||||
service = OrchestratorService()
|
||||
|
||||
def graph_resolver(domain_id: str, process_id: str):
|
||||
assert domain_id == "default"
|
||||
assert process_id == "general"
|
||||
return DummyGraph()
|
||||
|
||||
def graph_invoker(_graph, state: dict, dialog_session_id: str):
|
||||
assert state["message"] == "Explain this module"
|
||||
assert dialog_session_id == "dialog-1"
|
||||
return {"answer": "It works.", "changeset": []}
|
||||
|
||||
result = asyncio.run(
|
||||
service.run(
|
||||
task=_task(Scenario.GENERAL_QA, domain_id="default", process_id="general"),
|
||||
graph_resolver=graph_resolver,
|
||||
graph_invoker=graph_invoker,
|
||||
)
|
||||
)
|
||||
assert result.answer == "It works."
|
||||
assert result.meta["plan"]["status"] == "completed"
|
||||
|
||||
|
||||
def test_orchestrator_service_generates_changeset_for_docs_scenario() -> None:
|
||||
service = OrchestratorService()
|
||||
|
||||
def graph_resolver(_domain_id: str, _process_id: str):
|
||||
return DummyGraph()
|
||||
|
||||
def graph_invoker(_graph, _state: dict, _dialog_session_id: str):
|
||||
return {"answer": "unused", "changeset": []}
|
||||
|
||||
result = asyncio.run(
|
||||
service.run(
|
||||
task=_task(Scenario.DOCS_FROM_ANALYTICS),
|
||||
graph_resolver=graph_resolver,
|
||||
graph_invoker=graph_invoker,
|
||||
)
|
||||
)
|
||||
assert result.meta["plan"]["status"] == "completed"
|
||||
assert len(result.changeset) > 0
|
||||
|
||||
|
||||
def test_orchestrator_service_uses_project_qa_reasoning_without_graph() -> None:
|
||||
service = OrchestratorService()
|
||||
requested_graphs: list[tuple[str, str]] = []
|
||||
|
||||
def graph_resolver(domain_id: str, process_id: str):
|
||||
requested_graphs.append((domain_id, process_id))
|
||||
return DummyGraph()
|
||||
|
||||
def graph_invoker(_graph, state: dict, _dialog_session_id: str):
|
||||
if "resolved_request" not in state:
|
||||
return {
|
||||
"resolved_request": {
|
||||
"original_message": state["message"],
|
||||
"normalized_message": state["message"],
|
||||
"subject_hint": "",
|
||||
"source_hint": "code",
|
||||
"russian": True,
|
||||
}
|
||||
}
|
||||
if "question_profile" not in state:
|
||||
return {
|
||||
"question_profile": {
|
||||
"domain": "code",
|
||||
"intent": "inventory",
|
||||
"terms": ["control", "channel"],
|
||||
"entities": [],
|
||||
"russian": True,
|
||||
}
|
||||
}
|
||||
if "source_bundle" not in state:
|
||||
return {
|
||||
"source_bundle": {
|
||||
"profile": state["question_profile"],
|
||||
"rag_items": [],
|
||||
"file_candidates": [
|
||||
{"path": "src/config_manager/v2/control/base.py", "content": "class ControlChannel: pass"},
|
||||
{"path": "src/config_manager/v2/control/http_channel.py", "content": "class HttpControlChannel(ControlChannel): pass # http api"},
|
||||
],
|
||||
"rag_total": 0,
|
||||
"files_total": 2,
|
||||
}
|
||||
}
|
||||
if "analysis_brief" not in state:
|
||||
return {
|
||||
"analysis_brief": {
|
||||
"subject": "management channels",
|
||||
"findings": ["В коде найдены конкретные реализации каналов управления: http channel (`src/config_manager/v2/control/http_channel.py`)."],
|
||||
"evidence": ["src/config_manager/v2/control/http_channel.py"],
|
||||
"gaps": [],
|
||||
"answer_mode": "inventory",
|
||||
}
|
||||
}
|
||||
return {
|
||||
"answer_brief": {
|
||||
"question_profile": state["question_profile"],
|
||||
"resolved_subject": "management channels",
|
||||
"key_findings": ["В коде найдены конкретные реализации каналов управления: http channel (`src/config_manager/v2/control/http_channel.py`)."],
|
||||
"supporting_evidence": ["src/config_manager/v2/control/http_channel.py"],
|
||||
"missing_evidence": [],
|
||||
"answer_mode": "inventory",
|
||||
},
|
||||
"final_answer": "## Кратко\n### Что реализовано\n- В коде найдены конкретные реализации каналов управления: http channel (`src/config_manager/v2/control/http_channel.py`).",
|
||||
}
|
||||
|
||||
task = _task(Scenario.GENERAL_QA).model_copy(
|
||||
update={
|
||||
"user_message": "Какие каналы управления уже реализованы?",
|
||||
"metadata": {
|
||||
"rag_context": "",
|
||||
"confluence_context": "",
|
||||
"files_map": {
|
||||
"src/config_manager/v2/control/base.py": {
|
||||
"content": "class ControlChannel:\n async def start(self):\n ..."
|
||||
},
|
||||
"src/config_manager/v2/control/http_channel.py": {
|
||||
"content": "class HttpControlChannel(ControlChannel):\n async def start(self):\n ...\n# http api"
|
||||
},
|
||||
},
|
||||
"rag_items": [],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
result = asyncio.run(service.run(task=task, graph_resolver=graph_resolver, graph_invoker=graph_invoker))
|
||||
|
||||
assert "Что реализовано" in result.answer
|
||||
assert "http channel" in result.answer.lower()
|
||||
assert result.meta["plan"]["status"] == "completed"
|
||||
assert requested_graphs == [
|
||||
("project_qa", "conversation_understanding"),
|
||||
("project_qa", "question_classification"),
|
||||
("project_qa", "context_retrieval"),
|
||||
("project_qa", "context_analysis"),
|
||||
("project_qa", "answer_composition"),
|
||||
]
|
||||
49
tests/unit_tests/agent/orchestrator/test_plan_validator.py
Normal file
49
tests/unit_tests/agent/orchestrator/test_plan_validator.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from app.modules.agent.engine.orchestrator.models import (
|
||||
ExecutionPlan,
|
||||
OutputContract,
|
||||
PlanStep,
|
||||
RetryPolicy,
|
||||
RoutingMeta,
|
||||
Scenario,
|
||||
TaskConstraints,
|
||||
TaskSpec,
|
||||
)
|
||||
from app.modules.agent.engine.orchestrator.plan_validator import PlanValidator
|
||||
|
||||
|
||||
def _task(*, allow_writes: bool) -> TaskSpec:
|
||||
return TaskSpec(
|
||||
task_id="t1",
|
||||
dialog_session_id="d1",
|
||||
rag_session_id="r1",
|
||||
mode="auto",
|
||||
user_message="hello",
|
||||
scenario=Scenario.GENERAL_QA,
|
||||
routing=RoutingMeta(domain_id="default", process_id="general", confidence=0.9, reason="test"),
|
||||
constraints=TaskConstraints(allow_writes=allow_writes, max_steps=10, max_retries_per_step=2, step_timeout_sec=60),
|
||||
output_contract=OutputContract(result_type="answer"),
|
||||
)
|
||||
|
||||
|
||||
def test_plan_validator_rejects_write_step_when_not_allowed() -> None:
|
||||
plan = ExecutionPlan(
|
||||
plan_id="p1",
|
||||
task_id="t1",
|
||||
scenario=Scenario.GENERAL_QA,
|
||||
template_id="tmp",
|
||||
template_version="1.0",
|
||||
steps=[
|
||||
PlanStep(
|
||||
step_id="s1",
|
||||
title="write",
|
||||
action_id="collect_state",
|
||||
executor="function",
|
||||
side_effect="write",
|
||||
retry=RetryPolicy(max_attempts=1),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
errors = PlanValidator().validate(plan, _task(allow_writes=False))
|
||||
|
||||
assert "write_step_not_allowed:s1" in errors
|
||||
@@ -0,0 +1,71 @@
|
||||
from app.modules.agent.engine.orchestrator.actions.project_qa_actions import ProjectQaActions
|
||||
from app.modules.agent.engine.orchestrator.execution_context import ExecutionContext
|
||||
from app.modules.agent.engine.orchestrator.models import (
|
||||
ExecutionPlan,
|
||||
OutputContract,
|
||||
RoutingMeta,
|
||||
Scenario,
|
||||
TaskConstraints,
|
||||
TaskSpec,
|
||||
)
|
||||
|
||||
|
||||
def _ctx(message: str, rag_items: list[dict], files_map: dict[str, dict]) -> ExecutionContext:
|
||||
task = TaskSpec(
|
||||
task_id="task-1",
|
||||
dialog_session_id="dialog-1",
|
||||
rag_session_id="rag-1",
|
||||
user_message=message,
|
||||
scenario=Scenario.GENERAL_QA,
|
||||
routing=RoutingMeta(domain_id="project", process_id="qa", confidence=0.9, reason="test"),
|
||||
constraints=TaskConstraints(),
|
||||
output_contract=OutputContract(result_type="answer"),
|
||||
metadata={
|
||||
"rag_items": rag_items,
|
||||
"rag_context": "",
|
||||
"confluence_context": "",
|
||||
"files_map": files_map,
|
||||
},
|
||||
)
|
||||
plan = ExecutionPlan(
|
||||
plan_id="plan-1",
|
||||
task_id="task-1",
|
||||
scenario=Scenario.GENERAL_QA,
|
||||
template_id="tpl",
|
||||
template_version="1",
|
||||
steps=[],
|
||||
)
|
||||
return ExecutionContext(task=task, plan=plan, graph_resolver=lambda *_: None, graph_invoker=lambda *_: {})
|
||||
|
||||
|
||||
def test_project_qa_actions_build_inventory_answer_from_code_sources() -> None:
|
||||
ctx = _ctx(
|
||||
"Какие каналы управления уже реализованы?",
|
||||
[],
|
||||
{
|
||||
"src/config_manager/v2/control/base.py": {"content": "class ControlChannel:\n async def start(self):\n ..."},
|
||||
"src/config_manager/v2/core/control_bridge.py": {
|
||||
"content": "class ControlChannelBridge:\n async def on_start(self):\n ...\n async def on_status(self):\n ..."
|
||||
},
|
||||
"src/config_manager/v2/control/http_channel.py": {
|
||||
"content": "class HttpControlChannel(ControlChannel):\n async def start(self):\n ...\n# http api"
|
||||
},
|
||||
"src/config_manager/v2/control/telegram_channel.py": {
|
||||
"content": "class TelegramControlChannel(ControlChannel):\n async def start(self):\n ...\n# telegram bot"
|
||||
},
|
||||
},
|
||||
)
|
||||
actions = ProjectQaActions()
|
||||
|
||||
actions.classify_project_question(ctx)
|
||||
actions.collect_project_sources(ctx)
|
||||
actions.analyze_project_sources(ctx)
|
||||
actions.build_project_answer_brief(ctx)
|
||||
actions.compose_project_answer(ctx)
|
||||
|
||||
answer = str(ctx.artifacts.get_content("final_answer", ""))
|
||||
assert "### Что реализовано" in answer
|
||||
assert "http channel" in answer.lower()
|
||||
assert "telegram channel" in answer.lower()
|
||||
assert "### Где смотреть в проекте" in answer
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
import sys
|
||||
import types
|
||||
|
||||
langgraph = types.ModuleType("langgraph")
|
||||
langgraph_graph = types.ModuleType("langgraph.graph")
|
||||
langgraph_graph.END = "END"
|
||||
langgraph_graph.START = "START"
|
||||
langgraph_graph.StateGraph = object
|
||||
sys.modules.setdefault("langgraph", langgraph)
|
||||
sys.modules.setdefault("langgraph.graph", langgraph_graph)
|
||||
|
||||
from app.modules.agent.engine.graphs.project_qa_step_graphs import ProjectQaAnswerGraphFactory
|
||||
|
||||
|
||||
class _FakeLlm:
|
||||
def __init__(self) -> None:
|
||||
self.calls: list[tuple[str, str, str | None]] = []
|
||||
|
||||
def generate(self, prompt_name: str, user_input: str, *, log_context: str | None = None) -> str:
|
||||
self.calls.append((prompt_name, user_input, log_context))
|
||||
return "## Summary\n[entrypoint_1] [excerpt_1]"
|
||||
|
||||
|
||||
def test_project_qa_answer_graph_uses_v2_prompt_when_explain_pack_present() -> None:
|
||||
llm = _FakeLlm()
|
||||
factory = ProjectQaAnswerGraphFactory(llm)
|
||||
|
||||
result = factory._compose_answer(
|
||||
{
|
||||
"message": "Explain endpoint get_user",
|
||||
"question_profile": {"russian": False},
|
||||
"analysis_brief": {"findings": [], "evidence": [], "gaps": [], "answer_mode": "summary"},
|
||||
"explain_pack": {
|
||||
"intent": {
|
||||
"raw_query": "Explain endpoint get_user",
|
||||
"normalized_query": "Explain endpoint get_user",
|
||||
"keywords": ["get_user"],
|
||||
"hints": {"paths": [], "symbols": [], "endpoints": [], "commands": []},
|
||||
"expected_entry_types": ["http"],
|
||||
"depth": "medium",
|
||||
},
|
||||
"selected_entrypoints": [],
|
||||
"seed_symbols": [],
|
||||
"trace_paths": [],
|
||||
"evidence_index": {
|
||||
"entrypoint_1": {
|
||||
"evidence_id": "entrypoint_1",
|
||||
"kind": "entrypoint",
|
||||
"summary": "/users/{id}",
|
||||
"location": {"path": "app/api/users.py", "start_line": 10, "end_line": 10},
|
||||
"supports": ["handler-1"],
|
||||
}
|
||||
},
|
||||
"code_excerpts": [
|
||||
{
|
||||
"evidence_id": "excerpt_1",
|
||||
"symbol_id": "handler-1",
|
||||
"title": "get_user",
|
||||
"path": "app/api/users.py",
|
||||
"start_line": 10,
|
||||
"end_line": 18,
|
||||
"content": "async def get_user():\n return 1",
|
||||
"focus": "overview",
|
||||
}
|
||||
],
|
||||
"missing": [],
|
||||
"conflicts": [],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
assert result["final_answer"].startswith("## Summary")
|
||||
assert llm.calls[0][0] == "code_explain_answer_v2"
|
||||
assert '"evidence_id": "excerpt_1"' in llm.calls[0][1]
|
||||
@@ -0,0 +1,49 @@
|
||||
import sys
import types

# Register stub langgraph modules up front so the project imports below resolve
# without the real langgraph package being installed.
langgraph = types.ModuleType("langgraph")
langgraph_graph = types.ModuleType("langgraph.graph")
langgraph_graph.END = "END"
langgraph_graph.START = "START"
langgraph_graph.StateGraph = object
sys.modules.setdefault("langgraph", langgraph)
sys.modules.setdefault("langgraph.graph", langgraph_graph)
|
||||
|
||||
from app.modules.agent.engine.graphs.project_qa_step_graphs import ProjectQaRetrievalGraphFactory
|
||||
|
||||
|
||||
class _FailingRag:
|
||||
async def retrieve(self, rag_session_id: str, query: str):
|
||||
raise AssertionError("legacy rag should not be called for explain_part")
|
||||
|
||||
|
||||
def test_project_qa_retrieval_skips_legacy_rag_for_explain_part() -> None:
|
||||
factory = ProjectQaRetrievalGraphFactory(_FailingRag())
|
||||
|
||||
result = factory._retrieve_context(
|
||||
{
|
||||
"scenario": "explain_part",
|
||||
"project_id": "rag-1",
|
||||
"resolved_request": {
|
||||
"original_message": "Explain how ConfigManager works",
|
||||
"normalized_message": "Explain how ConfigManager works",
|
||||
},
|
||||
"question_profile": {
|
||||
"domain": "code",
|
||||
"intent": "explain",
|
||||
"terms": ["configmanager"],
|
||||
"entities": ["ConfigManager"],
|
||||
"russian": False,
|
||||
},
|
||||
"files_map": {
|
||||
"src/config_manager/__init__.py": {
|
||||
"content": "from .v2 import ConfigManagerV2 as ConfigManager",
|
||||
"content_hash": "hash-1",
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
bundle = result["source_bundle"]
|
||||
assert bundle["rag_items"] == []
|
||||
assert bundle["files_total"] >= 1
|
||||
42
tests/unit_tests/agent/orchestrator/test_quality_metrics.py
Normal file
42
tests/unit_tests/agent/orchestrator/test_quality_metrics.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import asyncio
|
||||
|
||||
from app.modules.agent.engine.orchestrator.models import OutputContract, OutputSection, RoutingMeta, Scenario, TaskConstraints, TaskSpec
|
||||
from app.modules.agent.engine.orchestrator.service import OrchestratorService
|
||||
|
||||
|
||||
def test_quality_metrics_present_and_scored() -> None:
|
||||
service = OrchestratorService()
|
||||
task = TaskSpec(
|
||||
task_id="quality-1",
|
||||
dialog_session_id="dialog-1",
|
||||
rag_session_id="rag-1",
|
||||
mode="auto",
|
||||
user_message="Explain architecture",
|
||||
scenario=Scenario.EXPLAIN_PART,
|
||||
routing=RoutingMeta(domain_id="project", process_id="qa", confidence=0.9, reason="test"),
|
||||
constraints=TaskConstraints(allow_writes=False),
|
||||
output_contract=OutputContract(
|
||||
result_type="answer",
|
||||
sections=[
|
||||
OutputSection(name="sequence_diagram", format="mermaid"),
|
||||
OutputSection(name="use_cases", format="markdown"),
|
||||
OutputSection(name="summary", format="markdown"),
|
||||
],
|
||||
),
|
||||
metadata={"rag_context": "A\nB", "confluence_context": "", "files_map": {}},
|
||||
)
|
||||
|
||||
result = asyncio.run(
|
||||
service.run(
|
||||
task=task,
|
||||
graph_resolver=lambda _d, _p: object(),
|
||||
graph_invoker=lambda _g, _s, _id: {"answer": "unused", "changeset": []},
|
||||
)
|
||||
)
|
||||
|
||||
quality = result.meta.get("quality", {})
|
||||
assert quality
|
||||
assert quality.get("faithfulness", {}).get("score") is not None
|
||||
assert quality.get("coverage", {}).get("score") is not None
|
||||
assert quality.get("status") in {"ok", "needs_review", "fail"}
|
||||
assert quality.get("coverage", {}).get("covered_count", 0) >= 1
|
||||
@@ -0,0 +1,50 @@
|
||||
from app.modules.agent.engine.orchestrator.models import (
|
||||
ArtifactType,
|
||||
OutputContract,
|
||||
OutputSection,
|
||||
RoutingMeta,
|
||||
Scenario,
|
||||
TaskConstraints,
|
||||
TaskSpec,
|
||||
)
|
||||
from app.modules.agent.engine.orchestrator.quality_metrics import QualityMetricsCalculator
|
||||
from app.modules.agent.engine.orchestrator.template_registry import ScenarioTemplateRegistry
|
||||
from app.modules.agent.engine.orchestrator.execution_context import ExecutionContext
|
||||
from app.modules.agent.engine.orchestrator.models import PlanStatus
|
||||
|
||||
|
||||
def test_quality_metrics_coverage_reflects_missing_required_sections() -> None:
|
||||
task = TaskSpec(
|
||||
task_id="quality-2",
|
||||
dialog_session_id="dialog-1",
|
||||
rag_session_id="rag-1",
|
||||
mode="auto",
|
||||
user_message="Explain architecture",
|
||||
scenario=Scenario.EXPLAIN_PART,
|
||||
routing=RoutingMeta(domain_id="project", process_id="qa", confidence=0.9, reason="test"),
|
||||
constraints=TaskConstraints(allow_writes=False),
|
||||
output_contract=OutputContract(
|
||||
result_type="answer",
|
||||
sections=[
|
||||
OutputSection(name="sequence_diagram", format="mermaid"),
|
||||
OutputSection(name="use_cases", format="markdown"),
|
||||
OutputSection(name="summary", format="markdown"),
|
||||
],
|
||||
),
|
||||
metadata={"rag_context": "A", "confluence_context": "", "files_map": {}},
|
||||
)
|
||||
|
||||
plan = ScenarioTemplateRegistry().build(task)
|
||||
plan.status = PlanStatus.COMPLETED
|
||||
ctx = ExecutionContext(
|
||||
task=task,
|
||||
plan=plan,
|
||||
graph_resolver=lambda _d, _p: object(),
|
||||
graph_invoker=lambda _g, _s, _id: {},
|
||||
)
|
||||
ctx.artifacts.put(key="final_answer", artifact_type=ArtifactType.TEXT, content="Only summary text")
|
||||
|
||||
metrics = QualityMetricsCalculator().build(ctx, step_results=[])
|
||||
|
||||
assert metrics["coverage"]["score"] < 1.0
|
||||
assert "sequence_diagram" in metrics["coverage"]["missing_items"]
|
||||
@@ -0,0 +1,48 @@
|
||||
from app.modules.agent.engine.orchestrator.models import OutputContract, RoutingMeta, Scenario, TaskConstraints, TaskSpec
|
||||
from app.modules.agent.engine.orchestrator.template_registry import ScenarioTemplateRegistry
|
||||
|
||||
|
||||
def _task(scenario: Scenario) -> TaskSpec:
    """Build a minimal TaskSpec for *scenario*; writes are allowed only for write scenarios."""
    write_scenarios = {Scenario.DOCS_FROM_ANALYTICS, Scenario.TARGETED_EDIT, Scenario.GHERKIN_MODEL}
    return TaskSpec(
        task_id="t1",
        dialog_session_id="d1",
        rag_session_id="r1",
        mode="auto",
        user_message="run scenario",
        scenario=scenario,
        routing=RoutingMeta(domain_id="project", process_id="qa", confidence=0.9, reason="test"),
        constraints=TaskConstraints(allow_writes=scenario in write_scenarios),
        output_contract=OutputContract(result_type="answer"),
        metadata={"rag_context": "ctx", "confluence_context": "", "files_map": {}},
    )
|
||||
|
||||
|
||||
def test_template_registry_has_multi_step_review_docs_edit_gherkin() -> None:
    """Each write-capable scenario template must contain its key steps and a minimum length."""
    registry = ScenarioTemplateRegistry()

    def step_ids(scenario: Scenario) -> list[str]:
        # Helper: build the plan for a scenario and project out its step ids.
        return [step.step_id for step in registry.build(_task(scenario)).steps]

    review_steps = step_ids(Scenario.ANALYTICS_REVIEW)
    docs_steps = step_ids(Scenario.DOCS_FROM_ANALYTICS)
    edit_steps = step_ids(Scenario.TARGETED_EDIT)
    gherkin_steps = step_ids(Scenario.GHERKIN_MODEL)

    assert "structural_check" in review_steps
    assert "compose_review_report" in review_steps
    assert "extract_change_intents" in docs_steps
    assert "build_changeset" in docs_steps
    assert "resolve_target" in edit_steps
    assert "finalize_changeset" in edit_steps
    assert "generate_gherkin_bundle" in gherkin_steps
    assert "validate_coverage" in gherkin_steps

    assert len(review_steps) >= 7
    assert len(docs_steps) >= 9
    assert len(edit_steps) >= 7
    assert len(gherkin_steps) >= 8
|
||||
|
||||
|
||||
def test_template_registry_adds_code_explain_pack_step_for_project_explain() -> None:
    """The explain_part plan must insert the explain-pack step between retrieval and analysis."""
    plan = ScenarioTemplateRegistry().build(_task(Scenario.EXPLAIN_PART))
    steps = [step.step_id for step in plan.steps]

    assert "code_explain_pack_step" in steps
    assert steps.index("context_retrieval") < steps.index("code_explain_pack_step") < steps.index("context_analysis")
|
||||
48
tests/unit_tests/agent/test_gigachat_client_retry.py
Normal file
48
tests/unit_tests/agent/test_gigachat_client_retry.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import requests
|
||||
|
||||
from app.modules.shared.gigachat.client import GigaChatClient
|
||||
from app.modules.shared.gigachat.settings import GigaChatSettings
|
||||
|
||||
|
||||
class _FakeTokenProvider:
|
||||
def get_access_token(self) -> str:
|
||||
return "token"
|
||||
|
||||
|
||||
class _FakeResponse:
|
||||
def __init__(self, status_code: int, payload: dict, text: str = "") -> None:
|
||||
self.status_code = status_code
|
||||
self._payload = payload
|
||||
self.text = text
|
||||
|
||||
def json(self) -> dict:
|
||||
return self._payload
|
||||
|
||||
|
||||
def test_gigachat_client_retries_transient_http_errors(monkeypatch) -> None:
|
||||
calls = {"count": 0}
|
||||
|
||||
def fake_post(*args, **kwargs):
|
||||
calls["count"] += 1
|
||||
if calls["count"] == 1:
|
||||
return _FakeResponse(503, {}, "temporary")
|
||||
return _FakeResponse(200, {"choices": [{"message": {"content": "ok"}}]})
|
||||
|
||||
monkeypatch.setattr(requests, "post", fake_post)
|
||||
client = GigaChatClient(
|
||||
GigaChatSettings(
|
||||
auth_url="https://auth.example.test",
|
||||
api_url="https://api.example.test",
|
||||
scope="scope",
|
||||
credentials="secret",
|
||||
ssl_verify=True,
|
||||
model="model",
|
||||
embedding_model="embed",
|
||||
),
|
||||
_FakeTokenProvider(),
|
||||
)
|
||||
|
||||
result = client.complete("system", "user")
|
||||
|
||||
assert result == "ok"
|
||||
assert calls["count"] == 2
|
||||
30
tests/unit_tests/agent/test_llm_service_logging.py
Normal file
30
tests/unit_tests/agent/test_llm_service_logging.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import logging
|
||||
|
||||
from app.modules.agent.llm.service import AgentLlmService
|
||||
|
||||
|
||||
class _FakeClient:
|
||||
def complete(self, *, system_prompt: str, user_prompt: str) -> str:
|
||||
assert system_prompt == "System prompt"
|
||||
assert user_prompt == "User input"
|
||||
return "LLM output"
|
||||
|
||||
|
||||
class _FakePrompts:
|
||||
def load(self, prompt_name: str) -> str:
|
||||
assert prompt_name == "general_answer"
|
||||
return "System prompt"
|
||||
|
||||
|
||||
def test_llm_service_logs_input_and_output_for_graph_context(caplog) -> None:
|
||||
service = AgentLlmService(_FakeClient(), _FakePrompts())
|
||||
|
||||
with caplog.at_level(logging.WARNING, logger="app.modules.agent.llm.service"):
|
||||
result = service.generate("general_answer", "User input", log_context="graph.default.answer")
|
||||
|
||||
assert result == "LLM output"
|
||||
messages = [record.getMessage() for record in caplog.records]
|
||||
assert any("graph llm input: context=graph.default.answer" in message for message in messages)
|
||||
assert any("graph llm output: context=graph.default.answer" in message for message in messages)
|
||||
assert any("User input" in message for message in messages)
|
||||
assert any("LLM output" in message for message in messages)
|
||||
24
tests/unit_tests/agent/test_logging_setup.py
Normal file
24
tests/unit_tests/agent/test_logging_setup.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import logging
|
||||
|
||||
from app.core.logging_setup import ScrubbingFormatter
|
||||
|
||||
|
||||
def test_scrubbing_formatter_redacts_identifiers_and_adds_blank_line() -> None:
|
||||
formatter = ScrubbingFormatter("%(levelname)s:%(name)s:%(message)s")
|
||||
record = logging.LogRecord(
|
||||
name="test.logger",
|
||||
level=logging.WARNING,
|
||||
pathname=__file__,
|
||||
lineno=1,
|
||||
msg="router decision: task_id=task-1 dialog_session_id=dialog-1 graph_id=project_qa/context_retrieval",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
|
||||
rendered = formatter.format(record)
|
||||
|
||||
assert "task_id=<redacted>" in rendered
|
||||
assert "dialog_session_id=<redacted>" in rendered
|
||||
assert "graph_id=<redacted>" in rendered
|
||||
assert "task-1" not in rendered
|
||||
assert rendered.endswith("\n")
|
||||
98
tests/unit_tests/agent/test_repo_webhook_service.py
Normal file
98
tests/unit_tests/agent/test_repo_webhook_service.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.modules.rag.webhook_service import RepoWebhookService
|
||||
|
||||
|
||||
class FakeStoryWriter:
    """Captures record_story_commit keyword payloads so tests can assert on them."""

    def __init__(self) -> None:
        # One dict of keyword arguments per recorded commit.
        self.calls: list[dict] = []

    def record_story_commit(self, **kwargs) -> None:
        self.calls.append(kwargs)
|
||||
|
||||
|
||||
class FakeCacheWriter:
    """Captures record_repo_cache keyword payloads so tests can assert on them."""

    def __init__(self) -> None:
        # One dict of keyword arguments per recorded cache write.
        self.calls: list[dict] = []

    def record_repo_cache(self, **kwargs) -> None:
        self.calls.append(kwargs)
|
||||
|
||||
|
||||
def test_gitea_webhook_binds_story() -> None:
|
||||
writer = FakeStoryWriter()
|
||||
cache = FakeCacheWriter()
|
||||
service = RepoWebhookService(writer, cache)
|
||||
|
||||
result = service.process(
|
||||
provider="gitea",
|
||||
payload={
|
||||
"repository": {"full_name": "acme/proj"},
|
||||
"ref": "refs/heads/feature/AAAA-1234",
|
||||
"pusher": {"username": "alice"},
|
||||
"commits": [
|
||||
{
|
||||
"id": "abc123",
|
||||
"message": "FEAT-1 update docs",
|
||||
"added": ["docs/new.md"],
|
||||
"modified": ["docs/api.md"],
|
||||
"removed": [],
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
assert result["accepted"] is True
|
||||
assert result["story_bound"] is True
|
||||
assert result["story_id"] == "FEAT-1"
|
||||
assert result["cache_recorded"] is True
|
||||
assert len(writer.calls) == 1
|
||||
assert len(cache.calls) == 1
|
||||
assert writer.calls[0]["project_id"] == "acme/proj"
|
||||
|
||||
|
||||
def test_webhook_without_story_id_is_non_fatal() -> None:
|
||||
writer = FakeStoryWriter()
|
||||
cache = FakeCacheWriter()
|
||||
service = RepoWebhookService(writer, cache)
|
||||
|
||||
result = service.process(
|
||||
provider="bitbucket",
|
||||
payload={
|
||||
"repository": {"full_name": "acme/proj"},
|
||||
"push": {
|
||||
"changes": [
|
||||
{
|
||||
"new": {
|
||||
"name": "feature/no-story",
|
||||
"target": {"hash": "abc123", "message": "update docs"},
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert result["accepted"] is True
|
||||
assert result["story_bound"] is False
|
||||
assert result["cache_recorded"] is True
|
||||
assert len(cache.calls) == 1
|
||||
assert writer.calls == []
|
||||
|
||||
|
||||
def test_provider_autodetect_by_headers() -> None:
    """Provider omitted: the service must detect Gitea from its event header and bind the story."""
    writer = FakeStoryWriter()
    service = RepoWebhookService(writer)

    payload = {
        "repository": {"full_name": "acme/proj"},
        "ref": "refs/heads/feature/AAAA-1234",
        "commits": [{"id": "abc123", "message": "AAAA-1234 update"}],
    }
    result = service.process(headers={"X-Gitea-Event": "push"}, payload=payload)

    assert result["accepted"] is True
    assert result["story_bound"] is True
    assert result["story_id"] == "AAAA-1234"
|
||||
48
tests/unit_tests/agent/test_story_session_recorder.py
Normal file
48
tests/unit_tests/agent/test_story_session_recorder.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.modules.agent.story_session_recorder import StorySessionRecorder
|
||||
from app.schemas.changeset import ChangeItem, ChangeOp
|
||||
|
||||
|
||||
class FakeStoryRepo:
    """Captures add_session_artifact keyword payloads so tests can assert on them."""

    def __init__(self) -> None:
        # One dict of keyword arguments per recorded artifact.
        self.calls: list[dict] = []

    def add_session_artifact(self, **kwargs) -> None:
        self.calls.append(kwargs)
|
||||
|
||||
|
||||
def test_record_run_stores_attachment_and_changeset_artifacts() -> None:
|
||||
repo = FakeStoryRepo()
|
||||
recorder = StorySessionRecorder(repo)
|
||||
|
||||
recorder.record_run(
|
||||
dialog_session_id="dialog-1",
|
||||
rag_session_id="rag-1",
|
||||
scenario="docs_from_analytics",
|
||||
attachments=[
|
||||
{"type": "confluence_url", "value": "https://example.org/doc"},
|
||||
{"type": "file_ref", "value": "local.md"},
|
||||
],
|
||||
answer="Generated docs update summary",
|
||||
changeset=[
|
||||
ChangeItem(
|
||||
op=ChangeOp.UPDATE,
|
||||
path="docs/api.md",
|
||||
base_hash="abc",
|
||||
proposed_content="new",
|
||||
reason="sync endpoint section",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
assert len(repo.calls) == 3
|
||||
assert repo.calls[0]["artifact_role"] == "analysis"
|
||||
assert repo.calls[0]["source_ref"] == "https://example.org/doc"
|
||||
|
||||
assert repo.calls[1]["artifact_role"] == "doc_change"
|
||||
assert repo.calls[1]["summary"] == "Generated docs update summary"
|
||||
|
||||
assert repo.calls[2]["artifact_role"] == "doc_change"
|
||||
assert repo.calls[2]["path"] == "docs/api.md"
|
||||
assert repo.calls[2]["change_type"] == "updated"
|
||||
70
tests/unit_tests/chat/test_chat_api_simple_code_explain.py
Normal file
70
tests/unit_tests/chat/test_chat_api_simple_code_explain.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import asyncio
|
||||
|
||||
from app.modules.chat.module import ChatModule
|
||||
from app.modules.chat.task_store import TaskStore
|
||||
from app.schemas.chat import ChatMessageRequest
|
||||
from app.schemas.chat import TaskQueuedResponse
|
||||
from app.modules.shared.event_bus import EventBus
|
||||
from app.modules.shared.retry_executor import RetryExecutor
|
||||
|
||||
|
||||
class _FakeRuntime:
|
||||
async def run(self, **kwargs):
|
||||
raise AssertionError("legacy runtime must not be called")
|
||||
|
||||
|
||||
class _FakeDirectChat:
|
||||
def __init__(self) -> None:
|
||||
self.calls = 0
|
||||
|
||||
async def handle_message(self, request):
|
||||
self.calls += 1
|
||||
return TaskQueuedResponse(
|
||||
task_id="task-1",
|
||||
status="done",
|
||||
)
|
||||
|
||||
|
||||
class _FakeRagSessions:
|
||||
def get(self, rag_session_id: str):
|
||||
return {"rag_session_id": rag_session_id}
|
||||
|
||||
|
||||
class _FakeRepository:
|
||||
def create_dialog(self, dialog_session_id: str, rag_session_id: str) -> None:
|
||||
return None
|
||||
|
||||
def get_dialog(self, dialog_session_id: str):
|
||||
return None
|
||||
|
||||
def add_message(self, dialog_session_id: str, role: str, content: str, task_id: str | None = None, payload: dict | None = None) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def test_chat_messages_endpoint_uses_direct_service(monkeypatch) -> None:
|
||||
monkeypatch.setenv("SIMPLE_CODE_EXPLAIN_ONLY", "true")
|
||||
direct_chat = _FakeDirectChat()
|
||||
module = ChatModule(
|
||||
agent_runner=_FakeRuntime(),
|
||||
event_bus=EventBus(),
|
||||
retry=RetryExecutor(),
|
||||
rag_sessions=_FakeRagSessions(),
|
||||
repository=_FakeRepository(),
|
||||
direct_chat=direct_chat,
|
||||
task_store=TaskStore(),
|
||||
)
|
||||
router = module.public_router()
|
||||
endpoint = next(route.endpoint for route in router.routes if getattr(route, "path", "") == "/api/chat/messages")
|
||||
response = asyncio.run(
|
||||
endpoint(
|
||||
ChatMessageRequest(
|
||||
session_id="dialog-1",
|
||||
project_id="rag-1",
|
||||
message="Explain get_user",
|
||||
),
|
||||
None,
|
||||
)
|
||||
)
|
||||
|
||||
assert response.task_id == "task-1"
|
||||
assert direct_chat.calls == 1
|
||||
61
tests/unit_tests/chat/test_direct_service.py
Normal file
61
tests/unit_tests/chat/test_direct_service.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import asyncio
|
||||
|
||||
from app.modules.chat.direct_service import CodeExplainChatService
|
||||
from app.modules.chat.session_resolver import ChatSessionResolver
|
||||
from app.modules.chat.task_store import TaskStore
|
||||
from app.modules.rag.explain.models import ExplainIntent, ExplainPack
|
||||
from app.schemas.chat import ChatFileContext, ChatMessageRequest
|
||||
|
||||
|
||||
class _FakeRetriever:
|
||||
def build_pack(self, rag_session_id: str, user_query: str, *, file_candidates: list[dict] | None = None) -> ExplainPack:
|
||||
return ExplainPack(
|
||||
intent=ExplainIntent(raw_query=user_query, normalized_query=user_query),
|
||||
missing=["code_excerpts"],
|
||||
)
|
||||
|
||||
|
||||
class _FakeLlm:
|
||||
def __init__(self) -> None:
|
||||
self.calls = 0
|
||||
|
||||
def generate(self, prompt_name: str, user_input: str, *, log_context: str | None = None) -> str:
|
||||
self.calls += 1
|
||||
return "should not be called"
|
||||
|
||||
|
||||
class _FakeDialogs:
|
||||
def get(self, dialog_session_id: str):
|
||||
return None
|
||||
|
||||
|
||||
def test_direct_service_skips_llm_when_evidence_is_insufficient() -> None:
|
||||
messages: list[tuple[str, str, str, str | None]] = []
|
||||
llm = _FakeLlm()
|
||||
task_store = TaskStore()
|
||||
service = CodeExplainChatService(
|
||||
retriever=_FakeRetriever(),
|
||||
llm=llm,
|
||||
session_resolver=ChatSessionResolver(_FakeDialogs(), lambda rag_session_id: rag_session_id == "rag-1"),
|
||||
task_store=task_store,
|
||||
message_sink=lambda dialog_session_id, role, content, task_id=None: messages.append((dialog_session_id, role, content, task_id)),
|
||||
)
|
||||
|
||||
result = asyncio.run(
|
||||
service.handle_message(
|
||||
ChatMessageRequest(
|
||||
session_id="dialog-1",
|
||||
project_id="rag-1",
|
||||
message="Explain get_user",
|
||||
files=[ChatFileContext(path="app/api/users.py", content="", content_hash="x")],
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
task = task_store.get(result.task_id)
|
||||
assert task is not None
|
||||
assert task.answer is not None
|
||||
assert "Недостаточно опоры в коде" in task.answer
|
||||
assert result.status == "done"
|
||||
assert llm.calls == 0
|
||||
assert [item[1] for item in messages] == ["user", "assistant"]
|
||||
77
tests/unit_tests/rag/asserts_intent_router.py
Normal file
77
tests/unit_tests/rag/asserts_intent_router.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
from app.modules.rag.intent_router_v2.models import IntentRouterResult
|
||||
|
||||
|
||||
def assert_intent(out: IntentRouterResult, expected: str) -> None:
    """Assert the routed intent equals *expected*."""
    assert out.intent == expected
|
||||
|
||||
|
||||
def assert_domains(out: IntentRouterResult, expected: list[str]) -> None:
    """Assert the retrieval spec targets exactly the *expected* domains, in order."""
    actual = out.retrieval_spec.domains
    assert actual == expected
|
||||
|
||||
|
||||
def assert_has_file_path(out: IntentRouterResult, path: str) -> None:
    """Assert the query plan contains a FILE_PATH anchor with exactly *path*."""
    matches = (anchor.type == "FILE_PATH" and anchor.value == path for anchor in out.query_plan.anchors)
    assert any(matches)
|
||||
|
||||
|
||||
def assert_path_scope(out: IntentRouterResult, file_path: str, dir_path: str | None = None) -> None:
    """Assert the retrieval filters' path_scope contains the file and, if given, the directory."""
    # getattr guard: filters may have no path_scope attribute at all.
    scope = list(getattr(out.retrieval_spec.filters, "path_scope", []) or [])
    assert file_path in scope
    if dir_path is not None:
        assert dir_path in scope
|
||||
|
||||
|
||||
def assert_file_only_scope(out: IntentRouterResult, file_path: str) -> None:
    """Assert path_scope is narrowed down to exactly one file."""
    scope = list(getattr(out.retrieval_spec.filters, "path_scope", []) or [])
    assert scope == [file_path]
|
||||
|
||||
|
||||
def assert_spans_valid(out: IntentRouterResult) -> None:
    """Assert every anchor span lies within the raw query; conversation-state anchors carry none."""
    raw_len = len(out.query_plan.raw)
    for anchor in out.query_plan.anchors:
        if anchor.source == "conversation_state":
            # State-derived anchors are not tied to a position in the raw text.
            assert anchor.span is None
        else:
            assert anchor.span is not None
            assert 0 <= anchor.span.start < anchor.span.end <= raw_len
|
||||
|
||||
|
||||
def assert_test_policy(out: IntentRouterResult, expected: str) -> None:
    """Assert the retrieval filters carry the *expected* test policy."""
    actual = getattr(out.retrieval_spec.filters, "test_policy", None)
    assert actual == expected
|
||||
|
||||
|
||||
def assert_sub_intent(out: IntentRouterResult, expected: str) -> None:
    """Assert the query plan's sub-intent equals *expected*."""
    assert out.query_plan.sub_intent == expected
|
||||
|
||||
|
||||
def assert_no_symbol_keyword(out: IntentRouterResult, forbidden: set[str] | None = None) -> None:
    """Assert no SYMBOL anchor is a bare language keyword (case-insensitive)."""
    denied = forbidden or {"def", "class", "return", "import", "from"}
    lowered_denied = {token.lower() for token in denied}
    symbols = {anchor.value.lower() for anchor in out.query_plan.anchors if anchor.type == "SYMBOL"}
    assert not (symbols & lowered_denied)
|
||||
|
||||
|
||||
def assert_domain_layer_prefixes(out: IntentRouterResult) -> None:
    """Assert layer-id prefixes are consistent with the selected domains (C=CODE, D=DOCS)."""
    prefixes = {layer.layer_id[0] for layer in out.retrieval_spec.layer_queries if layer.layer_id}
    domains = out.retrieval_spec.domains
    if domains == ["CODE"]:
        allowed = {"C"}
    elif domains == ["DOCS"]:
        allowed = {"D"}
    else:
        allowed = {"C", "D"}
    assert prefixes <= allowed
|
||||
|
||||
|
||||
def assert_no_symbol_leakage_from_paths(out: IntentRouterResult) -> None:
    """Assert no SYMBOL anchor is just a path segment leaked from a FILE_PATH anchor."""
    file_values = [anchor.value for anchor in out.query_plan.anchors if anchor.type == "FILE_PATH"]
    if not file_values:
        return
    # Lower-cased path segments (split on '/' and '.') from every FILE_PATH anchor.
    parts = {
        token
        for value in file_values
        for token in re.split(r"[/.]+", value.lower())
        if token
    }
    for anchor in out.query_plan.anchors:
        if anchor.type == "SYMBOL":
            assert anchor.value.lower() not in parts
|
||||
48
tests/unit_tests/rag/intent_router_testkit.py
Normal file
48
tests/unit_tests/rag/intent_router_testkit.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from app.modules.rag.contracts.enums import RagLayer
|
||||
from app.modules.rag.intent_router_v2 import ConversationState, IntentRouterV2, RepoContext
|
||||
|
||||
|
||||
def repo_context() -> RepoContext:
    """Build the fixed repo context (python, CODE+DOCS, all indexed layers) used by router tests."""
    layers = [
        RagLayer.CODE_ENTRYPOINTS,
        RagLayer.CODE_SYMBOL_CATALOG,
        RagLayer.CODE_DEPENDENCY_GRAPH,
        RagLayer.CODE_SEMANTIC_ROLES,
        RagLayer.CODE_SOURCE_CHUNKS,
        RagLayer.DOCS_MODULE_CATALOG,
        RagLayer.DOCS_FACT_INDEX,
        RagLayer.DOCS_SECTION_INDEX,
        RagLayer.DOCS_POLICY_INDEX,
    ]
    return RepoContext(
        languages=["python"],
        available_domains=["CODE", "DOCS"],
        available_layers=layers,
    )
|
||||
|
||||
|
||||
def run_sequence(queries: list[str], *, router: IntentRouterV2 | None = None, trace_label: str = "intent-router") -> list:
    """Route *queries* in order, threading conversation state forward and tracing every turn."""
    active_router = router or IntentRouterV2()
    state = ConversationState()
    outputs: list = []
    for turn, query in enumerate(queries, start=1):
        routed = active_router.route(query, state, repo_context())
        print_trace(turn, query, routed, label=trace_label)
        outputs.append(routed)
        # Carry the routed decision into the next turn's conversation state.
        state = state.advance(routed)
    return outputs
|
||||
|
||||
|
||||
def run_single(query: str, *, router: IntentRouterV2 | None = None, trace_label: str = "intent-router"):
    """Route a single query (fresh conversation state) and return its result."""
    return run_sequence([query], router=router, trace_label=trace_label)[0]
|
||||
|
||||
|
||||
def print_trace(index: int, query: str, result, *, label: str = "intent-router") -> None:
    """Dump one routed turn (input, JSON output, separator) to stdout for debugging."""
    prefix = f"[{label}][turn {index}]"
    print(f"{prefix} input: {query}")
    print()
    payload = json.dumps(result.model_dump(), ensure_ascii=False)
    print(f"{prefix} output: {payload}")
    print("=" * 50)
|
||||
217
tests/unit_tests/rag/test_code_indexing_pipeline.py
Normal file
217
tests/unit_tests/rag/test_code_indexing_pipeline.py
Normal file
@@ -0,0 +1,217 @@
|
||||
from app.modules.rag.contracts.enums import RagLayer
|
||||
from app.modules.rag.indexing.code.pipeline import CodeIndexingPipeline
|
||||
|
||||
|
||||
def test_code_pipeline_builds_source_symbols_edges_and_entrypoints() -> None:
|
||||
pipeline = CodeIndexingPipeline()
|
||||
content = """
|
||||
from fastapi import APIRouter
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
class UserService:
|
||||
def get_user(self, user_id):
|
||||
return user_id
|
||||
|
||||
@router.get("/users/{user_id}")
|
||||
async def get_user(user_id: str):
|
||||
service = UserService()
|
||||
return service.get_user(user_id)
|
||||
"""
|
||||
docs = pipeline.index_file(
|
||||
repo_id="acme/proj",
|
||||
commit_sha="abc123",
|
||||
path="app/api/users.py",
|
||||
content=content,
|
||||
)
|
||||
|
||||
layers = {doc.layer for doc in docs}
|
||||
assert RagLayer.CODE_SOURCE_CHUNKS in layers
|
||||
assert RagLayer.CODE_SYMBOL_CATALOG in layers
|
||||
assert RagLayer.CODE_DEPENDENCY_GRAPH in layers
|
||||
assert RagLayer.CODE_ENTRYPOINTS in layers
|
||||
assert RagLayer.CODE_SEMANTIC_ROLES in layers
|
||||
|
||||
symbol_doc = next(doc for doc in docs if doc.layer == RagLayer.CODE_SYMBOL_CATALOG and doc.metadata["kind"] == "function")
|
||||
assert "get_user" in symbol_doc.metadata["qname"]
|
||||
|
||||
edge_doc = next(doc for doc in docs if doc.layer == RagLayer.CODE_DEPENDENCY_GRAPH)
|
||||
assert edge_doc.metadata["edge_type"] in {
|
||||
"calls",
|
||||
"imports",
|
||||
"inherits",
|
||||
"instantiates",
|
||||
"reads_attr",
|
||||
"writes_attr",
|
||||
"dataflow_slice",
|
||||
}
|
||||
|
||||
entry_doc = next(doc for doc in docs if doc.layer == RagLayer.CODE_ENTRYPOINTS)
|
||||
assert entry_doc.metadata["framework"] == "fastapi"
|
||||
assert entry_doc.metadata["http_method"] == "GET"
|
||||
assert entry_doc.metadata["route_path"] == "/users/{user_id}"
|
||||
assert entry_doc.metadata["entrypoint_kind"] == "http_route"
|
||||
assert entry_doc.metadata["handler_symbol"] == "get_user"
|
||||
assert entry_doc.metadata["summary_text"] == "GET /users/{user_id} declared in get_user"
|
||||
assert "GET /users/{user_id}" in entry_doc.text
|
||||
|
||||
|
||||
def test_code_pipeline_indexes_import_alias_as_symbol() -> None:
|
||||
pipeline = CodeIndexingPipeline()
|
||||
content = "from .v2 import ConfigManagerV2 as ConfigManager\n"
|
||||
|
||||
docs = pipeline.index_file(
|
||||
repo_id="acme/proj",
|
||||
commit_sha="abc123",
|
||||
path="src/config_manager/__init__.py",
|
||||
content=content,
|
||||
)
|
||||
|
||||
alias_doc = next(doc for doc in docs if doc.layer == RagLayer.CODE_SYMBOL_CATALOG and doc.metadata["qname"] == "ConfigManager")
|
||||
assert alias_doc.metadata["kind"] == "const"
|
||||
|
||||
|
||||
def test_code_pipeline_marks_test_documents() -> None:
|
||||
pipeline = CodeIndexingPipeline()
|
||||
content = """
|
||||
def test_user_service():
|
||||
assert True
|
||||
"""
|
||||
|
||||
docs = pipeline.index_file(
|
||||
repo_id="acme/proj",
|
||||
commit_sha="abc123",
|
||||
path="tests/test_users.py",
|
||||
content=content,
|
||||
)
|
||||
|
||||
assert docs
|
||||
assert all(doc.metadata["is_test"] is True for doc in docs)
|
||||
|
||||
|
||||
def test_code_pipeline_extracts_data_flow_edges() -> None:
|
||||
pipeline = CodeIndexingPipeline()
|
||||
content = """
|
||||
class Context:
|
||||
def __init__(self):
|
||||
self.data = {}
|
||||
|
||||
def set(self, new_context):
|
||||
self.data = new_context
|
||||
|
||||
def process():
|
||||
ctx = Context()
|
||||
value = ctx.data
|
||||
return value
|
||||
"""
|
||||
docs = pipeline.index_file(
|
||||
repo_id="acme/proj",
|
||||
commit_sha="abc123",
|
||||
path="src/context.py",
|
||||
content=content,
|
||||
)
|
||||
edges = [doc.metadata for doc in docs if doc.layer == RagLayer.CODE_DEPENDENCY_GRAPH]
|
||||
edge_pairs = {(str(item.get("edge_type") or ""), str(item.get("dst_ref") or "")) for item in edges}
|
||||
|
||||
assert ("instantiates", "Context") in edge_pairs
|
||||
assert ("writes_attr", "Context.data") in edge_pairs
|
||||
assert ("reads_attr", "ctx.data") in edge_pairs
|
||||
|
||||
|
||||
def test_code_pipeline_builds_dataflow_slice_documents() -> None:
|
||||
pipeline = CodeIndexingPipeline()
|
||||
content = """
|
||||
class Context:
|
||||
def set(self, value):
|
||||
self.data = value
|
||||
|
||||
def read_data(ctx):
|
||||
return ctx.data
|
||||
|
||||
def run():
|
||||
ctx = Context()
|
||||
Context().set({"order_id": 1})
|
||||
return read_data(ctx)
|
||||
"""
|
||||
docs = pipeline.index_file(
|
||||
repo_id="acme/proj",
|
||||
commit_sha="abc123",
|
||||
path="src/context_flow.py",
|
||||
content=content,
|
||||
)
|
||||
|
||||
slices = [
|
||||
doc
|
||||
for doc in docs
|
||||
if doc.layer == RagLayer.CODE_DEPENDENCY_GRAPH and doc.metadata.get("edge_type") == "dataflow_slice"
|
||||
]
|
||||
assert slices
|
||||
assert any("Context.data" in item.metadata.get("path_symbols", []) for item in slices)
|
||||
assert all(item.metadata.get("path_length", 0) <= 6 for item in slices)
|
||||
|
||||
|
||||
def test_code_pipeline_builds_execution_trace_documents() -> None:
    """An HTTP entrypoint must produce execution-trace documents through its call chain."""
    indexer = CodeIndexingPipeline()
    source = """
from fastapi import APIRouter

router = APIRouter()

def parse():
    return "parsed"

def send_email():
    return parse()

@router.post("/run")
def run_pipeline():
    return send_email()
"""
    documents = indexer.index_file(
        repo_id="acme/proj",
        commit_sha="abc123",
        path="src/pipeline.py",
        content=source,
    )

    trace_docs = [
        doc
        for doc in documents
        if doc.layer == RagLayer.CODE_ENTRYPOINTS and doc.metadata.get("edge_type") == "execution_trace"
    ]

    assert trace_docs
    assert any(doc.metadata.get("path_length", 0) >= 2 for doc in trace_docs)
    assert any("run_pipeline" in doc.metadata.get("path_symbols", []) for doc in trace_docs)
def test_code_pipeline_builds_semantic_role_documents() -> None:
    """Classes must be classified as adapter / parser / handler by name and behavior."""
    indexer = CodeIndexingPipeline()
    source = """
class EmailAdapter:
    def send(self, payload):
        import requests
        return requests.post("http://localhost", json=payload)

class ExcelParser:
    def parse(self, rows):
        import csv
        return list(csv.reader(rows))

class OrderHandler:
    def handle(self, ctx, adapter):
        ctx.data = {"status": "ready"}
        value = ctx.data
        return adapter.send(value)
"""
    documents = indexer.index_file(
        repo_id="acme/proj",
        commit_sha="abc123",
        path="src/semantic_roles.py",
        content=source,
    )

    role_by_symbol = {}
    for doc in documents:
        if doc.layer == RagLayer.CODE_SEMANTIC_ROLES:
            role_by_symbol[doc.metadata.get("symbol_name")] = doc.metadata.get("role")

    assert role_by_symbol.get("EmailAdapter") == "adapter"
    assert role_by_symbol.get("ExcelParser") == "parser"
    assert role_by_symbol.get("OrderHandler") == "handler"
63
tests/unit_tests/rag/test_docs_indexing_pipeline.py
Normal file
63
tests/unit_tests/rag/test_docs_indexing_pipeline.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from app.modules.rag.contracts.enums import RagLayer
|
||||
from app.modules.rag.indexing.docs.pipeline import DocsIndexingPipeline
|
||||
|
||||
|
||||
def test_docs_pipeline_builds_catalog_facts_sections_and_policy() -> None:
    """A policy doc with frontmatter must populate catalog, fact, section and policy layers."""
    indexer = DocsIndexingPipeline()
    source = """---
id: api.billing.create_invoice
type: policy
domain: billing
links:
  calls_api:
    - api.billing.validate_invoice
tags: [billing]
status: active
---
# Create Invoice

## Spec Summary

Creates an invoice in billing.

## Request Contract

| field | type | required | validation |
| --- | --- | --- | --- |
| amount | decimal | yes | > 0 |

## Error Matrix

| status | error | client action |
| --- | --- | --- |
| 400 | invalid_amount | fix request |

## Rules

- metric: billing.invoice.created
- rule: amount must be positive
"""
    documents = indexer.index_file(
        repo_id="acme/proj",
        commit_sha="abc123",
        path="docs/billing/create_invoice.md",
        content=source,
    )

    produced_layers = {doc.layer for doc in documents}
    assert RagLayer.DOCS_MODULE_CATALOG in produced_layers
    assert RagLayer.DOCS_FACT_INDEX in produced_layers
    assert RagLayer.DOCS_SECTION_INDEX in produced_layers
    assert RagLayer.DOCS_POLICY_INDEX in produced_layers

    catalog_doc = next(doc for doc in documents if doc.layer == RagLayer.DOCS_MODULE_CATALOG)
    assert catalog_doc.metadata["module_id"] == "api.billing.create_invoice"
    assert catalog_doc.metadata["type"] == "policy"

    fact_texts = [doc.text for doc in documents if doc.layer == RagLayer.DOCS_FACT_INDEX]
    assert any("calls_api" in text for text in fact_texts)
    assert any("has_field" in text for text in fact_texts)
    assert any("returns_error" in text for text in fact_texts)

    section_doc = next(doc for doc in documents if doc.layer == RagLayer.DOCS_SECTION_INDEX)
    assert section_doc.metadata["section_path"]
22
tests/unit_tests/rag/test_explain_intent_builder.py
Normal file
22
tests/unit_tests/rag/test_explain_intent_builder.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from app.modules.rag.explain.intent_builder import ExplainIntentBuilder
|
||||
|
||||
|
||||
def test_explain_intent_builder_extracts_route_symbol_and_file_hints() -> None:
    """Route, dotted symbol and file path are all pulled into intent hints."""
    intent_builder = ExplainIntentBuilder()

    built = intent_builder.build("Explain how /users/{user_id} reaches UserService.get_user in app/api/users.py")

    assert "/users/{user_id}" in built.hints.endpoints
    assert "UserService.get_user" in built.hints.symbols
    assert "app/api/users.py" in built.hints.paths
    assert built.expected_entry_types == ["http"]
    assert built.include_tests is False
    assert built.depth == "medium"
def test_explain_intent_builder_enables_tests_when_user_asks_for_them() -> None:
    """Explicitly asking about tests flips include_tests on."""
    intent_builder = ExplainIntentBuilder()

    built = intent_builder.build("Покажи как это тестируется в pytest и какие tests покрывают UserService")

    assert built.include_tests is True
126
tests/unit_tests/rag/test_intent_router_e2e_flows.py
Normal file
126
tests/unit_tests/rag/test_intent_router_e2e_flows.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from app.modules.rag.intent_router_v2 import GigaChatIntentRouterFactory
|
||||
from app.modules.shared.env_loader import load_workspace_env
|
||||
from tests.unit_tests.rag.asserts_intent_router import (
|
||||
assert_domains,
|
||||
assert_file_only_scope,
|
||||
assert_intent,
|
||||
assert_test_policy,
|
||||
)
|
||||
from tests.unit_tests.rag.intent_router_testkit import run_sequence
|
||||
|
||||
pytestmark = pytest.mark.intent_router
|
||||
|
||||
|
||||
def _live_gigachat_enabled() -> bool:
    """Live GigaChat tests need an explicit opt-in flag plus a non-empty token."""
    load_workspace_env()
    opted_in = os.getenv("RUN_INTENT_ROUTER_V2_LIVE", "").strip() == "1"
    has_token = bool(os.getenv("GIGACHAT_TOKEN", "").strip())
    return opted_in and has_token
def test_e2e_path_carryover_flow() -> None:
    """A file mentioned once must keep scoping the follow-up turns."""
    turn_one, turn_two, turn_three = run_sequence(
        [
            "Посмотри файл app/core/config.py",
            "Теперь объясни функцию load_config",
            "Почему так?",
        ]
    )

    assert_file_only_scope(turn_one, "app/core/config.py")
    assert "app/core/config.py" in turn_two.retrieval_spec.filters.path_scope
    assert "app/core/config.py" in turn_three.retrieval_spec.filters.path_scope

    carried_paths = [
        anchor.value
        for anchor in turn_two.query_plan.anchors
        if anchor.type == "FILE_PATH" and anchor.source == "conversation_state"
    ]
    assert carried_paths == ["app/core/config.py"]
    assert "app/core/config.py" in turn_two.query_plan.keyword_hints
    assert "app/core" not in turn_two.query_plan.keyword_hints

    # Carried anchors never keep a span: they did not occur in this turn's text.
    assert any(
        anchor.type == "FILE_PATH" and anchor.source == "conversation_state" and anchor.span is None
        for anchor in turn_three.query_plan.anchors
    )
    carried_symbols = [
        anchor.value
        for anchor in turn_three.query_plan.anchors
        if anchor.type == "SYMBOL" and anchor.source == "conversation_state"
    ]
    assert carried_symbols in ([], ["load_config"])
    assert turn_three.query_plan.sub_intent == "EXPLAIN_LOCAL"
    queried_layers = [query.layer_id for query in turn_three.retrieval_spec.layer_queries]
    assert "C3_ENTRYPOINTS" not in queried_layers
def test_e2e_docs_switch_from_code_topic() -> None:
    """Switching from code to docs keeps the discussed symbol as a carried anchor."""
    code_turn, docs_turn = run_sequence(
        [
            "Объясни как работает ConfigManager",
            "А что про это сказано в документации?",
        ]
    )

    assert_intent(code_turn, "CODE_QA")
    assert_intent(docs_turn, "DOCS_QA")
    assert docs_turn.conversation_mode == "SWITCH"
    assert_domains(docs_turn, ["DOCS"])

    carried_anchors = [
        anchor
        for anchor in docs_turn.query_plan.anchors
        if anchor.type == "SYMBOL" and anchor.value == "ConfigManager" and anchor.source == "conversation_state"
    ]
    assert carried_anchors
    assert carried_anchors[0].span is None
    assert "ConfigManager" in docs_turn.query_plan.expansions
    assert "ConfigManager" in docs_turn.query_plan.keyword_hints
def test_e2e_tests_toggle_flow() -> None:
    """Asking for tests then explicitly excluding them flips the test policy."""
    tests_turn, prod_turn = run_sequence(
        [
            "Покажи тесты для ConfigManager",
            "А теперь не про тесты, а про прод код",
        ]
    )

    assert_intent(tests_turn, "CODE_QA")
    assert_intent(prod_turn, "CODE_QA")
    assert_test_policy(tests_turn, "INCLUDE")
    assert_test_policy(prod_turn, "EXCLUDE")
    assert tests_turn.query_plan.sub_intent == "FIND_TESTS"
    assert prod_turn.query_plan.sub_intent == "EXPLAIN"
    assert "tests" in prod_turn.query_plan.negations
    assert not prod_turn.query_plan.expansions
    assert prod_turn.evidence_policy.require_flow is False
def test_e2e_open_file_then_generic_next_steps_is_lightweight() -> None:
    """A vague follow-up after opening a file stays file-scoped and avoids heavy layers."""
    open_turn, follow_up = run_sequence(
        [
            "Открой файл app/core/config.py",
            "Что дальше?",
        ]
    )

    assert_file_only_scope(open_turn, "app/core/config.py")
    assert_file_only_scope(follow_up, "app/core/config.py")
    assert follow_up.query_plan.sub_intent in {"EXPLAIN_LOCAL", "NEXT_STEPS"}
    queried_layers = [query.layer_id for query in follow_up.retrieval_spec.layer_queries]
    assert "C3_ENTRYPOINTS" not in queried_layers
    assert follow_up.evidence_policy.require_flow is False
    assert "app/core/config.py" in follow_up.query_plan.keyword_hints
@pytest.mark.skipif(
    not _live_gigachat_enabled(),
    reason="requires RUN_INTENT_ROUTER_V2_LIVE=1 and GIGACHAT_TOKEN in environment or .env",
)
def test_intent_router_live_smoke_path_carryover() -> None:
    """Smoke-check path carryover against the real GigaChat-backed router."""
    live_router = GigaChatIntentRouterFactory().build()
    open_turn, follow_up = run_sequence(
        [
            "Открой файл app/core/config.py",
            "Что дальше?",
        ],
        router=live_router,
        trace_label="intent-router-live",
    )

    assert_file_only_scope(open_turn, "app/core/config.py")
    assert "app/core/config.py" in follow_up.retrieval_spec.filters.path_scope
    assert follow_up.query_plan.sub_intent in {"EXPLAIN_LOCAL", "NEXT_STEPS"}
    queried_layers = [query.layer_id for query in follow_up.retrieval_spec.layer_queries]
    assert "C3_ENTRYPOINTS" not in queried_layers
    assert follow_up.evidence_policy.require_flow is False
120
tests/unit_tests/rag/test_intent_router_invariants.py
Normal file
120
tests/unit_tests/rag/test_intent_router_invariants.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import pytest
|
||||
|
||||
from tests.unit_tests.rag.asserts_intent_router import (
|
||||
assert_domain_layer_prefixes,
|
||||
assert_domains,
|
||||
assert_file_only_scope,
|
||||
assert_has_file_path,
|
||||
assert_intent,
|
||||
assert_no_symbol_keyword,
|
||||
assert_no_symbol_leakage_from_paths,
|
||||
assert_spans_valid,
|
||||
assert_sub_intent,
|
||||
assert_test_policy,
|
||||
)
|
||||
from tests.unit_tests.rag.intent_router_testkit import run_sequence
|
||||
|
||||
pytestmark = pytest.mark.intent_router
|
||||
|
||||
|
||||
def test_invariant_code_file_path_with_canonical_key_term() -> None:
    """Inflected key terms are canonicalized (файлу -> файл)."""
    routed = run_sequence(["Уточни по файлу app/core/config.py"])[0]

    assert_intent(routed, "CODE_QA")
    assert_has_file_path(routed, "app/core/config.py")
    assert_file_only_scope(routed, "app/core/config.py")
    canonical_terms = [
        anchor.value for anchor in routed.query_plan.anchors if anchor.type == "KEY_TERM"
    ]
    assert "файл" in canonical_terms
    assert "файлу" not in canonical_terms
    assert_spans_valid(routed)
    assert_domain_layer_prefixes(routed)
def test_invariant_open_file_for_specified_file_phrase_uses_narrow_layers() -> None:
    """OPEN_FILE on an explicit path only queries source chunks."""
    routed = run_sequence(["Уточни по файлу app/core/config.py"])[0]

    assert_intent(routed, "CODE_QA")
    assert_sub_intent(routed, "OPEN_FILE")
    assert_file_only_scope(routed, "app/core/config.py")
    queried_layers = [query.layer_id for query in routed.retrieval_spec.layer_queries]
    assert queried_layers == ["C0_SOURCE_CHUNKS"]
    assert routed.evidence_policy.require_flow is False
def test_invariant_inline_code_span_routes_to_code_and_extracts_symbol() -> None:
    """Backticked code routes to CODE_QA; the identifier lands in SYMBOL anchors."""
    routed = run_sequence(["Уточни по коду `def build(x): return x`"])[0]

    assert_intent(routed, "CODE_QA")
    assert_spans_valid(routed)
    assert_no_symbol_keyword(routed)
    anchors = routed.query_plan.anchors
    extracted_symbols = [anchor.value for anchor in anchors if anchor.type == "SYMBOL"]
    extracted_terms = [anchor.value for anchor in anchors if anchor.type == "KEY_TERM"]
    assert "build" in extracted_symbols
    assert "def" in extracted_terms
def test_invariant_docs_cyrillic_path_with_quotes() -> None:
    """Cyrillic doc path in guillemets is preserved and routed to DOCS_QA."""
    routed = run_sequence(["Что сказано в «docs/архитектура.md»?"])[0]

    assert_intent(routed, "DOCS_QA")
    assert_sub_intent(routed, "EXPLAIN")
    assert_domains(routed, ["DOCS"])
    assert "docs/архитектура.md" in routed.query_plan.normalized
    assert_has_file_path(routed, "docs/архитектура.md")
    assert any(anchor.type == "DOC_REF" for anchor in routed.query_plan.anchors)
    assert routed.retrieval_spec.filters.doc_kinds == []
    assert_spans_valid(routed)
    assert_domain_layer_prefixes(routed)
def test_invariant_file_check_phrase_not_project_misc() -> None:
    """"Проверь <path> и объясни" stays CODE_QA without leaking symbols from the path."""
    routed = run_sequence(["Проверь app/modules/rag/explain/intent_builder.py и объясни"])[0]

    assert_intent(routed, "CODE_QA")
    assert_domains(routed, ["CODE"])
    assert_no_symbol_leakage_from_paths(routed)
    assert_domain_layer_prefixes(routed)
def test_invariant_tests_include_routing() -> None:
    """"Где тесты" yields INCLUDE test policy plus symbol and canonical term anchors."""
    routed = run_sequence(["Где тесты на ConfigManager?"])[0]

    assert_intent(routed, "CODE_QA")
    assert_test_policy(routed, "INCLUDE")
    anchors = routed.query_plan.anchors
    extracted_symbols = [anchor.value for anchor in anchors if anchor.type == "SYMBOL"]
    extracted_terms = [anchor.value for anchor in anchors if anchor.type == "KEY_TERM"]
    assert "ConfigManager" in extracted_symbols
    assert "тест" in extracted_terms
def test_invariant_keyword_hints_and_expansions_for_function_identifier() -> None:
    """Only the identifier feeds hints; generic words and syntax do not."""
    routed = run_sequence(["Теперь объясни функцию load_config"])[0]

    assert_intent(routed, "CODE_QA")
    assert "load_config" in routed.query_plan.keyword_hints
    assert "функция" not in routed.query_plan.keyword_hints
    assert "def" not in routed.query_plan.expansions
def test_invariant_open_file_sub_intent_uses_narrow_retrieval_profile() -> None:
    """OPEN_FILE keeps only C0 and drops catalog, graph and entrypoint layers."""
    routed = run_sequence(["Открой файл app/core/config.py"])[0]

    assert_intent(routed, "CODE_QA")
    assert_sub_intent(routed, "OPEN_FILE")
    assert_file_only_scope(routed, "app/core/config.py")
    queried_layers = [query.layer_id for query in routed.retrieval_spec.layer_queries]
    assert "C0_SOURCE_CHUNKS" in queried_layers
    for heavy_layer in ("C1_SYMBOL_CATALOG", "C2_DEPENDENCY_GRAPH", "C3_ENTRYPOINTS"):
        assert heavy_layer not in queried_layers
    assert routed.evidence_policy.require_flow is False
def test_invariant_docs_question_routes_to_docs() -> None:
    """A bare documentation question routes to DOCS with hints echoed into expansions."""
    routed = run_sequence(["Что сказано в документации?"])[0]

    assert_intent(routed, "DOCS_QA")
    assert_domains(routed, ["DOCS"])
    assert_domain_layer_prefixes(routed)
    assert routed.query_plan.keyword_hints
    assert any(hint in routed.query_plan.expansions for hint in routed.query_plan.keyword_hints)
78
tests/unit_tests/rag/test_layered_gateway.py
Normal file
78
tests/unit_tests/rag/test_layered_gateway.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from app.modules.rag.explain.layered_gateway import LayeredRetrievalGateway
|
||||
|
||||
|
||||
class _Embedder:
|
||||
def embed(self, texts: list[str]) -> list[list[float]]:
|
||||
return [[0.1, 0.2]]
|
||||
|
||||
|
||||
class _RetryingRepository:
|
||||
def __init__(self) -> None:
|
||||
self.calls: list[dict] = []
|
||||
|
||||
def retrieve(self, *args, **kwargs):
|
||||
self.calls.append(kwargs)
|
||||
if kwargs.get("exclude_path_prefixes"):
|
||||
raise RuntimeError("syntax error at or near ')'")
|
||||
return [
|
||||
{
|
||||
"path": "app/users/service.py",
|
||||
"content": "def get_user(): pass",
|
||||
"layer": "C1_SYMBOL_CATALOG",
|
||||
"title": "get_user",
|
||||
"metadata": {"symbol_id": "symbol-1"},
|
||||
"distance": 0.1,
|
||||
"span_start": 10,
|
||||
"span_end": 11,
|
||||
}
|
||||
]
|
||||
|
||||
def retrieve_lexical_code(self, *args, **kwargs):
|
||||
self.calls.append(kwargs)
|
||||
if kwargs.get("exclude_path_prefixes"):
|
||||
raise RuntimeError("broken lexical filter")
|
||||
return [
|
||||
{
|
||||
"path": "app/users/service.py",
|
||||
"content": "def get_user(): pass",
|
||||
"layer": "C0_SOURCE_CHUNKS",
|
||||
"title": "get_user",
|
||||
"metadata": {"symbol_id": "symbol-1"},
|
||||
"span_start": 10,
|
||||
"span_end": 11,
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
class _RecordingRepository:
|
||||
def __init__(self) -> None:
|
||||
self.calls: list[dict] = []
|
||||
|
||||
def retrieve(self, *args, **kwargs):
|
||||
self.calls.append(kwargs)
|
||||
return []
|
||||
|
||||
def retrieve_lexical_code(self, *args, **kwargs):
|
||||
self.calls.append(kwargs)
|
||||
return []
|
||||
|
||||
|
||||
def test_gateway_retries_without_test_filter_on_vector_failure() -> None:
    """A failing filtered query is retried unfiltered and the retry is reported."""
    gateway = LayeredRetrievalGateway(_RetryingRepository(), _Embedder())

    outcome = gateway.retrieve_layer("rag-1", "Explain get_user", "C1_SYMBOL_CATALOG", limit=3, exclude_tests=True)

    assert len(outcome.items) == 1
    assert "layer:C1_SYMBOL_CATALOG retrieval_failed:retried_without_test_filter" in outcome.missing
def test_gateway_honors_debug_disable_test_filter(monkeypatch) -> None:
    """The debug env var suppresses test-exclusion filters entirely."""
    monkeypatch.setenv("RAG_DEBUG_DISABLE_TEST_FILTER", "true")
    recording = _RecordingRepository()
    gateway = LayeredRetrievalGateway(recording, _Embedder())

    gateway.retrieve_layer("rag-1", "Explain get_user", "C1_SYMBOL_CATALOG", limit=3, exclude_tests=True)

    assert recording.calls
    first_call = recording.calls[0]
    assert first_call["exclude_path_prefixes"] is None
    assert first_call["exclude_like_patterns"] is None
44
tests/unit_tests/rag/test_path_filter.py
Normal file
44
tests/unit_tests/rag/test_path_filter.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.modules.rag.indexing.common.path_filter import (
|
||||
count_indexable_change_upserts,
|
||||
filter_changes_for_indexing,
|
||||
filter_snapshot_files,
|
||||
should_skip_indexing_path,
|
||||
)
|
||||
|
||||
|
||||
def test_should_skip_indexing_path_for_hidden_and_cache_paths() -> None:
    """Hidden directories and bytecode caches are skipped; normal sources are not."""
    skipped_paths = [
        ".env",
        ".venv/lib/site-packages/a.py",
        "src/.secrets/config.py",
        "src/__pycache__/module.cpython-312.pyc",
    ]
    for path in skipped_paths:
        assert should_skip_indexing_path(path) is True
    assert should_skip_indexing_path("src/main.py") is False
def test_filter_snapshot_files_excludes_hidden_and_cache_paths() -> None:
    """Only the regular source file survives snapshot filtering."""
    snapshot = [
        {"path": ".env", "content": "A"},
        {"path": "src/__pycache__/x.py", "content": "B"},
        {"path": "src/main.py", "content": "C"},
    ]

    kept = filter_snapshot_files(snapshot)

    assert [entry["path"] for entry in kept] == ["src/main.py"]
def test_filter_changes_for_indexing_keeps_deletes_and_filters_upserts() -> None:
    """Upserts of non-indexable paths drop out; deletes always pass through."""
    incoming = [
        {"op": "upsert", "path": ".env", "content": "A"},
        {"op": "upsert", "path": "src/main.py", "content": "B"},
        {"op": "delete", "path": ".cache/legacy.txt"},
    ]

    kept = filter_changes_for_indexing(incoming)

    assert kept == [
        {"op": "upsert", "path": "src/main.py", "content": "B"},
        {"op": "delete", "path": ".cache/legacy.txt"},
    ]
    assert count_indexable_change_upserts(kept) == 1
63
tests/unit_tests/rag/test_query_normalization.py
Normal file
63
tests/unit_tests/rag/test_query_normalization.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import pytest
|
||||
|
||||
from app.modules.rag.intent_router_v2.analysis.normalization import QueryNormalizer
|
||||
|
||||
pytestmark = pytest.mark.intent_router
|
||||
|
||||
|
||||
def test_query_normalizer_collapses_whitespace() -> None:
    """Runs of spaces and newlines collapse into single spaces."""
    qn = QueryNormalizer()

    cleaned = qn.normalize(" Объясни как работает \n класс X ")

    assert cleaned == "Объясни как работает класс X"
def test_query_normalizer_canonicalizes_quotes() -> None:
    """Guillemets and curly quotes are canonicalized to straight quotes."""
    qn = QueryNormalizer()

    cleaned = qn.normalize('Уточни «текст» и “текст”')

    assert cleaned == 'Уточни "текст" и "текст"'
def test_query_normalizer_preserves_backticks_verbatim() -> None:
    """Backticked code spans pass through untouched (only outer trim applies)."""
    qn = QueryNormalizer()

    cleaned = qn.normalize("Уточни по коду `def build(x):` ")

    assert cleaned == "Уточни по коду `def build(x):`"
def test_query_normalizer_preserves_latin_and_cyrillic_file_paths() -> None:
    """File paths — Latin or Cyrillic — must not be split at the extension dot."""
    qn = QueryNormalizer()

    cleaned = qn.normalize("Сверь app/core/config.py и «docs/руководство.md»")

    assert "app/core/config.py" in cleaned
    assert "docs/руководство.md" in cleaned
    assert "config. py" not in cleaned
    assert "руководство. md" not in cleaned
def test_query_normalizer_punctuation_spacing_does_not_break_extensions() -> None:
    """Fixing punctuation spacing must not inject spaces into file extensions."""
    qn = QueryNormalizer()

    cleaned = qn.normalize("Проверь docs/spec.md , затем app/main.py !")

    assert "docs/spec.md" in cleaned
    assert "app/main.py" in cleaned
    assert "spec. md" not in cleaned
    assert "main. py" not in cleaned
def test_query_normalizer_idempotent_and_without_enrichment() -> None:
    """Normalizing twice is a no-op, and no synonyms/translations are injected."""
    qn = QueryNormalizer()
    raw_query = ' Прочитай «README.md» и docs/spec.md '

    first_pass = qn.normalize(raw_query)
    second_pass = qn.normalize(first_pass)

    assert second_pass == first_pass
    lowered = first_pass.lower()
    assert "documentation" not in lowered
    assert "class" not in lowered
9
tests/unit_tests/rag/test_query_terms.py
Normal file
9
tests/unit_tests/rag/test_query_terms.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from app.modules.rag.retrieval.query_terms import extract_query_terms
|
||||
|
||||
|
||||
def test_extract_query_terms_from_code_question() -> None:
    """Mixed-case identifiers are lowercased, split, and RU verbs translated."""
    extracted = extract_query_terms("Объясни по коду как можно управлять COnfigmanager?")

    for expected_term in ("configmanager", "config_manager", "control"):
        assert expected_term in extracted
52
tests/unit_tests/rag/test_rag_service_filtering.py
Normal file
52
tests/unit_tests/rag/test_rag_service_filtering.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from app.modules.rag.services.rag_service import RagService
|
||||
|
||||
|
||||
class _FakeEmbedder:
|
||||
def embed(self, texts: list[str]) -> list[list[float]]:
|
||||
return [[0.0, 0.1, 0.2] for _ in texts]
|
||||
|
||||
|
||||
class _FakeRepository:
|
||||
def __init__(self) -> None:
|
||||
self.replaced_docs = []
|
||||
|
||||
def get_session(self, rag_session_id: str) -> dict:
|
||||
return {"project_id": rag_session_id}
|
||||
|
||||
def get_cached_documents(self, repo_id: str, blob_sha: str) -> list:
|
||||
return []
|
||||
|
||||
def cache_documents(self, repo_id: str, path: str, blob_sha: str, docs: list) -> None:
|
||||
return None
|
||||
|
||||
def replace_documents(self, rag_session_id: str, docs: list) -> None:
|
||||
self.replaced_docs = docs
|
||||
|
||||
|
||||
def test_rag_service_progress_uses_only_indexable_files() -> None:
    """Hidden/cache files are excluded from both counters and progress callbacks."""
    repository = _FakeRepository()
    service = RagService(embedder=_FakeEmbedder(), repository=repository)
    snapshot = [
        {"path": ".env", "content": "SECRET=1", "content_hash": "h1"},
        {"path": "src/.hidden/config.py", "content": "A=1", "content_hash": "h2"},
        {"path": "src/__pycache__/cache.py", "content": "A=2", "content_hash": "h3"},
        {"path": "src/main.py", "content": "def main():\n    return 1\n", "content_hash": "h4"},
    ]
    progress_events: list[tuple[int, int, str]] = []

    def record_progress(current: int, total: int, path: str) -> None:
        progress_events.append((current, total, path))

    indexed, failed, cache_hits, cache_misses = asyncio.run(
        service.index_snapshot("project-1", snapshot, progress_cb=record_progress)
    )

    assert indexed == 1
    assert failed == 0
    assert cache_hits == 0
    assert cache_misses == 1
    # Total is 1 because only src/main.py is indexable.
    assert progress_events == [(1, 1, "src/main.py")]
65
tests/unit_tests/rag/test_retrieval_statement_builder.py
Normal file
65
tests/unit_tests/rag/test_retrieval_statement_builder.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from app.modules.rag.persistence.retrieval_statement_builder import RetrievalStatementBuilder
|
||||
from app.modules.rag.retrieval.test_filter import build_test_filters, is_test_path
|
||||
|
||||
|
||||
def test_retrieve_builder_adds_test_exclusion_filters() -> None:
    """Exclusion filters appear as NOT(...) LIKE clauses with escaped params."""
    statement_builder = RetrievalStatementBuilder()
    test_filters = build_test_filters()

    sql, params = statement_builder.build_retrieve(
        "rag-1",
        [0.1, 0.2],
        query_text="Explain user service",
        layers=["C0_SOURCE_CHUNKS"],
        exclude_path_prefixes=test_filters.exclude_path_prefixes,
        exclude_like_patterns=test_filters.exclude_like_patterns,
    )

    for fragment in (
        "NOT (",
        "vector_dims(embedding) = vector_dims(CAST(:emb AS vector))",
        "path LIKE :exclude_prefix_0",
        "lower(path) LIKE :exclude_like_0",
        "ESCAPE E'\\\\'",
    ):
        assert fragment in sql
    assert params["exclude_prefix_0"] == "tests/%"
    assert "%.test.%" in params.values()
    assert "%\\_test.%" in params.values()
def test_retrieve_builder_adds_prefer_bonus_sorting() -> None:
    """Prefer-filters add bonus/rank columns and drive the ORDER BY."""
    statement_builder = RetrievalStatementBuilder()

    sql, params = statement_builder.build_retrieve(
        "rag-1",
        [0.1, 0.2],
        query_text="find context tests",
        layers=["C1_SYMBOL_CATALOG"],
        prefer_path_prefixes=["tests/"],
        prefer_like_patterns=["%/test\\_%.py"],
    )

    for fragment in (
        "AS prefer_bonus",
        "AS structural_rank",
        "WHEN layer = 'C4_SEMANTIC_ROLES' THEN 2",
        "ORDER BY prefer_bonus ASC, test_penalty ASC, layer_rank ASC",
    ):
        assert fragment in sql
    assert params["prefer_prefix_0"] == "tests/%"
    assert params["prefer_like_0"] == "%/test\\_%.py"
def test_lexical_builder_omits_test_filters_when_not_requested() -> None:
    """Without prefer_non_tests, no exclusion clauses or params are emitted."""
    statement_builder = RetrievalStatementBuilder()

    sql, params = statement_builder.build_lexical_code(
        "rag-1",
        query_text="Explain user service",
        prefer_non_tests=False,
    )

    assert sql is not None
    assert "exclude_prefix" not in sql
    assert "exclude_like" not in sql
    assert not any(key.startswith("exclude_") for key in params)
def test_test_filter_does_not_treat_contest_file_as_test() -> None:
    """"test" inside "contest" must not classify the file as a test."""
    assert is_test_path("app/contest.py") is False
    assert is_test_path("tests/test_users.py") is True
52
tests/unit_tests/rag/test_retriever_v2_no_fallback.py
Normal file
52
tests/unit_tests/rag/test_retriever_v2_no_fallback.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from app.modules.rag.explain import CodeExplainRetrieverV2, LayeredRetrievalGateway
|
||||
|
||||
|
||||
class _ExplodingEmbedder:
|
||||
def embed(self, texts: list[str]) -> list[list[float]]:
|
||||
raise RuntimeError("embedding unavailable")
|
||||
|
||||
|
||||
class _RepositoryWithoutFallback:
|
||||
def retrieve(self, *args, **kwargs):
|
||||
raise RuntimeError("vector retrieval unavailable")
|
||||
|
||||
def retrieve_lexical_code(
|
||||
self,
|
||||
rag_session_id: str,
|
||||
query_text: str,
|
||||
*,
|
||||
limit: int = 5,
|
||||
path_prefixes: list[str] | None = None,
|
||||
exclude_path_prefixes: list[str] | None = None,
|
||||
exclude_like_patterns: list[str] | None = None,
|
||||
prefer_non_tests: bool = False,
|
||||
):
|
||||
return []
|
||||
|
||||
|
||||
class _FakeGraphRepository:
|
||||
def get_symbols_by_ids(self, rag_session_id: str, symbol_ids: list[str]):
|
||||
return []
|
||||
|
||||
def get_chunks_by_symbol_ids(self, rag_session_id: str, symbol_ids: list[str], prefer_chunk_type: str = "symbol_block"):
|
||||
return []
|
||||
|
||||
def get_out_edges(self, rag_session_id: str, src_symbol_ids: list[str], edge_types: list[str], limit_per_src: int):
|
||||
return []
|
||||
|
||||
def resolve_symbol_by_ref(self, rag_session_id: str, dst_ref: str, package_hint: str | None = None):
|
||||
return None
|
||||
|
||||
|
||||
def test_retriever_v2_returns_pack_without_fallback_method() -> None:
    """With every retrieval path broken, the pack is empty but records each failure."""
    retriever = CodeExplainRetrieverV2(
        gateway=LayeredRetrievalGateway(_RepositoryWithoutFallback(), _ExplodingEmbedder()),
        graph_repository=_FakeGraphRepository(),
    )

    pack = retriever.build_pack("rag-1", "Explain get_user")

    assert pack.code_excerpts == []
    missing_entries = pack.missing
    assert any(entry.startswith("layer:C3_ENTRYPOINTS retrieval_failed") for entry in missing_entries)
    assert any(entry.startswith("layer:C1_SYMBOL_CATALOG retrieval_failed") for entry in missing_entries)
    assert "layer:C0 empty" in missing_entries
105
tests/unit_tests/rag/test_retriever_v2_pack.py
Normal file
105
tests/unit_tests/rag/test_retriever_v2_pack.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from app.modules.rag.explain.models import CodeLocation, LayeredRetrievalItem
|
||||
from app.modules.rag.explain.retriever_v2 import CodeExplainRetrieverV2
|
||||
|
||||
|
||||
class _FakeGateway:
|
||||
def retrieve_layer(
|
||||
self,
|
||||
rag_session_id: str,
|
||||
query: str,
|
||||
layer: str,
|
||||
*,
|
||||
limit: int,
|
||||
path_prefixes: list[str] | None = None,
|
||||
exclude_tests: bool = True,
|
||||
prefer_non_tests: bool = False,
|
||||
include_spans: bool = False,
|
||||
):
|
||||
if layer == "C3_ENTRYPOINTS":
|
||||
return __import__("types").SimpleNamespace(
|
||||
items=[
|
||||
LayeredRetrievalItem(
|
||||
source="app/api/users.py",
|
||||
content="GET /users/{id}",
|
||||
layer=layer,
|
||||
title="GET /users/{id}",
|
||||
metadata={"entry_type": "http", "handler_symbol_id": "handler-1"},
|
||||
location=CodeLocation(path="app/api/users.py", start_line=10, end_line=10),
|
||||
)
|
||||
],
|
||||
missing=[],
|
||||
)
|
||||
if layer == "C1_SYMBOL_CATALOG":
|
||||
return __import__("types").SimpleNamespace(
|
||||
items=[
|
||||
LayeredRetrievalItem(
|
||||
source="app/api/users.py",
|
||||
content="def get_user_handler",
|
||||
layer=layer,
|
||||
title="get_user_handler",
|
||||
metadata={"symbol_id": "handler-1"},
|
||||
location=CodeLocation(path="app/api/users.py", start_line=10, end_line=18),
|
||||
)
|
||||
],
|
||||
missing=[],
|
||||
)
|
||||
raise AssertionError(layer)
|
||||
|
||||
def retrieve_lexical_code(
|
||||
self,
|
||||
rag_session_id: str,
|
||||
query: str,
|
||||
*,
|
||||
limit: int,
|
||||
path_prefixes: list[str] | None = None,
|
||||
exclude_tests: bool = True,
|
||||
include_spans: bool = False,
|
||||
):
|
||||
return __import__("types").SimpleNamespace(items=[], missing=[])
|
||||
|
||||
|
||||
class _FakeGraphRepository:
|
||||
def get_symbols_by_ids(self, rag_session_id: str, symbol_ids: list[str]):
|
||||
return [
|
||||
LayeredRetrievalItem(
|
||||
source="app/api/users.py",
|
||||
content="def get_user_handler",
|
||||
layer="C1_SYMBOL_CATALOG",
|
||||
title="get_user_handler",
|
||||
metadata={"symbol_id": "handler-1"},
|
||||
location=CodeLocation(path="app/api/users.py", start_line=10, end_line=18),
|
||||
)
|
||||
]
|
||||
|
||||
def get_out_edges(self, rag_session_id: str, src_symbol_ids: list[str], edge_types: list[str], limit_per_src: int):
|
||||
return []
|
||||
|
||||
def resolve_symbol_by_ref(self, rag_session_id: str, dst_ref: str, package_hint: str | None = None):
|
||||
return None
|
||||
|
||||
def get_chunks_by_symbol_ids(self, rag_session_id: str, symbol_ids: list[str], prefer_chunk_type: str = "symbol_block"):
|
||||
return [
|
||||
LayeredRetrievalItem(
|
||||
source="app/api/users.py",
|
||||
content="async def get_user_handler(user_id: str):\n return await service.get_user(user_id)",
|
||||
layer="C0_SOURCE_CHUNKS",
|
||||
title="get_user_handler",
|
||||
metadata={"symbol_id": "handler-1"},
|
||||
location=CodeLocation(path="app/api/users.py", start_line=10, end_line=18),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def test_retriever_v2_builds_pack_with_trace_and_excerpts() -> None:
    """Happy path: entrypoint -> seed symbol -> trace path -> source excerpt."""
    retriever_under_test = CodeExplainRetrieverV2(
        gateway=_FakeGateway(),
        graph_repository=_FakeGraphRepository(),
    )

    pack = retriever_under_test.build_pack("rag-1", "Explain endpoint get_user")

    # Exactly one item survives at every stage of the pipeline.
    stage_counts = (
        len(pack.selected_entrypoints),
        len(pack.seed_symbols),
        len(pack.trace_paths),
        len(pack.code_excerpts),
    )
    assert stage_counts == (1, 1, 1, 1)
    assert pack.code_excerpts[0].path == "app/api/users.py"
|
||||
142
tests/unit_tests/rag/test_retriever_v2_production_first.py
Normal file
142
tests/unit_tests/rag/test_retriever_v2_production_first.py
Normal file
@@ -0,0 +1,142 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from app.modules.rag.explain.models import CodeLocation, LayeredRetrievalItem
|
||||
from app.modules.rag.explain.retriever_v2 import CodeExplainRetrieverV2
|
||||
|
||||
|
||||
class _ProductionFirstGateway:
|
||||
def __init__(self) -> None:
|
||||
self.lexical_calls: list[bool] = []
|
||||
|
||||
def retrieve_layer(
|
||||
self,
|
||||
rag_session_id: str,
|
||||
query: str,
|
||||
layer: str,
|
||||
*,
|
||||
limit: int,
|
||||
path_prefixes: list[str] | None = None,
|
||||
exclude_tests: bool = True,
|
||||
prefer_non_tests: bool = False,
|
||||
include_spans: bool = False,
|
||||
):
|
||||
if layer == "C3_ENTRYPOINTS":
|
||||
return SimpleNamespace(items=[], missing=[])
|
||||
if layer == "C1_SYMBOL_CATALOG":
|
||||
return SimpleNamespace(items=[], missing=[])
|
||||
raise AssertionError(layer)
|
||||
|
||||
def retrieve_lexical_code(
|
||||
self,
|
||||
rag_session_id: str,
|
||||
query: str,
|
||||
*,
|
||||
limit: int,
|
||||
path_prefixes: list[str] | None = None,
|
||||
exclude_tests: bool = True,
|
||||
include_spans: bool = False,
|
||||
):
|
||||
self.lexical_calls.append(exclude_tests)
|
||||
if exclude_tests:
|
||||
return SimpleNamespace(
|
||||
items=[
|
||||
LayeredRetrievalItem(
|
||||
source="app/users/service.py",
|
||||
content="def get_user():\n return repo.get_user()",
|
||||
layer="C0_SOURCE_CHUNKS",
|
||||
title="get_user",
|
||||
metadata={"symbol_id": "user-service", "is_test": False},
|
||||
location=CodeLocation(path="app/users/service.py", start_line=10, end_line=11),
|
||||
),
|
||||
LayeredRetrievalItem(
|
||||
source="app/users/repository.py",
|
||||
content="def get_user_repo():\n return {}",
|
||||
layer="C0_SOURCE_CHUNKS",
|
||||
title="get_user_repo",
|
||||
metadata={"symbol_id": "user-repo", "is_test": False},
|
||||
location=CodeLocation(path="app/users/repository.py", start_line=20, end_line=21),
|
||||
),
|
||||
],
|
||||
missing=[],
|
||||
)
|
||||
return SimpleNamespace(
|
||||
items=[
|
||||
LayeredRetrievalItem(
|
||||
source="tests/test_users.py",
|
||||
content="def test_get_user():\n assert service.get_user()",
|
||||
layer="C0_SOURCE_CHUNKS",
|
||||
title="test_get_user",
|
||||
metadata={"symbol_id": "test-user", "is_test": True},
|
||||
location=CodeLocation(path="tests/test_users.py", start_line=5, end_line=6),
|
||||
)
|
||||
],
|
||||
missing=[],
|
||||
)
|
||||
|
||||
|
||||
class _TestsOnlyGateway(_ProductionFirstGateway):
    """Variant where production lexical search is empty and only a test chunk exists."""

    def retrieve_lexical_code(
        self,
        rag_session_id: str,
        query: str,
        *,
        limit: int,
        path_prefixes: list[str] | None = None,
        exclude_tests: bool = True,
        include_spans: bool = False,
    ):
        """Empty under exclude_tests; one test chunk once the fallback re-queries."""
        self.lexical_calls.append(exclude_tests)
        if exclude_tests:
            return SimpleNamespace(items=[], missing=[])
        fallback_chunk = LayeredRetrievalItem(
            source="tests/test_users.py",
            content="def test_get_user():\n assert service.get_user()",
            layer="C0_SOURCE_CHUNKS",
            title="test_get_user",
            metadata={"symbol_id": "test-user", "is_test": True},
            location=CodeLocation(path="tests/test_users.py", start_line=5, end_line=6),
        )
        return SimpleNamespace(items=[fallback_chunk], missing=[])
|
||||
|
||||
|
||||
class _FakeGraphRepository:
|
||||
def get_symbols_by_ids(self, rag_session_id: str, symbol_ids: list[str]):
|
||||
return []
|
||||
|
||||
def get_chunks_by_symbol_ids(self, rag_session_id: str, symbol_ids: list[str], prefer_chunk_type: str = "symbol_block"):
|
||||
return []
|
||||
|
||||
def get_out_edges(self, rag_session_id: str, src_symbol_ids: list[str], edge_types: list[str], limit_per_src: int):
|
||||
return []
|
||||
|
||||
def resolve_symbol_by_ref(self, rag_session_id: str, dst_ref: str, package_hint: str | None = None):
|
||||
return None
|
||||
|
||||
|
||||
def test_retriever_prefers_prod_chunks_and_skips_test_fallback_when_enough_evidence() -> None:
    """With two production chunks found, the test-code fallback query never runs."""
    gateway = _ProductionFirstGateway()
    retriever = CodeExplainRetrieverV2(gateway=gateway, graph_repository=_FakeGraphRepository())

    pack = retriever.build_pack("rag-1", "Explain get_user")

    # Exactly one lexical call, and it excluded tests.
    assert gateway.lexical_calls == [True]
    excerpt_paths = [excerpt.path for excerpt in pack.code_excerpts]
    assert excerpt_paths == ["app/users/service.py", "app/users/repository.py"]
    assert not any(excerpt.focus.startswith("test:") for excerpt in pack.code_excerpts)
|
||||
|
||||
|
||||
def test_retriever_uses_test_fallback_when_production_evidence_is_missing() -> None:
    """When production search is empty, a second lexical pass admits test code."""
    gateway = _TestsOnlyGateway()
    retriever = CodeExplainRetrieverV2(gateway=gateway, graph_repository=_FakeGraphRepository())

    pack = retriever.build_pack("rag-1", "Explain get_user")

    # First call excludes tests, the fallback call does not.
    assert gateway.lexical_calls == [True, False]
    assert [item.path for item in pack.code_excerpts] == ["tests/test_users.py"]
    assert pack.code_excerpts[0].focus == "test:lexical"
|
||||
83
tests/unit_tests/rag/test_trace_builder.py
Normal file
83
tests/unit_tests/rag/test_trace_builder.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from app.modules.rag.explain.models import CodeLocation, LayeredRetrievalItem
|
||||
from app.modules.rag.explain.trace_builder import TraceBuilder
|
||||
|
||||
|
||||
class _FakeGraphRepository:
|
||||
def get_out_edges(self, rag_session_id: str, src_symbol_ids: list[str], edge_types: list[str], limit_per_src: int):
|
||||
assert rag_session_id == "rag-1"
|
||||
assert edge_types == ["calls", "imports", "inherits", "instantiates", "reads_attr", "writes_attr"]
|
||||
if src_symbol_ids == ["handler-1"]:
|
||||
return [
|
||||
LayeredRetrievalItem(
|
||||
source="app/api/users.py",
|
||||
content="handler calls get_user",
|
||||
layer="C2_DEPENDENCY_GRAPH",
|
||||
title="handler:calls",
|
||||
metadata={
|
||||
"src_symbol_id": "handler-1",
|
||||
"dst_symbol_id": None,
|
||||
"dst_ref": "UserService.get_user",
|
||||
"resolution": "partial",
|
||||
"edge_type": "calls",
|
||||
},
|
||||
location=CodeLocation(path="app/api/users.py", start_line=12, end_line=12),
|
||||
)
|
||||
]
|
||||
return []
|
||||
|
||||
def resolve_symbol_by_ref(self, rag_session_id: str, dst_ref: str, package_hint: str | None = None):
|
||||
assert rag_session_id == "rag-1"
|
||||
assert dst_ref == "UserService.get_user"
|
||||
assert package_hint == "app.api"
|
||||
return LayeredRetrievalItem(
|
||||
source="app/services/users.py",
|
||||
content="method UserService.get_user",
|
||||
layer="C1_SYMBOL_CATALOG",
|
||||
title="UserService.get_user",
|
||||
metadata={
|
||||
"symbol_id": "service-1",
|
||||
"package_or_module": "app.api.users",
|
||||
},
|
||||
location=CodeLocation(path="app/services/users.py", start_line=4, end_line=10),
|
||||
)
|
||||
|
||||
def get_symbols_by_ids(self, rag_session_id: str, symbol_ids: list[str]):
|
||||
assert rag_session_id == "rag-1"
|
||||
if symbol_ids == ["service-1"]:
|
||||
return [
|
||||
LayeredRetrievalItem(
|
||||
source="app/services/users.py",
|
||||
content="method UserService.get_user",
|
||||
layer="C1_SYMBOL_CATALOG",
|
||||
title="UserService.get_user",
|
||||
metadata={
|
||||
"symbol_id": "service-1",
|
||||
"package_or_module": "app.api.users",
|
||||
},
|
||||
location=CodeLocation(path="app/services/users.py", start_line=4, end_line=10),
|
||||
)
|
||||
]
|
||||
return []
|
||||
|
||||
|
||||
def test_trace_builder_resolves_partial_edges_across_files() -> None:
    """A partial 'calls' edge is resolved by ref into a cross-file trace path."""
    trace_builder = TraceBuilder(_FakeGraphRepository())
    seed = LayeredRetrievalItem(
        source="app/api/users.py",
        content="function handler",
        layer="C1_SYMBOL_CATALOG",
        title="get_user",
        metadata={
            "symbol_id": "handler-1",
            "package_or_module": "app.api.users",
        },
        location=CodeLocation(path="app/api/users.py", start_line=10, end_line=18),
    )

    paths = trace_builder.build_paths("rag-1", [seed], max_depth=3)

    assert len(paths) >= 1
    first_path = paths[0]
    assert first_path.symbol_ids == ["handler-1", "service-1"]
    assert "resolved:UserService.get_user" in first_path.notes
|
||||
Reference in New Issue
Block a user