Скелет проекта
This commit is contained in:
43
README.md
Normal file
43
README.md
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
# RAG Agent (Postgres)
|
||||||
|
|
||||||
|
Custom RAG agent that indexes text files from a git repository into Postgres
|
||||||
|
and answers queries using retrieval + LLM generation. Commits are tied to
|
||||||
|
**stories**; indexing and retrieval can be scoped by story.
|
||||||
|
|
||||||
|
## Quick start
|
||||||
|
|
||||||
|
1. Configure environment variables:
|
||||||
|
- `RAG_REPO_PATH` — path to git repo with text files
|
||||||
|
- `RAG_DB_DSN` — Postgres DSN (e.g. `postgresql://user:pass@localhost:5432/rag`)
|
||||||
|
- `RAG_EMBEDDINGS_DIM` — embedding vector dimension (e.g. `1536`)
|
||||||
|
2. Create DB schema:
|
||||||
|
- `python scripts/create_db.py` (or `psql "$RAG_DB_DSN" -f scripts/schema.sql`)
|
||||||
|
3. Index files for a story (e.g. branch name as story slug):
|
||||||
|
- `rag-agent index --story my-branch --changed --base-ref HEAD~1 --head-ref HEAD`
|
||||||
|
4. Ask a question (optionally scoped to a story):
|
||||||
|
- `rag-agent ask "What is covered?"`
|
||||||
|
- `rag-agent ask "What is covered?" --story my-branch`
|
||||||
|
|
||||||
|
## Git hook (index on commit)
|
||||||
|
|
||||||
|
Install the post-commit hook so changed files are indexed after each commit:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cp scripts/post-commit .git/hooks/post-commit && chmod +x .git/hooks/post-commit
|
||||||
|
```
|
||||||
|
|
||||||
|
Story for the commit is taken from (in order): env `RAG_STORY`, file `.rag-story` in repo root (one line = slug), or current branch name.
|
||||||
|
|
||||||
|
## DB structure
|
||||||
|
|
||||||
|
- **stories** — story slug (e.g. branch name); documents and chunks are tied to a story.
|
||||||
|
- **documents** — path + version per story; unique `(story_id, path)`.
|
||||||
|
- **chunks** — text chunks with embeddings (pgvector); updated when documents are re-indexed.
|
||||||
|
|
||||||
|
Scripts: `scripts/create_db.py` (Python, uses `ensure_schema` and `RAG_*` env), `scripts/schema.sql` (raw SQL).
|
||||||
|
|
||||||
|
## Notes
|
||||||
|
|
||||||
|
- The default embedding/LLM clients are stubs. Replace them in
|
||||||
|
`src/rag_agent/index/embeddings.py` and `src/rag_agent/agent/pipeline.py`.
|
||||||
|
- This project requires Postgres with the `pgvector` extension.
|
||||||
17
pyproject.toml
Normal file
17
pyproject.toml
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
[project]
|
||||||
|
name = "rag-agent"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Custom RAG agent with Postgres-backed vector index"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.10"
|
||||||
|
dependencies = [
|
||||||
|
"psycopg[binary]>=3.1.18",
|
||||||
|
"pgvector>=0.2.5",
|
||||||
|
"pydantic>=2.7.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
rag-agent = "rag_agent.cli:main"
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
line-length = 100
|
||||||
29
scripts/create_db.py
Normal file
29
scripts/create_db.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Create RAG vector DB schema in Postgres (extension + stories, documents, chunks).
|
||||||
|
Requires RAG_DB_DSN and RAG_EMBEDDINGS_DIM (optional, default 1536).
|
||||||
|
Run from repo root with package installed: pip install -e . && python scripts/create_db.py
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Allow importing rag_agent when run as scripts/create_db.py
|
||||||
|
repo_root = Path(__file__).resolve().parent.parent
|
||||||
|
sys.path.insert(0, str(repo_root / "src"))
|
||||||
|
|
||||||
|
from rag_agent.config import load_config
|
||||||
|
from rag_agent.index.postgres import connect, ensure_schema
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """Create the RAG schema (pgvector extension + stories/documents/chunks).

    Reads settings from the RAG_* environment variables via load_config();
    raises ValueError when required variables are missing.
    """
    config = load_config()
    conn = connect(config.db_dsn)
    try:
        # ensure_schema is idempotent (CREATE ... IF NOT EXISTS throughout).
        ensure_schema(conn, config.embeddings_dim)
    finally:
        # BUGFIX: previously the connection stayed open if ensure_schema raised.
        conn.close()
    print("Schema created successfully.")


if __name__ == "__main__":
    main()
|
||||||
35
scripts/post-commit
Normal file
35
scripts/post-commit
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
#!/usr/bin/env sh
# Git post-commit hook: index changed files into the RAG vector DB for the current story.
# Install: cp scripts/post-commit .git/hooks/post-commit && chmod +x .git/hooks/post-commit
# Requires: RAG_REPO_PATH, RAG_DB_DSN, RAG_EMBEDDINGS_DIM; story from RAG_STORY or .rag-story or current branch name.

set -e
cd "$(git rev-parse --show-toplevel)"

# Resolve story: env RAG_STORY > file .rag-story > current branch name
if [ -n "${RAG_STORY}" ]; then
    STORY="${RAG_STORY}"
elif [ -f .rag-story ]; then
    STORY=$(head -n1 .rag-story | tr -d '\n\r')
else
    STORY=$(git branch --show-current)
fi

if [ -z "$STORY" ]; then
    echo "post-commit: RAG_STORY or .rag-story or branch name required for indexing."
    exit 0
fi

# Run index (changed files only: previous commit -> HEAD).
# BUGFIX: the previous version had an `elif [ -n "$VIRTUAL_ENV" ]` branch that
# invoked `rag-agent` even though `command -v` had already shown it is not on
# PATH; under `set -e` that aborted the hook instead of falling back. The
# fallback chain is now: PATH > repo venv > `python -m` with src/ on PYTHONPATH.
if command -v rag-agent >/dev/null 2>&1; then
    rag-agent index --changed --base-ref HEAD~1 --head-ref HEAD --story "$STORY"
elif [ -f "venv/bin/rag-agent" ]; then
    venv/bin/rag-agent index --changed --base-ref HEAD~1 --head-ref HEAD --story "$STORY"
else
    # Best effort: do not fail the commit if the module cannot be run.
    PYTHONPATH=src python -m rag_agent.cli index --changed --base-ref HEAD~1 --head-ref HEAD --story "$STORY" 2>/dev/null || true
fi
|
||||||
33
scripts/schema.sql
Normal file
33
scripts/schema.sql
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
-- RAG vector DB schema for Postgres (pgvector).
-- Run once against an empty DB. If RAG_EMBEDDINGS_DIM is not 1536, change vector(1536) below.
-- Usage: psql "$RAG_DB_DSN" -f scripts/schema.sql
-- NOTE(review): keep in sync with ensure_schema() in src/rag_agent/index/postgres.py,
-- which creates the same objects with a configurable embedding dimension.

-- pgvector provides the `vector` type and the distance operators (e.g. <=>).
CREATE EXTENSION IF NOT EXISTS vector;

-- Scoping unit (e.g. a branch name); documents and chunks hang off a story.
CREATE TABLE IF NOT EXISTS stories (
    id SERIAL PRIMARY KEY,
    slug TEXT UNIQUE NOT NULL,
    created_at TIMESTAMPTZ NOT NULL DEFAULT (NOW() AT TIME ZONE 'utc')
);

-- One row per indexed file per story; `version` is the file fingerprint at index time.
CREATE TABLE IF NOT EXISTS documents (
    id SERIAL PRIMARY KEY,
    story_id INTEGER NOT NULL REFERENCES stories(id) ON DELETE CASCADE,
    path TEXT NOT NULL,
    version TEXT NOT NULL,
    updated_at TIMESTAMPTZ NOT NULL,
    UNIQUE(story_id, path)
);

-- Text chunks with embeddings; replaced wholesale when a document is re-indexed.
CREATE TABLE IF NOT EXISTS chunks (
    id SERIAL PRIMARY KEY,
    document_id INTEGER NOT NULL REFERENCES documents(id) ON DELETE CASCADE,
    chunk_index INTEGER NOT NULL,
    hash TEXT NOT NULL,
    content TEXT NOT NULL,
    embedding vector(1536) NOT NULL
);

CREATE INDEX IF NOT EXISTS idx_documents_story_id ON documents(story_id);
CREATE INDEX IF NOT EXISTS idx_chunks_document_id ON chunks(document_id);
-- NOTE(review): ivfflat lists are derived from existing rows; building this on an
-- empty table gives poor recall until the index is rebuilt after data load -- confirm
-- whether a REINDEX step is planned.
CREATE INDEX IF NOT EXISTS idx_chunks_embedding ON chunks USING ivfflat (embedding vector_cosine_ops);
|
||||||
3
src/rag_agent/__init__.py
Normal file
3
src/rag_agent/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
# Names re-exported by `from rag_agent import *`.
__all__ = [
    "config",
]
|
||||||
BIN
src/rag_agent/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
src/rag_agent/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/rag_agent/__pycache__/cli.cpython-312.pyc
Normal file
BIN
src/rag_agent/__pycache__/cli.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/rag_agent/__pycache__/config.cpython-312.pyc
Normal file
BIN
src/rag_agent/__pycache__/config.cpython-312.pyc
Normal file
Binary file not shown.
1
src/rag_agent/agent/__init__.py
Normal file
1
src/rag_agent/agent/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# Package marker; no names are re-exported via `from ... import *` yet.
__all__ = []
|
||||||
51
src/rag_agent/agent/pipeline.py
Normal file
51
src/rag_agent/agent/pipeline.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
import psycopg
|
||||||
|
|
||||||
|
from rag_agent.config import AppConfig
|
||||||
|
from rag_agent.index.embeddings import EmbeddingClient
|
||||||
|
from rag_agent.retrieval.search import search_similar
|
||||||
|
|
||||||
|
|
||||||
|
class LLMClient(Protocol):
    """Structural interface for text-generation backends.

    Any object providing a matching ``generate`` method satisfies this protocol.
    """

    def generate(self, prompt: str, model: str) -> str:
        """Return the model's completion for *prompt* using *model*."""
        raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class StubLLMClient:
    """Placeholder LLM backend used until a real client is wired in."""

    def generate(self, prompt: str, model: str) -> str:
        """Return a fixed "not configured" notice; *prompt* and *model* are ignored."""
        notice = (
            "LLM client is not configured. "
            "Replace StubLLMClient with a real implementation."
        )
        return notice
|
||||||
|
|
||||||
|
|
||||||
|
def build_prompt(question: str, contexts: list[str]) -> str:
    """Assemble the RAG prompt: instruction header, blank-line-joined contexts, question."""
    context_block = "\n\n".join(contexts)
    parts = [
        "You are a RAG assistant. Use the context below to answer the question.\n\n",
        f"Context:\n{context_block}\n\n",
        f"Question: {question}\nAnswer:",
    ]
    return "".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def answer_query(
    conn: psycopg.Connection,
    config: AppConfig,
    embedding_client: EmbeddingClient,
    llm_client: LLMClient,
    question: str,
    top_k: int = 5,
    story_id: int | None = None,
) -> str:
    """Embed *question*, retrieve the *top_k* most similar chunks (optionally
    scoped to *story_id*), and ask the LLM to answer from those contexts."""
    query_embedding = embedding_client.embed_texts([question])[0]
    hits = search_similar(conn, query_embedding, top_k=top_k, story_id=story_id)
    # Prefix each chunk with its source path so the model can cite provenance.
    context_texts = [f"Source: {hit.path}\n{hit.content}" for hit in hits]
    prompt = build_prompt(question, context_texts)
    return llm_client.generate(prompt, model=config.llm_model)
|
||||||
117
src/rag_agent/cli.py
Normal file
117
src/rag_agent/cli.py
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import hashlib
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from rag_agent.config import load_config
|
||||||
|
from rag_agent.ingest.chunker import chunk_text
|
||||||
|
from rag_agent.ingest.file_loader import iter_text_files
|
||||||
|
from rag_agent.ingest.git_watcher import filter_existing, filter_removed, get_changed_files
|
||||||
|
from rag_agent.index.embeddings import get_embedding_client
|
||||||
|
from rag_agent.index.postgres import (
|
||||||
|
connect,
|
||||||
|
delete_document,
|
||||||
|
ensure_schema,
|
||||||
|
get_or_create_story,
|
||||||
|
get_story_id,
|
||||||
|
replace_chunks,
|
||||||
|
upsert_document,
|
||||||
|
)
|
||||||
|
from rag_agent.agent.pipeline import StubLLMClient, answer_query
|
||||||
|
|
||||||
|
|
||||||
|
def _file_version(path: Path) -> str:
    """Fingerprint a file as sha256 over (posix path, mtime_ns, size).

    Changes whenever the file is touched or resized, giving a cheap document
    version without reading the file contents.
    """
    meta = path.stat()
    token = ":".join((path.as_posix(), str(meta.st_mtime_ns), str(meta.st_size)))
    return hashlib.sha256(token.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_index(args: argparse.Namespace) -> None:
    """Handle `rag-agent index`: (re)index files for a story.

    With --changed, only files that differ between --base-ref and --head-ref
    are processed (files deleted in that range are removed from the index);
    otherwise every file under the configured repo path is considered.
    """
    config = load_config()
    conn = connect(config.db_dsn)
    try:
        ensure_schema(conn, config.embeddings_dim)
        story_id = get_or_create_story(conn, args.story)
        embedding_client = get_embedding_client(config.embeddings_dim)

        if args.changed:
            changed_files = get_changed_files(config.repo_path, args.base_ref, args.head_ref)
            removed = filter_removed(changed_files)
            existing = filter_existing(changed_files)
        else:
            removed = []
            existing = [path for path in Path(config.repo_path).rglob("*") if path.is_file()]

        # Drop documents whose files no longer exist (chunks cascade in the DB).
        for path in removed:
            delete_document(conn, story_id, str(path))

        for path, text in iter_text_files(existing, config.allowed_extensions):
            chunks = chunk_text(text, config.chunk_size, config.chunk_overlap)
            if not chunks:
                continue  # empty/whitespace-only file: nothing to index
            embeddings = embedding_client.embed_texts([chunk.text for chunk in chunks])
            document_id = upsert_document(
                conn, story_id, str(path), _file_version(path)
            )
            replace_chunks(conn, document_id, chunks, embeddings)
    finally:
        # BUGFIX: the connection was previously never closed.
        conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_ask(args: argparse.Namespace) -> None:
    """Handle `rag-agent ask`: answer a question from the indexed chunks.

    Raises SystemExit when --story is given but no such story exists.
    """
    config = load_config()
    conn = connect(config.db_dsn)
    try:
        ensure_schema(conn, config.embeddings_dim)
        story_id = None
        if args.story:
            story_id = get_story_id(conn, args.story)
            if story_id is None:
                raise SystemExit(f"Story not found: {args.story}")
        embedding_client = get_embedding_client(config.embeddings_dim)
        # Placeholder backend; see agent/pipeline.py for the replacement hook.
        llm_client = StubLLMClient()
        answer = answer_query(
            conn,
            config,
            embedding_client,
            llm_client,
            args.question,
            top_k=args.top_k,
            story_id=story_id,
        )
    finally:
        # BUGFIX: the connection was previously never closed.
        conn.close()
    print(answer)
|
||||||
|
|
||||||
|
|
||||||
|
def build_parser() -> argparse.ArgumentParser:
    """Build the `rag-agent` CLI parser with `index` and `ask` subcommands.

    Each subparser stores its handler in `func` (dispatched by main()).
    """
    parser = argparse.ArgumentParser(prog="rag-agent")
    subcommands = parser.add_subparsers(dest="command", required=True)

    indexer = subcommands.add_parser("index", help="Index files into Postgres")
    indexer.add_argument(
        "--story",
        required=True,
        help="Story slug (e.g. branch name or story id); documents are tied to this story",
    )
    indexer.add_argument("--changed", action="store_true", help="Index only changed files")
    indexer.add_argument("--base-ref", default="HEAD~1", help="Base git ref")
    indexer.add_argument("--head-ref", default="HEAD", help="Head git ref")
    indexer.set_defaults(func=cmd_index)

    asker = subcommands.add_parser("ask", help="Ask a question")
    asker.add_argument("question", help="Question text")
    asker.add_argument(
        "--story",
        default=None,
        help="Limit retrieval to this story slug (optional)",
    )
    asker.add_argument("--top-k", type=int, default=5, help="Top K chunks to retrieve")
    asker.set_defaults(func=cmd_ask)

    return parser
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """CLI entry point: parse arguments and dispatch to the chosen subcommand."""
    arguments = build_parser().parse_args()
    arguments.func(arguments)


if __name__ == "__main__":
    main()
|
||||||
57
src/rag_agent/config.py
Normal file
57
src/rag_agent/config.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Iterable, Sequence
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class AppConfig:
    """Immutable application configuration, loaded from RAG_* env vars."""

    # Filesystem path to the git repository whose files are indexed.
    repo_path: str
    # Postgres DSN, e.g. postgresql://user:pass@host:5432/db.
    db_dsn: str
    # Chunking: whitespace-split tokens per chunk ...
    chunk_size: int
    # ... and tokens of overlap between consecutive chunks.
    chunk_overlap: int
    # Dimension of embedding vectors; must match the DB's vector(N) columns.
    embeddings_dim: int
    # Model identifiers (stubs by default; see index/embeddings.py and agent/pipeline.py).
    embeddings_model: str
    llm_model: str
    # File extensions eligible for indexing, with leading dot (matched case-insensitively).
    allowed_extensions: Sequence[str]
|
||||||
|
|
||||||
|
|
||||||
|
def _env_int(name: str, default: int) -> int:
    """Read env var *name* as an int; unset/blank yields *default*.

    Raises ValueError (chained from the int() failure) for non-integer values.
    """
    raw = os.getenv(name, "").strip()
    if not raw:
        return default
    try:
        return int(raw)
    except ValueError as exc:
        raise ValueError(f"Invalid integer for {name}: {raw}") from exc
|
||||||
|
|
||||||
|
|
||||||
|
def _env_list(name: str, default: Iterable[str]) -> list[str]:
    """Read env var *name* as a comma-separated list; unset/blank yields *default*.

    Items are stripped of surrounding whitespace and blank entries are dropped.
    """
    raw = os.getenv(name, "").strip()
    if not raw:
        return list(default)
    stripped = (piece.strip() for piece in raw.split(","))
    return [piece for piece in stripped if piece]
|
||||||
|
|
||||||
|
|
||||||
|
def load_config() -> AppConfig:
    """Build an AppConfig from environment variables.

    RAG_REPO_PATH and RAG_DB_DSN are mandatory; everything else has a default.
    Raises ValueError when a required variable is missing or an int is invalid.
    """
    repo_path = os.getenv("RAG_REPO_PATH", "").strip()
    if not repo_path:
        raise ValueError("RAG_REPO_PATH is required")

    db_dsn = os.getenv("RAG_DB_DSN", "").strip()
    if not db_dsn:
        raise ValueError("RAG_DB_DSN is required")

    extensions = tuple(_env_list("RAG_ALLOWED_EXTENSIONS", [".md", ".txt", ".rst"]))
    return AppConfig(
        repo_path=repo_path,
        db_dsn=db_dsn,
        chunk_size=_env_int("RAG_CHUNK_SIZE", 400),
        chunk_overlap=_env_int("RAG_CHUNK_OVERLAP", 50),
        embeddings_dim=_env_int("RAG_EMBEDDINGS_DIM", 1536),
        embeddings_model=os.getenv("RAG_EMBEDDINGS_MODEL", "stub-embeddings"),
        llm_model=os.getenv("RAG_LLM_MODEL", "stub-llm"),
        allowed_extensions=extensions,
    )
|
||||||
1
src/rag_agent/index/__init__.py
Normal file
1
src/rag_agent/index/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# Package marker; no names are re-exported via `from ... import *` yet.
__all__ = []
|
||||||
BIN
src/rag_agent/index/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
src/rag_agent/index/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/rag_agent/index/__pycache__/embeddings.cpython-312.pyc
Normal file
BIN
src/rag_agent/index/__pycache__/embeddings.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/rag_agent/index/__pycache__/postgres.cpython-312.pyc
Normal file
BIN
src/rag_agent/index/__pycache__/postgres.cpython-312.pyc
Normal file
Binary file not shown.
29
src/rag_agent/index/embeddings.py
Normal file
29
src/rag_agent/index/embeddings.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Iterable, Protocol
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingClient(Protocol):
    """Structural interface for embedding backends.

    Implementations map each input text to a fixed-length float vector.
    """

    def embed_texts(self, texts: Iterable[str]) -> list[list[float]]:
        """Return one embedding vector per input text, in order."""
        raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class StubEmbeddingClient:
    """Deterministic placeholder embedder: sha256-derived pseudo-vectors.

    Identical texts get identical vectors, but the vectors carry no semantic
    similarity. Replace with a real embedding client for meaningful retrieval.
    """

    # Target vector dimension; must match the DB's vector(N) columns.
    dim: int

    def embed_texts(self, texts: Iterable[str]) -> list[list[float]]:
        """Return one length-``dim`` vector with components in [0, 1] per text."""
        out: list[list[float]] = []
        for text in texts:
            digest = hashlib.sha256(text.encode("utf-8")).digest()
            values = [byte / 255.0 for byte in digest]
            # sha256 yields 32 components; tile and trim when dim exceeds that.
            if len(values) < self.dim:
                repeats = (self.dim // len(values)) + 1
                values = (values * repeats)[: self.dim]
            out.append(values[: self.dim])
        return out
|
||||||
|
|
||||||
|
|
||||||
|
def get_embedding_client(dim: int) -> EmbeddingClient:
    """Return the embedding backend for *dim*-dimensional vectors.

    Currently always the sha256-based stub; swap the implementation here
    to plug in a real embedding service.
    """
    client = StubEmbeddingClient(dim=dim)
    return client
|
||||||
194
src/rag_agent/index/postgres.py
Normal file
194
src/rag_agent/index/postgres.py
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
|
import psycopg
|
||||||
|
from pgvector.psycopg import register_vector
|
||||||
|
|
||||||
|
from rag_agent.ingest.chunker import TextChunk
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class ChunkRecord:
    """A chunk together with its parent document's identity and embedding.

    NOTE(review): not referenced elsewhere in this module -- possibly reserved
    for future use; confirm before removing.
    """

    # Path of the source document within the repo.
    document_path: str
    # Document version fingerprint at indexing time.
    document_version: str
    # The chunk itself (index, text, hash).
    chunk: TextChunk
    # Embedding vector for the chunk's text.
    embedding: list[float]
|
||||||
|
|
||||||
|
|
||||||
|
def connect(dsn: str) -> psycopg.Connection:
    """Open a psycopg connection to *dsn* with pgvector adapters registered.

    Callers are responsible for closing the returned connection.
    """
    connection = psycopg.connect(dsn)
    # Teach psycopg to send/receive pgvector `vector` values as Python lists.
    register_vector(connection)
    return connection
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_schema(conn: psycopg.Connection, embeddings_dim: int) -> None:
    """Create the pgvector extension, tables, and indexes if missing.

    Idempotent (CREATE ... IF NOT EXISTS throughout); commits on success.
    The chunks.embedding column width is fixed to *embeddings_dim* at table
    creation time -- re-running with a different value does NOT alter an
    existing table. Mirrors scripts/schema.sql (which hard-codes vector(1536)).
    """
    with conn.cursor() as cur:
        # pgvector provides the `vector` type and distance operators (<=> etc.).
        cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")
        # stories: scoping unit (e.g. a branch name).
        cur.execute(
            """
            CREATE TABLE IF NOT EXISTS stories (
                id SERIAL PRIMARY KEY,
                slug TEXT UNIQUE NOT NULL,
                created_at TIMESTAMPTZ NOT NULL DEFAULT (NOW() AT TIME ZONE 'utc')
            );
            """
        )
        # documents: one row per indexed file per story.
        cur.execute(
            """
            CREATE TABLE IF NOT EXISTS documents (
                id SERIAL PRIMARY KEY,
                story_id INTEGER NOT NULL REFERENCES stories(id) ON DELETE CASCADE,
                path TEXT NOT NULL,
                version TEXT NOT NULL,
                updated_at TIMESTAMPTZ NOT NULL,
                UNIQUE(story_id, path)
            );
            """
        )
        # chunks: embedded text fragments. The dimension is interpolated via
        # f-string; embeddings_dim is an int from config, not user-typed SQL.
        cur.execute(
            f"""
            CREATE TABLE IF NOT EXISTS chunks (
                id SERIAL PRIMARY KEY,
                document_id INTEGER NOT NULL REFERENCES documents(id) ON DELETE CASCADE,
                chunk_index INTEGER NOT NULL,
                hash TEXT NOT NULL,
                content TEXT NOT NULL,
                embedding vector({embeddings_dim}) NOT NULL
            );
            """
        )
        cur.execute(
            """
            CREATE INDEX IF NOT EXISTS idx_documents_story_id
            ON documents(story_id);
            """
        )
        cur.execute(
            """
            CREATE INDEX IF NOT EXISTS idx_chunks_document_id
            ON chunks(document_id);
            """
        )
        # NOTE(review): ivfflat built on an empty table has poor recall until
        # rebuilt after data load -- confirm whether a REINDEX step is planned.
        cur.execute(
            """
            CREATE INDEX IF NOT EXISTS idx_chunks_embedding
            ON chunks USING ivfflat (embedding vector_cosine_ops);
            """
        )
    conn.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def get_or_create_story(conn: psycopg.Connection, slug: str) -> int:
    """Return the id for story *slug* (stripped), inserting the row if new.

    Uses an upsert with a no-op assignment (ON CONFLICT ... DO UPDATE SET
    slug = EXCLUDED.slug) so that RETURNING id yields the existing row's id
    as well as a freshly inserted one. Commits before returning.
    """
    with conn.cursor() as cur:
        cur.execute(
            """
            INSERT INTO stories (slug)
            VALUES (%s)
            ON CONFLICT (slug) DO UPDATE SET slug = EXCLUDED.slug
            RETURNING id;
            """,
            (slug.strip(),),
        )
        story_id = cur.fetchone()[0]
    conn.commit()
    return story_id
|
||||||
|
|
||||||
|
|
||||||
|
def get_story_id(conn: psycopg.Connection, slug: str) -> int | None:
    """Look up a story id by slug, or by numeric id when *slug* is all digits.

    Returns None when no matching story exists; never creates anything.
    """
    needle = slug.strip()
    with conn.cursor() as cur:
        if needle.isdigit():
            # Purely numeric input is treated as a primary-key lookup.
            cur.execute("SELECT id FROM stories WHERE id = %s;", (int(needle),))
        else:
            cur.execute("SELECT id FROM stories WHERE slug = %s;", (needle,))
        row = cur.fetchone()
    return row[0] if row else None
|
||||||
|
|
||||||
|
|
||||||
|
def upsert_document(
    conn: psycopg.Connection, story_id: int, path: str, version: str
) -> int:
    """Insert or refresh the (story, path) document row; return its id.

    On conflict the version and updated_at (current UTC time) are overwritten.
    NOTE(review): intentionally does NOT commit -- the caller's subsequent
    replace_chunks() commits, so the document row and its chunks land in one
    transaction. Confirm all call sites follow up with a commit.
    """
    with conn.cursor() as cur:
        cur.execute(
            """
            INSERT INTO documents (story_id, path, version, updated_at)
            VALUES (%s, %s, %s, %s)
            ON CONFLICT (story_id, path) DO UPDATE
            SET version = EXCLUDED.version,
                updated_at = EXCLUDED.updated_at
            RETURNING id;
            """,
            (story_id, path, version, datetime.now(timezone.utc)),
        )
        document_id = cur.fetchone()[0]
    return document_id
|
||||||
|
|
||||||
|
|
||||||
|
def replace_chunks(
    conn: psycopg.Connection,
    document_id: int,
    chunks: Iterable[TextChunk],
    embeddings: Iterable[list[float]],
) -> None:
    """Replace all chunks of *document_id* with the given chunk/embedding pairs.

    Deletes existing rows, bulk-inserts the new ones, and commits -- so the
    document row (see upsert_document) and its chunks land in one transaction.

    Raises ValueError when *chunks* and *embeddings* have different lengths;
    previously a mismatch was silently truncated by zip().
    """
    rows = [
        (document_id, chunk.index, chunk.hash, chunk.text, embedding)
        # strict=True (Python >= 3.10, matching pyproject's requires-python)
        # surfaces a chunk/embedding count mismatch instead of dropping data.
        for chunk, embedding in zip(chunks, embeddings, strict=True)
    ]
    with conn.cursor() as cur:
        cur.execute(
            "DELETE FROM chunks WHERE document_id = %s;",
            (document_id,),
        )
        if rows:
            # executemany batches the inserts instead of one round-trip per chunk.
            cur.executemany(
                """
                INSERT INTO chunks (document_id, chunk_index, hash, content, embedding)
                VALUES (%s, %s, %s, %s, %s);
                """,
                rows,
            )
    conn.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def delete_document(
    conn: psycopg.Connection, story_id: int, path: str
) -> None:
    """Remove the (story, path) document row; its chunks follow via ON DELETE CASCADE.

    No-op when the document does not exist. Commits before returning.
    """
    with conn.cursor() as cur:
        cur.execute(
            "DELETE FROM documents WHERE story_id = %s AND path = %s;",
            (story_id, path),
        )
    conn.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_similar(
    conn: psycopg.Connection,
    query_embedding: list[float],
    top_k: int,
    story_id: int | None = None,
) -> list[tuple[str, str, float]]:
    """Return the *top_k* nearest chunks as (path, content, distance) tuples.

    Uses the pgvector cosine-distance operator `<=>` (smaller = more similar).
    The embedding is bound twice because it appears in both the SELECT list and
    the ORDER BY clause. When *story_id* is given, results are restricted to
    that story's documents.
    """
    with conn.cursor() as cur:
        if story_id is not None:
            cur.execute(
                """
                SELECT d.path, c.content, c.embedding <=> %s AS distance
                FROM chunks c
                JOIN documents d ON d.id = c.document_id
                WHERE d.story_id = %s
                ORDER BY c.embedding <=> %s
                LIMIT %s;
                """,
                (query_embedding, story_id, query_embedding, top_k),
            )
        else:
            cur.execute(
                """
                SELECT d.path, c.content, c.embedding <=> %s AS distance
                FROM chunks c
                JOIN documents d ON d.id = c.document_id
                ORDER BY c.embedding <=> %s
                LIMIT %s;
                """,
                (query_embedding, query_embedding, top_k),
            )
        rows = cur.fetchall()
        # Normalize driver row objects to plain tuples.
        return [(row[0], row[1], row[2]) for row in rows]
|
||||||
1
src/rag_agent/ingest/__init__.py
Normal file
1
src/rag_agent/ingest/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# Package marker; no names are re-exported via `from ... import *` yet.
__all__ = []
|
||||||
BIN
src/rag_agent/ingest/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
src/rag_agent/ingest/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/rag_agent/ingest/__pycache__/chunker.cpython-312.pyc
Normal file
BIN
src/rag_agent/ingest/__pycache__/chunker.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/rag_agent/ingest/__pycache__/file_loader.cpython-312.pyc
Normal file
BIN
src/rag_agent/ingest/__pycache__/file_loader.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/rag_agent/ingest/__pycache__/git_watcher.cpython-312.pyc
Normal file
BIN
src/rag_agent/ingest/__pycache__/git_watcher.cpython-312.pyc
Normal file
Binary file not shown.
42
src/rag_agent/ingest/chunker.py
Normal file
42
src/rag_agent/ingest/chunker.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Iterable, Iterator
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class TextChunk:
    """A single chunk of a document's text."""

    # 0-based position of the chunk within its document.
    index: int
    # Chunk text: tokens re-joined with single spaces (whitespace-normalized).
    text: str
    # sha256 hex digest of `text`.
    hash: str


def _hash_text(text: str) -> str:
    """Return the sha256 hex digest of *text* (utf-8)."""
    return hashlib.sha256(text.encode("utf-8")).hexdigest()


def chunk_text(text: str, chunk_size: int, overlap: int) -> list[TextChunk]:
    """Split *text* into overlapping chunks of whitespace-separated tokens.

    Each chunk holds up to *chunk_size* tokens; every chunk after the first
    starts *overlap* tokens before the previous chunk's end. Returns [] for
    empty or whitespace-only text.

    Raises ValueError when chunk_size <= 0 or overlap is outside
    [0, chunk_size): previously such arguments made the loop spin forever,
    because `start = end - overlap` stopped advancing.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(f"overlap must be in [0, chunk_size), got {overlap}")

    tokens = text.split()
    if not tokens:
        return []

    chunks: list[TextChunk] = []
    start = 0
    index = 0
    while start < len(tokens):
        end = min(start + chunk_size, len(tokens))
        chunk_text_value = " ".join(tokens[start:end])
        chunks.append(TextChunk(index=index, text=chunk_text_value, hash=_hash_text(chunk_text_value)))
        index += 1
        if end == len(tokens):
            break
        # Guaranteed to advance because overlap < chunk_size.
        start = max(end - overlap, 0)
    return chunks
|
||||||
|
|
||||||
|
|
||||||
|
def iter_chunks(
    texts: Iterable[str], chunk_size: int, overlap: int
) -> Iterator[list[TextChunk]]:
    """Lazily chunk each input text; yields one list of TextChunks per text."""
    for document_text in texts:
        yield chunk_text(document_text, chunk_size, overlap)
|
||||||
23
src/rag_agent/ingest/file_loader.py
Normal file
23
src/rag_agent/ingest/file_loader.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Iterable, Iterator
|
||||||
|
|
||||||
|
|
||||||
|
def is_allowed(path: Path, allowed_extensions: Iterable[str]) -> bool:
    """Case-insensitively check whether *path*'s extension is in *allowed_extensions*."""
    normalized = {extension.lower() for extension in allowed_extensions}
    return path.suffix.lower() in normalized
|
||||||
|
|
||||||
|
|
||||||
|
def read_text_file(path: Path) -> str:
    """Read *path* as utf-8 text; undecodable bytes are dropped (errors="ignore")."""
    contents = path.read_text(encoding="utf-8", errors="ignore")
    return contents
|
||||||
|
|
||||||
|
|
||||||
|
def iter_text_files(
    paths: Iterable[Path], allowed_extensions: Iterable[str]
) -> Iterator[tuple[Path, str]]:
    """Yield (path, text) for each existing regular file with an allowed extension.

    Directories, missing paths, and disallowed extensions are skipped silently.
    """
    for candidate in paths:
        if candidate.is_file() and is_allowed(candidate, allowed_extensions):
            yield candidate, read_text_file(candidate)
|
||||||
42
src/rag_agent/ingest/git_watcher.py
Normal file
42
src/rag_agent/ingest/git_watcher.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
|
|
||||||
|
def get_changed_files(
    repo_path: str, base_ref: str, head_ref: str = "HEAD"
) -> list[Path]:
    """List files that differ between *base_ref* and *head_ref* in *repo_path*.

    Runs `git -C <repo> diff --name-only <base> <head>` and returns each
    reported name as a path rooted at *repo_path*. Deleted files appear in the
    diff too (their paths no longer exist on disk; see filter_removed).

    Raises RuntimeError (chained) when git exits non-zero.
    """
    command = [
        "git",
        "-C",
        repo_path,
        "diff",
        "--name-only",
        base_ref,
        head_ref,
    ]
    try:
        completed = subprocess.run(
            command, check=True, capture_output=True, text=True
        )
    except subprocess.CalledProcessError as exc:
        raise RuntimeError(
            f"git diff failed: {exc.stderr.strip() or exc}"
        ) from exc

    root = Path(repo_path)
    names = (line.strip() for line in completed.stdout.splitlines())
    return [root / name for name in names if name]
|
||||||
|
|
||||||
|
|
||||||
|
def filter_existing(paths: Iterable[Path]) -> list[Path]:
    """Keep only paths that currently exist on disk (files or directories)."""
    present = [candidate for candidate in paths if candidate.exists()]
    return present
|
||||||
|
|
||||||
|
|
||||||
|
def filter_removed(paths: Iterable[Path]) -> list[Path]:
    """Keep only paths that no longer exist on disk (e.g. deleted in a commit)."""
    return [candidate for candidate in paths if not candidate.exists()]
|
||||||
1
src/rag_agent/retrieval/__init__.py
Normal file
1
src/rag_agent/retrieval/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# Package marker; no names are re-exported via `from ... import *` yet.
__all__ = []
|
||||||
24
src/rag_agent/retrieval/search.py
Normal file
24
src/rag_agent/retrieval/search.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import psycopg
|
||||||
|
|
||||||
|
from rag_agent.index.postgres import fetch_similar
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class SearchResult:
    """One retrieval hit: a chunk plus its source path and distance."""

    # Path of the source document within the repo.
    path: str
    # The chunk's text content.
    content: str
    # Cosine distance to the query embedding (smaller = more similar).
    distance: float
|
||||||
|
|
||||||
|
|
||||||
|
def search_similar(
    conn: psycopg.Connection,
    query_embedding: list[float],
    top_k: int = 5,
    story_id: int | None = None,
) -> list[SearchResult]:
    """Retrieve the *top_k* chunks nearest to *query_embedding*.

    Thin typed wrapper around fetch_similar(); when *story_id* is given, the
    search is limited to that story's documents.
    """
    raw_rows = fetch_similar(conn, query_embedding, top_k, story_id=story_id)
    return [
        SearchResult(path=path, content=content, distance=distance)
        for path, content, distance in raw_rows
    ]
|
||||||
9
tests/test_chunker.py
Normal file
9
tests/test_chunker.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
from rag_agent.ingest.chunker import chunk_text
|
||||||
|
|
||||||
|
|
||||||
|
def test_chunk_text_basic():
    """chunk_text with size 3 / overlap 1 over 8 tokens yields 4 chunks.

    Windows are (0,3), (2,5), (4,7), (6,8): each chunk starts one token before
    the previous chunk's end. BUGFIX: the original assertion expected 3 chunks
    and fails against the implementation.
    """
    text = "one two three four five six seven eight"
    chunks = chunk_text(text, chunk_size=3, overlap=1)
    assert len(chunks) == 4
    assert chunks[0].text == "one two three"
    assert chunks[1].text.startswith("three four")
    assert chunks[-1].text == "seven eight"
|
||||||
Reference in New Issue
Block a user