109 lines
3.8 KiB
Python
109 lines
3.8 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import hashlib
|
|
import logging
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING
|
|
from uuid import uuid4
|
|
|
|
if TYPE_CHECKING:
|
|
from app.modules.rag.persistence.repository import RagRepository
|
|
|
|
|
|
class LocalRepoFileCollector:
    """Collect indexable text files from a local repository checkout.

    Walks ``root`` recursively and returns one record per file that passes
    :meth:`_should_skip` and does not look binary (contains a NUL byte).
    Records are returned in ``sorted(Path)`` order.
    """

    # Directory names whose contents are never indexed.
    _SKIP_DIRS = {".git", ".venv", "venv", "__pycache__", ".pytest_cache", "node_modules"}
    # Lower-cased suffixes treated as text. Note: a bare ``.env`` file has an
    # empty suffix and is dot-prefixed, so ".env" here only matches names
    # like ``settings.env``.
    _TEXT_EXTENSIONS = {
        ".py", ".md", ".txt", ".rst", ".json", ".yaml", ".yml", ".toml", ".ini", ".cfg", ".env",
        ".js", ".ts", ".tsx", ".jsx", ".sql", ".sh",
    }

    def __init__(self, root: Path, max_bytes: int = 300_000) -> None:
        """
        Args:
            root: Directory to scan; returned ``path`` values are relative to it.
            max_bytes: Files whose ``stat().st_size`` exceeds this are skipped.
        """
        self._root = root
        self._max_bytes = max_bytes

    def collect(self) -> list[dict]:
        """Return ``{"path", "content", "content_hash"}`` records for indexable files.

        Excluded directories are pruned while walking, so trees such as
        ``node_modules`` or ``.git`` are never traversed. Files that vanish or
        cannot be read during the scan are logged and skipped rather than
        aborting the whole collection.
        """
        import os

        candidates: list[Path] = []
        for dirpath, dirnames, filenames in os.walk(self._root):
            # In-place pruning stops os.walk from descending; any file below
            # these dirs would be rejected by _should_skip anyway.
            dirnames[:] = [name for name in dirnames if not self._is_excluded_part(name)]
            base = Path(dirpath)
            candidates.extend(base / name for name in filenames)

        files: list[dict] = []
        # Sorting the file paths reproduces the order of sorting the full
        # rglob listing, since directories were never emitted as records.
        for path in sorted(candidates):
            try:
                if not path.is_file() or self._should_skip(path):
                    continue
                item = self._read_file(path)
            except OSError as exc:
                logging.getLogger(__name__).warning("Skipping unreadable file %s: %s", path, exc)
                continue
            if item:
                files.append(item)
        return files

    def _is_excluded_part(self, part: str) -> bool:
        """Return True for a path component that is a skipped or hidden name."""
        return part in self._SKIP_DIRS or part.startswith(".")

    def _should_skip(self, path: Path) -> bool:
        """Return True if ``path`` is excluded, hidden, non-text, or too large.

        Raises:
            OSError: If ``stat`` on the file fails.
        """
        rel_parts = path.relative_to(self._root).parts
        if any(self._is_excluded_part(part) for part in rel_parts):
            return True
        if path.suffix.lower() not in self._TEXT_EXTENSIONS:
            return True
        return path.stat().st_size > self._max_bytes

    def _read_file(self, path: Path) -> dict | None:
        """Read ``path`` and build its record, or return None if it looks binary.

        Invalid UTF-8 bytes are dropped during decoding; ``content_hash`` is
        the SHA-256 hex digest of the decoded content re-encoded as UTF-8.

        Raises:
            OSError: If the file cannot be read.
        """
        raw = path.read_bytes()
        if b"\x00" in raw:
            return None
        content = raw.decode("utf-8", errors="ignore")
        return {
            "path": path.relative_to(self._root).as_posix(),
            "content": content,
            "content_hash": hashlib.sha256(content.encode("utf-8")).hexdigest(),
        }
