372 lines
12 KiB
Python
372 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timezone
|
|
from typing import Iterable, Sequence
|
|
|
|
import psycopg
|
|
from pgvector.psycopg import register_vector
|
|
|
|
from rag_agent.ingest.chunker import TextChunk
|
|
|
|
# Values stored in chunks.change_type. They match the strings listed in the
# CHECK constraint that ensure_schema() puts on that column and are what
# _change_type_and_previous() returns.
CHANGE_ADDED = "added"
CHANGE_MODIFIED = "modified"
CHANGE_UNCHANGED = "unchanged"
|
|
|
|
|
|
@dataclass(frozen=True)
class ChunkRecord:
    """Immutable bundle of one text chunk, its embedding and its owning document."""

    # Path of the document the chunk belongs to.
    document_path: str
    # Version string of that document.
    # NOTE(review): presumably the same value stored in documents.version - confirm.
    document_version: str
    # The chunk itself (type imported from rag_agent.ingest.chunker).
    chunk: TextChunk
    # Embedding vector for the chunk.
    # NOTE(review): length presumably equals the embeddings_dim used for the schema - confirm.
    embedding: list[float]
|
|
|
|
|
|
@dataclass(frozen=True)
class ChangedChunkRecord:
    """Chunk that was added or modified in a story (for test-case generation)."""

    # documents.path of the document that owns the chunk.
    path: str
    # Current chunk text (chunks.content).
    content: str
    # chunks.change_type; fetch_changed_chunks() only selects 'added' and 'modified'.
    change_type: str
    # Line range of the chunk; either bound may be NULL in the database.
    start_line: int | None
    end_line: int | None
    # Text of the base chunk at the same line range when the chunk was
    # classified as modified; None otherwise.
    previous_content: str | None
|
|
|
|
|
|
def connect(dsn: str) -> psycopg.Connection:
    """Open a psycopg connection and register the pgvector adapters on it.

    Args:
        dsn: Connection string passed unchanged to ``psycopg.connect``.

    Returns:
        The open connection, after ``register_vector`` has been applied to it.

    Raises:
        Any exception raised by ``psycopg.connect`` or ``register_vector``.
        If registration fails, the freshly opened connection is closed before
        the exception propagates, so no connection is leaked.
    """
    conn = psycopg.connect(dsn)
    try:
        # NOTE(review): register_vector appears to require the "vector"
        # extension to already exist in the database, while ensure_schema()
        # (which creates it) needs a connection first - confirm how brand-new
        # databases are bootstrapped.
        register_vector(conn)
    except Exception:
        conn.close()
        raise
    return conn
|
|
|
|
|
|
def _execute_best_effort(
    conn: psycopg.Connection, cur: psycopg.Cursor, statement: str
) -> None:
    """Run one optional DDL statement, ignoring ``psycopg.ProgrammingError``.

    The statement runs inside ``conn.transaction()``. When a transaction is
    already open this becomes a savepoint, so a failure undoes only this
    statement instead of everything executed earlier on the connection.
    """
    try:
        with conn.transaction():
            cur.execute(statement)
    except psycopg.ProgrammingError:
        # Deliberately ignored: the column/constraint most likely exists already.
        pass


