from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Iterable, Sequence

import psycopg
from pgvector.psycopg import register_vector

from rag_agent.ingest.chunker import TextChunk

CHANGE_ADDED = "added"
CHANGE_MODIFIED = "modified"
CHANGE_UNCHANGED = "unchanged"


@dataclass(frozen=True)
class ChunkRecord:
    """A TextChunk bundled with its embedding and its document's path/version."""

    document_path: str
    document_version: str
    chunk: TextChunk
    embedding: list[float]


@dataclass(frozen=True)
class ChangedChunkRecord:
    """Chunk that was added or modified in a story (for test-case generation)."""

    path: str
    content: str
    change_type: str
    start_line: int | None
    end_line: int | None
    previous_content: str | None


def connect(dsn: str) -> psycopg.Connection:
    """Open a connection to *dsn* and register the pgvector adapters on it.

    The connection is closed again if adapter registration fails, so a failed
    call does not leak a server connection.

    NOTE(review): register_vector looks up the ``vector`` type in the
    database, so this presumably fails on a database where the extension has
    not been created yet (ensure_schema creates it, but only after a
    connection exists) — confirm how fresh databases are bootstrapped.
    """
    conn = psycopg.connect(dsn)
    try:
        register_vector(conn)
    except Exception:
        conn.close()
        raise
    return conn


def _execute_tolerant(
    conn: psycopg.Connection, cur: psycopg.Cursor, sql: str
) -> bool:
    """Run one statement inside a savepoint, tolerating ProgrammingError.

    Returns True if the statement succeeded and False if it raised
    ``psycopg.ProgrammingError`` (e.g. "already exists"). Only the savepoint
    is rolled back on failure; earlier statements of the enclosing transaction
    are preserved. (A plain ``conn.rollback()`` would discard them all.)
    """
    try:
        # Inside an already-open transaction this is a SAVEPOINT, not a COMMIT.
        with conn.transaction():
            cur.execute(sql)
    except psycopg.ProgrammingError:
        return False
    return True


