Скелет проекта
This commit is contained in:
3
src/rag_agent/__init__.py
Normal file
3
src/rag_agent/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
__all__ = [
|
||||
"config",
|
||||
]
|
||||
BIN
src/rag_agent/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
src/rag_agent/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/rag_agent/__pycache__/cli.cpython-312.pyc
Normal file
BIN
src/rag_agent/__pycache__/cli.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/rag_agent/__pycache__/config.cpython-312.pyc
Normal file
BIN
src/rag_agent/__pycache__/config.cpython-312.pyc
Normal file
Binary file not shown.
1
src/rag_agent/agent/__init__.py
Normal file
1
src/rag_agent/agent/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
__all__ = []
|
||||
51
src/rag_agent/agent/pipeline.py
Normal file
51
src/rag_agent/agent/pipeline.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Protocol
|
||||
|
||||
import psycopg
|
||||
|
||||
from rag_agent.config import AppConfig
|
||||
from rag_agent.index.embeddings import EmbeddingClient
|
||||
from rag_agent.retrieval.search import search_similar
|
||||
|
||||
|
||||
class LLMClient(Protocol):
|
||||
def generate(self, prompt: str, model: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class StubLLMClient:
|
||||
def generate(self, prompt: str, model: str) -> str:
|
||||
return (
|
||||
"LLM client is not configured. "
|
||||
"Replace StubLLMClient with a real implementation."
|
||||
)
|
||||
|
||||
|
||||
def build_prompt(question: str, contexts: list[str]) -> str:
|
||||
joined = "\n\n".join(contexts)
|
||||
return (
|
||||
"You are a RAG assistant. Use the context below to answer the question.\n\n"
|
||||
f"Context:\n{joined}\n\n"
|
||||
f"Question: {question}\nAnswer:"
|
||||
)
|
||||
|
||||
|
||||
def answer_query(
|
||||
conn: psycopg.Connection,
|
||||
config: AppConfig,
|
||||
embedding_client: EmbeddingClient,
|
||||
llm_client: LLMClient,
|
||||
question: str,
|
||||
top_k: int = 5,
|
||||
story_id: int | None = None,
|
||||
) -> str:
|
||||
query_embedding = embedding_client.embed_texts([question])[0]
|
||||
results = search_similar(
|
||||
conn, query_embedding, top_k=top_k, story_id=story_id
|
||||
)
|
||||
contexts = [f"Source: {item.path}\n{item.content}" for item in results]
|
||||
prompt = build_prompt(question, contexts)
|
||||
return llm_client.generate(prompt, model=config.llm_model)
|
||||
117
src/rag_agent/cli.py
Normal file
117
src/rag_agent/cli.py
Normal file
@@ -0,0 +1,117 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
|
||||
from rag_agent.config import load_config
|
||||
from rag_agent.ingest.chunker import chunk_text
|
||||
from rag_agent.ingest.file_loader import iter_text_files
|
||||
from rag_agent.ingest.git_watcher import filter_existing, filter_removed, get_changed_files
|
||||
from rag_agent.index.embeddings import get_embedding_client
|
||||
from rag_agent.index.postgres import (
|
||||
connect,
|
||||
delete_document,
|
||||
ensure_schema,
|
||||
get_or_create_story,
|
||||
get_story_id,
|
||||
replace_chunks,
|
||||
upsert_document,
|
||||
)
|
||||
from rag_agent.agent.pipeline import StubLLMClient, answer_query
|
||||
|
||||
|
||||
def _file_version(path: Path) -> str:
|
||||
stat = path.stat()
|
||||
payload = f"{path.as_posix()}:{stat.st_mtime_ns}:{stat.st_size}"
|
||||
return hashlib.sha256(payload.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def cmd_index(args: argparse.Namespace) -> None:
|
||||
config = load_config()
|
||||
conn = connect(config.db_dsn)
|
||||
ensure_schema(conn, config.embeddings_dim)
|
||||
story_id = get_or_create_story(conn, args.story)
|
||||
embedding_client = get_embedding_client(config.embeddings_dim)
|
||||
|
||||
if args.changed:
|
||||
changed_files = get_changed_files(config.repo_path, args.base_ref, args.head_ref)
|
||||
removed = filter_removed(changed_files)
|
||||
existing = filter_existing(changed_files)
|
||||
else:
|
||||
removed = []
|
||||
existing = [path for path in Path(config.repo_path).rglob("*") if path.is_file()]
|
||||
|
||||
for path in removed:
|
||||
delete_document(conn, story_id, str(path))
|
||||
|
||||
for path, text in iter_text_files(existing, config.allowed_extensions):
|
||||
chunks = chunk_text(text, config.chunk_size, config.chunk_overlap)
|
||||
if not chunks:
|
||||
continue
|
||||
embeddings = embedding_client.embed_texts([chunk.text for chunk in chunks])
|
||||
document_id = upsert_document(
|
||||
conn, story_id, str(path), _file_version(path)
|
||||
)
|
||||
replace_chunks(conn, document_id, chunks, embeddings)
|
||||
|
||||
|
||||
def cmd_ask(args: argparse.Namespace) -> None:
|
||||
config = load_config()
|
||||
conn = connect(config.db_dsn)
|
||||
ensure_schema(conn, config.embeddings_dim)
|
||||
story_id = None
|
||||
if args.story:
|
||||
story_id = get_story_id(conn, args.story)
|
||||
if story_id is None:
|
||||
raise SystemExit(f"Story not found: {args.story}")
|
||||
embedding_client = get_embedding_client(config.embeddings_dim)
|
||||
llm_client = StubLLMClient()
|
||||
answer = answer_query(
|
||||
conn,
|
||||
config,
|
||||
embedding_client,
|
||||
llm_client,
|
||||
args.question,
|
||||
top_k=args.top_k,
|
||||
story_id=story_id,
|
||||
)
|
||||
print(answer)
|
||||
|
||||
|
||||
def build_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(prog="rag-agent")
|
||||
sub = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
index_parser = sub.add_parser("index", help="Index files into Postgres")
|
||||
index_parser.add_argument(
|
||||
"--story",
|
||||
required=True,
|
||||
help="Story slug (e.g. branch name or story id); documents are tied to this story",
|
||||
)
|
||||
index_parser.add_argument("--changed", action="store_true", help="Index only changed files")
|
||||
index_parser.add_argument("--base-ref", default="HEAD~1", help="Base git ref")
|
||||
index_parser.add_argument("--head-ref", default="HEAD", help="Head git ref")
|
||||
index_parser.set_defaults(func=cmd_index)
|
||||
|
||||
ask_parser = sub.add_parser("ask", help="Ask a question")
|
||||
ask_parser.add_argument("question", help="Question text")
|
||||
ask_parser.add_argument(
|
||||
"--story",
|
||||
default=None,
|
||||
help="Limit retrieval to this story slug (optional)",
|
||||
)
|
||||
ask_parser.add_argument("--top-k", type=int, default=5, help="Top K chunks to retrieve")
|
||||
ask_parser.set_defaults(func=cmd_ask)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = build_parser()
|
||||
args = parser.parse_args()
|
||||
args.func(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
57
src/rag_agent/config.py
Normal file
57
src/rag_agent/config.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, Sequence
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AppConfig:
|
||||
repo_path: str
|
||||
db_dsn: str
|
||||
chunk_size: int
|
||||
chunk_overlap: int
|
||||
embeddings_dim: int
|
||||
embeddings_model: str
|
||||
llm_model: str
|
||||
allowed_extensions: Sequence[str]
|
||||
|
||||
|
||||
def _env_int(name: str, default: int) -> int:
|
||||
value = os.getenv(name, "").strip()
|
||||
if not value:
|
||||
return default
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError as exc:
|
||||
raise ValueError(f"Invalid integer for {name}: {value}") from exc
|
||||
|
||||
|
||||
def _env_list(name: str, default: Iterable[str]) -> list[str]:
|
||||
value = os.getenv(name, "").strip()
|
||||
if not value:
|
||||
return list(default)
|
||||
return [item.strip() for item in value.split(",") if item.strip()]
|
||||
|
||||
|
||||
def load_config() -> AppConfig:
|
||||
repo_path = os.getenv("RAG_REPO_PATH", "").strip()
|
||||
if not repo_path:
|
||||
raise ValueError("RAG_REPO_PATH is required")
|
||||
|
||||
db_dsn = os.getenv("RAG_DB_DSN", "").strip()
|
||||
if not db_dsn:
|
||||
raise ValueError("RAG_DB_DSN is required")
|
||||
|
||||
return AppConfig(
|
||||
repo_path=repo_path,
|
||||
db_dsn=db_dsn,
|
||||
chunk_size=_env_int("RAG_CHUNK_SIZE", 400),
|
||||
chunk_overlap=_env_int("RAG_CHUNK_OVERLAP", 50),
|
||||
embeddings_dim=_env_int("RAG_EMBEDDINGS_DIM", 1536),
|
||||
embeddings_model=os.getenv("RAG_EMBEDDINGS_MODEL", "stub-embeddings"),
|
||||
llm_model=os.getenv("RAG_LLM_MODEL", "stub-llm"),
|
||||
allowed_extensions=tuple(
|
||||
_env_list("RAG_ALLOWED_EXTENSIONS", [".md", ".txt", ".rst"])
|
||||
),
|
||||
)
|
||||
1
src/rag_agent/index/__init__.py
Normal file
1
src/rag_agent/index/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
__all__ = []
|
||||
BIN
src/rag_agent/index/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
src/rag_agent/index/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/rag_agent/index/__pycache__/embeddings.cpython-312.pyc
Normal file
BIN
src/rag_agent/index/__pycache__/embeddings.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/rag_agent/index/__pycache__/postgres.cpython-312.pyc
Normal file
BIN
src/rag_agent/index/__pycache__/postgres.cpython-312.pyc
Normal file
Binary file not shown.
29
src/rag_agent/index/embeddings.py
Normal file
29
src/rag_agent/index/embeddings.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, Protocol
|
||||
|
||||
|
||||
class EmbeddingClient(Protocol):
|
||||
def embed_texts(self, texts: Iterable[str]) -> list[list[float]]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class StubEmbeddingClient:
|
||||
dim: int
|
||||
|
||||
def embed_texts(self, texts: Iterable[str]) -> list[list[float]]:
|
||||
vectors: list[list[float]] = []
|
||||
for text in texts:
|
||||
digest = hashlib.sha256(text.encode("utf-8")).digest()
|
||||
values = [b / 255.0 for b in digest]
|
||||
if len(values) < self.dim:
|
||||
values = (values * ((self.dim // len(values)) + 1))[: self.dim]
|
||||
vectors.append(values[: self.dim])
|
||||
return vectors
|
||||
|
||||
|
||||
def get_embedding_client(dim: int) -> EmbeddingClient:
|
||||
return StubEmbeddingClient(dim=dim)
|
||||
194
src/rag_agent/index/postgres.py
Normal file
194
src/rag_agent/index/postgres.py
Normal file
@@ -0,0 +1,194 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Iterable
|
||||
|
||||
import psycopg
|
||||
from pgvector.psycopg import register_vector
|
||||
|
||||
from rag_agent.ingest.chunker import TextChunk
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ChunkRecord:
|
||||
document_path: str
|
||||
document_version: str
|
||||
chunk: TextChunk
|
||||
embedding: list[float]
|
||||
|
||||
|
||||
def connect(dsn: str) -> psycopg.Connection:
|
||||
conn = psycopg.connect(dsn)
|
||||
register_vector(conn)
|
||||
return conn
|
||||
|
||||
|
||||
def ensure_schema(conn: psycopg.Connection, embeddings_dim: int) -> None:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS stories (
|
||||
id SERIAL PRIMARY KEY,
|
||||
slug TEXT UNIQUE NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT (NOW() AT TIME ZONE 'utc')
|
||||
);
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS documents (
|
||||
id SERIAL PRIMARY KEY,
|
||||
story_id INTEGER NOT NULL REFERENCES stories(id) ON DELETE CASCADE,
|
||||
path TEXT NOT NULL,
|
||||
version TEXT NOT NULL,
|
||||
updated_at TIMESTAMPTZ NOT NULL,
|
||||
UNIQUE(story_id, path)
|
||||
);
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS chunks (
|
||||
id SERIAL PRIMARY KEY,
|
||||
document_id INTEGER NOT NULL REFERENCES documents(id) ON DELETE CASCADE,
|
||||
chunk_index INTEGER NOT NULL,
|
||||
hash TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
embedding vector({embeddings_dim}) NOT NULL
|
||||
);
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_documents_story_id
|
||||
ON documents(story_id);
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_chunks_document_id
|
||||
ON chunks(document_id);
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_chunks_embedding
|
||||
ON chunks USING ivfflat (embedding vector_cosine_ops);
|
||||
"""
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def get_or_create_story(conn: psycopg.Connection, slug: str) -> int:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO stories (slug)
|
||||
VALUES (%s)
|
||||
ON CONFLICT (slug) DO UPDATE SET slug = EXCLUDED.slug
|
||||
RETURNING id;
|
||||
""",
|
||||
(slug.strip(),),
|
||||
)
|
||||
story_id = cur.fetchone()[0]
|
||||
conn.commit()
|
||||
return story_id
|
||||
|
||||
|
||||
def get_story_id(conn: psycopg.Connection, slug: str) -> int | None:
|
||||
s = slug.strip()
|
||||
with conn.cursor() as cur:
|
||||
if s.isdigit():
|
||||
cur.execute("SELECT id FROM stories WHERE id = %s;", (int(s),))
|
||||
else:
|
||||
cur.execute("SELECT id FROM stories WHERE slug = %s;", (s,))
|
||||
row = cur.fetchone()
|
||||
return row[0] if row else None
|
||||
|
||||
|
||||
def upsert_document(
|
||||
conn: psycopg.Connection, story_id: int, path: str, version: str
|
||||
) -> int:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO documents (story_id, path, version, updated_at)
|
||||
VALUES (%s, %s, %s, %s)
|
||||
ON CONFLICT (story_id, path) DO UPDATE
|
||||
SET version = EXCLUDED.version,
|
||||
updated_at = EXCLUDED.updated_at
|
||||
RETURNING id;
|
||||
""",
|
||||
(story_id, path, version, datetime.now(timezone.utc)),
|
||||
)
|
||||
document_id = cur.fetchone()[0]
|
||||
return document_id
|
||||
|
||||
|
||||
def replace_chunks(
|
||||
conn: psycopg.Connection,
|
||||
document_id: int,
|
||||
chunks: Iterable[TextChunk],
|
||||
embeddings: Iterable[list[float]],
|
||||
) -> None:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"DELETE FROM chunks WHERE document_id = %s;",
|
||||
(document_id,),
|
||||
)
|
||||
for chunk, embedding in zip(chunks, embeddings):
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO chunks (document_id, chunk_index, hash, content, embedding)
|
||||
VALUES (%s, %s, %s, %s, %s);
|
||||
""",
|
||||
(document_id, chunk.index, chunk.hash, chunk.text, embedding),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def delete_document(
|
||||
conn: psycopg.Connection, story_id: int, path: str
|
||||
) -> None:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"DELETE FROM documents WHERE story_id = %s AND path = %s;",
|
||||
(story_id, path),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def fetch_similar(
|
||||
conn: psycopg.Connection,
|
||||
query_embedding: list[float],
|
||||
top_k: int,
|
||||
story_id: int | None = None,
|
||||
) -> list[tuple[str, str, float]]:
|
||||
with conn.cursor() as cur:
|
||||
if story_id is not None:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT d.path, c.content, c.embedding <=> %s AS distance
|
||||
FROM chunks c
|
||||
JOIN documents d ON d.id = c.document_id
|
||||
WHERE d.story_id = %s
|
||||
ORDER BY c.embedding <=> %s
|
||||
LIMIT %s;
|
||||
""",
|
||||
(query_embedding, story_id, query_embedding, top_k),
|
||||
)
|
||||
else:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT d.path, c.content, c.embedding <=> %s AS distance
|
||||
FROM chunks c
|
||||
JOIN documents d ON d.id = c.document_id
|
||||
ORDER BY c.embedding <=> %s
|
||||
LIMIT %s;
|
||||
""",
|
||||
(query_embedding, query_embedding, top_k),
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
return [(row[0], row[1], row[2]) for row in rows]
|
||||
1
src/rag_agent/ingest/__init__.py
Normal file
1
src/rag_agent/ingest/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
__all__ = []
|
||||
BIN
src/rag_agent/ingest/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
src/rag_agent/ingest/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/rag_agent/ingest/__pycache__/chunker.cpython-312.pyc
Normal file
BIN
src/rag_agent/ingest/__pycache__/chunker.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/rag_agent/ingest/__pycache__/file_loader.cpython-312.pyc
Normal file
BIN
src/rag_agent/ingest/__pycache__/file_loader.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/rag_agent/ingest/__pycache__/git_watcher.cpython-312.pyc
Normal file
BIN
src/rag_agent/ingest/__pycache__/git_watcher.cpython-312.pyc
Normal file
Binary file not shown.
42
src/rag_agent/ingest/chunker.py
Normal file
42
src/rag_agent/ingest/chunker.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, Iterator
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TextChunk:
|
||||
index: int
|
||||
text: str
|
||||
hash: str
|
||||
|
||||
|
||||
def _hash_text(text: str) -> str:
|
||||
return hashlib.sha256(text.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def chunk_text(text: str, chunk_size: int, overlap: int) -> list[TextChunk]:
|
||||
tokens = text.split()
|
||||
if not tokens:
|
||||
return []
|
||||
|
||||
chunks: list[TextChunk] = []
|
||||
start = 0
|
||||
index = 0
|
||||
while start < len(tokens):
|
||||
end = min(start + chunk_size, len(tokens))
|
||||
chunk_text_value = " ".join(tokens[start:end])
|
||||
chunks.append(TextChunk(index=index, text=chunk_text_value, hash=_hash_text(chunk_text_value)))
|
||||
index += 1
|
||||
if end == len(tokens):
|
||||
break
|
||||
start = max(end - overlap, 0)
|
||||
return chunks
|
||||
|
||||
|
||||
def iter_chunks(
|
||||
texts: Iterable[str], chunk_size: int, overlap: int
|
||||
) -> Iterator[list[TextChunk]]:
|
||||
for text in texts:
|
||||
yield chunk_text(text, chunk_size, overlap)
|
||||
23
src/rag_agent/ingest/file_loader.py
Normal file
23
src/rag_agent/ingest/file_loader.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Iterator
|
||||
|
||||
|
||||
def is_allowed(path: Path, allowed_extensions: Iterable[str]) -> bool:
|
||||
return path.suffix.lower() in {ext.lower() for ext in allowed_extensions}
|
||||
|
||||
|
||||
def read_text_file(path: Path) -> str:
|
||||
return path.read_text(encoding="utf-8", errors="ignore")
|
||||
|
||||
|
||||
def iter_text_files(
|
||||
paths: Iterable[Path], allowed_extensions: Iterable[str]
|
||||
) -> Iterator[tuple[Path, str]]:
|
||||
for path in paths:
|
||||
if not path.is_file():
|
||||
continue
|
||||
if not is_allowed(path, allowed_extensions):
|
||||
continue
|
||||
yield path, read_text_file(path)
|
||||
42
src/rag_agent/ingest/git_watcher.py
Normal file
42
src/rag_agent/ingest/git_watcher.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
|
||||
def get_changed_files(
|
||||
repo_path: str, base_ref: str, head_ref: str = "HEAD"
|
||||
) -> list[Path]:
|
||||
args = [
|
||||
"git",
|
||||
"-C",
|
||||
repo_path,
|
||||
"diff",
|
||||
"--name-only",
|
||||
base_ref,
|
||||
head_ref,
|
||||
]
|
||||
try:
|
||||
result = subprocess.run(
|
||||
args, check=True, capture_output=True, text=True
|
||||
)
|
||||
except subprocess.CalledProcessError as exc:
|
||||
raise RuntimeError(
|
||||
f"git diff failed: {exc.stderr.strip() or exc}"
|
||||
) from exc
|
||||
|
||||
files = []
|
||||
for line in result.stdout.splitlines():
|
||||
value = line.strip()
|
||||
if value:
|
||||
files.append(Path(repo_path) / value)
|
||||
return files
|
||||
|
||||
|
||||
def filter_existing(paths: Iterable[Path]) -> list[Path]:
|
||||
return [path for path in paths if path.exists()]
|
||||
|
||||
|
||||
def filter_removed(paths: Iterable[Path]) -> list[Path]:
|
||||
return [path for path in paths if not path.exists()]
|
||||
1
src/rag_agent/retrieval/__init__.py
Normal file
1
src/rag_agent/retrieval/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
__all__ = []
|
||||
24
src/rag_agent/retrieval/search.py
Normal file
24
src/rag_agent/retrieval/search.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import psycopg
|
||||
|
||||
from rag_agent.index.postgres import fetch_similar
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SearchResult:
|
||||
path: str
|
||||
content: str
|
||||
distance: float
|
||||
|
||||
|
||||
def search_similar(
|
||||
conn: psycopg.Connection,
|
||||
query_embedding: list[float],
|
||||
top_k: int = 5,
|
||||
story_id: int | None = None,
|
||||
) -> list[SearchResult]:
|
||||
rows = fetch_similar(conn, query_embedding, top_k, story_id=story_id)
|
||||
return [SearchResult(path=row[0], content=row[1], distance=row[2]) for row in rows]
|
||||
Reference in New Issue
Block a user