Files
agent/app/modules/rag/indexing/code/edges/extractor.py
2026-03-01 14:21:33 +03:00

115 lines
3.9 KiB
Python

from __future__ import annotations
import ast
from dataclasses import dataclass, field
from hashlib import sha256
@dataclass(slots=True)
class PyEdge:
edge_id: str
edge_type: str
src_symbol_id: str
src_qname: str
dst_symbol_id: str | None
dst_ref: str | None
path: str
start_line: int
end_line: int
resolution: str = "partial"
metadata: dict = field(default_factory=dict)
class EdgeExtractor:
    """Entry point that turns a parsed Python module into ``PyEdge`` records."""

    def extract(self, path: str, ast_tree: ast.AST | None, symbols: list) -> list[PyEdge]:
        """Return every edge found in ``ast_tree``.

        Args:
            path: File path recorded on each produced edge.
            ast_tree: Parsed module, or ``None`` when no tree is available.
            symbols: Objects exposing ``qname`` and ``symbol_id`` attributes;
                they are used to map qualified names to symbol ids.

        Returns:
            The edges collected by the visitor, or an empty list when
            ``ast_tree`` is ``None``.
        """
        collected: list[PyEdge] = []
        if ast_tree is not None:
            # Later symbols with the same qname overwrite earlier ones.
            ids_by_qname = {}
            for sym in symbols:
                ids_by_qname[sym.qname] = sym.symbol_id
            walker = _EdgeVisitor(path, ids_by_qname)
            walker.visit(ast_tree)
            collected = walker.edges
        return collected
class _EdgeVisitor(ast.NodeVisitor):
    """Collect ``inherits``, ``imports`` and ``calls`` edges from one module's AST.

    The visitor keeps a stack of the enclosing class/function names. The
    dotted join of that stack is the qualified name used as the source of
    every edge emitted while that scope is open. Results accumulate in
    ``self.edges``.
    """

    def __init__(self, path: str, qname_map: dict[str, str]) -> None:
        """Store the file path and the qualified-name -> symbol-id table."""
        self._path = path
        self._qname_map = qname_map
        self._scope: list[str] = []
        # Ids already emitted, so that each edge_id appears at most once.
        self._seen_edge_ids: set[str] = set()
        self.edges: list[PyEdge] = []

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        """Emit one ``inherits`` edge per positional base, then visit the body."""
        current = self._enter(node.name)
        for base in node.bases:
            self._add_edge("inherits", current, self._name(base), base)
        self.generic_visit(node)
        self._scope.pop()

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        """Handle a synchronous function definition."""
        self._visit_function(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        """Handle an ``async def`` exactly like a regular function."""
        self._visit_function(node)

    def visit_Import(self, node: ast.Import) -> None:
        """Emit an ``imports`` edge per imported module.

        Imports outside any class or function have no source symbol and are
        skipped.
        """
        current = self._current_qname()
        if not current:
            return
        for item in node.names:
            self._add_edge("imports", current, item.name, node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """Emit an ``imports`` edge per name of a ``from ... import`` statement.

        Imports outside any class or function are skipped. The target is
        ``module.name``; the relative-import level is not encoded.
        """
        current = self._current_qname()
        if not current:
            return
        module = node.module or ""
        for item in node.names:
            # strip(".") drops the leading dot left when ``module`` is empty.
            self._add_edge("imports", current, f"{module}.{item.name}".strip("."), node)

    def _visit_function(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None:
        """Emit a ``calls`` edge for every call found under ``node``.

        ``ast.walk`` yields every descendant, so calls located in nested
        definitions are attributed to this function here and once more to
        the nested function when ``generic_visit`` reaches it.
        """
        current = self._enter(node.name)
        for inner in ast.walk(node):
            if isinstance(inner, ast.Call):
                self._add_edge(
                    "calls",
                    current,
                    self._name(inner.func),
                    inner,
                    {"callsite_kind": "function_call"},
                )
        self.generic_visit(node)
        self._scope.pop()

    def _enter(self, name: str) -> str:
        """Push ``name`` on the scope stack and return the new qualified name."""
        self._scope.append(name)
        return self._current_qname() or name

    def _current_qname(self) -> str | None:
        """Return the dotted name of the open scope, or ``None`` at module level."""
        if not self._scope:
            return None
        return ".".join(self._scope)

    def _add_edge(self, edge_type: str, src_qname: str, dst_ref: str, node, extra: dict | None = None) -> None:
        """Append an edge from ``src_qname`` to ``dst_ref`` unless it is a duplicate.

        Args:
            edge_type: Relationship kind stored on the edge.
            src_qname: Qualified name of the enclosing symbol.
            dst_ref: Textual target reference; an empty value is ignored.
            node: AST node supplying the line span.
            extra: Optional metadata stored on the edge.
        """
        if not dst_ref:
            return
        # Sources missing from the symbol table get a deterministic id
        # derived from their qualified name.
        src_symbol_id = self._qname_map.get(src_qname, sha256(src_qname.encode("utf-8")).hexdigest())
        dst_symbol_id = self._qname_map.get(dst_ref)
        lineno = getattr(node, "lineno", 1)
        edge_id = sha256(f"{self._path}|{src_qname}|{edge_type}|{dst_ref}|{lineno}".encode("utf-8")).hexdigest()
        # The id does not depend on the column, so repeated references on
        # one line (e.g. ``f(a) + f(b)``) collide; keep only the first.
        if edge_id in self._seen_edge_ids:
            return
        self._seen_edge_ids.add(edge_id)
        start_line = int(lineno)
        # ``end_lineno`` can be present but None on nodes created without
        # location information; fall back to the start line in that case.
        end_lineno = getattr(node, "end_lineno", None)
        end_line = start_line if end_lineno is None else int(end_lineno)
        self.edges.append(
            PyEdge(
                edge_id=edge_id,
                edge_type=edge_type,
                src_symbol_id=src_symbol_id,
                src_qname=src_qname,
                dst_symbol_id=dst_symbol_id,
                dst_ref=dst_ref,
                path=self._path,
                start_line=start_line,
                end_line=end_line,
                resolution="resolved" if dst_symbol_id else "partial",
                metadata=extra or {},
            )
        )

    def _name(self, node) -> str:
        """Render a name, attribute chain or call target as a dotted string.

        Returns an empty string for any other node type. An attribute whose
        base cannot be rendered (a subscript or a literal, for instance)
        yields a reference with a leading dot such as ``.join``.
        """
        if isinstance(node, ast.Name):
            return node.id
        if isinstance(node, ast.Attribute):
            return f"{self._name(node.value)}.{node.attr}"
        if isinstance(node, ast.Call):
            return self._name(node.func)
        return ""