ййй
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""Shared helpers for pipeline_setup_v3."""
|
||||
@@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from app.modules.shared.env_loader import load_workspace_env
|
||||
|
||||
|
||||
def load_pipeline_setup_env(start_dir: str | Path | None = None) -> list[Path]:
    """Load the workspace environment, then layer the pipeline_setup_v3 ``.env`` on top.

    Args:
        start_dir: Directory to start from. A falsy value (``None`` or an
            empty string) means the current working directory.

    Returns:
        The list produced by ``load_workspace_env``, with the v3 ``.env``
        path appended when that file exists and was applied.

    Raises:
        RuntimeError: Propagated from ``_find_v3_root`` when no
            ``pipeline_setup_v3`` package directory encloses the start
            directory.
    """
    origin = start_dir if start_dir else Path.cwd()
    base = Path(origin).resolve()

    loaded = load_workspace_env(start_dir=base)

    env_path = _find_v3_root(base) / ".env"
    if not env_path.is_file():
        # Nothing v3-specific to apply; report only the workspace files.
        return loaded

    _apply_env_file(env_path)
    loaded.append(env_path)
    return loaded
|
||||
|
||||
|
||||
def _find_v3_root(base: Path) -> Path:
    """Return the nearest ``pipeline_setup_v3`` package directory at or above *base*.

    A directory qualifies when it is named ``pipeline_setup_v3`` and contains
    an ``__init__.py`` file. *base* itself is checked first, then each parent
    from nearest to farthest.

    Raises:
        RuntimeError: If no qualifying directory is found.
    """
    candidates = [base]
    candidates.extend(base.parents)
    for candidate in candidates:
        if candidate.name != "pipeline_setup_v3":
            continue
        if (candidate / "__init__.py").is_file():
            return candidate
    raise RuntimeError(f"Unable to locate tests/pipeline_setup_v3 root from: {base}")
|
||||
|
||||
|
||||
def _apply_env_file(path: Path) -> None:
    """Copy ``KEY=VALUE`` assignments from a dotenv-style file into ``os.environ``.

    Blank lines, lines starting with ``#`` and lines without ``=`` are
    ignored. A leading ``export`` keyword is dropped from the key, and the
    value is passed through ``_normalize_value`` before being stored.
    Variables already present in the environment are overwritten.

    Args:
        path: File to read. Decoded as UTF-8; a leading byte-order mark is
            tolerated.
    """
    # "utf-8-sig" strips a leading BOM (str.strip() does not treat U+FEFF as
    # whitespace), so it can no longer end up glued to the first variable name.
    for raw_line in path.read_text(encoding="utf-8-sig").splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, raw_value = line.split("=", 1)
        # Accept any whitespace (space or tab) after the `export` keyword;
        # a bare key named "export" is left untouched.
        if key.startswith("export") and key[6:7].isspace():
            key = key[len("export"):]
        name = key.strip()
        if name:
            os.environ[name] = _normalize_value(raw_value.strip())
|
||||
|
||||
|
||||
def _normalize_value(value: str) -> str:
    """Strip one pair of matching surrounding quotes from *value*.

    Only a value of at least two characters whose first and last characters
    are the same quote (single or double) is unwrapped; anything else is
    returned unchanged.
    """
    if len(value) < 2:
        return value
    opening, closing = value[0], value[-1]
    if opening in ("'", '"') and opening == closing:
        return value[1:-1]
    return value
|
||||
@@ -0,0 +1,110 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import uuid4
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.modules.rag.persistence.repository import RagRepository
|
||||
|
||||
|
||||
class LocalRepoFileCollector:
    """Gather small text files from the ``docs`` subtree of a directory.

    Each collected file is returned as a dict with its POSIX-style path
    relative to the root, its decoded text, and a SHA-256 hash of that text.
    """

    # Path components that exclude a file wherever they appear.
    _SKIP_DIRS = {".git", ".venv", "venv", "__pycache__", ".pytest_cache", "node_modules"}
    # Lower-cased suffixes that are accepted as text.
    _TEXT_EXTENSIONS = {
        ".py", ".md", ".txt", ".rst",
        ".json", ".yaml", ".yml", ".toml", ".ini", ".cfg", ".env",
        ".js", ".ts", ".tsx", ".jsx",
        ".sql", ".sh",
    }

    def __init__(self, root: Path, max_bytes: int = 300_000) -> None:
        """Remember the root to scan and the per-file size ceiling in bytes."""
        self._root = root
        self._max_bytes = max_bytes

    def collect(self) -> list[dict]:
        """Return a record for every accepted file, in sorted path order."""
        accepted = (
            entry
            for entry in sorted(self._root.rglob("*"))
            if entry.is_file() and not self._should_skip(entry)
        )
        # _read_file returns None for files containing NUL bytes; drop those.
        return [record for record in map(self._read_file, accepted) if record]

    def _should_skip(self, path: Path) -> bool:
        """Tell whether *path* must be excluded from collection.

        A file is excluded when it is outside the top-level ``docs``
        directory, when any relative path component is a skip directory or
        starts with a dot, when its suffix is not an accepted text
        extension, or when it is larger than the size ceiling.
        """
        relative = path.relative_to(self._root).parts
        if relative[:1] != ("docs",):
            return True
        for component in relative:
            if component in self._SKIP_DIRS or component.startswith("."):
                return True
        if path.suffix.lower() not in self._TEXT_EXTENSIONS:
            return True
        too_large = path.stat().st_size > self._max_bytes
        return too_large

    def _read_file(self, path: Path) -> dict | None:
        """Read *path* and build its record, or return ``None`` if it holds a NUL byte.

        Bytes that are not valid UTF-8 are dropped during decoding, and the
        hash is computed over the decoded text re-encoded as UTF-8.
        """
        payload = path.read_bytes()
        if payload.find(b"\x00") != -1:
            return None
        text = payload.decode("utf-8", errors="ignore")
        relative_path = path.relative_to(self._root).as_posix()
        digest = hashlib.sha256(text.encode("utf-8")).hexdigest()
        return {
            "path": relative_path,
            "content": text,
            "content_hash": digest,
        }
|
||||
|
||||
|
||||
class DeterministicEmbedder:
    """Produce repeatable embedding vectors derived from SHA-256 digests.

    The same text always yields the same vector; every component lies in
    the range [-1.0, 1.0].
    """

    def __init__(self, dim: int = 64) -> None:
        """Store the number of components each vector should have."""
        self._dim = dim

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Return one vector per input text, in input order."""
        return list(map(self._embed_one, texts))

    def _embed_one(self, text: str) -> list[float]:
        """Build the vector for a single text.

        Bytes of the SHA-256 digest of the UTF-8 text are scaled from
        0..255 to -1.0..1.0. When more components are needed than one
        digest provides, the digest is re-hashed and the process repeats.
        """
        block = hashlib.sha256(text.encode("utf-8")).digest()
        vector: list[float] = []
        remaining = self._dim
        while remaining > 0:
            chunk = block[:remaining]
            vector.extend((byte / 127.5) - 1.0 for byte in chunk)
            remaining -= len(chunk)
            # Chain the hash so the next chunk of components differs.
            block = hashlib.sha256(block).digest()
        return vector
