114 lines
4.8 KiB
Python
114 lines
4.8 KiB
Python
from __future__ import annotations

import json
import math
import re

from app.modules.rag.intent_router_v2.models import ConversationState, IntentDecision
from app.modules.rag.intent_router_v2.protocols import TextGenerator
from app.modules.rag.intent_router_v2.test_signals import has_test_focus
|
|
|
|
# Case-insensitive match for a source-file reference such as ``main.py`` or
# ``app/modules/rag/router.ts``: zero or more ``segment/`` prefixes, a file name,
# and one of the listed code extensions. The ``(?!\w)`` guard keeps a short
# extension from matching the start of a longer one (``data.csv`` is not ``.c``/``.cs``).
_CODE_FILE_PATH_RE = re.compile(
    r"\b(?:[\w.-]+/)*[\w.-]+\.(?:py|js|jsx|ts|tsx|java|kt|go|rb|php|c|cc|cpp|h|hpp|cs|swift|rs)(?!\w)\b",
    flags=re.I,
)
|
|
|
|
|
|
class IntentClassifierV2:
    """Classify a user query into one of the intent-router v2 intents.

    Resolution order:
    1. deterministic keyword/regex rules (``_deterministic``);
    2. the optional LLM (``_classify_with_llm``), when one was injected;
    3. a ``PROJECT_MISC`` fallback with confidence 0.55.
    """

    # Explicit "generate documentation" requests, matched as substrings of the
    # lower-cased, whitespace-normalised query.
    _GENERATE_DOCS_MARKERS = (
        "сгенерируй документац",
        "подготовь документац",
        "создай документац",
        "generate documentation",
        "write documentation",
    )
    # Stem of "генерация"/"генерацию" ("generation"). On its own it is far too
    # broad ("где происходит генерация токенов?" is a code question), so it only
    # signals a docs-generation request when the query also mentions docs.
    _GENERATE_STEM = "генерац"
    # Substring markers that point at documentation rather than source code.
    _DOCS_MARKERS = ("документац", "readme", "docs/", ".md", "runbook", "markdown")
    # "spec" / "specs" / "specification(s)" as a standalone word. A plain substring
    # test would also fire on "inspect", "respect", "aspect", "specific", ...
    _SPEC_MARKER_RE = re.compile(r"(?<![a-z])spec(?:s|ification|ifications)?(?![a-z])")
    # Substring markers that point at source code.
    _CODE_MARKERS = ("по коду", "код", "класс", "метод", "функц", "модул", "пакет", "файл", "block", "блок", "handler", "endpoint")
    # Capitalised identifier, optionally dotted (e.g. ``IntentClassifierV2.classify``).
    # NOTE(review): this also matches any capitalised ASCII word of 3+ characters,
    # so plain English questions ("What ...", "How ...") deterministically become
    # CODE_QA; consider requiring CamelCase or a dotted path.
    _IDENTIFIER_RE = re.compile(r"\b[A-Z][A-Za-z0-9_]{2,}(?:\.[A-Za-z_][A-Za-z0-9_]*)*\b")
    # Name starting with a lower-case letter or "_" immediately followed by "("
    # (a call such as ``load_config(``).
    _CALL_RE = re.compile(r"\b[a-z_][A-Za-z0-9_]{2,}\(")
    # Intents the LLM may choose from; anything else is rejected.
    _ALLOWED_INTENTS = ("CODE_QA", "DOCS_QA", "GENERATE_DOCS_FROM_CODE", "PROJECT_MISC")
    # Confidence used when the LLM omits a usable "confidence" value.
    _DEFAULT_LLM_CONFIDENCE = 0.7

    def __init__(self, llm: TextGenerator | None = None) -> None:
        """Store the optional text generator used as the second classification stage."""
        self._llm = llm

    def classify(self, user_query: str, conversation_state: ConversationState) -> IntentDecision:
        """Return the intent for ``user_query``.

        Deterministic rules win; otherwise the LLM (if any) is consulted; otherwise
        ``PROJECT_MISC`` is returned with confidence 0.55.
        """
        deterministic = self._deterministic(user_query)
        if deterministic:
            return deterministic
        llm_decision = self._classify_with_llm(user_query, conversation_state)
        if llm_decision:
            return llm_decision
        return IntentDecision(intent="PROJECT_MISC", confidence=0.55, reason="fallback_project_misc")

    def _deterministic(self, user_query: str) -> IntentDecision | None:
        """Apply keyword/regex rules; return ``None`` when no rule fires.

        Precedence: generate-docs request, then docs question, then code question.
        """
        # Lower-case and collapse all whitespace runs to single spaces so
        # multi-word markers match regardless of spacing.
        text = " ".join((user_query or "").lower().split())
        if self._is_generate_docs_request(text):
            return IntentDecision(intent="GENERATE_DOCS_FROM_CODE", confidence=0.97, reason="deterministic_generate_docs")
        if self._looks_like_docs_question(text):
            return IntentDecision(intent="DOCS_QA", confidence=0.9, reason="deterministic_docs")
        if self._looks_like_code_question(user_query, text):
            return IntentDecision(intent="CODE_QA", confidence=0.9, reason="deterministic_code")
        return None

    def _is_generate_docs_request(self, text: str) -> bool:
        """Return True if normalised ``text`` asks to generate documentation."""
        if any(marker in text for marker in self._GENERATE_DOCS_MARKERS):
            return True
        return self._GENERATE_STEM in text and self._has_docs_marker(text)

    def _has_docs_marker(self, text: str) -> bool:
        """Return True if lower-cased ``text`` mentions documentation."""
        return any(marker in text for marker in self._DOCS_MARKERS) or bool(self._SPEC_MARKER_RE.search(text))

    def _classify_with_llm(self, user_query: str, conversation_state: ConversationState) -> IntentDecision | None:
        """Ask the injected LLM for an intent.

        Returns ``None`` when no LLM is configured, the call fails, or the answer
        cannot be parsed into an allowed intent.
        """
        if self._llm is None:
            return None
        payload = json.dumps(
            {
                "message": user_query,
                "active_intent": conversation_state.active_intent,
                "last_query": conversation_state.last_query,
                "allowed_intents": list(self._ALLOWED_INTENTS),
            },
            ensure_ascii=False,
        )
        try:
            raw = self._llm.generate("rag_intent_router_v2", payload, log_context="rag.intent_router_v2.classify").strip()
        except Exception:
            # Best-effort stage: any LLM failure falls through to the caller's
            # PROJECT_MISC fallback instead of failing classification.
            return None
        return self._parse(raw)

    def _parse(self, raw: str) -> IntentDecision | None:
        """Parse the LLM answer (a JSON object, optionally code-fenced).

        Returns ``None`` for invalid JSON, non-object JSON, or an intent outside
        ``_ALLOWED_INTENTS``. Never raises for malformed model output.
        """
        candidate = self._strip_code_fence(raw)
        try:
            payload = json.loads(candidate)
        except json.JSONDecodeError:
            return None
        if not isinstance(payload, dict):
            # e.g. the model answered with a bare string or a list.
            return None
        intent = str(payload.get("intent") or "").strip().upper()
        if intent not in self._ALLOWED_INTENTS:
            return None
        return IntentDecision(
            intent=intent,
            confidence=self._coerce_confidence(payload.get("confidence")),
            reason=str(payload.get("reason") or "llm").strip() or "llm",
        )

    def _coerce_confidence(self, value: object) -> float:
        """Convert the LLM-reported confidence to a float in [0.0, 1.0].

        Falsy, non-numeric, and non-finite values map to ``_DEFAULT_LLM_CONFIDENCE``.
        """
        # Falsy (missing, None, 0, "") -> default, as in the original ``or 0.7``.
        if not value:
            return self._DEFAULT_LLM_CONFIDENCE
        try:
            confidence = float(value)
        except (TypeError, ValueError, OverflowError):
            return self._DEFAULT_LLM_CONFIDENCE
        if not math.isfinite(confidence):
            return self._DEFAULT_LLM_CONFIDENCE
        return min(max(confidence, 0.0), 1.0)

    def _strip_code_fence(self, text: str) -> str:
        """Remove a surrounding ``` fence (first and last lines) if present.

        Text that does not start with ``` or whose last line is not a bare ```
        is returned unchanged.
        """
        if not text.startswith("```"):
            return text
        lines = text.splitlines()
        if len(lines) < 3 or lines[-1].strip() != "```":
            return text
        return "\n".join(lines[1:-1]).strip()

    def _looks_like_docs_question(self, text: str) -> bool:
        """Return True if ``text`` mentions docs and references no code file."""
        if self._has_code_file_path(text):
            return False
        return self._has_docs_marker(text)

    def _looks_like_code_question(self, raw_text: str, lowered: str) -> bool:
        """Heuristically decide whether the query is about source code.

        ``raw_text`` keeps the original casing (needed by the identifier regexes);
        ``lowered`` is the normalised text used for marker matching.
        """
        if self._has_code_file_path(raw_text):
            return True
        if has_test_focus(lowered):
            return True
        has_code_marker = any(marker in lowered for marker in self._CODE_MARKERS)
        if self._has_docs_marker(lowered) and not has_code_marker:
            return False
        if has_code_marker:
            return True
        source = raw_text or ""
        if self._IDENTIFIER_RE.search(source):
            return True
        return bool(self._CALL_RE.search(source))

    def _has_code_file_path(self, text: str) -> bool:
        """Return True if ``text`` contains a path to a source-code file."""
        return bool(_CODE_FILE_PATH_RE.search(text or ""))
|