Files
agent/tests/pipeline_setup/utils/rag_indexer/indexer.py
2026-03-12 16:55:23 +03:00

109 lines
3.8 KiB
Python

from __future__ import annotations
import asyncio
import hashlib
import logging
from pathlib import Path
from typing import TYPE_CHECKING
from uuid import uuid4
if TYPE_CHECKING:
from app.modules.rag.persistence.repository import RagRepository
class LocalRepoFileCollector:
    """Walk a local repository tree and collect indexable text files.

    Skips VCS/virtualenv/dependency directories, hidden files and
    directories, unknown extensions, oversized files, and files that
    contain NUL bytes (treated as binary).
    """

    # Directory names that never contain user source worth indexing.
    _SKIP_DIRS = {".git", ".venv", "venv", "__pycache__", ".pytest_cache", "node_modules"}
    # File extensions treated as indexable text.
    # NOTE(review): the ".env" entry is unreachable — a file named ".env" has
    # an empty pathlib suffix, and dotfiles are skipped anyway; kept for
    # backward compatibility.
    _TEXT_EXTENSIONS = {
        ".py", ".md", ".txt", ".rst", ".json", ".yaml", ".yml", ".toml", ".ini", ".cfg", ".env",
        ".js", ".ts", ".tsx", ".jsx", ".sql", ".sh",
    }

    def __init__(self, root: Path, max_bytes: int = 300_000) -> None:
        """
        Args:
            root: Repository root directory to scan.
            max_bytes: Files larger than this are skipped.
        """
        self._root = root
        self._max_bytes = max_bytes

    def collect(self) -> list[dict]:
        """Return records for every indexable file under the root.

        Each record has ``path`` (POSIX-style, relative to root),
        ``content`` (UTF-8 decoded, undecodable bytes dropped) and
        ``content_hash`` (SHA-256 hex of the decoded content re-encoded).
        Results are ordered by path for determinism.
        """
        files: list[dict] = []
        for path in sorted(self._root.rglob("*")):
            if not path.is_file() or self._should_skip(path):
                continue
            item = self._read_file(path)
            if item:
                files.append(item)
        return files

    def _should_skip(self, path: Path) -> bool:
        """True if *path* must not be indexed (skip-dir, dotted component,
        unknown extension, oversized, or unstat-able)."""
        rel_parts = path.relative_to(self._root).parts
        if any(part in self._SKIP_DIRS for part in rel_parts):
            return True
        if any(part.startswith(".") for part in rel_parts):
            return True
        if path.suffix.lower() not in self._TEXT_EXTENSIONS:
            return True
        try:
            return path.stat().st_size > self._max_bytes
        except OSError:
            # File vanished or became unreadable between rglob() and stat():
            # skip it instead of aborting the whole collection run.
            return True

    def _read_file(self, path: Path) -> dict | None:
        """Read *path*; return its record, or None for binary or unreadable files."""
        try:
            raw = path.read_bytes()
        except OSError:
            # Permission error or race with deletion: skip this file rather
            # than letting one bad path abort collect().
            return None
        if b"\x00" in raw:
            # NUL byte => almost certainly binary despite a text extension.
            return None
        content = raw.decode("utf-8", errors="ignore")
        return {
            "path": path.relative_to(self._root).as_posix(),
            "content": content,
            "content_hash": hashlib.sha256(content.encode("utf-8")).hexdigest(),
        }
class DeterministicEmbedder:
    """Hash-based stand-in for a real embedding model.

    Maps each text to a fixed-length pseudo-embedding derived solely from
    SHA-256, so identical texts always produce identical vectors with no
    model dependency or randomness.
    """

    def __init__(self, dim: int = 64) -> None:
        # Length of every vector returned by embed().
        self._dim = dim

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed each text independently; result order matches input order."""
        return [self._embed_one(item) for item in texts]

    def _embed_one(self, text: str) -> list[float]:
        """Map *text* to ``self._dim`` floats in [-1.0, 1.0).

        Bytes come from sha256(text), extended by repeatedly re-hashing the
        previous digest when more than 32 bytes are needed; each byte b
        becomes b / 127.5 - 1.0.
        """
        digest = hashlib.sha256(text.encode("utf-8")).digest()
        stream = b""
        while len(stream) < self._dim:
            stream += digest
            digest = hashlib.sha256(digest).digest()
        return [(byte / 127.5) - 1.0 for byte in stream[: self._dim]]
class RagSessionIndexer:
    """Indexes a local repository snapshot into a new RAG session.

    Wires a :class:`DeterministicEmbedder` into the project's ``RagService``
    and drives a one-shot indexing run over the files collected by
    :class:`LocalRepoFileCollector`.
    """

    def __init__(self, repository: "RagRepository") -> None:
        # Imported lazily so the module can be loaded without pulling in the
        # full service dependency chain.
        from app.modules.rag.services.rag_service import RagService

        self._repository = repository
        self._rag = RagService(embedder=DeterministicEmbedder(), repository=repository)

    def index_repo(self, repo_path: Path, project_id: str | None = None) -> str:
        """Index every collectible text file under *repo_path*.

        Args:
            repo_path: Root of the repository to index.
            project_id: Optional project identifier; defaults to the
                directory name of *repo_path*.

        Returns:
            The id of the newly created RAG session (also printed to stdout).

        Raises:
            ValueError: If no indexable text files are found under the root.
        """
        self._repository.ensure_tables()

        rag_session_id = str(uuid4())
        resolved_project_id = project_id if project_id else repo_path.name
        self._repository.upsert_session(rag_session_id, resolved_project_id)

        collected = LocalRepoFileCollector(repo_path).collect()
        if not collected:
            raise ValueError(f"No indexable text files found under: {repo_path}")

        # Raise the service logger's threshold to ERROR for the duration of
        # the run (progress is reported via _print_progress instead), and
        # restore the previous level afterwards even on failure.
        service_logger = logging.getLogger("app.modules.rag.services.rag_service")
        saved_level = service_logger.level
        service_logger.setLevel(logging.ERROR)
        try:
            asyncio.run(
                self._rag.index_snapshot(
                    rag_session_id=rag_session_id,
                    files=collected,
                    progress_cb=self._print_progress,
                )
            )
        finally:
            service_logger.setLevel(saved_level)

        print(f"rag_session_id={rag_session_id}")
        return rag_session_id

    def _print_progress(self, current_file_index: int, total_files: int, current_file_name: str) -> None:
        """Emit a one-line progress marker for the file currently indexed."""
        print(f"[{current_file_index}/{total_files}] {current_file_name}")