|
||||
|
||||
|
||||
class RagSessionIndexer:
    """Index the collectable files of a local repository into a new RAG session.

    Embeddings come from ``DeterministicEmbedder``, so indexing does not
    depend on an external embedding backend.
    """

    def __init__(self, repository: "RagRepository") -> None:
        """Wire a ``RagService`` to *repository*.

        Args:
            repository: Persistence layer used for table setup, session
                records and (through ``RagService``) indexed content.
        """
        # Imported here rather than at module level.
        # NOTE(review): presumably to keep module import light or to avoid
        # an import cycle — confirm.
        from app.modules.rag.services.rag_service import RagService

        self._repository = repository
        self._rag = RagService(embedder=DeterministicEmbedder(), repository=repository)

    def index_repo(self, repo_path: Path, project_id: str | None = None) -> str:
        """Index *repo_path* and return the id of the newly created session.

        Args:
            repo_path: Root directory handed to ``LocalRepoFileCollector``.
            project_id: Project identifier stored with the session; a falsy
                value falls back to the directory name of *repo_path*.

        Returns:
            The generated session id (a UUID4 string), which is also printed
            as ``rag_session_id=<id>``.

        Raises:
            ValueError: If the collector finds no indexable files. No
                session record is written in that case.
        """
        self._repository.ensure_tables()

        # Collect before creating the session so an empty repository fails
        # without leaving an orphan session row behind.
        files = LocalRepoFileCollector(repo_path).collect()
        if not files:
            raise ValueError(f"No indexable text files found under: {repo_path}")

        rag_session_id = str(uuid4())
        resolved_project_id = project_id or repo_path.name
        self._repository.upsert_session(rag_session_id, resolved_project_id)

        # Quiet the service logger while indexing so only the progress lines
        # are shown; the previous level is restored even if indexing fails.
        logger = logging.getLogger("app.modules.rag.services.rag_service")
        previous_level = logger.level
        logger.setLevel(logging.ERROR)
        try:
            asyncio.run(
                self._rag.index_snapshot(
                    rag_session_id=rag_session_id,
                    files=files,
                    progress_cb=self._print_progress,
                )
            )
        finally:
            logger.setLevel(previous_level)

        print(f"rag_session_id={rag_session_id}")
        return rag_session_id

    def _print_progress(self, current_file_index: int, total_files: int, current_file_name: str) -> None:
        """Print one ``[index/total] name`` progress line."""
        print(f"[{current_file_index}/{total_files}] {current_file_name}")
|
||||
Reference in New Issue
Block a user