from __future__ import annotations

import json

from sqlalchemy import text

from app.modules.rag.contracts import RagDocument

class RagDocumentRepository:
    """Persists RAG document chunks and their session/chunk mappings.

    Every public method takes an externally managed SQLAlchemy connection
    (``conn``); this class issues statements on it but never commits or rolls
    back — transaction ownership stays with the caller.
    """

    def replace_documents(self, conn, rag_session_id: str, docs: list[RagDocument]) -> None:
        """Wipe ALL rows for *rag_session_id* and re-insert *docs* from scratch.

        Both the chunk table and the session/chunk map are cleared so the two
        tables cannot drift out of sync.
        """
        conn.execute(text("DELETE FROM rag_chunks WHERE rag_session_id = :sid"), {"sid": rag_session_id})
        conn.execute(text("DELETE FROM rag_session_chunk_map WHERE rag_session_id = :sid"), {"sid": rag_session_id})
        self.insert_documents(conn, rag_session_id, docs)

    def apply_document_changes(
        self,
        conn,
        rag_session_id: str,
        delete_paths: list[str],
        docs: list[RagDocument],
    ) -> None:
        """Incrementally sync one session: drop *delete_paths*, then upsert *docs*.

        Documents are "upserted" by deleting any existing rows for their paths
        before re-inserting, so repeating the call with the same *docs* is
        idempotent. When *docs* is empty only the deletions run.
        """
        if delete_paths:
            self._delete_paths(conn, rag_session_id, delete_paths)
        if not docs:
            return
        # De-duplicate and sort for a deterministic SQL parameter list.
        paths = sorted({doc.source.path for doc in docs})
        self._delete_paths(conn, rag_session_id, paths)
        self.insert_documents(conn, rag_session_id, docs)

    def _delete_paths(self, conn, rag_session_id: str, paths: list[str]) -> None:
        """Delete chunk rows AND mapping rows for *paths* within one session.

        Centralizes the paired DELETEs so the two tables are always purged
        together (previously this pair was copy-pasted per call site).
        """
        params = {"sid": rag_session_id, "paths": paths}
        conn.execute(
            text("DELETE FROM rag_chunks WHERE rag_session_id = :sid AND path = ANY(:paths)"),
            params,
        )
        conn.execute(
            text("DELETE FROM rag_session_chunk_map WHERE rag_session_id = :sid AND path = ANY(:paths)"),
            params,
        )

    def insert_documents(self, conn, rag_session_id: str, docs: list[RagDocument]) -> None:
        """Insert one ``rag_chunks`` row per document, plus a session/chunk map
        row when the document carries both a non-blank repo_id and blob_sha.
        """
        for doc in docs:
            row = doc.to_record()
            metadata = row["metadata"]
            # Hoisted: the same chunk index feeds both inserts below.
            chunk_index = int(metadata.get("chunk_index", 0))
            self._insert_chunk_row(conn, rag_session_id, row, metadata, chunk_index)
            repo_id = str(row["repo_id"] or "").strip()
            blob_sha = str(metadata.get("blob_sha") or "").strip()
            if repo_id and blob_sha:
                self._insert_chunk_map_row(conn, rag_session_id, repo_id, blob_sha, chunk_index, row["path"])

    def _insert_chunk_row(self, conn, rag_session_id: str, row, metadata, chunk_index: int) -> None:
        """Insert a single ``rag_chunks`` row built from one document record."""
        links = row["links"]
        emb = row["embedding"] or []
        # The driver has no native vector type here, so the embedding is
        # serialized as "[v1,v2,...]" and CAST to vector in the SQL; an empty
        # embedding is stored as NULL.
        emb_str = "[" + ",".join(str(x) for x in emb) + "]" if emb else None
        conn.execute(
            text(
                """
                INSERT INTO rag_chunks (
                    rag_session_id, path, chunk_index, content, embedding, artifact_type, section, doc_id,
                    doc_version, owner, system_component, last_modified, staleness_score, created_at, updated_at,
                    rag_doc_id, layer, lang, repo_id, commit_sha, title, metadata_json, links_json, span_start,
                    span_end, symbol_id, qname, kind, framework, entrypoint_type, module_id, section_path, doc_kind
                )
                VALUES (
                    :sid, :path, :chunk_index, :content, CAST(:emb AS vector), :artifact_type, :section, :doc_id,
                    :doc_version, :owner, :system_component, :last_modified, :staleness_score, CURRENT_TIMESTAMP,
                    CURRENT_TIMESTAMP, :rag_doc_id, :layer, :lang, :repo_id, :commit_sha, :title, :metadata_json,
                    :links_json, :span_start, :span_end, :symbol_id, :qname, :kind, :framework, :entrypoint_type,
                    :module_id, :section_path, :doc_kind
                )
                """
            ),
            {
                "sid": rag_session_id,
                "path": row["path"],
                "chunk_index": chunk_index,
                "content": row["text"],
                "emb": emb_str,
                "artifact_type": metadata.get("artifact_type"),
                # Older records used "section_title"; accept either spelling.
                "section": metadata.get("section") or metadata.get("section_title"),
                "doc_id": metadata.get("doc_id"),
                "doc_version": metadata.get("doc_version"),
                "owner": metadata.get("owner"),
                "system_component": metadata.get("system_component"),
                "last_modified": metadata.get("last_modified"),
                "staleness_score": metadata.get("staleness_score"),
                "rag_doc_id": row["doc_id"],
                "layer": row["layer"],
                "lang": row["lang"],
                "repo_id": row["repo_id"],
                "commit_sha": row["commit_sha"],
                "title": row["title"],
                "metadata_json": json.dumps(metadata, ensure_ascii=True),
                "links_json": json.dumps(links, ensure_ascii=True),
                "span_start": row["span_start"],
                "span_end": row["span_end"],
                "symbol_id": metadata.get("symbol_id"),
                # Fallback keys cover alternate metadata spellings produced by
                # different chunkers — presumably historical; verify upstream.
                "kind": metadata.get("kind") or metadata.get("type"),
                "qname": metadata.get("qname"),
                "framework": metadata.get("framework"),
                "entrypoint_type": metadata.get("entry_type") or metadata.get("entrypoint_type"),
                "module_id": metadata.get("module_id") or metadata.get("policy_id"),
                "section_path": metadata.get("section_path"),
                "doc_kind": metadata.get("doc_kind"),
            },
        )

    def _insert_chunk_map_row(
        self,
        conn,
        rag_session_id: str,
        repo_id: str,
        blob_sha: str,
        chunk_index: int,
        path,
    ) -> None:
        """Insert one session → (repo, blob, chunk) mapping row."""
        conn.execute(
            text(
                """
                INSERT INTO rag_session_chunk_map (
                    rag_session_id, repo_id, blob_sha, chunk_index, path
                ) VALUES (:sid, :repo_id, :blob_sha, :chunk_index, :path)
                """
            ),
            {
                "sid": rag_session_id,
                "repo_id": repo_id,
                "blob_sha": blob_sha,
                "chunk_index": chunk_index,
                "path": path,
            },
        )