Files
agent/tests/unit_tests/rag/intent_router_testkit.py

47 lines
1.6 KiB
Python

from __future__ import annotations
import json
from app.core.rag.contracts.enums import RagLayer
from app.core.agent.intent_router import ConversationState, IntentRouterV2, RepoContext
def repo_context() -> RepoContext:
    """Return the fixed repository context used by the intent-router tests.

    The context declares a Python-only repository exposing the ``DOCS`` and
    ``GENERAL`` domains, with the seven ``DOCS_*`` RAG layers listed below
    marked as available. A new ``RepoContext`` instance is built on every call.
    """
    docs_layers = [
        RagLayer.DOCS_DOC_CHUNKS,
        RagLayer.DOCS_DOCUMENT_CATALOG,
        RagLayer.DOCS_FACT_INDEX,
        RagLayer.DOCS_ENTITY_CATALOG,
        RagLayer.DOCS_WORKFLOW_INDEX,
        RagLayer.DOCS_RELATION_GRAPH,
        RagLayer.DOCS_INTEGRATION_INDEX,
    ]
    domains = ["DOCS", "GENERAL"]
    return RepoContext(
        languages=["python"],
        available_domains=domains,
        available_layers=docs_layers,
    )
def run_sequence(
    queries: list[str],
    *,
    router: IntentRouterV2 | None = None,
    trace_label: str = "intent-router",
) -> list:
    """Route ``queries`` one after another as turns of a single conversation.

    Each turn is routed against a freshly built ``repo_context()``. The turn is
    traced to stdout with ``print_trace``. The result is then folded into the
    conversation via ``ConversationState.advance``, so later turns are routed
    with the state produced by earlier ones.

    Args:
        queries: User inputs, in turn order.
        router: Router to use. When falsy (e.g. ``None``), a new
            ``IntentRouterV2()`` is created.
        trace_label: Prefix printed on every trace line.

    Returns:
        The router results, one per query, in input order.
    """
    chosen_router = router or IntentRouterV2()
    conversation = ConversationState()
    collected: list = []
    for turn, text in enumerate(queries, start=1):
        routed = chosen_router.route(text, conversation, repo_context())
        print_trace(turn, text, routed, label=trace_label)
        collected.append(routed)
        # Carry this turn's outcome forward so the next query sees it.
        conversation = conversation.advance(routed)
    return collected
def run_single(
    query: str,
    *,
    router: IntentRouterV2 | None = None,
    trace_label: str = "intent-router",
):
    """Route ``query`` as a one-turn conversation and return its result.

    This delegates to ``run_sequence`` with a single-element list, so the turn
    is traced exactly as it would be in a longer sequence.
    """
    (only_result,) = run_sequence([query], router=router, trace_label=trace_label)
    return only_result
def print_trace(index: int, query: str, result, *, label: str = "intent-router") -> None:
    """Print one routing turn to stdout: the input, a blank line, the JSON output.

    Args:
        index: Turn number shown in the trace prefix (``run_sequence`` passes
            1-based indices).
        query: Raw user input for this turn.
        result: Router result. It is assumed to be a Pydantic v2 model (anything
            whose ``model_dump`` accepts ``mode="json"``).
        label: Prefix identifying the trace source.
    """
    prefix = f"[{label}][turn {index}]"
    print(f"{prefix} input: {query}")
    print()
    # mode="json" makes Pydantic emit JSON-native values (sets -> lists,
    # datetimes/UUIDs -> strings, enums -> values). The default python-mode
    # dump can contain objects json.dumps rejects with TypeError.
    payload = result.model_dump(mode="json")
    print(f"{prefix} output: {json.dumps(payload, ensure_ascii=False)}")
    print("=" * 50)