183 lines
5.9 KiB
Python
183 lines
5.9 KiB
Python
from __future__ import annotations
|
|
|
|
import re
|
|
from typing import Literal
|
|
|
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|
|
|
|
|
# Closed vocabularies used by the router models in this module.
IntentType = Literal[
    "CODE_QA",
    "DOCS_QA",
    "GENERATE_DOCS_FROM_CODE",
    "PROJECT_MISC",
]
ConversationMode = Literal["START", "CONTINUE", "SWITCH"]
AnchorType = Literal["FILE_PATH", "SYMBOL", "DOC_REF", "KEY_TERM"]
AnchorSource = Literal["user_text", "conversation_state", "heuristic"]

# A backtick-delimited inline code span; group 1 captures the text between the backticks.
_INLINE_CODE_RE = re.compile(r"`([^`]*)`")
# An ASCII identifier-like token of three or more characters, on word boundaries.
_CODE_SYMBOL_RE = re.compile(r"\b([A-Za-z_][A-Za-z0-9_]{2,})\b")
|
|
|
|
|
|
class AnchorSpan(BaseModel):
    # Start/end positions optionally attached to a QueryAnchor; both default to 0.
    # NOTE(review): presumably character offsets into the raw query text — confirm
    # with whatever produces anchors.
    model_config = ConfigDict(extra="forbid")

    start: int = Field(default=0)
    end: int = Field(default=0)
|
|
|
|
|
|
class QueryAnchor(BaseModel):
    # A typed anchor attached to a query plan: a FILE_PATH, SYMBOL, DOC_REF or
    # KEY_TERM value, where it came from (user text by default), an optional
    # subtype and span, and a confidence that is always kept within [0, 1].
    model_config = ConfigDict(extra="forbid")

    type: AnchorType
    value: str
    source: AnchorSource = "user_text"
    subtype: str | None = None
    span: AnchorSpan | None = None
    confidence: float = 0.0

    @field_validator("confidence")
    @classmethod
    def clamp_confidence(cls, value: float) -> float:
        """Clamp *value* into ``[0.0, 1.0]``, mapping NaN to ``0.0``.

        Without the explicit check, ``max(0.0, min(1.0, nan))`` evaluates to
        ``1.0`` (every comparison with NaN is false), silently turning an
        undefined score into full confidence.

        Args:
            value: The already-parsed float confidence.

        Returns:
            The confidence limited to the unit interval.
        """
        import math

        number = float(value)
        if math.isnan(number):
            return 0.0
        return max(0.0, min(1.0, number))
|
|
|
|
|
|
class QueryPlan(BaseModel):
    # The router's plan for a single query: the raw and normalized text, a
    # sub-intent label (default "EXPLAIN"), lists of negations / expansions /
    # keyword hints, and the anchors attached to the query. All lists default
    # to a fresh empty list per instance.
    model_config = ConfigDict(extra="forbid")

    raw: str
    normalized: str
    sub_intent: str = Field(default="EXPLAIN")
    negations: list[str] = Field(default_factory=list)
    expansions: list[str] = Field(default_factory=list)
    keyword_hints: list[str] = Field(default_factory=list)
    anchors: list[QueryAnchor] = Field(default_factory=list)
|
|
|
|
|
|
class LayerQuery(BaseModel):
    # One per-layer retrieval request: which layer (``layer_id``) and a
    # ``top_k`` count — presumably the number of hits to fetch from that layer.
    # NOTE(review): no bounds on top_k are enforced here.
    model_config = ConfigDict(extra="forbid")

    layer_id: str
    top_k: int
|
|
|
|
|
|
class CodeRetrievalFilters(BaseModel):
    # Filters for code-domain retrieval. ``test_policy`` defaults to "EXCLUDE"
    # and is a free-form string (not validated against a fixed set); path and
    # language lists default to empty.
    model_config = ConfigDict(extra="forbid")

    test_policy: str = Field(default="EXCLUDE")
    path_scope: list[str] = Field(default_factory=list)
    language: list[str] = Field(default_factory=list)
