Скелет проекта

This commit is contained in:
2026-01-30 22:21:12 +03:00
commit 84ded7d7a9
30 changed files with 752 additions and 0 deletions

43
README.md Normal file
View File

@@ -0,0 +1,43 @@
# RAG Agent (Postgres)
Custom RAG agent that indexes text files from a git repository into Postgres
and answers queries using retrieval + LLM generation. Commits are tied to
**stories**; indexing and retrieval can be scoped by story.
## Quick start
1. Configure environment variables:
- `RAG_REPO_PATH` — path to git repo with text files
- `RAG_DB_DSN` — Postgres DSN (e.g. `postgresql://user:pass@localhost:5432/rag`)
- `RAG_EMBEDDINGS_DIM` — embedding vector dimension (e.g. `1536`)
2. Create DB schema:
- `python scripts/create_db.py` (or `psql "$RAG_DB_DSN" -f scripts/schema.sql`)
3. Index files for a story (e.g. branch name as story slug):
- `rag-agent index --story my-branch --changed --base-ref HEAD~1 --head-ref HEAD`
4. Ask a question (optionally scoped to a story):
- `rag-agent ask "What is covered?"`
- `rag-agent ask "What is covered?" --story my-branch`
## Git hook (index on commit)
Install the post-commit hook so changed files are indexed after each commit:
```bash
cp scripts/post-commit .git/hooks/post-commit && chmod +x .git/hooks/post-commit
```
Story for the commit is taken from (in order): env `RAG_STORY`, file `.rag-story` in repo root (one line = slug), or current branch name.
## DB structure
- **stories** — story slug (e.g. branch name); documents and chunks are tied to a story.
- **documents** — path + version per story; unique `(story_id, path)`.
- **chunks** — text chunks with embeddings (pgvector); updated when documents are re-indexed.
Scripts: `scripts/create_db.py` (Python, uses `ensure_schema` and `RAG_*` env), `scripts/schema.sql` (raw SQL).
## Notes
- The default embedding/LLM clients are stubs. Replace them in
`src/rag_agent/index/embeddings.py` and `src/rag_agent/agent/pipeline.py`.
- This project requires Postgres with the `pgvector` extension.

17
pyproject.toml Normal file
View File

@@ -0,0 +1,17 @@
[project]
name = "rag-agent"
version = "0.1.0"
description = "Custom RAG agent with Postgres-backed vector index"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"psycopg[binary]>=3.1.18",
"pgvector>=0.2.5",
"pydantic>=2.7.0",
]
[project.scripts]
rag-agent = "rag_agent.cli:main"
[tool.ruff]
line-length = 100

29
scripts/create_db.py Normal file
View File

@@ -0,0 +1,29 @@
#!/usr/bin/env python3
"""
Create RAG vector DB schema in Postgres (extension + stories, documents, chunks).
Requires RAG_DB_DSN and RAG_EMBEDDINGS_DIM (optional, default 1536).
Run from repo root with package installed: pip install -e . && python scripts/create_db.py
"""
from __future__ import annotations
import sys
from pathlib import Path
# Allow importing rag_agent when run as scripts/create_db.py
repo_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(repo_root / "src"))
from rag_agent.config import load_config
from rag_agent.index.postgres import connect, ensure_schema
def main() -> None:
config = load_config()
conn = connect(config.db_dsn)
ensure_schema(conn, config.embeddings_dim)
conn.close()
print("Schema created successfully.")
if __name__ == "__main__":
main()

35
scripts/post-commit Normal file
View File

