Скелет проекта
This commit is contained in:
1
src/rag_agent/index/__init__.py
Normal file
1
src/rag_agent/index/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
__all__ = []
|
||||
BIN
src/rag_agent/index/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
src/rag_agent/index/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/rag_agent/index/__pycache__/embeddings.cpython-312.pyc
Normal file
BIN
src/rag_agent/index/__pycache__/embeddings.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/rag_agent/index/__pycache__/postgres.cpython-312.pyc
Normal file
BIN
src/rag_agent/index/__pycache__/postgres.cpython-312.pyc
Normal file
Binary file not shown.
29
src/rag_agent/index/embeddings.py
Normal file
29
src/rag_agent/index/embeddings.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, Protocol
|
||||
|
||||
|
||||
class EmbeddingClient(Protocol):
    """Structural interface for anything that can embed a batch of texts."""

    def embed_texts(self, texts: Iterable[str]) -> list[list[float]]:
        """Return one embedding vector per input text, in input order."""
        raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
class StubEmbeddingClient:
    """Deterministic, offline embedding client.

    Derives a pseudo-embedding from the SHA-256 digest of each text: the 32
    digest bytes are scaled into [0, 1], then tiled and truncated to ``dim``
    components.  Useful for tests and environments without a real provider.
    """

    # Target dimensionality of every returned vector.
    dim: int

    def embed_texts(self, texts: Iterable[str]) -> list[list[float]]:
        """Embed each text into a ``dim``-length vector of floats in [0, 1]."""
        results: list[list[float]] = []
        for item in texts:
            digest_bytes = hashlib.sha256(item.encode("utf-8")).digest()
            components = [byte / 255.0 for byte in digest_bytes]
            if len(components) < self.dim:
                # Tile the 32 base components until at least `dim` are available.
                repeats = (self.dim // len(components)) + 1
                components = (components * repeats)[: self.dim]
            results.append(components[: self.dim])
        return results
|
||||
|
||||
|
||||
def get_embedding_client(dim: int) -> EmbeddingClient:
    """Factory for the project's embedding client.

    Currently always returns the hash-based stub; swap the implementation
    here when a real embedding provider is wired in.
    """
    client = StubEmbeddingClient(dim=dim)
    return client
|
||||
194
src/rag_agent/index/postgres.py
Normal file
194
src/rag_agent/index/postgres.py
Normal file
@@ -0,0 +1,194 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Iterable
|
||||
|
||||
import psycopg
|
||||
from pgvector.psycopg import register_vector
|
||||
|
||||
from rag_agent.ingest.chunker import TextChunk
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ChunkRecord:
    """Immutable pairing of a text chunk with its embedding and source document.

    Attributes:
        document_path: Path of the source document within the story.
        document_version: Version string of the document at indexing time.
        chunk: The chunked text segment (project ``TextChunk``).
        embedding: Embedding vector computed for ``chunk``.
    """

    document_path: str
    document_version: str
    chunk: TextChunk
    embedding: list[float]
|
||||
|
||||
|
||||
def connect(dsn: str) -> psycopg.Connection:
    """Open a database connection and enable pgvector type adaptation.

    Args:
        dsn: A libpq-style connection string.

    Returns:
        A psycopg connection with the ``vector`` type registered.
    """
    connection = psycopg.connect(dsn)
    register_vector(connection)
    return connection
|
||||
|
||||
|
||||
def ensure_schema(conn: psycopg.Connection, embeddings_dim: int) -> None:
    """Create the pgvector extension, tables, and indexes if they are absent.

    Idempotent: every statement uses ``IF NOT EXISTS``.  Commits on success.

    Args:
        conn: Open psycopg connection.
        embeddings_dim: Dimensionality of the ``chunks.embedding`` vector column.

    Raises:
        ValueError: If ``embeddings_dim`` is not a positive integer.
    """
    # The vector(N) type modifier cannot be sent as a bound query parameter,
    # so the dimension must be interpolated into the DDL string.  Validate it
    # strictly first so the f-string below cannot inject arbitrary SQL.
    if (
        not isinstance(embeddings_dim, int)
        or isinstance(embeddings_dim, bool)
        or embeddings_dim <= 0
    ):
        raise ValueError(
            f"embeddings_dim must be a positive integer, got {embeddings_dim!r}"
        )
    with conn.cursor() as cur:
        cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")
        cur.execute(
            """
            CREATE TABLE IF NOT EXISTS stories (
                id SERIAL PRIMARY KEY,
                slug TEXT UNIQUE NOT NULL,
                created_at TIMESTAMPTZ NOT NULL DEFAULT (NOW() AT TIME ZONE 'utc')
            );
            """
        )
        cur.execute(
            """
            CREATE TABLE IF NOT EXISTS documents (
                id SERIAL PRIMARY KEY,
                story_id INTEGER NOT NULL REFERENCES stories(id) ON DELETE CASCADE,
                path TEXT NOT NULL,
                version TEXT NOT NULL,
                updated_at TIMESTAMPTZ NOT NULL,
                UNIQUE(story_id, path)
            );
            """
        )
        cur.execute(
            f"""
            CREATE TABLE IF NOT EXISTS chunks (
                id SERIAL PRIMARY KEY,
                document_id INTEGER NOT NULL REFERENCES documents(id) ON DELETE CASCADE,
                chunk_index INTEGER NOT NULL,
                hash TEXT NOT NULL,
                content TEXT NOT NULL,
                embedding vector({embeddings_dim}) NOT NULL
            );
            """
        )
        cur.execute(
            """
            CREATE INDEX IF NOT EXISTS idx_documents_story_id
            ON documents(story_id);
            """
        )
        cur.execute(
            """
            CREATE INDEX IF NOT EXISTS idx_chunks_document_id
            ON chunks(document_id);
            """
        )
        # NOTE(review): ivfflat built on a possibly-empty table uses default
        # lists; pgvector recommends creating it after data load — confirm.
        cur.execute(
            """
            CREATE INDEX IF NOT EXISTS idx_chunks_embedding
            ON chunks USING ivfflat (embedding vector_cosine_ops);
            """
        )
    conn.commit()
|
||||
|
||||
|
||||
def get_or_create_story(conn: psycopg.Connection, slug: str) -> int:
    """Return the id for ``slug``, inserting the story if it does not exist.

    The no-op ``ON CONFLICT ... DO UPDATE`` makes ``RETURNING id`` yield the
    existing row's id on conflict instead of returning nothing.  Commits
    before returning.
    """
    normalized = slug.strip()
    with conn.cursor() as cur:
        cur.execute(
            """
            INSERT INTO stories (slug)
            VALUES (%s)
            ON CONFLICT (slug) DO UPDATE SET slug = EXCLUDED.slug
            RETURNING id;
            """,
            (normalized,),
        )
        row = cur.fetchone()
    conn.commit()
    return row[0]
|
||||
|
||||
|
||||
def get_story_id(conn: psycopg.Connection, slug: str) -> int | None:
    """Look up a story id by slug, or by numeric id when ``slug`` is all digits.

    Returns:
        The story id, or ``None`` if no matching story exists.
    """
    key = slug.strip()
    if key.isdigit():
        # A purely numeric "slug" is treated as a direct primary-key lookup.
        query, param = "SELECT id FROM stories WHERE id = %s;", int(key)
    else:
        query, param = "SELECT id FROM stories WHERE slug = %s;", key
    with conn.cursor() as cur:
        cur.execute(query, (param,))
        row = cur.fetchone()
    return row[0] if row else None
|
||||
|
||||
|
||||
def upsert_document(
    conn: psycopg.Connection, story_id: int, path: str, version: str
) -> int:
    """Insert or refresh a document row and return its id.

    On conflict with an existing (story_id, path) pair, the version and
    updated_at columns are overwritten.

    NOTE(review): unlike the sibling helpers, this does not commit —
    presumably so the caller can commit document + chunks together via
    ``replace_chunks``; confirm that callers rely on this.
    """
    now = datetime.now(timezone.utc)
    with conn.cursor() as cur:
        cur.execute(
            """
            INSERT INTO documents (story_id, path, version, updated_at)
            VALUES (%s, %s, %s, %s)
            ON CONFLICT (story_id, path) DO UPDATE
            SET version = EXCLUDED.version,
                updated_at = EXCLUDED.updated_at
            RETURNING id;
            """,
            (story_id, path, version, now),
        )
        return cur.fetchone()[0]
|
||||
|
||||
|
||||
def replace_chunks(
|
||||
conn: psycopg.Connection,
|
||||
document_id: int,
|
||||
chunks: Iterable[TextChunk],
|
||||
embeddings: Iterable[list[float]],
|
||||
) -> None:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"DELETE FROM chunks WHERE document_id = %s;",
|
||||
(document_id,),
|
||||
)
|
||||
for chunk, embedding in zip(chunks, embeddings):
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO chunks (document_id, chunk_index, hash, content, embedding)
|
||||
VALUES (%s, %s, %s, %s, %s);
|
||||
""",
|
||||
(document_id, chunk.index, chunk.hash, chunk.text, embedding),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def delete_document(
    conn: psycopg.Connection, story_id: int, path: str
) -> None:
    """Delete a document row; its chunks go with it via ON DELETE CASCADE.

    No-op when the document does not exist.  Commits before returning.
    """
    statement = "DELETE FROM documents WHERE story_id = %s AND path = %s;"
    with conn.cursor() as cur:
        cur.execute(statement, (story_id, path))
    conn.commit()
|
||||
|
||||
|
||||
def fetch_similar(
    conn: psycopg.Connection,
    query_embedding: list[float],
    top_k: int,
    story_id: int | None = None,
) -> list[tuple[str, str, float]]:
    """Return the ``top_k`` chunks nearest to ``query_embedding``.

    Uses the pgvector cosine-distance operator ``<=>``; smaller distance
    means more similar.  When ``story_id`` is given, results are restricted
    to that story's documents.

    Args:
        conn: Open psycopg connection with pgvector registered.
        query_embedding: Query vector; must match the column dimension.
        top_k: Maximum number of rows to return.
        story_id: Optional story filter.

    Returns:
        ``(document path, chunk content, cosine distance)`` tuples,
        nearest first.
    """
    # The two original branches differed only in the WHERE clause, so build
    # one statement conditionally instead of duplicating the SQL.
    where_clause = "WHERE d.story_id = %s" if story_id is not None else ""
    sql = f"""
        SELECT d.path, c.content, c.embedding <=> %s AS distance
        FROM chunks c
        JOIN documents d ON d.id = c.document_id
        {where_clause}
        ORDER BY c.embedding <=> %s
        LIMIT %s;
        """
    if story_id is not None:
        params = (query_embedding, story_id, query_embedding, top_k)
    else:
        params = (query_embedding, query_embedding, top_k)
    with conn.cursor() as cur:
        cur.execute(sql, params)
        rows = cur.fetchall()
    return [(row[0], row[1], row[2]) for row in rows]
|
||||
Reference in New Issue
Block a user