103 lines
4.7 KiB
Python
103 lines
4.7 KiB
Python
from __future__ import annotations

from collections import deque
from typing import TYPE_CHECKING

from app.modules.rag.explain.models import LayeredRetrievalItem, TracePath

if TYPE_CHECKING:
    from app.modules.rag.explain.graph_repository import CodeGraphRepository
|
|
|
|
|
|
class TraceBuilder:
    """Build ranked symbol-id paths by walking outgoing edges of a code graph.

    Starting from each seed symbol, the builder expands paths breadth-first
    through the injected graph repository, accumulates a score per traversed
    edge, and returns the highest-scoring distinct paths.
    """

    def __init__(self, graph_repository: CodeGraphRepository) -> None:
        """Keep a reference to the repository used for edge/symbol lookups."""
        self._graph = graph_repository

    def build_paths(
        self,
        rag_session_id: str,
        seed_symbols: list[LayeredRetrievalItem],
        *,
        max_depth: int,
        max_paths: int = 3,
        edge_types: list[str] | None = None,
    ) -> list[TracePath]:
        """Expand the seeds into scored paths and return the best ones.

        Args:
            rag_session_id: Session identifier forwarded to every repository call.
            seed_symbols: Starting items; those whose metadata has no
                ``symbol_id`` are skipped.
            max_depth: Maximum number of symbol ids in a path. A path that has
                reached this length is recorded without being expanded.
            max_paths: Maximum number of paths returned. Expansion stops once
                ``max_paths * 3`` candidate paths have been collected (the
                budget is shared by all seeds).
            edge_types: Edge types to follow; defaults to
                ``["calls", "imports", "inherits"]`` when empty or ``None``.

        Returns:
            Up to ``max_paths`` distinct paths sorted by descending score. When
            no path was produced, a single zero-score path holding the first
            seed's symbol id (or an empty list if there are no seeds).
        """
        edges_filter = edge_types or ["calls", "imports", "inherits"]
        symbol_map = self._symbol_map(seed_symbols)
        candidate_limit = max_paths * 3
        paths: list[TracePath] = []

        for seed in seed_symbols:
            seed_id = str(seed.metadata.get("symbol_id") or "")
            if not seed_id:
                continue

            # FIFO queue of (path so far, accumulated score, notes).
            queue: deque[tuple[list[str], float, list[str]]] = deque([([seed_id], 0.0, [])])
            while queue and len(paths) < candidate_limit:
                current_path, score, notes = queue.popleft()

                # Depth limit is checked first so that no repository query is
                # issued for a node that will not be expanded anyway.
                if len(current_path) >= max_depth:
                    paths.append(TracePath(symbol_ids=current_path, score=score, notes=notes))
                    continue

                src_symbol_id = current_path[-1]
                out_edges = self._graph.get_out_edges(
                    rag_session_id, [src_symbol_id], edges_filter, limit_per_src=4
                )
                if not out_edges:
                    paths.append(TracePath(symbol_ids=current_path, score=score, notes=notes))
                    continue

                for edge in out_edges:
                    metadata = edge.metadata
                    dst_symbol_id = str(metadata.get("dst_symbol_id") or "")
                    next_notes = list(notes)
                    next_score = score + self._edge_score(edge, symbol_map.get(src_symbol_id))

                    if not dst_symbol_id:
                        # The edge has no concrete target id: try to resolve it
                        # from its textual reference.
                        dst_ref = str(metadata.get("dst_ref") or "")
                        package_hint = self._package_hint(symbol_map.get(src_symbol_id))
                        resolved = self._graph.resolve_symbol_by_ref(
                            rag_session_id, dst_ref, package_hint=package_hint
                        )
                        if resolved is not None:
                            dst_symbol_id = str(resolved.metadata.get("symbol_id") or "")
                            # Never register a symbol under an empty id.
                            if dst_symbol_id:
                                symbol_map[dst_symbol_id] = resolved
                            next_score += 2.0
                            next_notes.append(f"resolved:{dst_ref}")

                    # Unknown target or a cycle: record the path as it stands.
                    if not dst_symbol_id or dst_symbol_id in current_path:
                        paths.append(
                            TracePath(symbol_ids=current_path, score=next_score, notes=next_notes)
                        )
                        continue

                    # Cache the target symbol so later edge scoring and package
                    # hints can use it as a source.
                    if dst_symbol_id not in symbol_map:
                        symbols = self._graph.get_symbols_by_ids(rag_session_id, [dst_symbol_id])
                        if symbols:
                            symbol_map[dst_symbol_id] = symbols[0]

                    queue.append((current_path + [dst_symbol_id], next_score, next_notes))

        unique = self._unique_paths(paths)
        unique.sort(key=lambda item: item.score, reverse=True)
        best = unique[:max_paths]
        if best:
            return best
        # Fallback: a single-node path for the first seed, with the id
        # normalised to a string the same way as everywhere else in this class.
        return [
            TracePath(symbol_ids=[str(seed.metadata.get("symbol_id") or "")], score=0.0)
            for seed in seed_symbols[:1]
        ]

    def _edge_score(self, edge: LayeredRetrievalItem, source_symbol: LayeredRetrievalItem | None) -> float:
        """Score one edge.

        Base 1.0; +2.0 when the edge metadata marks it ``resolved``; +1.0 when
        the edge's source equals the source symbol's source; -3.0 when the
        edge's source contains ``tests/``.
        """
        score = 1.0
        if str(edge.metadata.get("resolution") or "") == "resolved":
            score += 2.0
        source_path = source_symbol.source if source_symbol is not None else ""
        if source_path and edge.source == source_path:
            score += 1.0
        # "tests/" also matches any "/tests/" occurrence, so one check suffices.
        if "tests/" in edge.source:
            score -= 3.0
        return score

    def _package_hint(self, symbol: LayeredRetrievalItem | None) -> str | None:
        """Return the symbol's ``package_or_module`` minus its last dotted part.

        Returns ``None`` when there is no symbol or no package value, and the
        package itself when stripping the last part would leave nothing.
        """
        if symbol is None:
            return None
        package = str(symbol.metadata.get("package_or_module") or "")
        if not package:
            return None
        parent, _, _ = package.rpartition(".")
        return parent or package

    def _symbol_map(self, items: list[LayeredRetrievalItem]) -> dict[str, LayeredRetrievalItem]:
        """Index items by their non-empty ``symbol_id``; later items win on ties."""
        result: dict[str, LayeredRetrievalItem] = {}
        for item in items:
            symbol_id = str(item.metadata.get("symbol_id") or "")
            if symbol_id:
                result[symbol_id] = item
        return result

    def _unique_paths(self, items: list[TracePath]) -> list[TracePath]:
        """Drop paths whose non-empty symbol ids repeat an earlier path.

        The first occurrence is kept; paths with no non-empty ids are dropped.
        """
        result: list[TracePath] = []
        seen: set[tuple[str, ...]] = set()
        for item in items:
            key = tuple(symbol_id for symbol_id in item.symbol_ids if symbol_id)
            if not key or key in seen:
                continue
            seen.add(key)
            result.append(item)
        return result
|