73 lines
3.3 KiB
Python
73 lines
3.3 KiB
Python
from __future__ import annotations
|
|
|
|
from app.modules.rag.intent_router_v2.classifier import IntentClassifierV2
|
|
from app.modules.rag.intent_router_v2.conversation_policy import ConversationPolicy
|
|
from app.modules.rag.intent_router_v2.evidence_policy_factory import EvidencePolicyFactory
|
|
from app.modules.rag.intent_router_v2.graph_id_resolver import GraphIdResolver
|
|
from app.modules.rag.intent_router_v2.logger import IntentRouterLogger
|
|
from app.modules.rag.intent_router_v2.models import ConversationState, IntentRouterResult, RepoContext
|
|
from app.modules.rag.intent_router_v2.query_plan_builder import QueryPlanBuilder
|
|
from app.modules.rag.intent_router_v2.retrieval_spec_factory import RetrievalSpecFactory
|
|
|
|
|
|
class IntentRouterV2:
    """Facade that routes a raw user query to a fully-specified retrieval plan.

    Every collaborator can be injected (useful for tests); any collaborator
    left as ``None`` falls back to its default construction.
    """

    def __init__(
        self,
        classifier: IntentClassifierV2 | None = None,
        conversation_policy: ConversationPolicy | None = None,
        query_plan_builder: QueryPlanBuilder | None = None,
        retrieval_factory: RetrievalSpecFactory | None = None,
        evidence_factory: EvidencePolicyFactory | None = None,
        graph_resolver: GraphIdResolver | None = None,
        logger: IntentRouterLogger | None = None,
    ) -> None:
        def _default(provided, factory):
            # Preserve the original ``provided or factory()`` truthiness
            # fallback: a falsy injected collaborator is also replaced.
            return provided or factory()

        self._classifier = _default(classifier, IntentClassifierV2)
        self._conversation_policy = _default(conversation_policy, ConversationPolicy)
        self._query_plan_builder = _default(query_plan_builder, QueryPlanBuilder)
        self._retrieval_factory = _default(retrieval_factory, RetrievalSpecFactory)
        self._evidence_factory = _default(evidence_factory, EvidencePolicyFactory)
        self._graph_resolver = _default(graph_resolver, GraphIdResolver)
        self._logger = _default(logger, IntentRouterLogger)

    def route(
        self,
        user_query: str,
        conversation_state: ConversationState | None = None,
        repo_context: RepoContext | None = None,
    ) -> IntentRouterResult:
        """Classify *user_query* and assemble the complete routing result.

        Pipeline: log the request, classify, resolve intent/conversation
        mode, build a query plan, then derive the retrieval spec, evidence
        policy, and graph id into an :class:`IntentRouterResult`.
        """
        active_state = conversation_state or ConversationState()
        active_context = repo_context or RepoContext()
        self._logger.log_request(user_query, active_state, active_context)

        classification = self._classifier.classify(user_query, active_state)
        intent, mode = self._conversation_policy.resolve(
            classification, user_query, active_state
        )

        plan = self._query_plan_builder.build(
            user_query,
            active_state,
            continue_mode=(mode == "CONTINUE"),
            conversation_mode=mode,
            intent=intent,
        )

        retrieval_spec = self._retrieval_factory.build(
            intent,
            plan.anchors,
            active_context,
            raw_query=plan.raw,
            conversation_state=active_state,
            conversation_mode=mode,
            sub_intent=plan.sub_intent,
        )

        # An anchor sourced from the user's own text loosens/tightens the
        # evidence policy differently than derived anchors.
        user_anchored = any(anchor.source == "user_text" for anchor in plan.anchors)
        evidence_policy = self._evidence_factory.build(
            intent,
            sub_intent=plan.sub_intent,
            negations=plan.negations,
            has_user_anchor=user_anchored,
        )

        routed = IntentRouterResult(
            intent=intent,
            graph_id=self._graph_resolver.resolve(intent),
            conversation_mode=mode,
            query_plan=plan,
            retrieval_spec=retrieval_spec,
            evidence_policy=evidence_policy,
        )
        self._logger.log_result(routed)
        return routed
|