Files
agent/app/modules/rag/persistence/cache_repository.py
2026-03-01 14:21:33 +03:00

190 lines
8.2 KiB
Python

from __future__ import annotations
import json
from sqlalchemy import text
from app.modules.rag.contracts import EvidenceLink, RagDocument, RagSource, RagSpan
from app.modules.rag.contracts.enums import RagLayer
class RagCacheRepository:
    """Persist and reload per-blob RAG chunk caches.

    Two tables are used (see the SQL constants below):

    * ``rag_blob_cache`` holds one summary row per ``(repo_id, blob_sha, path)``,
      upserted from the first document's metadata.
    * ``rag_chunk_cache`` holds the ordered chunks (text, metadata, links, span,
      embedding) for a ``(repo_id, blob_sha)`` pair.
    """

    _SELECT_CHUNKS_SQL = """
        SELECT layer, lang, path, title, content, metadata_json, links_json, span_start, span_end,
               repo_id, commit_sha, embedding::text AS embedding_txt
        FROM rag_chunk_cache
        WHERE repo_id = :repo_id AND blob_sha = :blob_sha
        ORDER BY chunk_index ASC
    """

    _UPSERT_BLOB_SQL = """
        INSERT INTO rag_blob_cache (
            repo_id, blob_sha, path, artifact_type, section, doc_id, doc_version, owner,
            system_component, last_modified, staleness_score, layer, lang, metadata_json
        )
        VALUES (
            :repo_id, :blob_sha, :path, :artifact_type, :section, :doc_id, :doc_version, :owner,
            :system_component, :last_modified, :staleness_score, :layer, :lang, :metadata_json
        )
        ON CONFLICT (repo_id, blob_sha, path) DO UPDATE SET
            artifact_type = EXCLUDED.artifact_type,
            section = EXCLUDED.section,
            doc_id = EXCLUDED.doc_id,
            doc_version = EXCLUDED.doc_version,
            owner = EXCLUDED.owner,
            system_component = EXCLUDED.system_component,
            last_modified = EXCLUDED.last_modified,
            staleness_score = EXCLUDED.staleness_score,
            layer = EXCLUDED.layer,
            lang = EXCLUDED.lang,
            metadata_json = EXCLUDED.metadata_json,
            updated_at = CURRENT_TIMESTAMP
    """

    _DELETE_CHUNKS_SQL = "DELETE FROM rag_chunk_cache WHERE repo_id = :repo_id AND blob_sha = :blob_sha"

    _INSERT_CHUNK_SQL = """
        INSERT INTO rag_chunk_cache (
            repo_id, blob_sha, chunk_index, content, embedding, section, layer, lang, path, title,
            metadata_json, links_json, span_start, span_end, commit_sha
        )
        VALUES (
            :repo_id, :blob_sha, :chunk_index, :content, CAST(:embedding AS vector), :section, :layer,
            :lang, :path, :title, :metadata_json, :links_json, :span_start, :span_end, :commit_sha
        )
    """

    def get_cached_documents(self, repo_id: str, blob_sha: str) -> list[RagDocument]:
        """Return the cached chunks for ``(repo_id, blob_sha)`` ordered by ``chunk_index``.

        Returns an empty list when nothing is cached for that blob.
        """
        with self._engine().connect() as conn:
            rows = (
                conn.execute(
                    text(self._SELECT_CHUNKS_SQL),
                    {"repo_id": repo_id, "blob_sha": blob_sha},
                )
                .mappings()
                .fetchall()
            )
        return [self._row_to_document(row) for row in rows]

    def cache_documents(self, repo_id: str, path: str, blob_sha: str, docs: list[RagDocument]) -> None:
        """Replace the cached chunks for ``(repo_id, blob_sha)`` with ``docs``.

        Upserts the ``rag_blob_cache`` row for ``(repo_id, blob_sha, path)`` from
        the first document's metadata. Then deletes every previously cached chunk
        for the blob and inserts ``docs`` with ``chunk_index`` equal to each one's
        position in the list.

        All statements run in one transaction. It is committed when the block
        finishes and rolled back if any statement raises. Does nothing when
        ``docs`` is empty.
        """
        if not docs:
            return
        records = [doc.to_record() for doc in docs]
        with self._engine().begin() as conn:
            conn.execute(
                text(self._UPSERT_BLOB_SQL),
                self._blob_params(repo_id, blob_sha, path, records[0]),
            )
            conn.execute(
                text(self._DELETE_CHUNKS_SQL),
                {"repo_id": repo_id, "blob_sha": blob_sha},
            )
            # Passing a list of parameter dicts makes SQLAlchemy run a single
            # executemany instead of one round-trip per chunk.
            conn.execute(
                text(self._INSERT_CHUNK_SQL),
                [
                    self._chunk_params(repo_id, blob_sha, idx, record)
                    for idx, record in enumerate(records)
                ],
            )

    def record_repo_cache(
        self,
        *,
        project_id: str,
        commit_sha: str | None,
        changed_files: list[str],
        summary: str,
    ) -> None:
        """Cache one synthetic code-layer document per changed file.

        Each document's text is ``repo_webhook:<path>:<first 300 chars of summary>``.
        Each is cached under a blob sha derived from ``commit_sha`` and its path
        (see ``_blob_sha``). Every file is written through its own
        ``cache_documents`` call.
        """
        docs = [
            RagDocument(
                layer=RagLayer.CODE_SOURCE_CHUNKS,
                lang="python" if path.endswith(".py") else None,
                source=RagSource(project_id, commit_sha, path),
                title=path,
                text=f"repo_webhook:{path}:{summary[:300]}",
                metadata={"chunk_index": idx, "artifact_type": "CODE", "section": "repo_webhook"},
            )
            for idx, path in enumerate(changed_files)
        ]
        for doc in docs:
            blob_sha = self._blob_sha(commit_sha, doc.source.path)
            doc.metadata["blob_sha"] = blob_sha
            self.cache_documents(project_id, doc.source.path, blob_sha, [doc])

    def _row_to_document(self, row) -> RagDocument:
        """Rebuild a ``RagDocument`` from one ``rag_chunk_cache`` row mapping."""
        return RagDocument(
            layer=str(row["layer"]),
            lang=row.get("lang"),
            source=RagSource(
                repo_id=str(row["repo_id"]),
                commit_sha=row.get("commit_sha"),
                path=str(row["path"]),
            ),
            title=str(row["title"] or row["path"]),
            text=str(row["content"] or ""),
            metadata=self._loads(row.get("metadata_json")),
            links=[EvidenceLink(**item) for item in self._loads(row.get("links_json"), default=[])],
            span=RagSpan(row.get("span_start"), row.get("span_end")),
            # The SELECT casts the vector column to text; a NULL becomes "" -> [].
            embedding=self._parse_vector(str(row["embedding_txt"] or "")),
        )

    @staticmethod
    def _blob_params(repo_id: str, blob_sha: str, path: str, record: dict) -> dict:
        """Build bind parameters for the ``rag_blob_cache`` upsert from one doc record.

        Assumes ``record`` is the mapping returned by ``RagDocument.to_record()``
        (keys ``metadata``, ``layer``, ``lang``).
        """
        meta = record["metadata"]
        return {
            "repo_id": repo_id,
            "blob_sha": blob_sha,
            "path": path,
            "artifact_type": meta.get("artifact_type"),
            "section": meta.get("section") or meta.get("section_title"),
            "doc_id": meta.get("doc_id"),
            "doc_version": meta.get("doc_version"),
            "owner": meta.get("owner"),
            "system_component": meta.get("system_component"),
            "last_modified": meta.get("last_modified"),
            "staleness_score": meta.get("staleness_score"),
            "layer": record["layer"],
            "lang": record["lang"],
            "metadata_json": json.dumps(meta, ensure_ascii=True),
        }

    @staticmethod
    def _chunk_params(repo_id: str, blob_sha: str, chunk_index: int, record: dict) -> dict:
        """Build bind parameters for one ``rag_chunk_cache`` insert from a doc record."""
        metadata = record["metadata"]
        return {
            "repo_id": repo_id,
            "blob_sha": blob_sha,
            "chunk_index": chunk_index,
            "content": record["text"],
            "embedding": RagCacheRepository._vector_literal(record["embedding"]),
            "section": metadata.get("section") or metadata.get("section_title"),
            "layer": record["layer"],
            "lang": record["lang"],
            "path": record["path"],
            "title": record["title"],
            "metadata_json": json.dumps(metadata, ensure_ascii=True),
            "links_json": json.dumps(record["links"], ensure_ascii=True),
            "span_start": record["span_start"],
            "span_end": record["span_end"],
            "commit_sha": record["commit_sha"],
        }

    @staticmethod
    def _vector_literal(embedding) -> str | None:
        """Format an embedding as ``"[x,y,...]"`` for ``CAST(... AS vector)``.

        Returns None (stored as SQL NULL) when the embedding is missing or empty.
        """
        if not embedding:
            return None
        return "[" + ",".join(str(x) for x in embedding) + "]"

    def _blob_sha(self, commit_sha: str | None, path: str) -> str:
        """Return a deterministic SHA-256 hex digest of ``"<commit_sha or 'no-commit'>:<path>"``."""
        from hashlib import sha256

        return sha256(f"{commit_sha or 'no-commit'}:{path}".encode("utf-8")).hexdigest()

    def _engine(self):
        """Return the shared SQLAlchemy engine (imported lazily)."""
        from app.modules.shared.db import get_engine

        return get_engine()

    def _loads(self, value, default=None):
        """Decode a JSON column value, returning ``default`` when it is empty.

        ``default`` falls back to a fresh ``{}`` when not given. Values the
        driver has already decoded (dict/list, e.g. from a JSON/JSONB column)
        are returned as-is. ``bytes``/``bytearray`` are decoded directly; any
        other value is decoded from ``str(value)``. A stored JSON ``null`` is
        treated like an empty value.

        Raises:
            json.JSONDecodeError: if a string value is not valid JSON.
        """
        if default is None:
            default = {}
        if not value:
            return default
        if isinstance(value, (dict, list)):
            return value
        if isinstance(value, (bytes, bytearray)):
            decoded = json.loads(value)
        else:
            decoded = json.loads(str(value))
        return default if decoded is None else decoded

    def _parse_vector(self, value: str) -> list[float]:
        """Parse a vector text literal such as ``"[1,2.5]"`` into floats.

        Blank input, or an empty ``[]``, yields an empty list.
        """
        text_value = value.strip()
        if not text_value:
            return []
        if text_value.startswith("[") and text_value.endswith("]"):
            text_value = text_value[1:-1]
        if not text_value:
            return []
        return [float(part.strip()) for part in text_value.split(",") if part.strip()]