Files
agent/app/modules/rag/explain/graph_repository.py

217 lines
8.2 KiB
Python

from __future__ import annotations
import json
from sqlalchemy import text
from app.modules.rag.explain.models import CodeLocation, LayeredRetrievalItem
from app.modules.shared.db import get_engine
class CodeGraphRepository:
    """Read-only queries over the layered ``rag_chunks`` store.

    Layers used: C0 source chunks, C1 symbol catalog, C2 dependency graph.
    """

    def get_out_edges(
        self,
        rag_session_id: str,
        src_symbol_ids: list[str],
        edge_types: list[str],
        limit_per_src: int,
    ) -> list[LayeredRetrievalItem]:
        """Return C2 dependency-graph chunks whose edge originates at one of *src_symbol_ids*.

        Rows are read in ``(path, span_start)`` order and at most ``limit_per_src``
        edges are kept per source symbol; the surplus is silently dropped.
        Returns ``[]`` when *src_symbol_ids* is empty.
        """
        if not src_symbol_ids:
            return []
        query = text(
            """
            SELECT path, content, layer, title, metadata_json, span_start, span_end
            FROM rag_chunks
            WHERE rag_session_id = :sid
              AND layer = 'C2_DEPENDENCY_GRAPH'
              AND CAST(metadata_json AS jsonb)->>'src_symbol_id' = ANY(:src_ids)
              AND CAST(metadata_json AS jsonb)->>'edge_type' = ANY(:edge_types)
            ORDER BY path, span_start
            """
        )
        params = {"sid": rag_session_id, "src_ids": src_symbol_ids, "edge_types": edge_types}
        with get_engine().connect() as conn:
            rows = conn.execute(query, params).mappings().fetchall()
        seen_per_src: dict[str, int] = {}
        results: list[LayeredRetrievalItem] = []
        for row in rows:
            meta = self._loads(row.get("metadata_json"))
            src_id = str(meta.get("src_symbol_id") or "")
            # Per-source cap: count every row, keep only the first limit_per_src.
            count = seen_per_src.get(src_id, 0) + 1
            seen_per_src[src_id] = count
            if count <= limit_per_src:
                results.append(self._to_item(row, meta))
        return results
