217 lines
8.2 KiB
Python
217 lines
8.2 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
|
|
from sqlalchemy import text
|
|
|
|
from app.modules.rag.explain.models import CodeLocation, LayeredRetrievalItem
|
|
from app.modules.shared.db import get_engine
|
|
|
|
|
|
class CodeGraphRepository:
    """Read-only lookups over the ``rag_chunks`` table for code-graph retrieval.

    Every query is scoped to one ``rag_session_id`` and reads one of three layers:

    * ``C0_SOURCE_CHUNKS`` - source chunks located by ``path`` and
      ``span_start``/``span_end``;
    * ``C1_SYMBOL_CATALOG`` - symbol rows with ``symbol_id``, ``qname`` and
      ``title`` columns;
    * ``C2_DEPENDENCY_GRAPH`` - edge rows whose ``metadata_json`` carries
      ``src_symbol_id``, ``dst_symbol_id`` and ``edge_type``.
    """

    # Metadata keys that ``_get_edges`` may splice into its SQL text.
    _EDGE_ENDPOINT_KEYS = frozenset({"src_symbol_id", "dst_symbol_id"})

    # Maximum candidate rows fetched by ``resolve_symbol_by_ref`` before scoring.
    _RESOLVE_CANDIDATE_LIMIT = 12

    # Sentinel end line used when a span end is missing (NULL column / unset location).
    _OPEN_SPAN_END = 999999

    def get_out_edges(
        self,
        rag_session_id: str,
        src_symbol_ids: list[str],
        edge_types: list[str],
        limit_per_src: int,
    ) -> list[LayeredRetrievalItem]:
        """Return dependency-graph rows whose ``src_symbol_id`` is in *src_symbol_ids*.

        Args:
            rag_session_id: Session whose chunks are searched.
            src_symbol_ids: Accepted values of the ``src_symbol_id`` metadata field.
            edge_types: Accepted values of the ``edge_type`` metadata field.
            limit_per_src: Maximum number of rows kept per ``src_symbol_id``.

        Returns:
            Matching ``C2_DEPENDENCY_GRAPH`` items ordered by ``path, span_start``;
            an empty list when *src_symbol_ids* or *edge_types* is empty.
        """
        return self._get_edges(
            rag_session_id,
            src_symbol_ids,
            edge_types,
            limit_per_src,
            endpoint_key="src_symbol_id",
        )

    def get_in_edges(
        self,
        rag_session_id: str,
        dst_symbol_ids: list[str],
        edge_types: list[str],
        limit_per_dst: int,
    ) -> list[LayeredRetrievalItem]:
        """Return dependency-graph rows whose ``dst_symbol_id`` is in *dst_symbol_ids*.

        Args:
            rag_session_id: Session whose chunks are searched.
            dst_symbol_ids: Accepted values of the ``dst_symbol_id`` metadata field.
            edge_types: Accepted values of the ``edge_type`` metadata field.
            limit_per_dst: Maximum number of rows kept per ``dst_symbol_id``.

        Returns:
            Matching ``C2_DEPENDENCY_GRAPH`` items ordered by ``path, span_start``;
            an empty list when *dst_symbol_ids* or *edge_types* is empty.
        """
        return self._get_edges(
            rag_session_id,
            dst_symbol_ids,
            edge_types,
            limit_per_dst,
            endpoint_key="dst_symbol_id",
        )

    def _get_edges(
        self,
        rag_session_id: str,
        symbol_ids: list[str],
        edge_types: list[str],
        limit_per_symbol: int,
        *,
        endpoint_key: str,
    ) -> list[LayeredRetrievalItem]:
        """Shared implementation of ``get_out_edges`` and ``get_in_edges``.

        *endpoint_key* names the metadata field that is matched against
        *symbol_ids* and used to group rows for *limit_per_symbol*. It is spliced
        into the SQL text as a literal JSON key, so it is validated against
        ``_EDGE_ENDPOINT_KEYS`` before anything else.

        Raises:
            ValueError: if *endpoint_key* is not one of ``_EDGE_ENDPOINT_KEYS``.
        """
        if endpoint_key not in self._EDGE_ENDPOINT_KEYS:
            raise ValueError(f"unsupported edge endpoint key: {endpoint_key!r}")
        # `x = ANY(<empty array>)` never matches, so skip the round trip entirely.
        if not symbol_ids or not edge_types:
            return []

        sql = f"""
            SELECT path, content, layer, title, metadata_json, span_start, span_end
            FROM rag_chunks
            WHERE rag_session_id = :sid
              AND layer = 'C2_DEPENDENCY_GRAPH'
              AND CAST(metadata_json AS jsonb)->>'{endpoint_key}' = ANY(:symbol_ids)
              AND CAST(metadata_json AS jsonb)->>'edge_type' = ANY(:edge_types)
            ORDER BY path, span_start
        """
        with get_engine().connect() as conn:
            rows = conn.execute(
                text(sql),
                {"sid": rag_session_id, "symbol_ids": symbol_ids, "edge_types": edge_types},
            ).mappings().fetchall()

        # The per-symbol cap is applied client-side over the ordered rows, so each
        # symbol keeps its first `limit_per_symbol` rows by (path, span_start).
        seen: dict[str, int] = {}
        items: list[LayeredRetrievalItem] = []
        for row in rows:
            metadata = self._loads(row.get("metadata_json"))
            endpoint = str(metadata.get(endpoint_key) or "")
            seen[endpoint] = seen.get(endpoint, 0) + 1
            if seen[endpoint] <= limit_per_symbol:
                items.append(self._to_item(row, metadata))
        return items

    def resolve_symbol_by_ref(
        self,
        rag_session_id: str,
        dst_ref: str,
        package_hint: str | None = None,
    ) -> LayeredRetrievalItem | None:
        """Resolve a textual symbol reference to the best-matching catalog symbol.

        Candidates are ``C1_SYMBOL_CATALOG`` rows whose ``qname`` or ``title``
        equals the stripped *dst_ref*, or whose ``qname`` ends with it (matched
        literally). At most ``_RESOLVE_CANDIDATE_LIMIT`` candidates are fetched.
        Exact ``qname`` matches come first, then exact ``title`` matches, then
        suffix matches, each group ordered by ``path``. This keeps exact matches
        from being cut by the limit. Each candidate is scored:

        * +3 if ``qname`` equals the reference;
        * +2 if ``title`` equals the reference;
        * +3 if the ``package_or_module`` metadata starts with *package_hint*;
        * +1 if *package_hint* occurs in ``path``.

        The first candidate with the highest score wins.

        Returns:
            The winning item, or ``None`` if *dst_ref* is blank or nothing matched.
        """
        ref = (dst_ref or "").strip()
        if not ref:
            return None

        with get_engine().connect() as conn:
            rows = conn.execute(
                text(
                    """
                    SELECT path, content, layer, title, metadata_json, span_start, span_end, qname
                    FROM rag_chunks
                    WHERE rag_session_id = :sid
                      AND layer = 'C1_SYMBOL_CATALOG'
                      AND (qname = :ref OR title = :ref OR qname LIKE :tail)
                    ORDER BY
                      CASE WHEN qname = :ref THEN 0 WHEN title = :ref THEN 1 ELSE 2 END,
                      path
                    LIMIT :max_rows
                    """
                ),
                {
                    "sid": rag_session_id,
                    "ref": ref,
                    # Escape LIKE metacharacters so `ref` is a literal suffix; unescaped,
                    # every "_" in an identifier would match any single character.
                    "tail": "%" + self._escape_like(ref),
                    "max_rows": self._RESOLVE_CANDIDATE_LIMIT,
                },
            ).mappings().fetchall()

        best: LayeredRetrievalItem | None = None
        best_score = -1
        for row in rows:
            metadata = self._loads(row.get("metadata_json"))
            package = str(metadata.get("package_or_module") or "")
            score = 0
            if str(row.get("qname") or "") == ref:
                score += 3
            if str(row.get("title") or "") == ref:
                score += 2
            if package_hint:
                if package.startswith(package_hint):
                    score += 3
                if package_hint in str(row.get("path") or ""):
                    score += 1
            # Strict `>` keeps the earliest row among equal scores.
            if score > best_score:
                best = self._to_item(row, metadata)
                best_score = score
        return best

    def get_symbols_by_ids(self, rag_session_id: str, symbol_ids: list[str]) -> list[LayeredRetrievalItem]:
        """Return ``C1_SYMBOL_CATALOG`` rows whose ``symbol_id`` is in *symbol_ids*.

        Results are ordered by ``path, span_start``; an empty *symbol_ids* returns
        ``[]`` without querying.
        """
        if not symbol_ids:
            return []
        with get_engine().connect() as conn:
            rows = conn.execute(
                text(
                    """
                    SELECT path, content, layer, title, metadata_json, span_start, span_end
                    FROM rag_chunks
                    WHERE rag_session_id = :sid
                      AND layer = 'C1_SYMBOL_CATALOG'
                      AND symbol_id = ANY(:symbol_ids)
                    ORDER BY path, span_start
                    """
                ),
                {"sid": rag_session_id, "symbol_ids": symbol_ids},
            ).mappings().fetchall()
        return [self._row_to_item(row) for row in rows]

    def get_chunks_by_symbol_ids(
        self,
        rag_session_id: str,
        symbol_ids: list[str],
        prefer_chunk_type: str = "symbol_block",
    ) -> list[LayeredRetrievalItem]:
        """Return at most one source chunk per symbol in *symbol_ids*.

        Symbols are loaded with ``get_symbols_by_ids``. For each symbol, the
        overlapping ``C0_SOURCE_CHUNKS`` row is picked as in ``_chunk_for_symbol``
        (preferring *prefer_chunk_type*). Symbols without a location, or with no
        overlapping chunk, are skipped. All chunk lookups share one connection.
        """
        symbols = [
            symbol
            for symbol in self.get_symbols_by_ids(rag_session_id, symbol_ids)
            if symbol.location is not None
        ]
        if not symbols:
            return []

        chunks: list[LayeredRetrievalItem] = []
        with get_engine().connect() as conn:
            for symbol in symbols:
                chunk = self._query_chunk(conn, rag_session_id, symbol.location, prefer_chunk_type)
                if chunk is not None:
                    chunks.append(chunk)
        return chunks

    def _chunk_for_symbol(
        self,
        rag_session_id: str,
        symbol: LayeredRetrievalItem,
        *,
        prefer_chunk_type: str,
    ) -> LayeredRetrievalItem | None:
        """Return the source chunk that best covers *symbol*'s location, if any.

        Returns ``None`` without querying when *symbol* has no location.
        """
        location = symbol.location
        if location is None:
            return None
        with get_engine().connect() as conn:
            return self._query_chunk(conn, rag_session_id, location, prefer_chunk_type)

    def _query_chunk(
        self,
        conn,
        rag_session_id: str,
        location: CodeLocation,
        prefer_chunk_type: str,
    ) -> LayeredRetrievalItem | None:
        """Pick one ``C0_SOURCE_CHUNKS`` row in ``location.path`` overlapping its lines.

        A missing ``start_line`` is treated as 0 and a missing ``end_line`` as
        ``_OPEN_SPAN_END``. NULL chunk spans are treated the same way. Among
        overlapping chunks, ones whose ``chunk_type`` metadata equals
        *prefer_chunk_type* win. Ties are then broken by distance between the
        chunk start and ``start_line``, then by the earlier chunk start.
        """
        row = conn.execute(
            text(
                """
                SELECT path, content, layer, title, metadata_json, span_start, span_end
                FROM rag_chunks
                WHERE rag_session_id = :sid
                  AND layer = 'C0_SOURCE_CHUNKS'
                  AND path = :path
                  AND COALESCE(span_start, 0) <= :end_line
                  AND COALESCE(span_end, :open_end) >= :start_line
                ORDER BY
                  CASE WHEN CAST(metadata_json AS jsonb)->>'chunk_type' = :prefer_chunk_type THEN 0 ELSE 1 END,
                  ABS(COALESCE(span_start, 0) - :start_line),
                  COALESCE(span_start, 0)
                LIMIT 1
                """
            ),
            {
                "sid": rag_session_id,
                "path": location.path,
                "start_line": location.start_line or 0,
                "end_line": location.end_line or self._OPEN_SPAN_END,
                "open_end": self._OPEN_SPAN_END,
                "prefer_chunk_type": prefer_chunk_type,
            },
        ).mappings().first()
        return None if row is None else self._row_to_item(row)

    def _row_to_item(self, row) -> LayeredRetrievalItem:
        """Build an item from *row*, decoding its ``metadata_json`` column."""
        return self._to_item(row, self._loads(row.get("metadata_json")))

    def _to_item(self, row, metadata: dict) -> LayeredRetrievalItem:
        """Convert a result row plus decoded *metadata* into a ``LayeredRetrievalItem``.

        ``path`` is used as both the item source and the location path. Missing
        text columns become ``""``. ``span_start``/``span_end`` are passed through
        unchanged (possibly ``None``).
        """
        path = str(row.get("path") or "")
        return LayeredRetrievalItem(
            source=path,
            content=str(row.get("content") or ""),
            layer=str(row.get("layer") or ""),
            title=str(row.get("title") or ""),
            metadata=metadata,
            location=CodeLocation(
                path=path,
                start_line=row.get("span_start"),
                end_line=row.get("span_end"),
            ),
        )

    def _loads(self, value) -> dict:
        """Decode a ``metadata_json`` column value into a dict.

        Accepts JSON text or bytes, or a dict the driver has already decoded
        (returned as a shallow copy). Empty values, and JSON that does not decode
        to an object, yield ``{}`` so callers can use ``.get`` safely.

        Raises:
            json.JSONDecodeError: if *value* is non-empty text that is not valid JSON.
        """
        if not value:
            return {}
        if isinstance(value, dict):
            return dict(value)
        raw = value if isinstance(value, (str, bytes, bytearray)) else str(value)
        decoded = json.loads(raw)
        return decoded if isinstance(decoded, dict) else {}

    @staticmethod
    def _escape_like(value: str) -> str:
        """Escape ``\\``, ``%`` and ``_`` so *value* matches literally inside LIKE.

        Backslash is PostgreSQL's default LIKE escape character, so queries using
        the result need no ``ESCAPE`` clause.
        """
        return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")