# 83 lines, 3.1 KiB, Python
from __future__ import annotations
|
|
|
|
import json
|
|
|
|
from app.core.agent.intent_router import ConversationState, IntentRouterV2
|
|
from app.core.agent.intent_router.docs_mvp.llm_classifier import DocsMvpLlmClassifier
|
|
from app.core.agent.intent_router.intent.classifier import IntentClassifierV2
|
|
from tests.unit_tests.rag.intent_router_testkit import repo_context
|
|
|
|
|
|
class FakeLlm:
    """Test double for an LLM client that records calls and replays a canned reply.

    Every ``generate`` invocation is appended to ``calls`` as a
    ``(prompt_name, user_input)`` tuple. When constructed with ``fail=True``
    the call is still recorded, after which ``RuntimeError`` is raised.
    """

    def __init__(self, response: str, *, fail: bool = False) -> None:
        """Store the canned ``response`` and the ``fail`` switch; start with no calls."""
        self.calls: list[tuple[str, str]] = []
        self.fail = fail
        self.response = response

    def generate(self, prompt_name: str, user_input: str, *, log_context: str | None = None) -> str:
        """Record the call, then return the canned response or raise.

        ``log_context`` is accepted but not used by this fake.

        Raises:
            RuntimeError: if the instance was created with ``fail=True``.
        """
        # Record first so a failing call is still visible in ``calls``.
        invocation = (prompt_name, user_input)
        self.calls.append(invocation)
        if not self.fail:
            return self.response
        raise RuntimeError("llm unavailable")
|
|
|
|
|
|
def _router(llm: FakeLlm) -> IntentRouterV2:
    """Build an ``IntentRouterV2`` for tests.

    The router gets a fresh ``IntentClassifierV2``, a ``DocsMvpLlmClassifier``
    wrapping ``llm``, and ``enable_llm_disambiguation=True``.
    """
    intent_classifier = IntentClassifierV2()
    docs_classifier = DocsMvpLlmClassifier(llm)
    router = IntentRouterV2(
        classifier=intent_classifier,
        docs_llm_classifier=docs_classifier,
        enable_llm_disambiguation=True,
    )
    return router
|
|
|
|
|
|
def test_docs_technical_query_keeps_deterministic_routing_without_llm_call() -> None:
    """An explicit endpoint question is routed deterministically and the fake LLM sees no calls."""
    fake_llm = FakeLlm("{}")
    router = _router(fake_llm)

    outcome = router.route("Объясни endpoint /health", ConversationState(), repo_context())

    docs = outcome.docs_routing
    assert docs is not None
    assert docs.sub_intent == "API_METHOD_EXPLAIN"
    assert docs.routing_mode == "deterministic"
    assert outcome.llm_router_used is False
    # The fake records every generate() call, so an empty list means it was never invoked.
    assert fake_llm.calls == []
|
|
|
|
|
|
def test_ambiguous_query_can_be_resolved_by_llm() -> None:
    """A canned DOCS_DISCOVERY reply from the fake LLM yields llm_assisted routing and the matching plan."""
    # Classification payload the fake LLM replays verbatim as JSON.
    llm_payload = {
        "intent": "DOCS_DISCOVERY",
        "sub_intent": "FIND_DOCUMENTS_BY_DOMAIN",
        "confidence": 0.83,
        "anchors": {"entity_name": "health", "doc_query": "документация по health"},
        "scope": {"level": "domain"},
        "reason_short": "health here is a docs topic",
    }
    fake_llm = FakeLlm(json.dumps(llm_payload, ensure_ascii=False))
    router = _router(fake_llm)

    outcome = router.route("документация по health", ConversationState(), repo_context())

    docs = outcome.docs_routing
    assert docs is not None
    assert docs.routing_mode == "llm_assisted"
    assert docs.sub_intent == "FIND_DOCUMENTS_BY_DOMAIN"
    plan = outcome.retrieval_plan
    assert plan is not None
    assert plan.plan_id == "docs_find_documents_by_domain_v1"
    assert outcome.llm_router_used is True
|
|
|
|
|
|
def test_ambiguous_query_falls_back_to_general_docs_when_llm_fails() -> None:
    """With a fake LLM that raises, routing reports llm_fallback and the general docs QA plan."""
    failing_llm = FakeLlm("{}", fail=True)
    router = _router(failing_llm)

    outcome = router.route("health документация", ConversationState(), repo_context())

    docs = outcome.docs_routing
    assert docs is not None
    assert docs.routing_mode == "llm_fallback"
    assert docs.sub_intent == "GENERAL_DOCS_QA"
    plan = outcome.retrieval_plan
    assert plan is not None
    assert plan.plan_id == "docs_general_docs_qa_v1"
|
|
|
|
|
|
def test_llm_classifier_rejects_unknown_labels() -> None:
    """``classify`` is expected to return ``None`` for a reply labelled DOCS_QA / MADE_UP."""
    canned_reply = json.dumps({"intent": "DOCS_QA", "sub_intent": "MADE_UP"})
    classifier = DocsMvpLlmClassifier(FakeLlm(canned_reply))

    verdict = classifier.classify({"query": "test"})

    assert verdict is None
|