Фиксация изменений
This commit is contained in:
102
app/modules/rag/explain/trace_builder.py
Normal file
102
app/modules/rag/explain/trace_builder.py
Normal file
@@ -0,0 +1,102 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.modules.rag.explain.models import LayeredRetrievalItem, TracePath
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.modules.rag.explain.graph_repository import CodeGraphRepository
|
||||
|
||||
|
||||
class TraceBuilder:
    """Build ranked trace paths by walking outgoing edges of the code graph.

    Starting from each seed symbol, the builder expands outgoing edges
    breadth-first, scores every path it finishes, removes duplicates and
    returns the highest-scoring paths first.
    """

    # Edge kinds followed when the caller does not pass ``edge_types``.
    _DEFAULT_EDGE_TYPES: tuple[str, ...] = ("calls", "imports", "inherits")
    # Number of outgoing edges requested from the repository per symbol.
    _EDGES_PER_SYMBOL = 4
    # Candidate paths collected before ranking, as a multiple of ``max_paths``.
    _CANDIDATE_FACTOR = 3
    # Scoring weights.
    _BASE_EDGE_SCORE = 1.0
    _RESOLVED_BONUS = 2.0
    _SAME_SOURCE_BONUS = 1.0
    _TEST_SOURCE_PENALTY = 3.0
    # Substring that marks an edge source as test code; it also matches
    # nested ``.../tests/...`` locations.
    _TEST_PATH_MARKER = "tests/"

    def __init__(self, graph_repository: CodeGraphRepository) -> None:
        """Store the repository used for edge and symbol lookups."""
        self._graph = graph_repository

    def build_paths(
        self,
        rag_session_id: str,
        seed_symbols: list[LayeredRetrievalItem],
        *,
        max_depth: int,
        max_paths: int = 3,
        edge_types: list[str] | None = None,
    ) -> list[TracePath]:
        """Return up to ``max_paths`` trace paths, best score first.

        Args:
            rag_session_id: Session identifier forwarded to every repository
                call.
            seed_symbols: Starting symbols; seeds without a ``symbol_id`` in
                their metadata are skipped.
            max_depth: A path is finished once it holds this many symbol ids.
            max_paths: Maximum number of paths returned. Candidate collection
                stops once ``max_paths * 3`` paths have been gathered across
                all seeds.
            edge_types: Edge kinds to follow; defaults to calls, imports and
                inherits when empty or ``None``.

        Returns:
            The de-duplicated paths sorted by descending score. When no path
            was produced, a single zero-score path for the first seed (or an
            empty list when there are no seeds).
        """
        # Function-scope import keeps this block self-contained.
        from collections import deque

        edges_filter = edge_types or list(self._DEFAULT_EDGE_TYPES)
        symbol_map = self._symbol_map(seed_symbols)
        candidate_budget = max_paths * self._CANDIDATE_FACTOR
        paths: list[TracePath] = []

        for seed in seed_symbols:
            seed_id = str(seed.metadata.get("symbol_id") or "")
            if not seed_id:
                continue

            # FIFO queue of (path so far, accumulated score, notes).
            queue: deque[tuple[list[str], float, list[str]]] = deque([([seed_id], 0.0, [])])
            while queue and len(paths) < candidate_budget:
                current_path, score, notes = queue.popleft()

                # Check the depth limit first: a path that is already deep
                # enough is finished regardless of its outgoing edges, so the
                # repository does not need to be queried for it.
                if len(current_path) >= max_depth:
                    paths.append(TracePath(symbol_ids=current_path, score=score, notes=notes))
                    continue

                src_symbol_id = current_path[-1]
                out_edges = self._graph.get_out_edges(
                    rag_session_id,
                    [src_symbol_id],
                    edges_filter,
                    limit_per_src=self._EDGES_PER_SYMBOL,
                )
                if not out_edges:
                    paths.append(TracePath(symbol_ids=current_path, score=score, notes=notes))
                    continue

                for edge in out_edges:
                    metadata = edge.metadata
                    source_symbol = symbol_map.get(src_symbol_id)
                    dst_symbol_id = str(metadata.get("dst_symbol_id") or "")
                    next_notes = list(notes)
                    next_score = score + self._edge_score(edge, source_symbol)

                    if not dst_symbol_id:
                        # The edge has no destination id: try to resolve its
                        # textual reference through the repository.
                        dst_ref = str(metadata.get("dst_ref") or "")
                        package_hint = self._package_hint(source_symbol)
                        resolved = self._graph.resolve_symbol_by_ref(
                            rag_session_id, dst_ref, package_hint=package_hint
                        )
                        if resolved is not None:
                            dst_symbol_id = str(resolved.metadata.get("symbol_id") or "")
                            # Never register a symbol under an empty key.
                            if dst_symbol_id:
                                symbol_map[dst_symbol_id] = resolved
                            next_score += self._RESOLVED_BONUS
                            next_notes.append(f"resolved:{dst_ref}")

                    # An unresolvable destination or a cycle ends the path here.
                    if not dst_symbol_id or dst_symbol_id in current_path:
                        paths.append(TracePath(symbol_ids=current_path, score=next_score, notes=next_notes))
                        continue

                    # Load the destination symbol so later hops can use its
                    # source and package metadata.
                    if dst_symbol_id not in symbol_map:
                        symbols = self._graph.get_symbols_by_ids(rag_session_id, [dst_symbol_id])
                        if symbols:
                            symbol_map[dst_symbol_id] = symbols[0]

                    queue.append((current_path + [dst_symbol_id], next_score, next_notes))

        unique = self._unique_paths(paths)
        unique.sort(key=lambda item: item.score, reverse=True)
        ranked = unique[:max_paths]
        if ranked:
            return ranked
        # Fallback: a single zero-score path for the first seed, with its id
        # normalised to ``str`` the same way as everywhere else in this class.
        return [
            TracePath(symbol_ids=[str(seed.metadata.get("symbol_id") or "")], score=0.0)
            for seed in seed_symbols[:1]
        ]

    def _edge_score(self, edge: LayeredRetrievalItem, source_symbol: LayeredRetrievalItem | None) -> float:
        """Score one edge.

        The score starts at 1.0, gains 2.0 when the edge metadata marks it as
        ``resolved``, gains 1.0 when the edge shares its ``source`` with the
        source symbol, and loses 3.0 when the edge source contains ``tests/``.
        """
        score = self._BASE_EDGE_SCORE
        if str(edge.metadata.get("resolution") or "") == "resolved":
            score += self._RESOLVED_BONUS

        source_path = source_symbol.source if source_symbol is not None else ""
        if source_path and edge.source == source_path:
            score += self._SAME_SOURCE_BONUS

        # One substring check is enough: any source containing "/tests/"
        # also contains "tests/".
        if self._TEST_PATH_MARKER in edge.source:
            score -= self._TEST_SOURCE_PENALTY
        return score

    def _package_hint(self, symbol: LayeredRetrievalItem | None) -> str | None:
        """Return the parent of the symbol's ``package_or_module`` value.

        The last dotted component is dropped; a value without dots is
        returned unchanged. Returns ``None`` when the symbol is missing or
        carries no package information.
        """
        if symbol is None:
            return None
        package = str(symbol.metadata.get("package_or_module") or "")
        if not package:
            return None
        parent, _, _ = package.rpartition(".")
        return parent or package

    def _symbol_map(self, items: list[LayeredRetrievalItem]) -> dict[str, LayeredRetrievalItem]:
        """Index items by their ``symbol_id``, skipping items without one.

        When several items share an id, the last one wins.
        """
        result: dict[str, LayeredRetrievalItem] = {}
        for item in items:
            symbol_id = str(item.metadata.get("symbol_id") or "")
            if symbol_id:
                result[symbol_id] = item
        return result

    def _unique_paths(self, items: list[TracePath]) -> list[TracePath]:
        """Drop paths whose non-empty symbol-id sequence was already seen.

        The first occurrence is kept and input order is preserved; paths with
        no non-empty symbol ids are discarded.
        """
        result: list[TracePath] = []
        seen: set[tuple[str, ...]] = set()
        for item in items:
            key = tuple(symbol_id for symbol_id in item.symbol_ids if symbol_id)
            if not key or key in seen:
                continue
            seen.add(key)
            result.append(item)
        return result
|
||||
Reference in New Issue
Block a user