def ensure_schema(conn: psycopg.Connection, embeddings_dim: int) -> None:
    """Create or migrate the ``stories``, ``documents`` and ``chunks`` tables.

    Every statement is idempotent (``IF NOT EXISTS``) or executed as a
    best-effort migration, so the function can be called on every start-up.
    All changes are committed at the end.

    Args:
        conn: Open psycopg connection.
        embeddings_dim: Dimension of the ``vector`` column in ``chunks``.
            It is coerced with ``int()`` because it is interpolated into DDL.

    Raises:
        ValueError: If ``embeddings_dim`` is not a positive integer.
        psycopg.Error: If a mandatory statement fails.
    """
    # The dimension cannot be sent as a bound parameter inside DDL, so make
    # sure it really is a plain positive integer before formatting it in.
    dim = int(embeddings_dim)
    if dim <= 0:
        raise ValueError(f"embeddings_dim must be a positive integer, got {embeddings_dim!r}")

    with conn.cursor() as cur:
        cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")

        # created_at is TIMESTAMPTZ, so plain NOW() stores the correct instant.
        # "NOW() AT TIME ZONE 'utc'" yields a zone-less timestamp that would be
        # re-interpreted in the session time zone when stored.
        cur.execute(
            """
            CREATE TABLE IF NOT EXISTS stories (
                id SERIAL PRIMARY KEY,
                slug TEXT UNIQUE NOT NULL,
                created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                indexed_base_ref TEXT,
                indexed_head_ref TEXT,
                indexed_at TIMESTAMPTZ
            );
            """
        )
        # Migrations for a stories table created before these columns existed.
        for col_def in (
            "ADD COLUMN IF NOT EXISTS indexed_base_ref TEXT",
            "ADD COLUMN IF NOT EXISTS indexed_head_ref TEXT",
            "ADD COLUMN IF NOT EXISTS indexed_at TIMESTAMPTZ",
        ):
            _execute_best_effort(conn, cur, f"ALTER TABLE stories {col_def};")

        cur.execute(
            """
            CREATE TABLE IF NOT EXISTS documents (
                id SERIAL PRIMARY KEY,
                story_id INTEGER NOT NULL REFERENCES stories(id) ON DELETE CASCADE,
                path TEXT NOT NULL,
                version TEXT NOT NULL,
                updated_at TIMESTAMPTZ NOT NULL,
                UNIQUE(story_id, path)
            );
            """
        )
        cur.execute(
            f"""
            CREATE TABLE IF NOT EXISTS chunks (
                id SERIAL PRIMARY KEY,
                document_id INTEGER NOT NULL REFERENCES documents(id) ON DELETE CASCADE,
                chunk_index INTEGER NOT NULL,
                hash TEXT NOT NULL,
                content TEXT NOT NULL,
                embedding vector({dim}) NOT NULL,
                start_line INTEGER,
                end_line INTEGER,
                change_type TEXT NOT NULL DEFAULT 'added'
                    CHECK (change_type IN ('added', 'modified', 'unchanged')),
                previous_content TEXT
            );
            """
        )
        # Migrations: add columns if the table already existed without them.
        for col_def in (
            "ADD COLUMN IF NOT EXISTS start_line INTEGER",
            "ADD COLUMN IF NOT EXISTS end_line INTEGER",
            "ADD COLUMN IF NOT EXISTS previous_content TEXT",
            "ADD COLUMN IF NOT EXISTS change_type TEXT DEFAULT 'added'",
        ):
            _execute_best_effort(conn, cur, f"ALTER TABLE chunks {col_def};")
        _execute_best_effort(
            conn, cur, "ALTER TABLE chunks ALTER COLUMN change_type SET NOT NULL;"
        )
        # On a freshly created table the inline CHECK above is typically
        # auto-named chunks_change_type_check already, so this statement is
        # expected to fail there. Running it under a savepoint keeps that
        # failure from rolling back the tables created above.
        _execute_best_effort(
            conn,
            cur,
            """
            ALTER TABLE chunks ADD CONSTRAINT chunks_change_type_check
                CHECK (change_type IN ('added', 'modified', 'unchanged'));
            """,
        )

        for index_sql in (
            """
            CREATE INDEX IF NOT EXISTS idx_documents_story_id
                ON documents(story_id);
            """,
            """
            CREATE INDEX IF NOT EXISTS idx_chunks_document_id
                ON chunks(document_id);
            """,
            """
            CREATE INDEX IF NOT EXISTS idx_chunks_embedding
                ON chunks USING ivfflat (embedding vector_cosine_ops);
            """,
            """
            CREATE INDEX IF NOT EXISTS idx_chunks_change_type
                ON chunks(change_type);
            """,
        ):
            cur.execute(index_sql)
    conn.commit()
