Фиксация изменений
This commit is contained in:
Binary file not shown.
Binary file not shown.
@@ -1,43 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.modules.rag.contracts import RagLayer, RetrievalMode
|
||||
|
||||
|
||||
class RagQueryRouter:
    """Pick a retrieval mode for a query and list the layers used by that mode."""

    # Lowercase substrings; a query containing any of them is routed to code retrieval.
    _CODE_HINTS = (
        "как работает код",
        "explain code",
        "explain the code",
        "по коду",
        "из кода",
        "построй документацию по коду",
        "документацию по коду",
        "where is implemented",
        "где реализовано",
        "endpoint",
        "handler",
        "symbol",
        "function",
        "class",
        "method",
    )

    # Layers returned for documentation-oriented retrieval.
    _DOCS_LAYERS = [
        RagLayer.DOCS_MODULE_CATALOG,
        RagLayer.DOCS_FACT_INDEX,
        RagLayer.DOCS_SECTION_INDEX,
        RagLayer.DOCS_POLICY_INDEX,
    ]

    # Layers returned for code-oriented retrieval.
    _CODE_LAYERS = [
        RagLayer.CODE_ENTRYPOINTS,
        RagLayer.CODE_SYMBOL_CATALOG,
        RagLayer.CODE_DEPENDENCY_GRAPH,
        RagLayer.CODE_SOURCE_CHUNKS,
    ]

    def resolve_mode(self, query: str) -> str:
        """Return ``RetrievalMode.CODE`` if the query contains a code hint, else ``RetrievalMode.DOCS``.

        Matching is a case-insensitive substring test against ``_CODE_HINTS``.
        """
        haystack = query.lower()
        for hint in self._CODE_HINTS:
            if hint in haystack:
                return RetrievalMode.CODE
        return RetrievalMode.DOCS

    def layers_for_mode(self, mode: str) -> list[str]:
        """Return a fresh list of layers for ``mode``.

        Any mode other than ``RetrievalMode.CODE`` yields the docs layers.
        A copy is returned so callers cannot mutate the class-level lists.
        """
        if mode == RetrievalMode.CODE:
            return list(self._CODE_LAYERS)
        return list(self._DOCS_LAYERS)
|
||||
97
app/modules/rag/retrieval/test_filter.py
Normal file
97
app/modules/rag/retrieval/test_filter.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from fnmatch import fnmatch
|
||||
from typing import Iterable
|
||||
|
||||
# Patterns used when RAG_TEST_PATH_PATTERNS is unset or blank.
# Entries ending in "/" are directory patterns; the rest are file-name globs.
DEFAULT_TEST_PATH_PATTERNS = (
    "tests/",
    "test/",
    "__tests__/",
    "mocks/",
    "fixtures/",
    "stubs/",
    "conftest.py",
    "*_test.*",
    "*.test.*",
    "*.spec.*",
)

# Lowercased environment-variable values that count as "enabled".
_TRUE_VALUES = {"1", "true", "yes", "on"}

# Whitelist of characters allowed in a pattern; anything else is dropped by validation.
_SAFE_PATTERN_RE = re.compile(r"^[A-Za-z0-9_./*?-]+$")
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class RetrievalPathFilter:
|
||||
exclude_path_prefixes: list[str]
|
||||
exclude_like_patterns: list[str]
|
||||
|
||||
|
||||
def exclude_tests_default() -> bool:
    """Report whether ``RAG_EXCLUDE_TESTS_DEFAULT`` is set to a truthy value.

    The variable defaults to ``"true"`` when unset; the comparison ignores
    surrounding whitespace and letter case.
    """
    setting = os.getenv("RAG_EXCLUDE_TESTS_DEFAULT", "true")
    return setting.strip().lower() in _TRUE_VALUES
|
||||
|
||||
|
||||
def debug_disable_test_filter() -> bool:
    """Report whether ``RAG_DEBUG_DISABLE_TEST_FILTER`` is set to a truthy value.

    The variable defaults to ``"false"`` when unset; the comparison ignores
    surrounding whitespace and letter case.
    """
    setting = os.getenv("RAG_DEBUG_DISABLE_TEST_FILTER", "false")
    return setting.strip().lower() in _TRUE_VALUES
|
||||
|
||||
|
||||
def configured_test_patterns() -> list[str]:
    """Return the test-path patterns from ``RAG_TEST_PATH_PATTERNS``.

    The variable is a comma-separated list. When it is unset or blank, a
    copy of ``DEFAULT_TEST_PATH_PATTERNS`` is returned instead. Empty items
    between commas are dropped.
    """
    env_value = os.getenv("RAG_TEST_PATH_PATTERNS", "")
    if env_value.strip():
        pieces = (piece.strip() for piece in env_value.split(","))
        return [piece for piece in pieces if piece]
    return list(DEFAULT_TEST_PATH_PATTERNS)