|
|
|
|
|
|
class DocsRetrievalFilters(BaseModel):
    # Filters for docs-domain retrieval: path scope, document kinds and
    # document languages, each defaulting to an empty list.
    model_config = ConfigDict(extra="forbid")

    path_scope: list[str] = Field(default_factory=list)
    doc_kinds: list[str] = Field(default_factory=list)
    doc_language: list[str] = Field(default_factory=list)
|
|
|
|
|
|
class HybridRetrievalFilters(BaseModel):
    # Superset of the code and docs filter fields, for retrieval that spans
    # both domains. ``test_policy`` defaults to "EXCLUDE"; lists default empty.
    model_config = ConfigDict(extra="forbid")

    # Code-side filters.
    test_policy: str = Field(default="EXCLUDE")
    path_scope: list[str] = Field(default_factory=list)
    language: list[str] = Field(default_factory=list)
    # Docs-side filters.
    doc_kinds: list[str] = Field(default_factory=list)
    doc_language: list[str] = Field(default_factory=list)
|
|
|
|
|
|
class RetrievalSpec(BaseModel):
    # What to retrieve: target domains, per-layer queries, domain filters
    # (code filters by default) and a rerank profile name (empty by default).
    model_config = ConfigDict(extra="forbid")

    domains: list[str] = Field(default_factory=list)
    layer_queries: list[LayerQuery] = Field(default_factory=list)
    # NOTE(review): the three filter models overlap (all carry ``path_scope``
    # and all forbid extra keys), so e.g. a dict with only ``path_scope``
    # validates against every member; which one is chosen is left to pydantic's
    # union mode. Consider a discriminator field if this matters downstream.
    filters: CodeRetrievalFilters | DocsRetrievalFilters | HybridRetrievalFilters = Field(
        default_factory=CodeRetrievalFilters
    )
    rerank_profile: str = Field(default="")
|
|
|
|
|
|
class EvidencePolicy(BaseModel):
    # Boolean evidence requirements for answering; every flag defaults to False.
    # NOTE(review): exact meaning of def/flow/spec is defined by the consumer,
    # not here.
    model_config = ConfigDict(extra="forbid")

    require_def: bool = Field(default=False)
    require_flow: bool = Field(default=False)
    require_spec: bool = Field(default=False)
    allow_answer_without_evidence: bool = Field(default=False)
|
|
|
|
|
|
class IntentRouterResult(BaseModel):
    # Full output of the intent router for one turn. Only ``schema_version``
    # has a default ("1.1"); every other field must be supplied.
    model_config = ConfigDict(extra="forbid")

    schema_version: str = Field(default="1.1")
    # Classification of the turn.
    intent: IntentType
    graph_id: str
    conversation_mode: ConversationMode
    # Planning and retrieval details.
    query_plan: QueryPlan
    retrieval_spec: RetrievalSpec
    evidence_policy: EvidencePolicy
