Фикс состояния
This commit is contained in:
1
tests/pipeline_setup/utils/__init__.py
Normal file
1
tests/pipeline_setup/utils/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Shared low-level utilities for pipeline_setup, including the reusable RAG indexer."""
|
||||
11
tests/pipeline_setup/utils/rag_indexer/__init__.py
Normal file
11
tests/pipeline_setup/utils/rag_indexer/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""Public entry point of the reusable RAG indexer test utility.

Re-exports the building blocks from ``indexer`` so callers can import
them directly from ``tests.pipeline_setup.utils.rag_indexer``.
"""

from tests.pipeline_setup.utils.rag_indexer.indexer import (
    DeterministicEmbedder,
    LocalRepoFileCollector,
    RagSessionIndexer,
)

# Explicit public API of this package.
__all__ = [
    "DeterministicEmbedder",
    "LocalRepoFileCollector",
    "RagSessionIndexer",
]
|
||||
108
tests/pipeline_setup/utils/rag_indexer/indexer.py
Normal file
108
tests/pipeline_setup/utils/rag_indexer/indexer.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import uuid4
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.modules.rag.persistence.repository import RagRepository
|
||||
|
||||
|
||||
class LocalRepoFileCollector:
|
||||
_SKIP_DIRS = {".git", ".venv", "venv", "__pycache__", ".pytest_cache", "node_modules"}
|
||||
_TEXT_EXTENSIONS = {
|
||||
".py", ".md", ".txt", ".rst", ".json", ".yaml", ".yml", ".toml", ".ini", ".cfg", ".env",
|
||||
".js", ".ts", ".tsx", ".jsx", ".sql", ".sh",
|
||||
}
|
||||
|
||||
def __init__(self, root: Path, max_bytes: int = 300_000) -> None:
|
||||
self._root = root
|
||||
self._max_bytes = max_bytes
|
||||
|
||||
def collect(self) -> list[dict]:
|
||||
files: list[dict] = []
|
||||
for path in sorted(self._root.rglob("*")):
|
||||
if not path.is_file() or self._should_skip(path):
|
||||
continue
|
||||
item = self._read_file(path)
|
||||
if item:
|
||||
files.append(item)
|
||||
return files
|
||||
|
||||
def _should_skip(self, path: Path) -> bool:
|
||||
rel_parts = path.relative_to(self._root).parts
|
||||
if any(part in self._SKIP_DIRS for part in rel_parts):
|
||||
return True
|
||||
if any(part.startswith(".") for part in rel_parts):
|
||||
return True
|
||||
if path.suffix.lower() not in self._TEXT_EXTENSIONS:
|
||||
return True
|
||||
return path.stat().st_size > self._max_bytes
|
||||
|
||||
def _read_file(self, path: Path) -> dict | None:
|
||||
raw = path.read_bytes()
|
||||
if b"\x00" in raw:
|
||||
return None
|
||||
content = raw.decode("utf-8", errors="ignore")
|
||||
return {
|
||||
"path": path.relative_to(self._root).as_posix(),
|
||||
"content": content,
|
||||
"content_hash": hashlib.sha256(content.encode("utf-8")).hexdigest(),
|
||||
}
|
||||
|
||||
|
||||
class DeterministicEmbedder:
|
||||
def __init__(self, dim: int = 64) -> None:
|
||||
self._dim = dim
|
||||
|
||||
def embed(self, texts: list[str]) -> list[list[float]]:
|
||||
return [self._embed_one(text) for text in texts]
|
||||
|
||||
def _embed_one(self, text: str) -> list[float]:
|
||||
digest = hashlib.sha256(text.encode("utf-8")).digest()
|
||||
values: list[float] = []
|
||||
while len(values) < self._dim:
|
||||
for byte in digest:
|
||||
values.append((byte / 127.5) - 1.0)
|
||||
if len(values) == self._dim:
|
||||
break
|
||||
digest = hashlib.sha256(digest).digest()
|
||||
return values
|
||||
|
||||
|
||||
class RagSessionIndexer:
    """Indexes a local repository into a fresh RAG session for test setup."""

    def __init__(self, repository: "RagRepository") -> None:
        # Imported lazily so merely importing this module does not pull in
        # the full application service stack.
        from app.modules.rag.services.rag_service import RagService

        self._repository = repository
        self._rag = RagService(embedder=DeterministicEmbedder(), repository=repository)

    def index_repo(self, repo_path: Path, project_id: str | None = None) -> str:
        """Create a new RAG session, index every text file under *repo_path*,
        and return the generated session id.

        Raises ValueError when the tree contains no indexable text files.
        """
        self._repository.ensure_tables()

        session_id = str(uuid4())
        # Fall back to the directory name when no explicit project id is given.
        self._repository.upsert_session(session_id, project_id or repo_path.name)

        collected = LocalRepoFileCollector(repo_path).collect()
        if not collected:
            raise ValueError(f"No indexable text files found under: {repo_path}")

        # Temporarily silence the service logger so the per-file progress
        # output below stays readable; always restore the previous level.
        svc_logger = logging.getLogger("app.modules.rag.services.rag_service")
        saved_level = svc_logger.level
        svc_logger.setLevel(logging.ERROR)
        try:
            asyncio.run(
                self._rag.index_snapshot(
                    rag_session_id=session_id,
                    files=collected,
                    progress_cb=self._print_progress,
                )
            )
        finally:
            svc_logger.setLevel(saved_level)

        print(f"rag_session_id={session_id}")
        return session_id

    def _print_progress(self, current_file_index: int, total_files: int, current_file_name: str) -> None:
        # One line per file keeps long indexing runs traceable.
        print(f"[{current_file_index}/{total_files}] {current_file_name}")
|
||||
Reference in New Issue
Block a user