Files
agent/app/modules/rag/intent_router_v2/models.py

183 lines
5.9 KiB
Python

from __future__ import annotations

import math
import re
from typing import Literal

from pydantic import BaseModel, ConfigDict, Field, field_validator
# Closed vocabularies shared by the router models in this module.
IntentType = Literal[
    "CODE_QA",
    "DOCS_QA",
    "GENERATE_DOCS_FROM_CODE",
    "PROJECT_MISC",
]
ConversationMode = Literal["START", "CONTINUE", "SWITCH"]
AnchorType = Literal["FILE_PATH", "SYMBOL", "DOC_REF", "KEY_TERM"]
AnchorSource = Literal["user_text", "conversation_state", "heuristic"]

# A single-backtick inline code span; group 1 is its (possibly empty) content.
_INLINE_CODE_RE = re.compile(r"`([^`]*)`")
# An identifier-like token: a letter or underscore followed by at least two
# more ASCII word characters, i.e. a minimum length of three.
_CODE_SYMBOL_RE = re.compile(r"\b([A-Za-z_][A-Za-z0-9_]{2,})\b")
class AnchorSpan(BaseModel):
    # Start/end offsets locating an anchor. Presumably character positions in
    # the query text (end-exclusive?) -- TODO confirm with whatever populates
    # QueryAnchor.span.
    # Comments rather than a docstring: pydantic copies a class docstring into
    # the model's JSON schema description, which would change its output.
    model_config = ConfigDict(extra="forbid")  # unknown keys are rejected

    start: int = 0
    end: int = 0
class QueryAnchor(BaseModel):
    # One reference attached to a query (kind given by AnchorType), where it
    # came from (AnchorSource), an optional subtype/span, and a confidence
    # score that is clamped into [0.0, 1.0] on validation.
    # Comments rather than a docstring: pydantic copies a class docstring into
    # the model's JSON schema description, which would change its output.
    model_config = ConfigDict(extra="forbid")  # unknown keys are rejected

    type: AnchorType
    value: str
    source: AnchorSource = "user_text"
    subtype: str | None = None
    span: AnchorSpan | None = None
    confidence: float = 0.0

    @field_validator("confidence")
    @classmethod
    def clamp_confidence(cls, value: float) -> float:
        """Clamp ``confidence`` into the closed interval [0.0, 1.0].

        NaN is mapped to 0.0. A bare ``max(0.0, min(1.0, nan))`` evaluates to
        1.0 (``min``/``max`` keep the first argument when comparisons with NaN
        are false), which would silently turn an undefined score into full
        confidence.
        """
        number = float(value)
        if math.isnan(number):
            return 0.0
        return max(0.0, min(1.0, number))
class QueryPlan(BaseModel):
    # Plan derived from one user query. `raw` is the query text itself
    # (ConversationState.advance stores it as last_query and scans it for
    # inline code); `normalized` is a cleaned-up form of it. The list fields
    # default to a fresh empty list per instance.
    # Comments rather than a docstring: pydantic copies a class docstring into
    # the model's JSON schema description, which would change its output.
    model_config = ConfigDict(extra="forbid")  # unknown keys are rejected

    raw: str
    normalized: str
    sub_intent: str = "EXPLAIN"

    negations: list[str] = Field(default_factory=list)
    expansions: list[str] = Field(default_factory=list)
    keyword_hints: list[str] = Field(default_factory=list)
    anchors: list[QueryAnchor] = Field(default_factory=list)
class LayerQuery(BaseModel):
    # A query against one retrieval layer; both fields are required.
    # Presumably `top_k` caps the number of hits taken from `layer_id`
    # -- TODO confirm with the retriever.
    # Comments rather than a docstring: pydantic copies a class docstring into
    # the model's JSON schema description, which would change its output.
    model_config = ConfigDict(extra="forbid")  # unknown keys are rejected

    layer_id: str
    top_k: int
class CodeRetrievalFilters(BaseModel):
    # Filters for code retrieval. test_policy defaults to "EXCLUDE"
    # (presumably: leave test files out -- TODO confirm); the path and
    # language lists default to fresh empty lists.
    # Comments rather than a docstring: pydantic copies a class docstring into
    # the model's JSON schema description, which would change its output.
    model_config = ConfigDict(extra="forbid")  # unknown keys are rejected

    test_policy: str = "EXCLUDE"
    path_scope: list[str] = Field(default_factory=list)
    language: list[str] = Field(default_factory=list)