|
|
|
|
|
|
class ConversationState(BaseModel):
    # Rolling per-conversation context. Instances are not mutated between turns;
    # ``advance`` builds the next state from the current one plus a router result.
    model_config = ConfigDict(extra="forbid")

    active_intent: IntentType | None = None
    active_domain: str | None = None
    active_anchors: list[QueryAnchor] = Field(default_factory=list)
    active_symbol: str | None = None
    active_path_scope: list[str] = Field(default_factory=list)
    active_code_span_symbols: list[str] = Field(default_factory=list)
    last_query: str = ""
    turn_index: int = 0

    def advance(self, result: IntentRouterResult) -> "ConversationState":
        """Return the state for the next turn; ``self`` is left unchanged.

        Update rules:

        * ``active_intent`` / ``last_query`` come straight from *result*, and
          ``turn_index`` is incremented by one.
        * ``active_anchors`` keeps only this turn's anchors whose source is
          ``"user_text"``.
        * ``active_symbol`` becomes the value of the last user SYMBOL anchor;
          failing that, it is cleared if the user gave a FILE_PATH anchor;
          otherwise the previous symbol is kept.
        * ``active_code_span_symbols`` are identifiers found in backtick spans
          of the raw query, falling back to the previous list when none are found.
        * ``active_path_scope`` is the filters' ``path_scope`` when non-empty,
          otherwise the previous scope.
        * ``active_domain`` is taken from the spec only when it names exactly
          one domain; otherwise the previous domain is kept.
        """
        plan = result.query_plan
        spec = result.retrieval_spec

        fresh_anchors = [anchor for anchor in plan.anchors if anchor.source == "user_text"]

        # Last SYMBOL wins; None means "no symbol anchor this turn".
        latest_symbol = next(
            (anchor.value for anchor in reversed(fresh_anchors) if anchor.type == "SYMBOL"),
            None,
        )
        if latest_symbol is not None:
            next_symbol = latest_symbol
        elif any(anchor.type == "FILE_PATH" for anchor in fresh_anchors):
            # A new file focus without a symbol drops the stale symbol.
            next_symbol = None
        else:
            next_symbol = self.active_symbol

        code_symbols = _extract_code_symbols(plan.raw)
        scope = list(getattr(spec.filters, "path_scope", []) or [])
        domains = list(spec.domains or [])

        return ConversationState(
            active_intent=result.intent,
            active_domain=domains[0] if len(domains) == 1 else self.active_domain,
            active_anchors=fresh_anchors,
            active_symbol=next_symbol,
            active_path_scope=scope if scope else list(self.active_path_scope),
            active_code_span_symbols=(
                code_symbols if code_symbols else list(self.active_code_span_symbols)
            ),
            last_query=plan.raw,
            turn_index=self.turn_index + 1,
        )
|
|
|
|
|
|
class RepoContext(BaseModel):
    # Repository-level capabilities: known languages, available domains
    # (defaults to a fresh ["CODE", "DOCS"] per instance) and available layers.
    model_config = ConfigDict(extra="forbid")

    languages: list[str] = Field(default_factory=list)
    available_domains: list[str] = Field(default_factory=lambda: ["CODE", "DOCS"])
    available_layers: list[str] = Field(default_factory=list)
|
|
|
|
|
|
class IntentDecision(BaseModel):
    # A single intent classification: the chosen intent, a confidence that is
    # always kept within [0, 1], and a free-text reason (empty by default).
    model_config = ConfigDict(extra="forbid")

    intent: IntentType
    confidence: float = 0.0
    reason: str = ""

    @field_validator("confidence")
    @classmethod
    def clamp_confidence(cls, value: float) -> float:
        """Clamp *value* into ``[0.0, 1.0]``, mapping NaN to ``0.0``.

        Without the explicit check, ``max(0.0, min(1.0, nan))`` evaluates to
        ``1.0`` (every comparison with NaN is false), silently turning an
        undefined score into full confidence.

        Args:
            value: The already-parsed float confidence.

        Returns:
            The confidence limited to the unit interval.
        """
        import math

        number = float(value)
        if math.isnan(number):
            return 0.0
        return max(0.0, min(1.0, number))
|
|
|
|
|
|
def _extract_code_symbols(raw: str, limit: int = 8) -> list[str]:
    """Collect identifier-like tokens that appear inside inline code spans.

    Every backtick-delimited span in *raw* (``_INLINE_CODE_RE``) is scanned for
    tokens matching ``_CODE_SYMBOL_RE`` (ASCII identifiers of three or more
    characters). Tokens are de-duplicated and kept in first-seen order.

    Args:
        raw: The query text. ``None`` or ``""`` yields an empty list.
        limit: Maximum number of symbols to return; defaults to 8 (the
            previously hard-coded cap). Non-positive values yield ``[]``.

    Returns:
        Up to *limit* unique symbols in order of first appearance.
    """
    if limit <= 0:
        return []
    symbols: list[str] = []
    seen: set[str] = set()  # O(1) membership instead of scanning the list
    for match in _INLINE_CODE_RE.finditer(raw or ""):
        for token in _CODE_SYMBOL_RE.findall(match.group(1)):
            if token in seen:
                continue
            seen.add(token)
            symbols.append(token)
            # Any further tokens would be sliced off anyway, so stop scanning.
            if len(symbols) >= limit:
                return symbols
    return symbols
|