Files
agent/app/modules/rag/persistence/document_repository.py
2026-03-01 14:21:33 +03:00

123 lines
5.9 KiB
Python

from __future__ import annotations
import json
from sqlalchemy import text
from app.modules.rag.contracts import RagDocument
class RagDocumentRepository:
    """Persist the chunks of a RAG session's documents.

    Two tables are written:

    * ``rag_chunks`` -- one row per chunk: its text, embedding, a set of
      columns pulled out of the chunk's ``metadata`` dict, and the full
      ``metadata`` / ``links`` serialized as JSON.
    * ``rag_session_chunk_map`` -- one ``(repo_id, blob_sha, chunk_index, path)``
      row per chunk, written only when both ``repo_id`` and ``blob_sha`` are
      non-blank.

    Methods run their statements on the ``conn`` they are given and never
    commit. ``conn`` is assumed to be a SQLAlchemy ``Connection`` (or a
    compatible object) whose transaction is owned by the caller. The SQL
    relies on PostgreSQL array binding (``path = ANY(:paths)``) and on the
    pgvector ``vector`` type.
    """

    _DELETE_SESSION_CHUNKS_SQL = "DELETE FROM rag_chunks WHERE rag_session_id = :sid"
    _DELETE_SESSION_CHUNK_MAP_SQL = "DELETE FROM rag_session_chunk_map WHERE rag_session_id = :sid"
    _DELETE_PATH_CHUNKS_SQL = (
        "DELETE FROM rag_chunks WHERE rag_session_id = :sid AND path = ANY(:paths)"
    )
    _DELETE_PATH_CHUNK_MAP_SQL = (
        "DELETE FROM rag_session_chunk_map WHERE rag_session_id = :sid AND path = ANY(:paths)"
    )
    _INSERT_CHUNK_SQL = """
        INSERT INTO rag_chunks (
            rag_session_id, path, chunk_index, content, embedding, artifact_type, section, doc_id,
            doc_version, owner, system_component, last_modified, staleness_score, created_at, updated_at,
            rag_doc_id, layer, lang, repo_id, commit_sha, title, metadata_json, links_json, span_start,
            span_end, symbol_id, qname, kind, framework, entrypoint_type, module_id, section_path, doc_kind
        )
        VALUES (
            :sid, :path, :chunk_index, :content, CAST(:emb AS vector), :artifact_type, :section, :doc_id,
            :doc_version, :owner, :system_component, :last_modified, :staleness_score, CURRENT_TIMESTAMP,
            CURRENT_TIMESTAMP, :rag_doc_id, :layer, :lang, :repo_id, :commit_sha, :title, :metadata_json,
            :links_json, :span_start, :span_end, :symbol_id, :qname, :kind, :framework, :entrypoint_type,
            :module_id, :section_path, :doc_kind
        )
    """
    _INSERT_CHUNK_MAP_SQL = """
        INSERT INTO rag_session_chunk_map (
            rag_session_id, repo_id, blob_sha, chunk_index, path
        ) VALUES (:sid, :repo_id, :blob_sha, :chunk_index, :path)
    """

    def replace_documents(self, conn, rag_session_id: str, docs: list[RagDocument]) -> None:
        """Replace every stored chunk of ``rag_session_id`` with ``docs``.

        All ``rag_chunks`` and ``rag_session_chunk_map`` rows of the session are
        deleted, then ``docs`` are inserted via :meth:`insert_documents`.
        """
        session_params = {"sid": rag_session_id}
        conn.execute(text(self._DELETE_SESSION_CHUNKS_SQL), session_params)
        conn.execute(text(self._DELETE_SESSION_CHUNK_MAP_SQL), session_params)
        self.insert_documents(conn, rag_session_id, docs)

    def apply_document_changes(
        self,
        conn,
        rag_session_id: str,
        delete_paths: list[str],
        docs: list[RagDocument],
    ) -> None:
        """Apply an incremental change set to a session.

        Rows whose path is in ``delete_paths`` are removed. Every path that
        appears in ``docs`` (``doc.source.path``) is also cleared, so that
        re-inserting a path replaces its previous rows instead of adding to
        them. ``docs`` are then inserted. A falsy ``delete_paths`` (empty list
        or ``None``) means "nothing to delete explicitly".
        """
        # One DELETE pair over the union gives the same end state as deleting
        # the two path sets separately, with half the statements.
        stale_paths = sorted(set(delete_paths or ()) | {doc.source.path for doc in docs})
        self._delete_paths(conn, rag_session_id, stale_paths)
        self.insert_documents(conn, rag_session_id, docs)

    def insert_documents(self, conn, rag_session_id: str, docs: list[RagDocument]) -> None:
        """Insert ``docs`` into ``rag_chunks`` (and ``rag_session_chunk_map``).

        Nothing is deleted first. Each document is converted with
        ``doc.to_record()``. The following record keys are read: ``path``,
        ``text``, ``embedding``, ``metadata``, ``links``, ``doc_id``, ``layer``,
        ``lang``, ``repo_id``, ``commit_sha``, ``title``, ``span_start`` and
        ``span_end``. Rows are sent as at most two batched (executemany)
        statements: all chunk rows first, then all chunk-map rows.
        """
        chunk_rows: list[dict] = []
        map_rows: list[dict] = []
        for doc in docs:
            record = doc.to_record()
            chunk_index = self._chunk_index(record["metadata"])
            chunk_rows.append(self._chunk_params(rag_session_id, record, chunk_index))
            map_row = self._chunk_map_params(rag_session_id, record, chunk_index)
            if map_row is not None:
                map_rows.append(map_row)
        # Skip empty batches so that nothing is executed when there is nothing
        # to insert (e.g. an empty ``docs``).
        if chunk_rows:
            conn.execute(text(self._INSERT_CHUNK_SQL), chunk_rows)
        if map_rows:
            conn.execute(text(self._INSERT_CHUNK_MAP_SQL), map_rows)

    def _delete_paths(self, conn, rag_session_id: str, paths: list[str]) -> None:
        """Delete the session's chunk and chunk-map rows whose ``path`` is in ``paths``; no-op if empty."""
        if not paths:
            return
        path_params = {"sid": rag_session_id, "paths": paths}
        conn.execute(text(self._DELETE_PATH_CHUNKS_SQL), path_params)
        conn.execute(text(self._DELETE_PATH_CHUNK_MAP_SQL), path_params)

    @staticmethod
    def _chunk_index(metadata) -> int:
        """Return ``metadata["chunk_index"]`` as an int, treating a missing or ``None`` value as 0."""
        return int(metadata.get("chunk_index") or 0)

    @staticmethod
    def _chunk_params(rag_session_id: str, record: dict, chunk_index: int) -> dict:
        """Build the ``rag_chunks`` bind parameters for one ``to_record()`` result."""
        metadata = record["metadata"]
        embedding = record["embedding"]
        # Serialized as the "[x1,x2,...]" text that ``CAST(:emb AS vector)``
        # parses; no/empty embedding binds NULL.
        emb_literal = "[" + ",".join(map(str, embedding)) + "]" if embedding else None
        return {
            "sid": rag_session_id,
            "path": record["path"],
            "chunk_index": chunk_index,
            "content": record["text"],
            "emb": emb_literal,
            "artifact_type": metadata.get("artifact_type"),
            # Several columns accept either of two metadata keys (presumably
            # different chunkers use different names); the first truthy wins.
            "section": metadata.get("section") or metadata.get("section_title"),
            "doc_id": metadata.get("doc_id"),
            "doc_version": metadata.get("doc_version"),
            "owner": metadata.get("owner"),
            "system_component": metadata.get("system_component"),
            "last_modified": metadata.get("last_modified"),
            "staleness_score": metadata.get("staleness_score"),
            "rag_doc_id": record["doc_id"],
            "layer": record["layer"],
            "lang": record["lang"],
            "repo_id": record["repo_id"],
            "commit_sha": record["commit_sha"],
            "title": record["title"],
            "metadata_json": json.dumps(metadata, ensure_ascii=True),
            "links_json": json.dumps(record["links"], ensure_ascii=True),
            "span_start": record["span_start"],
            "span_end": record["span_end"],
            "symbol_id": metadata.get("symbol_id"),
            "qname": metadata.get("qname"),
            "kind": metadata.get("kind") or metadata.get("type"),
            "framework": metadata.get("framework"),
            "entrypoint_type": metadata.get("entry_type") or metadata.get("entrypoint_type"),
            "module_id": metadata.get("module_id") or metadata.get("policy_id"),
            "section_path": metadata.get("section_path"),
            "doc_kind": metadata.get("doc_kind"),
        }

    @staticmethod
    def _chunk_map_params(rag_session_id: str, record: dict, chunk_index: int) -> dict | None:
        """Build the ``rag_session_chunk_map`` bind parameters, or ``None`` to skip the row.

        A row is produced only when both ``record["repo_id"]`` and
        ``metadata["blob_sha"]`` are non-blank after ``str(...).strip()``; the
        stripped values are what get stored.
        """
        repo_id = str(record["repo_id"] or "").strip()
        blob_sha = str(record["metadata"].get("blob_sha") or "").strip()
        if not (repo_id and blob_sha):
            return None
        return {
            "sid": rag_session_id,
            "repo_id": repo_id,
            "blob_sha": blob_sha,
            "chunk_index": chunk_index,
            "path": record["path"],
        }