from __future__ import annotations

from collections import deque
from typing import TYPE_CHECKING

from app.modules.rag.explain.models import LayeredRetrievalItem, TracePath

if TYPE_CHECKING:
    from app.modules.rag.explain.graph_repository import CodeGraphRepository


class TraceBuilder:
    """Build ranked paths through a code graph starting from seed symbols.

    Outgoing edges (by default ``calls``, ``imports`` and ``inherits``) are
    explored breadth-first from each seed; every explored path is scored and
    the best distinct paths are returned.
    """

    def __init__(self, graph_repository: CodeGraphRepository) -> None:
        """Store the repository used for all edge and symbol lookups."""
        self._graph = graph_repository

    def build_paths(
        self,
        rag_session_id: str,
        seed_symbols: list[LayeredRetrievalItem],
        *,
        max_depth: int,
        max_paths: int = 3,
        edge_types: list[str] | None = None,
    ) -> list[TracePath]:
        """Trace paths outward from ``seed_symbols`` and return the best ones.

        Args:
            rag_session_id: Session whose graph is queried.
            seed_symbols: Starting symbols; items whose metadata has no
                ``symbol_id`` are skipped.
            max_depth: Maximum number of symbols in a path; a path that has
                reached this length is recorded and not extended.
            max_paths: Number of paths to return. Exploration stops once
                ``max_paths * 3`` candidate paths have been collected; this
                budget is shared by all seeds, processed in order.
            edge_types: Edge types to follow; ``None`` or an empty list means
                ``["calls", "imports", "inherits"]``.

        Returns:
            Up to ``max_paths`` paths with distinct symbol sequences, highest
            score first. If none were found, a single zero-score path holding
            the first seed's symbol id (``""`` when it has none) is returned,
            or ``[]`` when ``seed_symbols`` is empty.
        """
        edges_filter = edge_types or ["calls", "imports", "inherits"]
        symbol_map = self._symbol_map(seed_symbols)
        paths: list[TracePath] = []
        path_budget = max_paths * 3

        for seed in seed_symbols:
            seed_id = self._symbol_id(seed)
            if not seed_id:
                continue
            # Each entry: (symbol ids along the path, accumulated score, notes).
            # deque gives O(1) pops from the left for breadth-first order.
            queue: deque[tuple[list[str], float, list[str]]] = deque([([seed_id], 0.0, [])])
            while queue and len(paths) < path_budget:
                current_path, score, notes = queue.popleft()
                # Check the depth limit before querying so paths that cannot be
                # extended do not cost a repository round-trip.
                if len(current_path) >= max_depth:
                    paths.append(TracePath(symbol_ids=current_path, score=score, notes=notes))
                    continue
                src_symbol_id = current_path[-1]
                out_edges = self._graph.get_out_edges(rag_session_id, [src_symbol_id], edges_filter, limit_per_src=4)
                if not out_edges:
                    paths.append(TracePath(symbol_ids=current_path, score=score, notes=notes))
                    continue
                for edge in out_edges:
                    metadata = edge.metadata
                    dst_symbol_id = str(metadata.get("dst_symbol_id") or "")
                    next_notes = list(notes)
                    next_score = score + self._edge_score(edge, symbol_map.get(src_symbol_id))
                    if not dst_symbol_id:
                        # Edge has no target id: try to resolve it from its textual reference.
                        dst_ref = str(metadata.get("dst_ref") or "")
                        package_hint = self._package_hint(symbol_map.get(src_symbol_id))
                        resolved = self._graph.resolve_symbol_by_ref(rag_session_id, dst_ref, package_hint=package_hint)
                        if resolved is not None:
                            dst_symbol_id = self._symbol_id(resolved)
                            # Only a resolution that yields a real id earns the bonus.
                            if dst_symbol_id:
                                symbol_map[dst_symbol_id] = resolved
                                next_score += 2.0
                                next_notes.append(f"resolved:{dst_ref}")
                    if not dst_symbol_id or dst_symbol_id in current_path:
                        # Unresolvable target or a cycle: keep the path as it stands.
                        paths.append(TracePath(symbol_ids=current_path, score=next_score, notes=next_notes))
                        continue
                    if dst_symbol_id not in symbol_map:
                        # Cache the target so later edges from it can be scored.
                        symbols = self._graph.get_symbols_by_ids(rag_session_id, [dst_symbol_id])
                        if symbols:
                            symbol_map[dst_symbol_id] = symbols[0]
                    queue.append((current_path + [dst_symbol_id], next_score, next_notes))

        unique = self._unique_paths(paths)
        unique.sort(key=lambda item: item.score, reverse=True)
        ranked = unique[:max_paths]
        if ranked:
            return ranked
        return [TracePath(symbol_ids=[self._symbol_id(seed)], score=0.0) for seed in seed_symbols[:1]]

    @staticmethod
    def _symbol_id(item: LayeredRetrievalItem) -> str:
        """Return ``item``'s ``symbol_id`` metadata as a string, or ``""`` if absent/falsy."""
        return str(item.metadata.get("symbol_id") or "")

    def _edge_score(self, edge: LayeredRetrievalItem, source_symbol: LayeredRetrievalItem | None) -> float:
        """Score one edge for path ranking.

        Starts at 1.0, adds 2.0 when the edge's ``resolution`` metadata is
        ``"resolved"``, adds 1.0 when the edge's source equals the source
        symbol's source, and subtracts 3.0 when the edge's source contains
        ``"tests/"``.
        """
        metadata = edge.metadata
        edge_source = edge.source or ""
        score = 1.0
        if str(metadata.get("resolution") or "") == "resolved":
            score += 2.0
        source_path = source_symbol.source if source_symbol is not None else ""
        if source_path and edge_source == source_path:
            score += 1.0
        # A plain substring check covers both "tests/..." and ".../tests/...";
        # it also matches directories such as "unit_tests/".
        if "tests/" in edge_source:
            score -= 3.0
        return score

    def _package_hint(self, symbol: LayeredRetrievalItem | None) -> str | None:
        """Return the parent of ``symbol``'s ``package_or_module`` for reference resolution.

        ``"a.b.c"`` yields ``"a.b"``; a single-segment name is returned as-is.
        Returns ``None`` when there is no symbol or no package metadata.
        """
        if symbol is None:
            return None
        package = str(symbol.metadata.get("package_or_module") or "")
        if not package:
            return None
        return ".".join(package.split(".")[:-1]) or package

    def _symbol_map(self, items: list[LayeredRetrievalItem]) -> dict[str, LayeredRetrievalItem]:
        """Index ``items`` by symbol id, skipping items without one; later items win on duplicates."""
        result: dict[str, LayeredRetrievalItem] = {}
        for item in items:
            symbol_id = self._symbol_id(item)
            if symbol_id:
                result[symbol_id] = item
        return result

    def _unique_paths(self, items: list[TracePath]) -> list[TracePath]:
        """Drop paths whose non-empty symbol-id sequence was already seen, keeping first occurrences in order."""
        result: list[TracePath] = []
        seen: set[tuple[str, ...]] = set()
        for item in items:
            key = tuple(symbol_id for symbol_id in item.symbol_ids if symbol_id)
            if not key or key in seen:
                continue
            seen.add(key)
            result.append(item)
        return result