"""Repository for the RAG blob/chunk cache tables (rag_blob_cache, rag_chunk_cache)."""
from __future__ import annotations
|
|
|
|
import json
|
|
|
|
from sqlalchemy import text
|
|
|
|
from app.modules.rag.contracts import EvidenceLink, RagDocument, RagSource, RagSpan
|
|
from app.modules.rag.contracts.enums import RagLayer
|
|
|
|
|
|
class RagCacheRepository:
    """Persistence layer for the RAG cache tables.

    One summary row per blob lives in ``rag_blob_cache``; the chunked
    documents for that blob live in ``rag_chunk_cache`` keyed by
    (repo_id, blob_sha, chunk_index). Embeddings are stored in a pgvector
    column and round-tripped through their text literal form.
    """

    def get_cached_documents(self, repo_id: str, blob_sha: str) -> list[RagDocument]:
        """Return the cached documents for one blob, ordered by chunk index.

        Returns an empty list when nothing is cached for
        ``(repo_id, blob_sha)``.
        """
        with self._engine().connect() as conn:
            rows = conn.execute(
                text(
                    """
                    SELECT layer, lang, path, title, content, metadata_json, links_json, span_start, span_end,
                        repo_id, commit_sha, embedding::text AS embedding_txt
                    FROM rag_chunk_cache
                    WHERE repo_id = :repo_id AND blob_sha = :blob_sha
                    ORDER BY chunk_index ASC
                    """
                ),
                {"repo_id": repo_id, "blob_sha": blob_sha},
            ).mappings().fetchall()
            docs: list[RagDocument] = []
            for row in rows:
                metadata = self._loads(row.get("metadata_json"))
                docs.append(
                    RagDocument(
                        layer=str(row["layer"]),
                        lang=row.get("lang"),
                        source=RagSource(
                            repo_id=str(row["repo_id"]),
                            commit_sha=row.get("commit_sha"),
                            path=str(row["path"]),
                        ),
                        # Fall back to the path when no explicit title was stored.
                        title=str(row["title"] or row["path"]),
                        text=str(row["content"] or ""),
                        metadata=metadata,
                        links=[EvidenceLink(**item) for item in self._loads(row.get("links_json"), default=[])],
                        span=RagSpan(row.get("span_start"), row.get("span_end")),
                        embedding=self._parse_vector(str(row["embedding_txt"] or "")),
                    )
                )
            return docs

    def cache_documents(self, repo_id: str, path: str, blob_sha: str, docs: list[RagDocument]) -> None:
        """Upsert the blob summary row and replace all chunk rows for a blob.

        No-op when ``docs`` is empty. Blob-level metadata is taken from the
        first document's record. Runs inside a single transaction
        (``engine.begin()`` commits on success and rolls back on error) so
        the delete + re-insert of chunk rows is atomic.
        """
        if not docs:
            return
        first = docs[0].to_record()
        first_meta = first["metadata"]
        with self._engine().begin() as conn:
            conn.execute(
                text(
                    """
                    INSERT INTO rag_blob_cache (
                        repo_id, blob_sha, path, artifact_type, section, doc_id, doc_version, owner,
                        system_component, last_modified, staleness_score, layer, lang, metadata_json
                    )
                    VALUES (
                        :repo_id, :blob_sha, :path, :artifact_type, :section, :doc_id, :doc_version, :owner,
                        :system_component, :last_modified, :staleness_score, :layer, :lang, :metadata_json
                    )
                    ON CONFLICT (repo_id, blob_sha, path) DO UPDATE SET
                        artifact_type = EXCLUDED.artifact_type,
                        section = EXCLUDED.section,
                        doc_id = EXCLUDED.doc_id,
                        doc_version = EXCLUDED.doc_version,
                        owner = EXCLUDED.owner,
                        system_component = EXCLUDED.system_component,
                        last_modified = EXCLUDED.last_modified,
                        staleness_score = EXCLUDED.staleness_score,
                        layer = EXCLUDED.layer,
                        lang = EXCLUDED.lang,
                        metadata_json = EXCLUDED.metadata_json,
                        updated_at = CURRENT_TIMESTAMP
                    """
                ),
                {
                    "repo_id": repo_id,
                    "blob_sha": blob_sha,
                    "path": path,
                    "artifact_type": first_meta.get("artifact_type"),
                    "section": first_meta.get("section") or first_meta.get("section_title"),
                    "doc_id": first_meta.get("doc_id"),
                    "doc_version": first_meta.get("doc_version"),
                    "owner": first_meta.get("owner"),
                    "system_component": first_meta.get("system_component"),
                    "last_modified": first_meta.get("last_modified"),
                    "staleness_score": first_meta.get("staleness_score"),
                    "layer": first["layer"],
                    "lang": first["lang"],
                    "metadata_json": json.dumps(first_meta, ensure_ascii=True),
                },
            )
            # Full replace: drop any stale chunk rows before re-inserting.
            conn.execute(
                text("DELETE FROM rag_chunk_cache WHERE repo_id = :repo_id AND blob_sha = :blob_sha"),
                {"repo_id": repo_id, "blob_sha": blob_sha},
            )
            # Build every chunk's parameter dict first, then insert them in
            # one executemany batch instead of one round-trip per chunk.
            chunk_params: list[dict] = []
            for idx, doc in enumerate(docs):
                row = doc.to_record()
                metadata = row["metadata"]
                emb = row["embedding"] or []
                chunk_params.append(
                    {
                        "repo_id": repo_id,
                        "blob_sha": blob_sha,
                        "chunk_index": idx,
                        "content": row["text"],
                        # pgvector text literal, e.g. "[0.1,0.2]"; NULL when no embedding.
                        "embedding": "[" + ",".join(str(x) for x in emb) + "]" if emb else None,
                        "section": metadata.get("section") or metadata.get("section_title"),
                        "layer": row["layer"],
                        "lang": row["lang"],
                        "path": row["path"],
                        "title": row["title"],
                        "metadata_json": json.dumps(metadata, ensure_ascii=True),
                        "links_json": json.dumps(row["links"], ensure_ascii=True),
                        "span_start": row["span_start"],
                        "span_end": row["span_end"],
                        "commit_sha": row["commit_sha"],
                    }
                )
            conn.execute(
                text(
                    """
                    INSERT INTO rag_chunk_cache (
                        repo_id, blob_sha, chunk_index, content, embedding, section, layer, lang, path, title,
                        metadata_json, links_json, span_start, span_end, commit_sha
                    )
                    VALUES (
                        :repo_id, :blob_sha, :chunk_index, :content, CAST(:embedding AS vector), :section, :layer,
                        :lang, :path, :title, :metadata_json, :links_json, :span_start, :span_end, :commit_sha
                    )
                    """
                ),
                chunk_params,
            )

    def record_repo_cache(
        self,
        *,
        project_id: str,
        commit_sha: str | None,
        changed_files: list[str],
        summary: str,
    ) -> None:
        """Cache one lightweight document per changed file from a repo webhook.

        Each file gets a synthetic blob sha derived from (commit_sha, path)
        and a short text record carrying up to 300 chars of the summary.
        """
        for idx, path in enumerate(changed_files):
            doc = RagDocument(
                layer=RagLayer.CODE_SOURCE_CHUNKS,
                lang="python" if path.endswith(".py") else None,
                source=RagSource(project_id, commit_sha, path),
                title=path,
                text=f"repo_webhook:{path}:{summary[:300]}",
                metadata={"chunk_index": idx, "artifact_type": "CODE", "section": "repo_webhook"},
            )
            blob_sha = self._blob_sha(commit_sha, doc.source.path)
            doc.metadata["blob_sha"] = blob_sha
            self.cache_documents(project_id, doc.source.path, blob_sha, [doc])

    def _blob_sha(self, commit_sha: str | None, path: str) -> str:
        """Derive a deterministic synthetic blob sha from commit + path."""
        from hashlib import sha256

        return sha256(f"{commit_sha or 'no-commit'}:{path}".encode("utf-8")).hexdigest()

    def _engine(self):
        """Return the shared SQLAlchemy engine (imported lazily, presumably to avoid import cycles — TODO confirm)."""
        from app.modules.shared.db import get_engine

        return get_engine()

    def _loads(self, value, default=None):
        """Parse a JSON column value; return ``default`` ({} if not given) for empty/NULL."""
        if default is None:
            default = {}
        if not value:
            return default
        return json.loads(str(value))

    def _parse_vector(self, value: str) -> list[float]:
        """Parse a pgvector text literal like "[0.1, 0.2]" into floats ([] for empty)."""
        text_value = value.strip()
        if not text_value:
            return []
        if text_value.startswith("[") and text_value.endswith("]"):
            text_value = text_value[1:-1]
        if not text_value:
            return []
        return [float(part.strip()) for part in text_value.split(",") if part.strip()]