@@ -0,0 +1,35 @@
#!/usr/bin/env sh
# Git post-commit hook: index changed files into RAG vector DB for the current story.
# Install: cp scripts/post-commit .git/hooks/post-commit && chmod +x .git/hooks/post-commit
# Requires: RAG_REPO_PATH, RAG_DB_DSN, RAG_EMBEDDINGS_DIM; story from RAG_STORY or .rag-story or current branch name.
set -e
cd "$(git rev-parse --show-toplevel)"
# Resolve story: env RAG_STORY > file .rag-story > current branch name
if [ -n "${RAG_STORY}" ]; then
STORY="${RAG_STORY}"
elif [ -f .rag-story ]; then
STORY=$(head -n1 .rag-story | tr -d '\n\r')
else
STORY=$(git branch --show-current)
fi
if [ -z "$STORY" ]; then
echo "post-commit: RAG_STORY or .rag-story or branch name required for indexing."
exit 0
fi
# Run index (changed files only: previous commit -> HEAD)
if command -v rag-agent >/dev/null 2>&1; then
rag-agent index --changed --base-ref HEAD~1 --head-ref HEAD --story "$STORY"
elif [ -n "${VIRTUAL_ENV}" ]; then
rag-agent index --changed --base-ref HEAD~1 --head-ref HEAD --story "$STORY"
else
# Try repo venv or python -m
if [ -f "venv/bin/rag-agent" ]; then
venv/bin/rag-agent index --changed --base-ref HEAD~1 --head-ref HEAD --story "$STORY"
else
PYTHONPATH=src python -m rag_agent.cli index --changed --base-ref HEAD~1 --head-ref HEAD --story "$STORY" 2>/dev/null || true
fi
fi

33
scripts/schema.sql Normal file
View File

