from __future__ import annotations

import json

from sqlalchemy import text

from app.modules.rag.explain.models import CodeLocation, LayeredRetrievalItem
from app.modules.shared.db import get_engine


class CodeGraphRepository:
    """Read-only queries over ``rag_chunks`` for the code-graph layers.

    Layers used here: ``C0_SOURCE_CHUNKS`` (source chunks),
    ``C1_SYMBOL_CATALOG`` (symbols) and ``C2_DEPENDENCY_GRAPH`` (edges).
    Every query is scoped to a single ``rag_session_id``.
    """

    # Metadata keys naming the two endpoints of a C2 edge. Only these fixed
    # values are ever interpolated into SQL (see ``_get_edges``).
    _EDGE_ENDPOINT_KEYS = frozenset({"src_symbol_id", "dst_symbol_id"})

    # Escape character used for literal LIKE matching; "!" avoids any
    # dependence on backslash handling in SQL string literals.
    _LIKE_ESCAPE = "!"

    def get_out_edges(
        self,
        rag_session_id: str,
        src_symbol_ids: list[str],
        edge_types: list[str],
        limit_per_src: int,
    ) -> list[LayeredRetrievalItem]:
        """Return dependency edges whose source is one of ``src_symbol_ids``.

        Only edges whose ``edge_type`` is in ``edge_types`` are returned. At
        most ``limit_per_src`` edges are kept per source symbol, taken in
        ``path, span_start`` order.
        """
        return self._get_edges(
            rag_session_id,
            endpoint_key="src_symbol_id",
            symbol_ids=src_symbol_ids,
            edge_types=edge_types,
            limit_per_symbol=limit_per_src,
        )

    def get_in_edges(
        self,
        rag_session_id: str,
        dst_symbol_ids: list[str],
        edge_types: list[str],
        limit_per_dst: int,
    ) -> list[LayeredRetrievalItem]:
        """Return dependency edges whose target is one of ``dst_symbol_ids``.

        Only edges whose ``edge_type`` is in ``edge_types`` are returned. At
        most ``limit_per_dst`` edges are kept per target symbol, taken in
        ``path, span_start`` order.
        """
        return self._get_edges(
            rag_session_id,
            endpoint_key="dst_symbol_id",
            symbol_ids=dst_symbol_ids,
            edge_types=edge_types,
            limit_per_symbol=limit_per_dst,
        )

    def _get_edges(
        self,
        rag_session_id: str,
        *,
        endpoint_key: str,
        symbol_ids: list[str],
        edge_types: list[str],
        limit_per_symbol: int,
    ) -> list[LayeredRetrievalItem]:
        """Shared C2 edge query, capped at ``limit_per_symbol`` per endpoint.

        Raises:
            ValueError: if ``endpoint_key`` is not a known edge endpoint key.
        """
        if endpoint_key not in self._EDGE_ENDPOINT_KEYS:
            raise ValueError(f"unsupported edge endpoint key: {endpoint_key!r}")
        # ANY() over an empty array matches nothing; skip the round trip.
        if not symbol_ids or not edge_types:
            return []
        # endpoint_key is whitelisted above, so interpolating it is safe.
        sql = f"""
            SELECT path, content, layer, title, metadata_json, span_start, span_end
            FROM rag_chunks
            WHERE rag_session_id = :sid
              AND layer = 'C2_DEPENDENCY_GRAPH'
              AND CAST(metadata_json AS jsonb)->>'{endpoint_key}' = ANY(:symbol_ids)
              AND CAST(metadata_json AS jsonb)->>'edge_type' = ANY(:edge_types)
            ORDER BY path, span_start
        """
        with get_engine().connect() as conn:
            rows = conn.execute(
                text(sql),
                {"sid": rag_session_id, "symbol_ids": symbol_ids, "edge_types": edge_types},
            ).mappings().fetchall()

        seen_per_symbol: dict[str, int] = {}
        items: list[LayeredRetrievalItem] = []
        for row in rows:
            metadata = self._loads(row.get("metadata_json"))
            symbol_id = str(metadata.get(endpoint_key) or "")
            count = seen_per_symbol.get(symbol_id, 0) + 1
            seen_per_symbol[symbol_id] = count
            if count <= limit_per_symbol:
                items.append(self._to_item(row, metadata))
        return items

    def resolve_symbol_by_ref(
        self,
        rag_session_id: str,
        dst_ref: str,
        package_hint: str | None = None,
    ) -> LayeredRetrievalItem | None:
        """Resolve a textual symbol reference to the best C1 catalog entry.

        Candidates are rows whose ``qname`` or ``title`` equals ``dst_ref`` or
        whose ``qname`` ends with it. Each candidate is scored:

        - +3 for an exact ``qname`` match;
        - +2 for an exact ``title`` match;
        - +3 if its ``package_or_module`` starts with ``package_hint``;
        - +1 if ``package_hint`` occurs in its path.

        The first candidate with the highest score wins.

        Returns:
            The best candidate, or ``None`` for a blank ref or no candidates.
        """
        ref = (dst_ref or "").strip()
        if not ref:
            return None
        with get_engine().connect() as conn:
            rows = conn.execute(
                text(
                    """
                    SELECT path, content, layer, title, metadata_json, span_start, span_end, qname
                    FROM rag_chunks
                    WHERE rag_session_id = :sid
                      AND layer = 'C1_SYMBOL_CATALOG'
                      AND (qname = :ref OR title = :ref OR qname LIKE :tail ESCAPE '!')
                    ORDER BY
                      CASE
                        WHEN qname = :ref THEN 0
                        WHEN title = :ref THEN 1
                        ELSE 2
                      END,
                      path
                    LIMIT 12
                    """
                ),
                # Exact matches are ordered first so LIMIT never drops them in
                # favour of loose tail matches; wildcards in ref match literally.
                {
                    "sid": rag_session_id,
                    "ref": ref,
                    "tail": "%" + self._escape_like(ref),
                },
            ).mappings().fetchall()

        best: LayeredRetrievalItem | None = None
        best_score = -1
        for row in rows:
            metadata = self._loads(row.get("metadata_json"))
            package = str(metadata.get("package_or_module") or "")
            score = 0
            if str(row.get("qname") or "") == ref:
                score += 3
            if str(row.get("title") or "") == ref:
                score += 2
            if package_hint and package.startswith(package_hint):
                score += 3
            if package_hint and package_hint in str(row.get("path") or ""):
                score += 1
            # Strict ">" keeps the earliest row on ties.
            if score > best_score:
                best = self._to_item(row, metadata)
                best_score = score
        return best

    def get_symbols_by_ids(self, rag_session_id: str, symbol_ids: list[str]) -> list[LayeredRetrievalItem]:
        """Return C1 catalog rows whose ``symbol_id`` is in ``symbol_ids``.

        Rows are ordered by ``path, span_start``; an empty id list returns
        ``[]`` without querying.
        """
        if not symbol_ids:
            return []
        with get_engine().connect() as conn:
            rows = conn.execute(
                text(
                    """
                    SELECT path, content, layer, title, metadata_json, span_start, span_end
                    FROM rag_chunks
                    WHERE rag_session_id = :sid
                      AND layer = 'C1_SYMBOL_CATALOG'
                      AND symbol_id = ANY(:symbol_ids)
                    ORDER BY path, span_start
                    """
                ),
                {"sid": rag_session_id, "symbol_ids": symbol_ids},
            ).mappings().fetchall()
        return [self._to_item(row, self._loads(row.get("metadata_json"))) for row in rows]

    def get_chunks_by_symbol_ids(
        self,
        rag_session_id: str,
        symbol_ids: list[str],
        prefer_chunk_type: str = "symbol_block",
    ) -> list[LayeredRetrievalItem]:
        """Map each symbol to its best overlapping C0 source chunk.

        Symbols without a location or without an overlapping chunk are
        skipped. Chunks whose ``chunk_type`` equals ``prefer_chunk_type``
        are preferred.
        """
        chunks: list[LayeredRetrievalItem] = []
        for symbol in self.get_symbols_by_ids(rag_session_id, symbol_ids):
            # _chunk_for_symbol returns None for symbols without a location.
            chunk = self._chunk_for_symbol(rag_session_id, symbol, prefer_chunk_type=prefer_chunk_type)
            if chunk is not None:
                chunks.append(chunk)
        return chunks

    def _chunk_for_symbol(
        self,
        rag_session_id: str,
        symbol: LayeredRetrievalItem,
        *,
        prefer_chunk_type: str,
    ) -> LayeredRetrievalItem | None:
        """Return the single C0 chunk that best covers ``symbol``'s line span.

        Only chunks in the same path whose span overlaps the symbol's span are
        candidates. They are ranked by:

        1. a ``chunk_type`` equal to ``prefer_chunk_type``;
        2. closeness of ``span_start`` to the symbol's start line;
        3. earliest start, as a deterministic tie-break.

        Missing span bounds are treated as 0 / 999999.
        """
        location = symbol.location
        if location is None:
            return None
        with get_engine().connect() as conn:
            row = conn.execute(
                text(
                    """
                    SELECT path, content, layer, title, metadata_json, span_start, span_end
                    FROM rag_chunks
                    WHERE rag_session_id = :sid
                      AND layer = 'C0_SOURCE_CHUNKS'
                      AND path = :path
                      AND COALESCE(span_start, 0) <= :end_line
                      AND COALESCE(span_end, 999999) >= :start_line
                    ORDER BY
                      CASE WHEN CAST(metadata_json AS jsonb)->>'chunk_type' = :prefer_chunk_type THEN 0 ELSE 1 END,
                      ABS(COALESCE(span_start, 0) - :start_line),
                      COALESCE(span_start, 0)
                    LIMIT 1
                    """
                ),
                {
                    "sid": rag_session_id,
                    "path": location.path,
                    "start_line": location.start_line or 0,
                    "end_line": location.end_line or 999999,
                    "prefer_chunk_type": prefer_chunk_type,
                },
            ).mappings().first()
        if row is None:
            return None
        return self._to_item(row, self._loads(row.get("metadata_json")))

    @classmethod
    def _escape_like(cls, value: str) -> str:
        """Escape LIKE wildcards in ``value`` so it matches literally.

        The escape character itself is doubled first so that pre-existing
        occurrences are not misread as escapes.
        """
        esc = cls._LIKE_ESCAPE
        return value.replace(esc, esc + esc).replace("%", esc + "%").replace("_", esc + "_")

    def _to_item(self, row, metadata: dict) -> LayeredRetrievalItem:
        """Build a ``LayeredRetrievalItem`` from a ``rag_chunks`` row mapping."""
        path = str(row.get("path") or "")
        return LayeredRetrievalItem(
            source=path,
            content=str(row.get("content") or ""),
            layer=str(row.get("layer") or ""),
            title=str(row.get("title") or ""),
            metadata=metadata,
            location=CodeLocation(
                path=path,
                start_line=row.get("span_start"),
                end_line=row.get("span_end"),
            ),
        )

    def _loads(self, value) -> dict:
        """Decode a ``metadata_json`` column value into a dict.

        Accepts JSON text or bytes, or an already-decoded dict (as a driver
        may return for a native json/jsonb column). Empty values, and JSON
        payloads that are not objects, yield ``{}``.

        Raises:
            json.JSONDecodeError: if the value is not valid JSON.
        """
        if not value:
            return {}
        if isinstance(value, dict):
            return value
        if not isinstance(value, (str, bytes, bytearray)):
            value = str(value)
        decoded = json.loads(value)
        # Callers use .get() on the result, so a non-object payload
        # (null, list, scalar) is treated as absent metadata.
        return decoded if isinstance(decoded, dict) else {}