def ensure_schema(conn: psycopg.Connection, embeddings_dim: int) -> None:
    """Create the stories/documents/chunks tables and indexes if missing.

    Also applies additive migrations for databases created by older versions
    of this module, so it is safe to call repeatedly. Commits on success.

    Args:
        conn: open psycopg connection.
        embeddings_dim: dimension of the ``chunks.embedding`` vector column.
            Only used when the chunks table is first created
            (``CREATE TABLE IF NOT EXISTS``); an existing column keeps its
            original dimension.
    """
    # int() guarantees only a number is interpolated into the DDL below.
    dim = int(embeddings_dim)
    with conn.cursor() as cur:
        cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")
        cur.execute(
            """
            CREATE TABLE IF NOT EXISTS stories (
                id SERIAL PRIMARY KEY,
                slug TEXT UNIQUE NOT NULL,
                created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                indexed_base_ref TEXT,
                indexed_head_ref TEXT,
                indexed_at TIMESTAMPTZ
            );
            """
        )
        for col_def in (
            "ADD COLUMN IF NOT EXISTS indexed_base_ref TEXT",
            "ADD COLUMN IF NOT EXISTS indexed_head_ref TEXT",
            "ADD COLUMN IF NOT EXISTS indexed_at TIMESTAMPTZ",
        ):
            _execute_tolerant(conn, cur, f"ALTER TABLE stories {col_def};")
        # Older schemas defaulted to NOW() AT TIME ZONE 'utc', a timestamp
        # WITHOUT time zone that TIMESTAMPTZ reinterprets in the session
        # TimeZone (wrong whenever that is not UTC). NOW() is the true instant.
        cur.execute(
            "ALTER TABLE stories ALTER COLUMN created_at SET DEFAULT NOW();"
        )
        cur.execute(
            """
            CREATE TABLE IF NOT EXISTS documents (
                id SERIAL PRIMARY KEY,
                story_id INTEGER NOT NULL REFERENCES stories(id) ON DELETE CASCADE,
                path TEXT NOT NULL,
                version TEXT NOT NULL,
                updated_at TIMESTAMPTZ NOT NULL,
                UNIQUE(story_id, path)
            );
            """
        )
        cur.execute(
            f"""
            CREATE TABLE IF NOT EXISTS chunks (
                id SERIAL PRIMARY KEY,
                document_id INTEGER NOT NULL REFERENCES documents(id) ON DELETE CASCADE,
                chunk_index INTEGER NOT NULL,
                hash TEXT NOT NULL,
                content TEXT NOT NULL,
                embedding vector({dim}) NOT NULL,
                start_line INTEGER,
                end_line INTEGER,
                change_type TEXT NOT NULL DEFAULT 'added'
                    CHECK (change_type IN ('added', 'modified', 'unchanged')),
                previous_content TEXT
            );
            """
        )
        # Migrations: add columns if table already existed without them (Postgres 11+)
        for col_def in (
            "ADD COLUMN IF NOT EXISTS start_line INTEGER",
            "ADD COLUMN IF NOT EXISTS end_line INTEGER",
            "ADD COLUMN IF NOT EXISTS previous_content TEXT",
            "ADD COLUMN IF NOT EXISTS change_type TEXT DEFAULT 'added'",
        ):
            _execute_tolerant(conn, cur, f"ALTER TABLE chunks {col_def};")
        _execute_tolerant(
            conn, cur, "ALTER TABLE chunks ALTER COLUMN change_type SET NOT NULL;"
        )
        # The inline CHECK above is auto-named chunks_change_type_check, so on
        # most databases this fails with DuplicateObject; the savepoint keeps
        # that failure from undoing the tables created earlier in this call.
        _execute_tolerant(
            conn,
            cur,
            """
            ALTER TABLE chunks ADD CONSTRAINT chunks_change_type_check
            CHECK (change_type IN ('added', 'modified', 'unchanged'));
            """,
        )
        cur.execute(
            "CREATE INDEX IF NOT EXISTS idx_documents_story_id "
            "ON documents(story_id);"
        )
        cur.execute(
            "CREATE INDEX IF NOT EXISTS idx_chunks_document_id "
            "ON chunks(document_id);"
        )
        cur.execute(
            "CREATE INDEX IF NOT EXISTS idx_chunks_embedding "
            "ON chunks USING ivfflat (embedding vector_cosine_ops);"
        )
        cur.execute(
            "CREATE INDEX IF NOT EXISTS idx_chunks_change_type "
            "ON chunks(change_type);"
        )
    conn.commit()


def get_or_create_story(conn: psycopg.Connection, slug: str) -> int:
    """Return the id of the story with *slug* (whitespace-stripped), creating it if absent.

    The no-op ``DO UPDATE`` makes ``RETURNING id`` produce the existing row's
    id on conflict (``DO NOTHING`` would return no row). Commits.
    """
    with conn.cursor() as cur:
        cur.execute(
            """
            INSERT INTO stories (slug) VALUES (%s)
            ON CONFLICT (slug) DO UPDATE SET slug = EXCLUDED.slug
            RETURNING id;
            """,
            (slug.strip(),),
        )
        story_id = cur.fetchone()[0]
    conn.commit()
    return story_id


def update_story_indexed_range(
    conn: psycopg.Connection,
    story_id: int,
    base_ref: str,
    head_ref: str,
) -> None:
    """Record that this story was indexed as all changes from base_ref to head_ref (all commits in story).

    Refs are whitespace-stripped; ``indexed_at`` is set to the current
    instant. Commits.
    """
    with conn.cursor() as cur:
        cur.execute(
            """
            UPDATE stories
            SET indexed_base_ref = %s, indexed_head_ref = %s, indexed_at = NOW()
            WHERE id = %s;
            """,
            (base_ref.strip(), head_ref.strip(), story_id),
        )
    conn.commit()