|
||||
|
||||
|
||||
def build_test_filters(patterns: Iterable[str] | None = None) -> RetrievalPathFilter:
    """Translate test-path patterns into prefix and SQL LIKE exclusions.

    Args:
        patterns: Patterns to translate. When ``None`` or empty, the
            patterns from ``configured_test_patterns()`` are used.

    Returns:
        A ``RetrievalPathFilter`` whose ``exclude_path_prefixes`` holds the
        directory patterns (those ending in ``"/"``) and whose
        ``exclude_like_patterns`` holds the LIKE patterns for both
        directory and file patterns, without duplicates.
    """
    prefixes: list[str] = []
    like_patterns: list[str] = []
    for pattern in _validated_patterns(patterns or configured_test_patterns()):
        if pattern.endswith("/"):
            _append(prefixes, pattern)
            # Directory patterns are literal path segments (no glob translation),
            # but "_" and "%" are LIKE wildcards. Escape them with the same
            # backslash convention _glob_to_sql_like uses, so "__tests__/" does
            # not also match e.g. "/xxtestsxx/".
            # NOTE(review): exclude_path_prefixes stays raw; confirm the consumer
            # does not feed it into LIKE unescaped.
            literal = pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
            _append(like_patterns, f"%/{literal}%")
            continue
        sql_like = _glob_to_sql_like(pattern)
        _append(like_patterns, sql_like)
        if "/" not in pattern:
            # A slash-free pattern names a file in any directory, so also emit
            # a variant anchored right after a path separator.
            _append(like_patterns, f"%/{sql_like}")
    return RetrievalPathFilter(exclude_path_prefixes=prefixes, exclude_like_patterns=like_patterns)
|
||||
|
||||
|
||||
def is_test_path(path: str, patterns: Iterable[str] | None = None) -> bool:
    """Tell whether ``path`` matches any test-path pattern.

    Args:
        path: Path to check; compared after stripping whitespace and
            lowercasing. A blank or ``None`` path never matches.
        patterns: Patterns to check against. When ``None`` or empty, the
            patterns from ``configured_test_patterns()`` are used.

    Returns:
        ``True`` on the first matching pattern, otherwise ``False``.
    """
    candidate = (path or "").strip().lower()
    if not candidate:
        return False

    def _matches(pattern: str) -> bool:
        if pattern.endswith("/"):
            # Directory pattern: literal match at the start of the path or
            # as a whole segment anywhere inside it.
            segment = pattern.rstrip("/")
            return candidate.startswith(pattern) or f"/{segment}/" in candidate
        # File pattern: glob match on the whole path or below any directory.
        return fnmatch(candidate, pattern) or fnmatch(candidate, f"*/{pattern}")

    active_patterns = _validated_patterns(patterns or configured_test_patterns())
    return any(_matches(pattern) for pattern in active_patterns)
|
||||
|
||||
|
||||
def _validated_patterns(patterns: Iterable[str]) -> list[str]:
    """Normalize patterns and keep only the safe, unique ones.

    Each pattern is stripped and lowercased. Blank patterns and patterns
    that do not fully match ``_SAFE_PATTERN_RE`` are silently dropped.
    Duplicates are removed, keeping first-seen order.
    """
    normalized = ((candidate or "").strip().lower() for candidate in patterns)
    accepted = (item for item in normalized if item and _SAFE_PATTERN_RE.fullmatch(item))
    # dict.fromkeys de-duplicates while preserving insertion order.
    return list(dict.fromkeys(accepted))
|
||||
|
||||
|
||||
def _glob_to_sql_like(pattern: str) -> str:
    """Convert a ``*`` / ``?`` glob into a SQL LIKE pattern.

    Literal backslash, ``%`` and ``_`` are backslash-escaped; ``*`` becomes
    ``%`` and ``?`` becomes ``_``.
    """
    # One pass is enough: no replacement output is itself a key that a
    # later step of the chained form would have rewritten.
    like_table = str.maketrans(
        {
            "\\": "\\\\",
            "%": "\\%",
            "_": "\\_",
            "*": "%",
            "?": "_",
        }
    )
    return pattern.translate(like_table)
|
||||
|
||||
|
||||
def _append(values: list[str], item: str) -> None:
    """Append ``item`` to ``values`` in place unless it is empty or already present."""
    if not item or item in values:
        return
    values.append(item)
|
||||
Reference in New Issue
Block a user