def get_in_edges(
self,
rag_session_id: str,
dst_symbol_ids: list[str],
edge_types: list[str],
limit_per_dst: int,
) -> list[LayeredRetrievalItem]:
if not dst_symbol_ids:
return []
sql = """
SELECT path, content, layer, title, metadata_json, span_start, span_end
FROM rag_chunks
WHERE rag_session_id = :sid
AND layer = 'C2_DEPENDENCY_GRAPH'
AND CAST(metadata_json AS jsonb)->>'dst_symbol_id' = ANY(:dst_ids)
AND CAST(metadata_json AS jsonb)->>'edge_type' = ANY(:edge_types)
ORDER BY path, span_start
"""
with get_engine().connect() as conn:
rows = conn.execute(
text(sql),
{"sid": rag_session_id, "dst_ids": dst_symbol_ids, "edge_types": edge_types},
).mappings().fetchall()
grouped: dict[str, int] = {}
items: list[LayeredRetrievalItem] = []
for row in rows:
metadata = self._loads(row.get("metadata_json"))
dst_symbol_id = str(metadata.get("dst_symbol_id") or "")
grouped[dst_symbol_id] = grouped.get(dst_symbol_id, 0) + 1
if grouped[dst_symbol_id] > limit_per_dst:
continue
items.append(self._to_item(row, metadata))
return items
def resolve_symbol_by_ref(
self,
rag_session_id: str,
dst_ref: str,
package_hint: str | None = None,
) -> LayeredRetrievalItem | None:
ref = (dst_ref or "").strip()
if not ref:
return None
with get_engine().connect() as conn:
rows = conn.execute(
text(
"""
SELECT path, content, layer, title, metadata_json, span_start, span_end, qname
FROM rag_chunks
WHERE rag_session_id = :sid
AND layer = 'C1_SYMBOL_CATALOG'
AND (qname = :ref OR title = :ref OR qname LIKE :tail)
ORDER BY path
LIMIT 12
"""
),
{"sid": rag_session_id, "ref": ref, "tail": f"%{ref}"},
).mappings().fetchall()
best: LayeredRetrievalItem | None = None
best_score = -1
for row in rows:
metadata = self._loads(row.get("metadata_json"))
package = str(metadata.get("package_or_module") or "")
score = 0
if str(row.get("qname") or "") == ref:
score += 3
if str(row.get("title") or "") == ref:
score += 2
if package_hint and package.startswith(package_hint):
score += 3
if package_hint and package_hint in str(row.get("path") or ""):
score += 1
if score > best_score:
best = self._to_item(row, metadata)
best_score = score
return best
def get_symbols_by_ids(self, rag_session_id: str, symbol_ids: list[str]) -> list[LayeredRetrievalItem]:
if not symbol_ids:
return []
with get_engine().connect() as conn:
rows = conn.execute(
text(
"""
SELECT path, content, layer, title, metadata_json, span_start, span_end
FROM rag_chunks
WHERE rag_session_id = :sid
AND layer = 'C1_SYMBOL_CATALOG'
AND symbol_id = ANY(:symbol_ids)
ORDER BY path, span_start
"""
),
{"sid": rag_session_id, "symbol_ids": symbol_ids},
).mappings().fetchall()
return [self._to_item(row, self._loads(row.get("metadata_json"))) for row in rows]
def get_chunks_by_symbol_ids(
self,
rag_session_id: str,
symbol_ids: list[str],
prefer_chunk_type: str = "symbol_block",
) -> list[LayeredRetrievalItem]:
symbols = self.get_symbols_by_ids(rag_session_id, symbol_ids)
chunks: list[LayeredRetrievalItem] = []
for symbol in symbols:
location = symbol.location
if location is None:
continue
chunk = self._chunk_for_symbol(rag_session_id, symbol, prefer_chunk_type=prefer_chunk_type)
if chunk is not None:
chunks.append(chunk)
return chunks
def _chunk_for_symbol(
self,
rag_session_id: str,
symbol: LayeredRetrievalItem,
*,
prefer_chunk_type: str,
) -> LayeredRetrievalItem | None:
location = symbol.location
if location is None:
return None
with get_engine().connect() as conn:
rows = conn.execute(
text(
"""
SELECT path, content, layer, title, metadata_json, span_start, span_end
FROM rag_chunks
WHERE rag_session_id = :sid
AND layer = 'C0_SOURCE_CHUNKS'
AND path = :path
AND COALESCE(span_start, 0) <= :end_line
AND COALESCE(span_end, 999999) >= :start_line
ORDER BY
CASE WHEN CAST(metadata_json AS jsonb)->>'chunk_type' = :prefer_chunk_type THEN 0 ELSE 1 END,
ABS(COALESCE(span_start, 0) - :start_line)
LIMIT 1
"""
),
{
"sid": rag_session_id,
"path": location.path,
"start_line": location.start_line or 0,
"end_line": location.end_line or 999999,
"prefer_chunk_type": prefer_chunk_type,
},
).mappings().fetchall()
if not rows:
return None
row = rows[0]
return self._to_item(row, self._loads(row.get("metadata_json")))
def _to_item(self, row, metadata: dict) -> LayeredRetrievalItem:
return LayeredRetrievalItem(
source=str(row.get("path") or ""),
content=str(row.get("content") or ""),
layer=str(row.get("layer") or ""),
title=str(row.get("title") or ""),
metadata=metadata,
location=CodeLocation(
path=str(row.get("path") or ""),
start_line=row.get("span_start"),
end_line=row.get("span_end"),
),
)
def _loads(self, value) -> dict:
if not value:
return {}
return json.loads(str(value))