def get_story_indexed_range(
    conn: psycopg.Connection, story_id: int
) -> tuple[str | None, str | None, datetime | None]:
    """Return (indexed_base_ref, indexed_head_ref, indexed_at) for the story, or (None, None, None)."""
    with conn.cursor() as cur:
        cur.execute(
            """
            SELECT indexed_base_ref, indexed_head_ref, indexed_at
            FROM stories WHERE id = %s;
            """,
            (story_id,),
        )
        row = cur.fetchone()
    if row is None:
        return (None, None, None)
    base_ref, head_ref, indexed_at = row
    return (base_ref, head_ref, indexed_at)


def get_story_id(conn: psycopg.Connection, slug: str) -> int | None:
    """Resolve a story id from *slug*, which may be a numeric id or a slug.

    An all-decimal-digit string is looked up as ``stories.id``; anything else
    is matched against ``stories.slug``. Returns None when no story matches.

    NOTE(review): a story whose slug is itself all digits cannot be found by
    slug here — confirm that is intended.
    """
    s = slug.strip()
    with conn.cursor() as cur:
        # isdecimal (not isdigit): isdigit accepts e.g. '²', which int() rejects.
        if s.isdecimal():
            cur.execute("SELECT id FROM stories WHERE id = %s;", (int(s),))
        else:
            cur.execute("SELECT id FROM stories WHERE slug = %s;", (s,))
        row = cur.fetchone()
    return row[0] if row else None


def upsert_document(
    conn: psycopg.Connection, story_id: int, path: str, version: str
) -> int:
    """Insert or update the (story_id, path) document row and return its id.

    ``updated_at`` is set to the current UTC time. Does NOT commit: the
    caller is expected to commit, e.g. via replace_chunks.
    """
    with conn.cursor() as cur:
        cur.execute(
            """
            INSERT INTO documents (story_id, path, version, updated_at)
            VALUES (%s, %s, %s, %s)
            ON CONFLICT (story_id, path)
            DO UPDATE SET version = EXCLUDED.version, updated_at = EXCLUDED.updated_at
            RETURNING id;
            """,
            (story_id, path, version, datetime.now(timezone.utc)),
        )
        document_id = cur.fetchone()[0]
    return document_id


def _change_type_and_previous(
    chunk: TextChunk,
    base_by_range: dict[tuple[int, int], TextChunk],
) -> tuple[str, str | None]:
    """Determine change_type and previous_content for a chunk given base chunks keyed by (start_line, end_line).

    A chunk without line numbers, or with no base chunk at the same range, is
    "added". Same range with the same hash is "unchanged". Same range with a
    different hash is "modified", and the base chunk's text is returned as
    previous_content.
    """
    if chunk.start_line is None or chunk.end_line is None:
        return (CHANGE_ADDED, None)
    base = base_by_range.get((chunk.start_line, chunk.end_line))
    if base is None:
        return (CHANGE_ADDED, None)
    if base.hash == chunk.hash:
        return (CHANGE_UNCHANGED, None)
    return (CHANGE_MODIFIED, base.text)


