115 lines
3.9 KiB
Python
115 lines
3.9 KiB
Python
from __future__ import annotations
|
|
|
|
import ast
|
|
from dataclasses import dataclass, field
|
|
from hashlib import sha256
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class PyEdge:
|
|
edge_id: str
|
|
edge_type: str
|
|
src_symbol_id: str
|
|
src_qname: str
|
|
dst_symbol_id: str | None
|
|
dst_ref: str | None
|
|
path: str
|
|
start_line: int
|
|
end_line: int
|
|
resolution: str = "partial"
|
|
metadata: dict = field(default_factory=dict)
|
|
|
|
|
|
class EdgeExtractor:
    """Entry point that turns a parsed module into a list of edges."""

    def extract(self, path: str, ast_tree: ast.AST | None, symbols: list) -> list[PyEdge]:
        """Return the edges found in *ast_tree*.

        Args:
            path: File path recorded on every produced edge.
            ast_tree: Parsed module, or ``None`` when no tree is available.
            symbols: Objects exposing ``qname`` and ``symbol_id`` attributes;
                used to map qualified names to symbol ids.

        Returns:
            The edges collected by the visitor, or an empty list when
            *ast_tree* is ``None`` (in which case *symbols* is not read).
        """
        if ast_tree is None:
            return []

        # Later symbols with the same qualified name overwrite earlier ones.
        ids_by_qname = {}
        for sym in symbols:
            key = sym.qname
            ids_by_qname[key] = sym.symbol_id

        walker = _EdgeVisitor(path, ids_by_qname)
        walker.visit(ast_tree)
        return walker.edges
|
|
|
|
|
|
class _EdgeVisitor(ast.NodeVisitor):
    """Walk a module AST and record ``inherits``, ``imports`` and ``calls`` edges.

    The visitor keeps a stack of the enclosing class/function names. The dotted
    join of that stack is the qualified name used as the source of each edge.
    Edges are appended to ``edges`` in the order they are discovered.
    """

    def __init__(self, path: str, qname_map: dict[str, str]) -> None:
        """Prepare a visitor for one file.

        Args:
            path: File path stored on every edge and mixed into edge ids.
            qname_map: Qualified name -> symbol id, used to resolve both the
                source and the destination of an edge.
        """
        self._path = path
        self._qname_map = qname_map
        self._scope: list[str] = []
        # How many times each id payload has been produced so far; used to
        # keep edge ids unique when several edges share a line.
        self._edge_id_counts: dict[str, int] = {}
        self.edges: list[PyEdge] = []

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        """Record one ``inherits`` edge per nameable base, then descend."""
        current = self._enter(node.name)
        for base in node.bases:
            self._add_edge("inherits", current, self._name(base), base)
        self.generic_visit(node)
        self._scope.pop()

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        """Handle a ``def`` statement."""
        self._visit_function(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        """Handle an ``async def`` statement."""
        self._visit_function(node)

    def visit_Import(self, node: ast.Import) -> None:
        """Record ``import x`` statements that appear inside a class or function.

        Imports outside any class/function scope have no source qualified name
        and are skipped.
        """
        current = self._current_qname()
        if not current:
            return
        for item in node.names:
            self._add_edge("imports", current, item.name, node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """Record ``from m import x`` statements inside a class or function.

        The reference is ``"<module>.<name>"``; when the module part is empty
        (``from . import x``) the stray leading dot is stripped. The relative
        import level is not encoded in the reference.
        """
        current = self._current_qname()
        if not current:
            return
        module = node.module or ""
        for item in node.names:
            self._add_edge("imports", current, f"{module}.{item.name}".strip("."), node)

    def _visit_function(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None:
        """Record a ``calls`` edge for every call expression under *node*.

        ``ast.walk`` yields every descendant of the function node, so calls
        found in nested functions and classes are attributed to this function
        too. ``generic_visit`` then descends, and nested definitions record
        their own edges under their own qualified names.
        """
        current = self._enter(node.name)
        for inner in ast.walk(node):
            if isinstance(inner, ast.Call):
                self._add_edge(
                    "calls",
                    current,
                    self._name(inner.func),
                    inner,
                    {"callsite_kind": "function_call"},
                )
        self.generic_visit(node)
        self._scope.pop()

    def _enter(self, name: str) -> str:
        """Push *name* onto the scope stack and return the new qualified name."""
        self._scope.append(name)
        return self._current_qname() or name

    def _current_qname(self) -> str | None:
        """Return the dotted name of the current scope, or ``None`` at module level."""
        if not self._scope:
            return None
        return ".".join(self._scope)

    def _edge_id(self, edge_type: str, src_qname: str, dst_ref: str, node) -> str:
        """Return a SHA-256 hex id that is unique within this visitor.

        The first edge for a given (path, source, type, target, line) keeps the
        original id derived from exactly those fields. Further edges with the
        same fields, such as the two calls in ``g(g(1))``, would otherwise get
        the same id, so an occurrence ordinal is mixed into their payload.
        """
        payload = f"{self._path}|{src_qname}|{edge_type}|{dst_ref}|{getattr(node, 'lineno', 1)}"
        seen = self._edge_id_counts.get(payload, 0)
        self._edge_id_counts[payload] = seen + 1
        if seen:
            payload = f"{payload}#{seen}"
        return sha256(payload.encode("utf-8")).hexdigest()

    def _add_edge(self, edge_type: str, src_qname: str, dst_ref: str, node, extra: dict | None = None) -> None:
        """Append one edge to ``edges``; do nothing when *dst_ref* is empty.

        Args:
            edge_type: Relationship kind stored on the edge.
            src_qname: Qualified name of the originating scope.
            dst_ref: Dotted name of the target as written in the source.
            node: AST node supplying ``lineno`` / ``end_lineno``; both default
                to 1 / ``lineno`` when the node lacks them.
            extra: Optional metadata stored on the edge.
        """
        if not dst_ref:
            return
        # Unknown sources still get a stable id derived from their name.
        src_symbol_id = self._qname_map.get(src_qname, sha256(src_qname.encode("utf-8")).hexdigest())
        dst_symbol_id = self._qname_map.get(dst_ref)
        edge_id = self._edge_id(edge_type, src_qname, dst_ref, node)
        self.edges.append(
            PyEdge(
                edge_id=edge_id,
                edge_type=edge_type,
                src_symbol_id=src_symbol_id,
                src_qname=src_qname,
                dst_symbol_id=dst_symbol_id,
                dst_ref=dst_ref,
                path=self._path,
                start_line=int(getattr(node, "lineno", 1)),
                end_line=int(getattr(node, "end_lineno", getattr(node, "lineno", 1))),
                resolution="resolved" if dst_symbol_id else "partial",
                metadata=extra or {},
            )
        )

    def _name(self, node) -> str:
        """Render a name, attribute chain or call target as a dotted string.

        Returns ``""`` for any other node type. An attribute whose base cannot
        be rendered keeps a leading dot (``"x".join`` -> ``".join"``).
        """
        if isinstance(node, ast.Name):
            return node.id
        if isinstance(node, ast.Attribute):
            return f"{self._name(node.value)}.{node.attr}"
        if isinstance(node, ast.Call):
            return self._name(node.func)
        return ""
|