|
|
|
|
|
|
class DeterministicEmbedder:
    """Hash-based embedder: identical text always yields an identical vector.

    Vectors are derived purely from SHA-256 digests (no model involved), so
    they are stable but carry no semantic similarity.
    """

    def __init__(self, dim: int = 64) -> None:
        # Length of every vector produced by this embedder.
        self._dim = dim

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed each text independently, preserving input order."""
        return list(map(self._embed_one, texts))

    def _embed_one(self, text: str) -> list[float]:
        """Map ``text`` to ``dim`` floats in ``[-1.0, 1.0]``.

        Bytes come from a hash chain: sha256(text), then sha256 of the
        previous digest, repeated until at least ``dim`` bytes exist. Each
        byte ``b`` becomes ``b / 127.5 - 1.0``.
        """
        block = hashlib.sha256(text.encode("utf-8")).digest()
        stream = bytearray()
        while len(stream) < self._dim:
            stream.extend(block)
            block = hashlib.sha256(block).digest()
        return [(octet / 127.5) - 1.0 for octet in stream[: self._dim]]
|
|
|
|
|
|
class RagSessionIndexer:
    """Index a local repository into a new RAG session.

    Wraps ``RagService`` configured with a :class:`DeterministicEmbedder`.
    """

    def __init__(self, repository: "RagRepository") -> None:
        # Imported lazily; presumably to avoid an import cycle or a heavy
        # import at module load time — TODO confirm.
        from app.modules.rag.services.rag_service import RagService

        self._repository = repository
        self._rag = RagService(embedder=DeterministicEmbedder(), repository=repository)

    def index_repo(self, repo_path: Path, project_id: str | None = None) -> str:
        """Index every collectable text file under ``repo_path`` into a fresh session.

        Files are collected before any repository call, so a missing or empty
        directory raises without creating tables or an orphaned session.

        Args:
            repo_path: Repository root to scan.
            project_id: Project identifier. When falsy, the name of the
                resolved ``repo_path`` is used (so ``Path(".")`` yields the
                real directory name rather than ``""``).

        Returns:
            The newly generated session id (also printed to stdout).

        Raises:
            ValueError: If ``repo_path`` is not a directory or contains no
                indexable text files.
        """
        if not repo_path.is_dir():
            raise ValueError(f"Repository path is not a directory: {repo_path}")
        files = LocalRepoFileCollector(repo_path).collect()
        if not files:
            raise ValueError(f"No indexable text files found under: {repo_path}")

        self._repository.ensure_tables()
        rag_session_id = str(uuid4())
        self._repository.upsert_session(rag_session_id, self._resolve_project_id(repo_path, project_id))

        # Raise the RagService logger's threshold to ERROR while indexing;
        # progress goes through _print_progress. The level is always restored.
        service_logger = logging.getLogger("app.modules.rag.services.rag_service")
        previous_level = service_logger.level
        service_logger.setLevel(logging.ERROR)
        try:
            asyncio.run(
                self._rag.index_snapshot(
                    rag_session_id=rag_session_id,
                    files=files,
                    progress_cb=self._print_progress,
                )
            )
        finally:
            service_logger.setLevel(previous_level)
        print(f"rag_session_id={rag_session_id}")
        return rag_session_id

    @staticmethod
    def _resolve_project_id(repo_path: Path, project_id: str | None) -> str:
        """Return ``project_id`` if truthy, else a name derived from ``repo_path``.

        The path is resolved first because ``Path(".").name`` is ``""``; for a
        filesystem root (whose name is also empty) the full path is used.
        """
        if project_id:
            return project_id
        resolved = repo_path.resolve()
        return resolved.name or resolved.as_posix()

    def _print_progress(self, current_file_index: int, total_files: int, current_file_name: str) -> None:
        """Progress callback: print ``[index/total] name`` to stdout."""
        print(f"[{current_file_index}/{total_files}] {current_file_name}")
|