def replace_chunks(
    conn: psycopg.Connection,
    document_id: int,
    chunks: Iterable[TextChunk],
    embeddings: Iterable[list[float]],
    base_chunks: Sequence[TextChunk] | None = None,
) -> None:
    """Replace all chunks of *document_id* with *chunks* and their *embeddings*.

    Each chunk is classified against *base_chunks* (matched by line range) as
    added/modified/unchanged before insertion. Commits.

    Raises:
        ValueError: if the number of chunks and embeddings differ. The check
            runs before anything is deleted; previously ``zip`` silently
            dropped the surplus items.
    """
    chunk_list = list(chunks)
    embedding_list = list(embeddings)
    if len(chunk_list) != len(embedding_list):
        raise ValueError(
            f"replace_chunks: got {len(chunk_list)} chunks but "
            f"{len(embedding_list)} embeddings"
        )

    # Later base chunks with the same range overwrite earlier ones.
    base_by_range: dict[tuple[int, int], TextChunk] = {
        (c.start_line, c.end_line): c
        for c in base_chunks or ()
        if c.start_line is not None and c.end_line is not None
    }

    rows = []
    for chunk, embedding in zip(chunk_list, embedding_list):
        change_type, previous_content = _change_type_and_previous(
            chunk, base_by_range
        )
        rows.append(
            (
                document_id,
                chunk.index,
                chunk.hash,
                chunk.text,
                embedding,
                chunk.start_line,
                chunk.end_line,
                change_type,
                previous_content,
            )
        )

    with conn.cursor() as cur:
        cur.execute(
            "DELETE FROM chunks WHERE document_id = %s;",
            (document_id,),
        )
        if rows:
            cur.executemany(
                """
                INSERT INTO chunks (
                    document_id, chunk_index, hash, content, embedding,
                    start_line, end_line, change_type, previous_content
                )
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s);
                """,
                rows,
            )
    conn.commit()


def fetch_changed_chunks(
    conn: psycopg.Connection, story_id: int
) -> list[ChangedChunkRecord]:
    """Return chunks that were added or modified in this story (for test-case generation).

    Ordered by document path, then start_line (NULLs first), then chunk_index.
    """
    with conn.cursor() as cur:
        cur.execute(
            """
            SELECT d.path, c.content, c.change_type, c.start_line, c.end_line, c.previous_content
            FROM chunks c
            JOIN documents d ON d.id = c.document_id
            WHERE d.story_id = %s AND c.change_type IN (%s, %s)
            ORDER BY d.path, c.start_line NULLS FIRST, c.chunk_index;
            """,
            (story_id, CHANGE_ADDED, CHANGE_MODIFIED),
        )
        rows = cur.fetchall()
    return [
        ChangedChunkRecord(
            path=path,
            content=content,
            change_type=change_type,
            start_line=start_line,
            end_line=end_line,
            previous_content=previous_content,
        )
        for path, content, change_type, start_line, end_line, previous_content in rows
    ]


def delete_document(
    conn: psycopg.Connection, story_id: int, path: str
) -> None:
    """Delete the (story_id, path) document; its chunks go via ON DELETE CASCADE. Commits."""
    with conn.cursor() as cur:
        cur.execute(
            "DELETE FROM documents WHERE story_id = %s AND path = %s;",
            (story_id, path),
        )
    conn.commit()


def fetch_similar(
    conn: psycopg.Connection,
    query_embedding: list[float],
    top_k: int,
    story_id: int | None = None,
) -> list[tuple[str, str, float]]:
    """Return up to *top_k* (path, content, cosine_distance) nearest chunks.

    Restricted to *story_id* when given, otherwise searched across all
    stories. The query vector is cast explicitly with ``::vector``: a plain
    Python list is sent as a float array, and pgvector's array-to-vector cast
    is assignment-only, so ``vector <=> float8[]`` would not resolve. The cast
    is a no-op for values already adapted as ``vector``.
    """
    with conn.cursor() as cur:
        if story_id is not None:
            cur.execute(
                """
                SELECT d.path, c.content, c.embedding <=> %s::vector AS distance
                FROM chunks c
                JOIN documents d ON d.id = c.document_id
                WHERE d.story_id = %s
                ORDER BY c.embedding <=> %s::vector
                LIMIT %s;
                """,
                (query_embedding, story_id, query_embedding, top_k),
            )
        else:
            cur.execute(
                """
                SELECT d.path, c.content, c.embedding <=> %s::vector AS distance
                FROM chunks c
                JOIN documents d ON d.id = c.document_id
                ORDER BY c.embedding <=> %s::vector
                LIMIT %s;
                """,
                (query_embedding, query_embedding, top_k),
            )
        rows = cur.fetchall()
    return [(path, content, distance) for path, content, distance in rows]