class DocsRetrievalFilters(BaseModel):
    # Filters for documentation retrieval: path scope, document kinds and
    # document languages, each defaulting to a fresh empty list.
    # Comments rather than a docstring: pydantic copies a class docstring into
    # the model's JSON schema description, which would change its output.
    model_config = ConfigDict(extra="forbid")  # unknown keys are rejected

    path_scope: list[str] = Field(default_factory=list)
    doc_kinds: list[str] = Field(default_factory=list)
    doc_language: list[str] = Field(default_factory=list)
class HybridRetrievalFilters(BaseModel):
    # Union of the code and docs filter fields, for retrieval that spans both
    # domains. Defaults match CodeRetrievalFilters / DocsRetrievalFilters.
    # Comments rather than a docstring: pydantic copies a class docstring into
    # the model's JSON schema description, which would change its output.
    model_config = ConfigDict(extra="forbid")  # unknown keys are rejected

    # Code-side filters.
    test_policy: str = "EXCLUDE"
    path_scope: list[str] = Field(default_factory=list)
    language: list[str] = Field(default_factory=list)
    # Docs-side filters.
    doc_kinds: list[str] = Field(default_factory=list)
    doc_language: list[str] = Field(default_factory=list)
class RetrievalSpec(BaseModel):
    # What to retrieve: target domains, per-layer queries, one filter model
    # and a rerank profile name (empty string by default). `filters` defaults
    # to a freshly constructed CodeRetrievalFilters().
    # NOTE(review): `filters` is a plain union; when built from a dict,
    # pydantic chooses the member, and extra="forbid" disqualifies any member
    # that lacks one of the supplied keys -- confirm callers rely on that.
    # Comments rather than a docstring: pydantic copies a class docstring into
    # the model's JSON schema description, which would change its output.
    model_config = ConfigDict(extra="forbid")  # unknown keys are rejected

    domains: list[str] = Field(default_factory=list)
    layer_queries: list[LayerQuery] = Field(default_factory=list)
    filters: CodeRetrievalFilters | DocsRetrievalFilters | HybridRetrievalFilters = Field(
        default_factory=CodeRetrievalFilters,
    )
    rerank_profile: str = ""
class EvidencePolicy(BaseModel):
    # Evidence requirements for answering; every flag is False by default.
    # Judging by the names (TODO confirm with the consumer): which evidence
    # kinds must be found (definition / flow / spec) and whether an answer
    # may be given with no evidence at all.
    # Comments rather than a docstring: pydantic copies a class docstring into
    # the model's JSON schema description, which would change its output.
    model_config = ConfigDict(extra="forbid")  # unknown keys are rejected

    require_def: bool = False
    require_flow: bool = False
    require_spec: bool = False
    allow_answer_without_evidence: bool = False
class IntentRouterResult(BaseModel):
    # Complete output of one routing decision; ConversationState.advance
    # consumes it to build the next conversation state. Only schema_version
    # has a default ("1.1"); every other field is required.
    # Comments rather than a docstring: pydantic copies a class docstring into
    # the model's JSON schema description, which would change its output.
    model_config = ConfigDict(extra="forbid")  # unknown keys are rejected

    schema_version: str = "1.1"

    # Routing decision.
    intent: IntentType
    graph_id: str
    conversation_mode: ConversationMode

    # What to search for and how to judge the results.
    query_plan: QueryPlan
    retrieval_spec: RetrievalSpec
    evidence_policy: EvidencePolicy