|
|
|
|
|
|
def get_or_create_story(conn: psycopg.Connection, slug: str) -> int:
    """Return the id of the story with this slug, inserting the row if needed.

    The slug is stripped of surrounding whitespace first. The no-op
    ``DO UPDATE`` makes ``RETURNING id`` yield the existing row's id on
    conflict. The change is committed before returning.
    """
    upsert_sql = """
        INSERT INTO stories (slug)
        VALUES (%s)
        ON CONFLICT (slug) DO UPDATE SET slug = EXCLUDED.slug
        RETURNING id;
    """
    normalized_slug = slug.strip()
    with conn.cursor() as cursor:
        cursor.execute(upsert_sql, (normalized_slug,))
        first_row = cursor.fetchone()
        story_id = first_row[0]
    conn.commit()
    return story_id
|
|
|
|
|
|
def update_story_indexed_range(
    conn: psycopg.Connection,
    story_id: int,
    base_ref: str,
    head_ref: str,
) -> None:
    """Record that this story was indexed as all changes from base_ref to head_ref.

    Stores both refs (stripped of surrounding whitespace) and the current time
    in ``stories.indexed_at``, then commits.

    Args:
        conn: Open psycopg connection.
        story_id: Primary key of the story row to update.
        base_ref: Ref the indexed range starts from.
        head_ref: Ref the indexed range ends at.
    """
    with conn.cursor() as cur:
        # indexed_at is TIMESTAMPTZ, so NOW() already stores the right instant.
        # "NOW() AT TIME ZONE 'utc'" produces a zone-less timestamp that would
        # be re-read in the session time zone and shift the stored value.
        cur.execute(
            """
            UPDATE stories
            SET indexed_base_ref = %s, indexed_head_ref = %s,
                indexed_at = NOW()
            WHERE id = %s;
            """,
            (base_ref.strip(), head_ref.strip(), story_id),
        )
    conn.commit()
|
|
|
|
|
|
def get_story_indexed_range(
    conn: psycopg.Connection, story_id: int
) -> tuple[str | None, str | None, datetime | None]:
    """Look up the indexed range recorded for a story.

    Returns ``(indexed_base_ref, indexed_head_ref, indexed_at)`` from the
    ``stories`` row, or ``(None, None, None)`` when no row has that id.
    """
    select_sql = """
        SELECT indexed_base_ref, indexed_head_ref, indexed_at
        FROM stories WHERE id = %s;
    """
    with conn.cursor() as cursor:
        cursor.execute(select_sql, (story_id,))
        found = cursor.fetchone()
    if found is not None:
        return (found[0], found[1], found[2])
    return (None, None, None)
|
|
|
|
|
|
def get_story_id(conn: psycopg.Connection, slug: str) -> int | None:
    """Resolve a story reference (numeric id or slug) to its id.

    The input is stripped of surrounding whitespace. If it consists only of
    decimal digits it is first tried as a primary key; when no story has that
    id, or the input is not numeric, it is looked up as a slug. The fallback
    lets stories whose slug is all digits be found.

    Args:
        conn: Open psycopg connection.
        slug: Story slug, or the story id written as a decimal string.

    Returns:
        The story id, or ``None`` if nothing matches.
    """
    key = slug.strip()
    with conn.cursor() as cur:
        # isdecimal() rather than isdigit(): isdigit() also accepts characters
        # such as "²" that int() cannot parse, which raised ValueError.
        if key.isdecimal():
            cur.execute("SELECT id FROM stories WHERE id = %s;", (int(key),))
            row = cur.fetchone()
            if row is not None:
                return row[0]
        cur.execute("SELECT id FROM stories WHERE slug = %s;", (key,))
        row = cur.fetchone()
    return row[0] if row else None