@@ -0,0 +1,33 @@
-- RAG vector DB schema for Postgres (pgvector).
-- Run once against an empty DB. If RAG_EMBEDDINGS_DIM is not 1536, change vector(1536) below.
-- Usage: psql "$RAG_DB_DSN" -f scripts/schema.sql
CREATE EXTENSION IF NOT EXISTS vector;
CREATE TABLE IF NOT EXISTS stories (
id SERIAL PRIMARY KEY,
slug TEXT UNIQUE NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT (NOW() AT TIME ZONE 'utc')
);
CREATE TABLE IF NOT EXISTS documents (
id SERIAL PRIMARY KEY,
story_id INTEGER NOT NULL REFERENCES stories(id) ON DELETE CASCADE,
path TEXT NOT NULL,
version TEXT NOT NULL,
updated_at TIMESTAMPTZ NOT NULL,
UNIQUE(story_id, path)
);
CREATE TABLE IF NOT EXISTS chunks (
id SERIAL PRIMARY KEY,
document_id INTEGER NOT NULL REFERENCES documents(id) ON DELETE CASCADE,
chunk_index INTEGER NOT NULL,
hash TEXT NOT NULL,
content TEXT NOT NULL,
embedding vector(1536) NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_documents_story_id ON documents(story_id);
CREATE INDEX IF NOT EXISTS idx_chunks_document_id ON chunks(document_id);
CREATE INDEX IF NOT EXISTS idx_chunks_embedding ON chunks USING ivfflat (embedding vector_cosine_ops);

View File

@@ -0,0 +1,3 @@
__all__ = [
"config",
]

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1 @@
__all__ = []

View File

@@ -0,0 +1,51 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Protocol
import psycopg
from rag_agent.config import AppConfig
from rag_agent.index.embeddings import EmbeddingClient
from rag_agent.retrieval.search import search_similar
class LLMClient(Protocol):
def generate(self, prompt: str, model: str) -> str:
raise NotImplementedError
@dataclass
class StubLLMClient:
def generate(self, prompt: str, model: str) -> str:
return (
"LLM client is not configured. "
"Replace StubLLMClient with a real implementation."
)
def build_prompt(question: str, contexts: list[str]) -> str:
joined = "\n\n".join(contexts)
return (
"You are a RAG assistant. Use the context below to answer the question.\n\n"
f"Context:\n{joined}\n\n"
f"Question: {question}\nAnswer:"
)
def answer_query(
conn: psycopg.Connection,
config: AppConfig,
embedding_client: EmbeddingClient,
llm_client: LLMClient,
question: str,
top_k: int = 5,
story_id: int | None = None,
) -> str:
query_embedding = embedding_client.embed_texts([question])[0]
results = search_similar(
conn, query_embedding, top_k=top_k, story_id=story_id
)
contexts = [f"Source: {item.path}\n{item.content}" for item in results]
prompt = build_prompt(question, contexts)
return llm_client.generate(prompt, model=config.llm_model)

117
src/rag_agent/cli.py Normal file
View File

@@ -0,0 +1,117 @@
from __future__ import annotations
import argparse
import hashlib
from pathlib import Path
from rag_agent.config import load_config
from rag_agent.ingest.chunker import chunk_text
from rag_agent.ingest.file_loader import iter_text_files
from rag_agent.ingest.git_watcher import filter_existing, filter_removed, get_changed_files
from rag_agent.index.embeddings import get_embedding_client
from rag_agent.index.postgres import (
connect,
delete_document,
ensure_schema,
get_or_create_story,
get_story_id,
replace_chunks,
upsert_document,
)
from rag_agent.agent.pipeline import StubLLMClient, answer_query
def _file_version(path: Path) -> str:
stat = path.stat()
payload = f"{path.as_posix()}:{stat.st_mtime_ns}:{stat.st_size}"
return hashlib.sha256(payload.encode("utf-8")).hexdigest()
def cmd_index(args: argparse.Namespace) -> None:
config = load_config()
conn = connect(config.db_dsn)
ensure_schema(conn, config.embeddings_dim)
story_id = get_or_create_story(conn, args.story)
embedding_client = get_embedding_client(config.embeddings_dim)
if args.changed:
changed_files = get_changed_files(config.repo_path, args.base_ref, args.head_ref)
removed = filter_removed(changed_files)
existing = filter_existing(changed_files)
else:
removed = []
existing = [path for path in Path(config.repo_path).rglob("*") if path.is_file()]
for path in removed:
delete_document(conn, story_id, str(path))
for path, text in iter_text_files(existing, config.allowed_extensions):
chunks = chunk_text(text, config.chunk_size, config.chunk_overlap)
if not chunks:
continue
embeddings = embedding_client.embed_texts([chunk.text for chunk in chunks])
document_id = upsert_document(
conn, story_id, str(path), _file_version(path)
)
replace_chunks(conn, document_id, chunks, embeddings)
def cmd_ask(args: argparse.Namespace) -> None:
config = load_config()
conn = connect(config.db_dsn)
ensure_schema(conn, config.embeddings_dim)
story_id = None
if args.story:
story_id = get_story_id(conn, args.story)
if story_id is None:
raise SystemExit(f"Story not found: {args.story}")
embedding_client = get_embedding_client(config.embeddings_dim)
llm_client = StubLLMClient()
answer = answer_query(
conn,
config,
embedding_client,
llm_client,
args.question,
top_k=args.top_k,
story_id=story_id,
)
print(answer)
def build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(prog="rag-agent")
sub = parser.add_subparsers(dest="command", required=True)
index_parser = sub.add_parser("index", help="Index files into Postgres")
index_parser.add_argument(
"--story",
required=True,
help="Story slug (e.g. branch name or story id); documents are tied to this story",
)
index_parser.add_argument("--changed", action="store_true", help="Index only changed files")
index_parser.add_argument("--base-ref", default="HEAD~1", help="Base git ref")
index_parser.add_argument("--head-ref", default="HEAD", help="Head git ref")
index_parser.set_defaults(func=cmd_index)
ask_parser = sub.add_parser("ask", help="Ask a question")
ask_parser.add_argument("question", help="Question text")
ask_parser.add_argument(
"--story",
default=None,
help="Limit retrieval to this story slug (optional)",
)
ask_parser.add_argument("--top-k", type=int, default=5, help="Top K chunks to retrieve")
ask_parser.set_defaults(func=cmd_ask)
return parser
def main() -> None:
parser = build_parser()
args = parser.parse_args()
args.func(args)
if __name__ == "__main__":
main()

57
src/rag_agent/config.py Normal file
View File

@@ -0,0 +1,57 @@
from __future__ import annotations
import os
from dataclasses import dataclass
from typing import Iterable, Sequence
@dataclass(frozen=True)
class AppConfig:
repo_path: str
db_dsn: str
chunk_size: int
chunk_overlap: int
embeddings_dim: int
embeddings_model: str
llm_model: str
allowed_extensions: Sequence[str]
def _env_int(name: str, default: int) -> int:
value = os.getenv(name, "").strip()
if not value:
return default
try:
return int(value)
except ValueError as exc:
raise ValueError(f"Invalid integer for {name}: {value}") from exc
def _env_list(name: str, default: Iterable[str]) -> list[str]:
value = os.getenv(name, "").strip()
if not value:
return list(default)
return [item.strip() for item in value.split(",") if item.strip()]
def load_config() -> AppConfig:
repo_path = os.getenv("RAG_REPO_PATH", "").strip()
if not repo_path:
raise ValueError("RAG_REPO_PATH is required")
db_dsn = os.getenv("RAG_DB_DSN", "").strip()
if not db_dsn:
raise ValueError("RAG_DB_DSN is required")
return AppConfig(
repo_path=repo_path,
db_dsn=db_dsn,
chunk_size=_env_int("RAG_CHUNK_SIZE", 400),
chunk_overlap=_env_int("RAG_CHUNK_OVERLAP", 50),
embeddings_dim=_env_int("RAG_EMBEDDINGS_DIM", 1536),
embeddings_model=os.getenv("RAG_EMBEDDINGS_MODEL", "stub-embeddings"),
llm_model=os.getenv("RAG_LLM_MODEL", "stub-llm"),
allowed_extensions=tuple(
_env_list("RAG_ALLOWED_EXTENSIONS", [".md", ".txt", ".rst"])
),
)

View File

@@ -0,0 +1 @@
__all__ = []

View File

@@ -0,0 +1,29 @@
from __future__ import annotations
import hashlib
from dataclasses import dataclass
from typing import Iterable, Protocol
class EmbeddingClient(Protocol):
def embed_texts(self, texts: Iterable[str]) -> list[list[float]]:
raise NotImplementedError
@dataclass
class StubEmbeddingClient:
dim: int
def embed_texts(self, texts: Iterable[str]) -> list[list[float]]:
vectors: list[list[float]] = []
for text in texts:
digest = hashlib.sha256(text.encode("utf-8")).digest()
values = [b / 255.0 for b in digest]
if len(values) < self.dim:
values = (values * ((self.dim // len(values)) + 1))[: self.dim]
vectors.append(values[: self.dim])
return vectors
def get_embedding_client(dim: int) -> EmbeddingClient:
return StubEmbeddingClient(dim=dim)

View File

@@ -0,0 +1,194 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Iterable
import psycopg
from pgvector.psycopg import register_vector
from rag_agent.ingest.chunker import TextChunk
@dataclass(frozen=True)
class ChunkRecord:
document_path: str
document_version: str
chunk: TextChunk
embedding: list[float]
def connect(dsn: str) -> psycopg.Connection:
conn = psycopg.connect(dsn)
register_vector(conn)
return conn
def ensure_schema(conn: psycopg.Connection, embeddings_dim: int) -> None:
with conn.cursor() as cur:
cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")
cur.execute(
"""
CREATE TABLE IF NOT EXISTS stories (
id SERIAL PRIMARY KEY,
slug TEXT UNIQUE NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT (NOW() AT TIME ZONE 'utc')
);
"""
)
cur.execute(
"""
CREATE TABLE IF NOT EXISTS documents (
id SERIAL PRIMARY KEY,
story_id INTEGER NOT NULL REFERENCES stories(id) ON DELETE CASCADE,
path TEXT NOT NULL,
version TEXT NOT NULL,
updated_at TIMESTAMPTZ NOT NULL,
UNIQUE(story_id, path)
);
"""
)
cur.execute(
f"""
CREATE TABLE IF NOT EXISTS chunks (
id SERIAL PRIMARY KEY,
document_id INTEGER NOT NULL REFERENCES documents(id) ON DELETE CASCADE,
chunk_index INTEGER NOT NULL,
hash TEXT NOT NULL,
content TEXT NOT NULL,
embedding vector({embeddings_dim}) NOT NULL
);
"""
)
cur.execute(
"""
CREATE INDEX IF NOT EXISTS idx_documents_story_id
ON documents(story_id);
"""
)
cur.execute(
"""
CREATE INDEX IF NOT EXISTS idx_chunks_document_id
ON chunks(document_id);
"""
)
cur.execute(
"""
CREATE INDEX IF NOT EXISTS idx_chunks_embedding
ON chunks USING ivfflat (embedding vector_cosine_ops);
"""
)
conn.commit()
def get_or_create_story(conn: psycopg.Connection, slug: str) -> int:
with conn.cursor() as cur:
cur.execute(
"""
INSERT INTO stories (slug)
VALUES (%s)
ON CONFLICT (slug) DO UPDATE SET slug = EXCLUDED.slug
RETURNING id;
""",
(slug.strip(),),
)
story_id = cur.fetchone()[0]
conn.commit()
return story_id
def get_story_id(conn: psycopg.Connection, slug: str) -> int | None:
s = slug.strip()
with conn.cursor() as cur:
if s.isdigit():
cur.execute("SELECT id FROM stories WHERE id = %s;", (int(s),))
else:
cur.execute("SELECT id FROM stories WHERE slug = %s;", (s,))
row = cur.fetchone()
return row[0] if row else None
def upsert_document(
conn: psycopg.Connection, story_id: int, path: str, version: str
) -> int:
with conn.cursor() as cur:
cur.execute(
"""
INSERT INTO documents (story_id, path, version, updated_at)
VALUES (%s, %s, %s, %s)
ON CONFLICT (story_id, path) DO UPDATE
SET version = EXCLUDED.version,
updated_at = EXCLUDED.updated_at
RETURNING id;
""",
(story_id, path, version, datetime.now(timezone.utc)),
)
document_id = cur.fetchone()[0]
return document_id
def replace_chunks(
conn: psycopg.Connection,
document_id: int,
chunks: Iterable[TextChunk],
embeddings: Iterable[list[float]],
) -> None:
with conn.cursor() as cur:
cur.execute(
"DELETE FROM chunks WHERE document_id = %s;",
(document_id,),
)
for chunk, embedding in zip(chunks, embeddings):
cur.execute(
"""
INSERT INTO chunks (document_id, chunk_index, hash, content, embedding)
VALUES (%s, %s, %s, %s, %s);
""",
(document_id, chunk.index, chunk.hash, chunk.text, embedding),
)
conn.commit()
def delete_document(
conn: psycopg.Connection, story_id: int, path: str
) -> None:
with conn.cursor() as cur:
cur.execute(
"DELETE FROM documents WHERE story_id = %s AND path = %s;",
(story_id, path),
)
conn.commit()
def fetch_similar(
conn: psycopg.Connection,
query_embedding: list[float],
top_k: int,
story_id: int | None = None,
) -> list[tuple[str, str, float]]:
with conn.cursor() as cur:
if story_id is not None:
cur.execute(
"""
SELECT d.path, c.content, c.embedding <=> %s AS distance
FROM chunks c
JOIN documents d ON d.id = c.document_id
WHERE d.story_id = %s
ORDER BY c.embedding <=> %s
LIMIT %s;
""",
(query_embedding, story_id, query_embedding, top_k),
)
else:
cur.execute(
"""
SELECT d.path, c.content, c.embedding <=> %s AS distance
FROM chunks c
JOIN documents d ON d.id = c.document_id
ORDER BY c.embedding <=> %s
LIMIT %s;
""",
(query_embedding, query_embedding, top_k),
)
rows = cur.fetchall()
return [(row[0], row[1], row[2]) for row in rows]

View File

@@ -0,0 +1 @@
__all__ = []

View File

@@ -0,0 +1,42 @@
from __future__ import annotations
import hashlib
from dataclasses import dataclass
from typing import Iterable, Iterator
@dataclass(frozen=True)
class TextChunk:
index: int
text: str
hash: str
def _hash_text(text: str) -> str:
return hashlib.sha256(text.encode("utf-8")).hexdigest()
def chunk_text(text: str, chunk_size: int, overlap: int) -> list[TextChunk]:
tokens = text.split()
if not tokens:
return []
chunks: list[TextChunk] = []
start = 0
index = 0
while start < len(tokens):
end = min(start + chunk_size, len(tokens))
chunk_text_value = " ".join(tokens[start:end])
chunks.append(TextChunk(index=index, text=chunk_text_value, hash=_hash_text(chunk_text_value)))
index += 1
if end == len(tokens):
break
start = max(end - overlap, 0)
return chunks
def iter_chunks(
texts: Iterable[str], chunk_size: int, overlap: int
) -> Iterator[list[TextChunk]]:
for text in texts:
yield chunk_text(text, chunk_size, overlap)

View File

@@ -0,0 +1,23 @@
from __future__ import annotations
from pathlib import Path
from typing import Iterable, Iterator
def is_allowed(path: Path, allowed_extensions: Iterable[str]) -> bool:
return path.suffix.lower() in {ext.lower() for ext in allowed_extensions}
def read_text_file(path: Path) -> str:
return path.read_text(encoding="utf-8", errors="ignore")
def iter_text_files(
paths: Iterable[Path], allowed_extensions: Iterable[str]
) -> Iterator[tuple[Path, str]]:
for path in paths:
if not path.is_file():
continue
if not is_allowed(path, allowed_extensions):
continue
yield path, read_text_file(path)

View File

@@ -0,0 +1,42 @@
from __future__ import annotations
import subprocess
from pathlib import Path
from typing import Iterable
def get_changed_files(
repo_path: str, base_ref: str, head_ref: str = "HEAD"
) -> list[Path]:
args = [
"git",
"-C",
repo_path,
"diff",
"--name-only",
base_ref,
head_ref,
]
try:
result = subprocess.run(
args, check=True, capture_output=True, text=True
)
except subprocess.CalledProcessError as exc:
raise RuntimeError(
f"git diff failed: {exc.stderr.strip() or exc}"
) from exc
files = []
for line in result.stdout.splitlines():
value = line.strip()
if value:
files.append(Path(repo_path) / value)
return files
def filter_existing(paths: Iterable[Path]) -> list[Path]:
return [path for path in paths if path.exists()]
def filter_removed(paths: Iterable[Path]) -> list[Path]:
return [path for path in paths if not path.exists()]

View File

@@ -0,0 +1 @@
__all__ = []

View File

@@ -0,0 +1,24 @@
from __future__ import annotations
from dataclasses import dataclass
import psycopg
from rag_agent.index.postgres import fetch_similar
@dataclass(frozen=True)
class SearchResult:
path: str
content: str
distance: float
def search_similar(
conn: psycopg.Connection,
query_embedding: list[float],
top_k: int = 5,
story_id: int | None = None,
) -> list[SearchResult]:
rows = fetch_similar(conn, query_embedding, top_k, story_id=story_id)
return [SearchResult(path=row[0], content=row[1], distance=row[2]) for row in rows]

9
tests/test_chunker.py Normal file
View File

@@ -0,0 +1,9 @@
from rag_agent.ingest.chunker import chunk_text
def test_chunk_text_basic():
text = "one two three four five six seven eight"
chunks = chunk_text(text, chunk_size=3, overlap=1)
assert len(chunks) == 3
assert chunks[0].text == "one two three"
assert chunks[1].text.startswith("three four")