# agent/tests/unit_tests/rag/test_intent_router_invariants.py
# (file-browser listing metadata: 122 lines, 4.6 KiB, Python)

import pytest
from tests.unit_tests.rag.intent_router_testkit import run_sequence
# Apply the `intent_router` marker to every test in this module so the whole
# suite can be selected or deselected as a group (e.g. `pytest -m intent_router`).
pytestmark = pytest.mark.intent_router
def _docs_result(query: str):
    """Route one query and return its result, requiring docs-routing output.

    The query is passed to ``run_sequence`` as a one-element sequence, and the
    first result is returned. Before returning, the helper asserts that both
    ``docs_routing`` and ``retrieval_plan`` are populated, so callers can
    dereference them without further checks.
    """
    results = run_sequence([query])
    routed = results[0]
    for required_attr in ("docs_routing", "retrieval_plan"):
        assert getattr(routed, required_attr) is not None
    return routed
@pytest.mark.parametrize(
    ("query", "endpoint"),
    [
        ("как работает метод health", "/health"),
        ("объясни /health", "/health"),
        ("что делает endpoint /send", "/send"),
    ],
)
def test_docs_mvp_api_method_explain_cases(query: str, endpoint: str) -> None:
    """A question about one specific endpoint routes to the method-explain plan."""
    outcome = _docs_result(query)
    routing, plan = outcome.docs_routing, outcome.retrieval_plan

    # Routing: a DOCS_QA question scoped to a single API method, anchored on it.
    assert (routing.sub_intent, routing.intent) == ("API_METHOD_EXPLAIN", "DOCS_QA")
    assert routing.scope.level == "method"
    assert routing.anchors.endpoint_path == endpoint

    # Retrieval: the dedicated plan, filtered to the same endpoint.
    assert plan.plan_id == "docs_api_method_explain_v1"
    assert plan.filters["endpoint_path"] == endpoint
@pytest.mark.parametrize(
    ("query", "scope_level", "domain_name"),
    [
        ("какие есть методы в проекте", "project", None),
        ("покажи все api", "project", None),
        ("какие методы в notifications", "domain", "notifications"),
    ],
)
def test_docs_mvp_list_api_methods_cases(query: str, scope_level: str, domain_name: str | None) -> None:
    """Method-listing queries route to LIST_API_METHODS at the expected scope.

    ``domain_name`` is the expected ``domain_name`` retrieval filter. ``None``
    means the query names no domain, so the plan must not carry one.
    """
    result = _docs_result(query)
    assert result.docs_routing.sub_intent == "LIST_API_METHODS"
    assert result.docs_routing.intent == "DOCS_DISCOVERY"
    assert result.docs_routing.scope.level == scope_level
    assert result.retrieval_plan.plan_id == "docs_list_api_methods_v1"
    assert result.retrieval_plan.primary_doc_types == ["api_method"]
    # Checked unconditionally so that project-scope cases also verify that no
    # domain filter leaked into the plan. An absent key and an explicit None
    # both read as None through .get(). Assumes `filters` is a dict-like
    # mapping, since it is subscripted as one elsewhere in this module.
    assert result.retrieval_plan.filters.get("domain_name") == domain_name
@pytest.mark.parametrize(
    ("query", "domain_name", "subdomain_name", "entity_name"),
    [
        ("какие документы есть по notifications", "notifications", None, None),
        ("найди документацию по telegram_delivery", "telegram_delivery", None, None),
        ("какие документы связаны с health", None, None, "health"),
    ],
)
def test_docs_mvp_find_documents_cases(
    query: str,
    domain_name: str | None,
    subdomain_name: str | None,
    entity_name: str | None,
) -> None:
    """Document-discovery queries route to the find-documents-by-domain plan.

    Each expected name that is set (truthy) must appear as the matching
    retrieval filter. Names left as ``None`` are not checked.
    """
    outcome = _docs_result(query)
    assert outcome.docs_routing.sub_intent == "FIND_DOCUMENTS_BY_DOMAIN"
    assert outcome.docs_routing.intent == "DOCS_DISCOVERY"
    assert outcome.retrieval_plan.plan_id == "docs_find_documents_by_domain_v1"

    # Table-driven filter checks, evaluated in the same order as the parameters.
    expected_filters = {
        "domain_name": domain_name,
        "subdomain_name": subdomain_name,
        "entity_name": entity_name,
    }
    for filter_key, wanted in expected_filters.items():
        if not wanted:
            continue
        assert outcome.retrieval_plan.filters[filter_key] == wanted
@pytest.mark.parametrize(
    ("query", "scope_level", "endpoint"),
    [
        ("сгенерируй openapi по /health", "method", "/health"),
        ("собери swagger по notifications", "domain", None),
        ("сделай спецификацию api по всему проекту", "project", None),
    ],
)
def test_docs_mvp_generate_openapi_cases(query: str, scope_level: str, endpoint: str | None) -> None:
    """OpenAPI/swagger generation requests route to GENERATE_OPENAPI at the expected scope.

    ``endpoint`` is the expected ``endpoint_path`` retrieval filter. ``None``
    means the query names no endpoint, so the plan must not carry one.
    """
    result = _docs_result(query)
    routing = result.docs_routing
    plan = result.retrieval_plan
    assert routing.sub_intent == "GENERATE_OPENAPI"
    assert routing.intent == "DOCS_GENERATION"
    assert routing.scope.level == scope_level
    assert plan.plan_id == "docs_generate_openapi_v1"
    # Checked unconditionally so that domain- and project-scope requests also
    # verify that no endpoint filter narrows the generation. An absent key and
    # an explicit None both read as None through .get(). Assumes `filters` is a
    # dict-like mapping, since it is subscripted as one elsewhere in this module.
    assert plan.filters.get("endpoint_path") == endpoint
@pytest.mark.parametrize(
    "query",
    [
        "что делает это приложение",
        "как устроен сервис",
        "как связаны worker и api",
    ],
)
def test_docs_mvp_general_docs_qa_cases(query: str) -> None:
    """Broad questions about the system fall back to the general docs-QA plan."""
    answer = _docs_result(query)
    routing = answer.docs_routing
    assert (routing.sub_intent, routing.intent) == ("GENERAL_DOCS_QA", "DOCS_FALLBACK")
    assert answer.retrieval_plan.plan_id == "docs_general_docs_qa_v1"
def test_docs_mvp_retrieval_filters_are_merged_into_legacy_spec() -> None:
    """Retrieval filters are merged into the legacy ``retrieval_spec``.

    A domain-scoped method-listing query should expose its doc type, domain,
    and scope level as attributes of ``retrieval_spec.filters``.
    """
    result = _docs_result("какие методы в notifications")
    # Guard the legacy spec the same way _docs_result guards docs_routing and
    # retrieval_plan. A missing spec then fails as an assertion on this
    # invariant, not as an AttributeError on `.filters`.
    assert result.retrieval_spec is not None
    legacy_filters = result.retrieval_spec.filters
    assert getattr(legacy_filters, "doc_type", None) == "api_method"
    assert getattr(legacy_filters, "domain_name", None) == "notifications"
    assert getattr(legacy_filters, "scope_level", None) == "domain"