|
|
|
|
|
|
def upsert_document(
    conn: psycopg.Connection, story_id: int, path: str, version: str
) -> int:
    """Insert or refresh the ``documents`` row for ``(story_id, path)``.

    On conflict the stored version and ``updated_at`` are overwritten;
    ``updated_at`` is the current UTC time taken on the client. Returns the
    row id. This function does not commit.
    """
    upsert_sql = """
        INSERT INTO documents (story_id, path, version, updated_at)
        VALUES (%s, %s, %s, %s)
        ON CONFLICT (story_id, path) DO UPDATE
        SET version = EXCLUDED.version,
            updated_at = EXCLUDED.updated_at
        RETURNING id;
    """
    with conn.cursor() as cursor:
        now_utc = datetime.now(timezone.utc)
        cursor.execute(upsert_sql, (story_id, path, version, now_utc))
        first_row = cursor.fetchone()
        document_id = first_row[0]
    return document_id
|
|
|
|
|
|
def _change_type_and_previous(
    chunk: TextChunk,
    base_by_range: dict[tuple[int, int], TextChunk],
) -> tuple[str, str | None]:
    """Classify a chunk against the base chunks and return (change_type, previous_content).

    ``base_by_range`` maps ``(start_line, end_line)`` to the base chunk at that
    range. A chunk without a complete line range, or without a base chunk at
    its range, is "added". Equal hashes mean "unchanged". Otherwise the chunk
    is "modified" and the base chunk's text is returned as previous content.
    """
    has_range = chunk.start_line is not None and chunk.end_line is not None
    previous = (
        base_by_range.get((chunk.start_line, chunk.end_line)) if has_range else None
    )
    if previous is None:
        return (CHANGE_ADDED, None)
    if previous.hash != chunk.hash:
        return (CHANGE_MODIFIED, previous.text)
    return (CHANGE_UNCHANGED, None)
|
|
|
|
|
|
def replace_chunks(
    conn: psycopg.Connection,
    document_id: int,
    chunks: Iterable[TextChunk],
    embeddings: Iterable[list[float]],
    base_chunks: Sequence[TextChunk] | None = None,
) -> None:
    """Replace all stored chunks of a document with the given chunks.

    Existing ``chunks`` rows for ``document_id`` are deleted and one row is
    inserted per (chunk, embedding) pair. Each chunk is classified as added,
    modified or unchanged by comparing it with the base chunk that has the
    same ``(start_line, end_line)``. The change is committed at the end.

    Args:
        conn: Open psycopg connection.
        document_id: ``documents.id`` that owns the chunks.
        chunks: New chunks, in the same order as ``embeddings``.
        embeddings: One embedding per chunk.
        base_chunks: Chunks of the base version used for change detection;
            ``None`` or empty classifies every chunk as added.

    Raises:
        ValueError: If ``chunks`` and ``embeddings`` differ in length. This is
            checked before anything is deleted; a plain ``zip`` would silently
            drop the unmatched tail after the old rows were already removed.
    """
    chunk_list = list(chunks)
    embedding_list = list(embeddings)
    if len(chunk_list) != len(embedding_list):
        raise ValueError(
            f"chunks and embeddings differ in length: "
            f"{len(chunk_list)} != {len(embedding_list)}"
        )

    base_by_range: dict[tuple[int, int], TextChunk] = {
        (c.start_line, c.end_line): c
        for c in (base_chunks or ())
        if c.start_line is not None and c.end_line is not None
    }

    # Build every row before touching the database so a failure while
    # classifying chunks cannot happen after the DELETE has been issued.
    rows = []
    for chunk, embedding in zip(chunk_list, embedding_list):
        change_type, previous_content = _change_type_and_previous(
            chunk, base_by_range
        )
        rows.append(
            (
                document_id,
                chunk.index,
                chunk.hash,
                chunk.text,
                embedding,
                chunk.start_line,
                chunk.end_line,
                change_type,
                previous_content,
            )
        )

    insert_sql = """
        INSERT INTO chunks (
            document_id, chunk_index, hash, content, embedding,
            start_line, end_line, change_type, previous_content
        )
        VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s);
    """
    with conn.cursor() as cur:
        cur.execute(
            "DELETE FROM chunks WHERE document_id = %s;",
            (document_id,),
        )
        for row in rows:
            cur.execute(insert_sql, row)
    conn.commit()