class ConversationState(BaseModel):
    # Dialogue context carried from one routed turn to the next; advance()
    # derives the successor state from an IntentRouterResult.
    # Comments rather than a docstring: pydantic copies a class docstring into
    # the model's JSON schema description, which would change its output.
    model_config = ConfigDict(extra="forbid")  # unknown keys are rejected

    active_intent: IntentType | None = None
    active_domain: str | None = None
    active_anchors: list[QueryAnchor] = Field(default_factory=list)
    active_symbol: str | None = None
    active_path_scope: list[str] = Field(default_factory=list)
    active_code_span_symbols: list[str] = Field(default_factory=list)
    last_query: str = ""
    turn_index: int = 0

    def advance(self, result: IntentRouterResult) -> ConversationState:
        """Build the state that follows ``result``; ``self`` is left unchanged.

        Carry-over rules:

        * intent and last query always come from ``result``; anchors are
          replaced by the ``"user_text"`` anchors of ``result``'s query plan
          (possibly an empty list);
        * the active symbol becomes the value of the last user SYMBOL anchor;
          if there is none but a user FILE_PATH anchor exists, it is cleared;
          otherwise the previous symbol is kept;
        * inline-code symbols from the raw query, the filter path scope and
          the domain (only when exactly one domain is listed) replace the old
          values when present, else the old values are kept;
        * ``turn_index`` is incremented by one.
        """
        plan = result.query_plan
        spec = result.retrieval_spec

        fresh_anchors = [a for a in plan.anchors if a.source == "user_text"]

        # One pass over the user anchors: remember the last SYMBOL anchor and
        # whether any FILE_PATH anchor was mentioned.
        last_symbol_anchor = None
        names_a_file = False
        for anchor in fresh_anchors:
            if anchor.type == "SYMBOL":
                last_symbol_anchor = anchor
            elif anchor.type == "FILE_PATH":
                names_a_file = True

        if last_symbol_anchor is not None:
            next_symbol = last_symbol_anchor.value
        elif names_a_file:
            # A new file reference without a symbol invalidates the old focus.
            next_symbol = None
        else:
            next_symbol = self.active_symbol

        code_symbols = _extract_code_symbols(plan.raw) or list(self.active_code_span_symbols)

        # getattr keeps this tolerant of a filters object without path_scope;
        # a None/empty scope falls back to the previous one below.
        scope = list(getattr(spec.filters, "path_scope", None) or [])

        domains = spec.domains or []
        next_domain = domains[0] if len(domains) == 1 else self.active_domain

        return ConversationState(
            active_intent=result.intent,
            active_domain=next_domain,
            active_anchors=list(fresh_anchors),
            active_symbol=next_symbol,
            active_path_scope=scope or list(self.active_path_scope),
            active_code_span_symbols=code_symbols,
            last_query=plan.raw,
            turn_index=self.turn_index + 1,
        )
class RepoContext(BaseModel):
    # Description of the repository being queried: its languages and the
    # retrieval domains/layers available. available_domains defaults to a new
    # ["CODE", "DOCS"] list per instance; the other lists default to empty.
    # Comments rather than a docstring: pydantic copies a class docstring into
    # the model's JSON schema description, which would change its output.
    model_config = ConfigDict(extra="forbid")  # unknown keys are rejected

    languages: list[str] = Field(default_factory=list)
    # A factory (not a literal default) so instances never share one list.
    available_domains: list[str] = Field(default_factory=lambda: ["CODE", "DOCS"])
    available_layers: list[str] = Field(default_factory=list)
class IntentDecision(BaseModel):
    # An intent label with a confidence score (clamped into [0.0, 1.0] on
    # validation) and a free-text reason, empty by default.
    # Comments rather than a docstring: pydantic copies a class docstring into
    # the model's JSON schema description, which would change its output.
    model_config = ConfigDict(extra="forbid")  # unknown keys are rejected

    intent: IntentType
    confidence: float = 0.0
    reason: str = ""

    @field_validator("confidence")
    @classmethod
    def clamp_confidence(cls, value: float) -> float:
        """Clamp ``confidence`` into the closed interval [0.0, 1.0].

        NaN is mapped to 0.0. A bare ``max(0.0, min(1.0, nan))`` evaluates to
        1.0 (``min``/``max`` keep the first argument when comparisons with NaN
        are false), which would silently turn an undefined score into full
        confidence.
        """
        number = float(value)
        if math.isnan(number):
            return 0.0
        return max(0.0, min(1.0, number))
def _extract_code_symbols(raw: str | None, limit: int = 8) -> list[str]:
    """Return identifier-like tokens found inside backtick code spans of *raw*.

    Every inline code span matched by ``_INLINE_CODE_RE`` is scanned with
    ``_CODE_SYMBOL_RE`` (tokens of 3+ characters starting with a letter or
    underscore). Tokens are de-duplicated, kept in order of first appearance,
    and capped at *limit* entries.

    Args:
        raw: Query text; ``None`` or ``""`` yields an empty list.
        limit: Maximum number of symbols to return. Defaults to 8, the
            previously hard-coded cap; values <= 0 yield an empty list.

    Returns:
        Unique tokens in first-seen order, at most *limit* of them.
    """
    if limit <= 0:
        return []
    # A dict works as an insertion-ordered set: O(1) membership checks instead
    # of scanning a list, while preserving first-seen order.
    found: dict[str, None] = {}
    for span in _INLINE_CODE_RE.finditer(raw or ""):
        for token in _CODE_SYMBOL_RE.findall(span.group(1)):
            if token in found:
                continue
            found[token] = None
            # Later tokens would be discarded by the cap anyway, so stop here.
            if len(found) == limit:
                return list(found)
    return list(found)