from __future__ import annotations from dataclasses import dataclass from sqlalchemy import text from app.modules.shared.db import get_engine @dataclass class RagJobRow: index_job_id: str rag_session_id: str status: str indexed_files: int failed_files: int error_code: str | None error_desc: str | None error_module: str | None class RagRepository: def ensure_tables(self) -> None: engine = get_engine() with engine.connect() as conn: conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector")) conn.execute( text( """ CREATE TABLE IF NOT EXISTS rag_sessions ( rag_session_id VARCHAR(64) PRIMARY KEY, project_id VARCHAR(512) NOT NULL, created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP ) """ ) ) conn.execute( text( """ CREATE TABLE IF NOT EXISTS rag_index_jobs ( index_job_id VARCHAR(64) PRIMARY KEY, rag_session_id VARCHAR(64) NOT NULL, status VARCHAR(16) NOT NULL, indexed_files INTEGER NOT NULL DEFAULT 0, failed_files INTEGER NOT NULL DEFAULT 0, error_code VARCHAR(128) NULL, error_desc TEXT NULL, error_module VARCHAR(64) NULL, created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP ) """ ) ) conn.execute( text( """ CREATE TABLE IF NOT EXISTS rag_chunks ( id BIGSERIAL PRIMARY KEY, rag_session_id VARCHAR(64) NOT NULL, path TEXT NOT NULL, chunk_index INTEGER NOT NULL, content TEXT NOT NULL, embedding vector NULL, created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP ) """ ) ) conn.execute( text( """ ALTER TABLE rag_chunks ADD COLUMN IF NOT EXISTS created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP """ ) ) conn.execute( text( """ ALTER TABLE rag_chunks ADD COLUMN IF NOT EXISTS updated_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP """ ) ) conn.execute(text("CREATE INDEX IF NOT EXISTS idx_rag_chunks_session ON rag_chunks (rag_session_id)")) conn.commit() def upsert_session(self, rag_session_id: str, project_id: str) -> None: with get_engine().connect() as conn: conn.execute( text( """ INSERT INTO rag_sessions (rag_session_id, project_id) VALUES (:sid, :pid) ON CONFLICT (rag_session_id) DO UPDATE SET project_id = EXCLUDED.project_id """ ), {"sid": rag_session_id, "pid": project_id}, ) conn.commit() def session_exists(self, rag_session_id: str) -> bool: with get_engine().connect() as conn: row = conn.execute( text("SELECT 1 FROM rag_sessions WHERE rag_session_id = :sid"), {"sid": rag_session_id}, ).fetchone() return bool(row) def get_session(self, rag_session_id: str) -> dict | None: with get_engine().connect() as conn: row = conn.execute( text("SELECT rag_session_id, project_id FROM rag_sessions WHERE rag_session_id = :sid"), {"sid": rag_session_id}, ).mappings().fetchone() return dict(row) if row else None def create_job(self, index_job_id: str, rag_session_id: str, status: str) -> None: with get_engine().connect() as conn: conn.execute( text( """ INSERT INTO rag_index_jobs (index_job_id, rag_session_id, status) VALUES (:jid, :sid, :status) """ ), {"jid": index_job_id, "sid": rag_session_id, "status": status}, ) conn.commit() def update_job( self, index_job_id: str, *, status: str, indexed_files: int, failed_files: int, error_code: str | None = None, error_desc: str | None = None, error_module: str | None = None, ) -> None: with get_engine().connect() as conn: conn.execute( text( """ UPDATE rag_index_jobs SET status = :status, indexed_files = :indexed, failed_files = :failed, error_code = :ecode, error_desc = :edesc, error_module = :emodule, updated_at = CURRENT_TIMESTAMP WHERE index_job_id = :jid """ ), { "jid": index_job_id, "status": status, "indexed": indexed_files, "failed": failed_files, "ecode": error_code, "edesc": error_desc, "emodule": error_module, }, ) conn.commit() def get_job(self, index_job_id: str) -> RagJobRow | None: with get_engine().connect() as conn: row = conn.execute( text( """ SELECT index_job_id, rag_session_id, status, indexed_files, failed_files, error_code, error_desc, error_module FROM rag_index_jobs WHERE index_job_id = :jid """ ), {"jid": index_job_id}, ).mappings().fetchone() if not row: return None return RagJobRow(**dict(row)) def replace_chunks(self, rag_session_id: str, items: list[dict]) -> None: with get_engine().connect() as conn: conn.execute(text("DELETE FROM rag_chunks WHERE rag_session_id = :sid"), {"sid": rag_session_id}) self._insert_chunks(conn, rag_session_id, items) conn.commit() def apply_changes(self, rag_session_id: str, delete_paths: list[str], upserts: list[dict]) -> None: with get_engine().connect() as conn: if delete_paths: conn.execute( text("DELETE FROM rag_chunks WHERE rag_session_id = :sid AND path = ANY(:paths)"), {"sid": rag_session_id, "paths": delete_paths}, ) if upserts: paths = sorted({str(x["path"]) for x in upserts}) conn.execute( text("DELETE FROM rag_chunks WHERE rag_session_id = :sid AND path = ANY(:paths)"), {"sid": rag_session_id, "paths": paths}, ) self._insert_chunks(conn, rag_session_id, upserts) conn.commit() def retrieve(self, rag_session_id: str, query_embedding: list[float], limit: int = 5) -> list[dict]: emb = "[" + ",".join(str(x) for x in query_embedding) + "]" with get_engine().connect() as conn: rows = conn.execute( text( """ SELECT path, content FROM rag_chunks WHERE rag_session_id = :sid ORDER BY embedding <=> CAST(:emb AS vector) LIMIT :lim """ ), {"sid": rag_session_id, "emb": emb, "lim": limit}, ).mappings().fetchall() return [dict(x) for x in rows] def fallback_chunks(self, rag_session_id: str, limit: int = 5) -> list[dict]: with get_engine().connect() as conn: rows = conn.execute( text( """ SELECT path, content FROM rag_chunks WHERE rag_session_id = :sid ORDER BY id DESC LIMIT :lim """ ), {"sid": rag_session_id, "lim": limit}, ).mappings().fetchall() return [dict(x) for x in rows] def _insert_chunks(self, conn, rag_session_id: str, items: list[dict]) -> None: for item in items: emb = item.get("embedding") or [] emb_str = "[" + ",".join(str(x) for x in emb) + "]" if emb else None conn.execute( text( """ INSERT INTO rag_chunks (rag_session_id, path, chunk_index, content, embedding, created_at, updated_at) VALUES (:sid, :path, :idx, :content, CAST(:emb AS vector), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) """ ), { "sid": rag_session_id, "path": item["path"], "idx": int(item["chunk_index"]), "content": item["content"], "emb": emb_str, }, )