from __future__ import annotations

import asyncio
import hashlib
import logging
import os
from collections.abc import Iterator
from pathlib import Path
from typing import TYPE_CHECKING
from uuid import uuid4

if TYPE_CHECKING:
    from app.modules.rag.persistence.repository import RagRepository

_LOGGER = logging.getLogger(__name__)


class LocalRepoFileCollector:
    """Collect small UTF-8 text files from a directory tree.

    A file is collected when it lives outside every skipped/hidden directory,
    is not itself a dotfile, has one of the known text extensions, is no
    larger than ``max_bytes`` and contains no NUL byte.
    """

    # Directory names that are never descended into.
    _SKIP_DIRS = {".git", ".venv", "venv", "__pycache__", ".pytest_cache", "node_modules"}
    # Lower-cased suffixes treated as text. Dotfiles such as ``.env`` are
    # rejected by the hidden-name rule, so ".env" only matches names like
    # ``prod.env``.
    _TEXT_EXTENSIONS = {
        ".py",
        ".md",
        ".txt",
        ".rst",
        ".json",
        ".yaml",
        ".yml",
        ".toml",
        ".ini",
        ".cfg",
        ".env",
        ".js",
        ".ts",
        ".tsx",
        ".jsx",
        ".sql",
        ".sh",
    }

    def __init__(self, root: Path, max_bytes: int = 300_000) -> None:
        """Store the scan root and the per-file size limit in bytes."""
        self._root = root
        self._max_bytes = max_bytes

    def collect(self) -> list[dict]:
        """Return one record per indexable file, ordered by path.

        Each record is a dict with the keys ``"path"`` (POSIX-style path
        relative to the root), ``"content"`` (decoded text) and
        ``"content_hash"`` (SHA-256 hex digest of the decoded text).

        Files that cannot be inspected or read (``OSError``) are skipped
        with a warning instead of aborting the whole scan.
        """
        files: list[dict] = []
        for path in sorted(self._iter_candidates()):
            try:
                if not path.is_file() or self._should_skip(path):
                    continue
                item = self._read_file(path)
            except OSError as exc:
                # E.g. permission denied, or the file vanished mid-scan.
                _LOGGER.warning("Skipping unreadable file %s: %s", path, exc)
                continue
            if item is not None:
                files.append(item)
        return files

    def _iter_candidates(self) -> Iterator[Path]:
        """Yield every file path under the root, pruning skipped directories.

        Pruning happens during traversal so large trees such as ``.git`` or
        ``node_modules`` are never walked at all.
        """
        for dirpath, dirnames, filenames in os.walk(self._root):
            # In-place assignment is what tells os.walk not to descend.
            dirnames[:] = [
                name
                for name in dirnames
                if name not in self._SKIP_DIRS and not name.startswith(".")
            ]
            base = Path(dirpath)
            for name in filenames:
                yield base / name

    def _should_skip(self, path: Path) -> bool:
        """Return True when ``path`` must not be indexed.

        Raises:
            OSError: if the file size cannot be read.
        """
        rel_parts = path.relative_to(self._root).parts
        if any(part in self._SKIP_DIRS for part in rel_parts):
            return True
        # Hidden directories and dotfiles alike.
        if any(part.startswith(".") for part in rel_parts):
            return True
        if path.suffix.lower() not in self._TEXT_EXTENSIONS:
            return True
        return path.stat().st_size > self._max_bytes

    def _read_file(self, path: Path) -> dict | None:
        """Read ``path`` and build its record, or return None for binary data.

        A NUL byte anywhere in the file marks it as binary. Undecodable bytes
        are dropped, and the hash is taken over the decoded text rather than
        the raw bytes.

        Raises:
            OSError: if the file cannot be read.
        """
        raw = path.read_bytes()
        if b"\x00" in raw:
            return None
        content = raw.decode("utf-8", errors="ignore")
        return {
            "path": path.relative_to(self._root).as_posix(),
            "content": content,
            "content_hash": hashlib.sha256(content.encode("utf-8")).hexdigest(),
        }


class DeterministicEmbedder:
    """Derive a fixed-size pseudo-embedding from a text's SHA-256 digest.

    The same text always yields the same vector; the values carry no
    semantic meaning.
    """

    def __init__(self, dim: int = 64) -> None:
        """Store the number of components each vector will have."""
        self._dim = dim

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Return one vector per input text, in input order."""
        return [self._embed_one(text) for text in texts]

    def _embed_one(self, text: str) -> list[float]:
        """Map ``text`` to ``dim`` floats in the range [-1.0, 1.0].

        Digest bytes are scaled linearly; when one digest is exhausted it is
        re-hashed to produce the next 32 bytes.
        """
        digest = hashlib.sha256(text.encode("utf-8")).digest()
        values: list[float] = []
        while len(values) < self._dim:
            remaining = self._dim - len(values)
            values.extend((byte / 127.5) - 1.0 for byte in digest[:remaining])
            digest = hashlib.sha256(digest).digest()
        return values


class RagSessionIndexer:
    """Index a local repository into a fresh RAG session."""

    def __init__(self, repository: "RagRepository") -> None:
        """Wire a ``RagService`` with a deterministic embedder to ``repository``."""
        # Imported at call time rather than module import time.
        # NOTE(review): presumably to avoid an import cycle - confirm.
        from app.modules.rag.services.rag_service import RagService

        self._repository = repository
        self._rag = RagService(embedder=DeterministicEmbedder(), repository=repository)

    def index_repo(self, repo_path: Path, project_id: str | None = None) -> str:
        """Index every collectable file under ``repo_path``.

        Args:
            repo_path: Root directory to scan.
            project_id: Project identifier for the session. Defaults to the
                directory name of ``repo_path``.

        Returns:
            The newly generated session id, which is also printed to stdout.

        Raises:
            ValueError: if no indexable text file is found. No session is
                registered in that case.

        Uses ``asyncio.run``, so it must not be called from a running event
        loop.
        """
        self._repository.ensure_tables()

        # Validate before registering the session so an empty repository
        # does not leave an orphan session behind.
        files = LocalRepoFileCollector(repo_path).collect()
        if not files:
            raise ValueError(f"No indexable text files found under: {repo_path}")

        rag_session_id = str(uuid4())
        # ``Path(".").name`` is empty; fall back to the resolved directory name.
        resolved_project_id = project_id or repo_path.name or repo_path.resolve().name
        self._repository.upsert_session(rag_session_id, resolved_project_id)

        # Silence the service logger below ERROR while progress is printed,
        # then restore its previous level.
        rag_logger = logging.getLogger("app.modules.rag.services.rag_service")
        previous_level = rag_logger.level
        rag_logger.setLevel(logging.ERROR)
        try:
            asyncio.run(
                self._rag.index_snapshot(
                    rag_session_id=rag_session_id,
                    files=files,
                    progress_cb=self._print_progress,
                )
            )
        finally:
            rag_logger.setLevel(previous_level)

        print(f"rag_session_id={rag_session_id}")
        return rag_session_id

    def _print_progress(self, current_file_index: int, total_files: int, current_file_name: str) -> None:
        """Print a ``[current/total] name`` progress line to stdout."""
        print(f"[{current_file_index}/{total_files}] {current_file_name}")