|
|
|
|
|
|
def fetch_changed_chunks(
    conn: psycopg.Connection, story_id: int
) -> list[ChangedChunkRecord]:
    """Return chunks that were added or modified in this story (for test-case generation).

    Rows are ordered by document path, then start line (NULLs first), then
    chunk index.
    """
    select_sql = """
        SELECT d.path, c.content, c.change_type, c.start_line, c.end_line,
               c.previous_content
        FROM chunks c
        JOIN documents d ON d.id = c.document_id
        WHERE d.story_id = %s
          AND c.change_type IN ('added', 'modified')
        ORDER BY d.path, c.start_line NULLS FIRST, c.chunk_index;
    """
    with conn.cursor() as cursor:
        cursor.execute(select_sql, (story_id,))
        fetched = cursor.fetchall()

    records: list[ChangedChunkRecord] = []
    for record_row in fetched:
        records.append(
            ChangedChunkRecord(
                path=record_row[0],
                content=record_row[1],
                change_type=record_row[2],
                start_line=record_row[3],
                end_line=record_row[4],
                previous_content=record_row[5],
            )
        )
    return records
|
|
|
|
|
|
def delete_document(
    conn: psycopg.Connection, story_id: int, path: str
) -> None:
    """Delete the ``documents`` row for ``(story_id, path)`` and commit.

    Nothing happens if no such row exists.
    """
    delete_sql = "DELETE FROM documents WHERE story_id = %s AND path = %s;"
    params = (story_id, path)
    with conn.cursor() as cursor:
        cursor.execute(delete_sql, params)
    conn.commit()
|
|
|
|
|
|
def fetch_similar(
    conn: psycopg.Connection,
    query_embedding: list[float],
    top_k: int,
    story_id: int | None = None,
) -> list[tuple[str, str, float]]:
    """Return the ``top_k`` chunks closest to ``query_embedding``.

    Distance is computed with pgvector's ``<=>`` operator and results are
    ordered by ascending distance.

    Args:
        conn: Open psycopg connection.
        query_embedding: Embedding to search with.
        top_k: Maximum number of rows to return. Values <= 0 return an empty
            list without querying the database.
        story_id: When given, only chunks of documents in this story are
            searched; otherwise all chunks are searched.

    Returns:
        A list of ``(document path, chunk content, distance)`` tuples.
    """
    if top_k <= 0:
        # A negative LIMIT is a database error and LIMIT 0 is a wasted round trip.
        return []

    # Only this fixed fragment is formatted into the SQL; all values are bound.
    story_filter = "WHERE d.story_id = %s" if story_id is not None else ""
    # The explicit ::vector cast lets a plain Python list be used as the
    # parameter. NOTE(review): psycopg presumably sends a list as an array,
    # for which pgvector defines no "vector <=> array" operator - confirm.
    query = f"""
        SELECT d.path, c.content, c.embedding <=> %s::vector AS distance
        FROM chunks c
        JOIN documents d ON d.id = c.document_id
        {story_filter}
        ORDER BY c.embedding <=> %s::vector
        LIMIT %s;
    """
    params: list[object] = [query_embedding]
    if story_id is not None:
        params.append(story_id)
    params.extend((query_embedding, top_k))

    with conn.cursor() as cur:
        cur.execute(query, params)
        rows = cur.fetchall()
    return [(row[0], row[1], row[2]) for row in rows]
|