131 lines
4.9 KiB
Python
131 lines
4.9 KiB
Python
from __future__ import annotations
|
|
|
|
import ast
|
|
from dataclasses import dataclass, field
|
|
from hashlib import sha256
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class PySymbol:
|
|
symbol_id: str
|
|
qname: str
|
|
kind: str
|
|
path: str
|
|
start_line: int
|
|
end_line: int
|
|
signature: str
|
|
decorators: list[str] = field(default_factory=list)
|
|
docstring: str | None = None
|
|
parent_symbol_id: str | None = None
|
|
lang_payload: dict = field(default_factory=dict)
|
|
|
|
|
|
class SymbolExtractor:
    """Turn an already-parsed Python module into a flat list of PySymbol records."""

    def extract(self, path: str, text: str, ast_tree: ast.AST | None) -> list[PySymbol]:
        """Collect the symbols defined in *ast_tree*.

        Args:
            path: File path stored on every produced symbol and mixed into
                its symbol id.
            text: Source text of the file; not consulted by this implementation.
            ast_tree: The parsed module, or None.

        Returns:
            The collected symbols in visitation order, or an empty list when
            *ast_tree* is None.
        """
        if ast_tree is not None:
            visitor = _SymbolVisitor(path)
            visitor.visit(ast_tree)
            return visitor.symbols
        return []
|
|
|
|
|
|
class _SymbolVisitor(ast.NodeVisitor):
    """Walk a module AST and collect classes, functions/methods and imports.

    Every collected item becomes a ``PySymbol`` appended to ``self.symbols`` in
    visitation order. Definitions are visited recursively, so nested classes
    and functions get dotted qualified names (``Outer.Inner.method``). Imports
    are recorded (with kind "const") only when they are not inside a class or
    function body.

    NOTE(review): symbol ids are ``sha256(path|qname|kind)``, so two definitions
    sharing a qname and kind (e.g. a property getter and its setter) get the
    same id.
    """

    def __init__(self, path: str) -> None:
        self._path = path
        # Enclosing definitions, innermost last, as (kind, qname, symbol_id).
        self._stack: list[tuple[str, str, str]] = []
        self.symbols: list[PySymbol] = []

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """Record each name of a module-level ``from X import Y [as Z]``."""
        # Imports nested inside a class/function are intentionally ignored.
        if self._stack:
            return
        module = node.module or ""
        start_line, end_line = self._line_span(node)
        for item in node.names:
            local_name = item.asname or item.name
            # NOTE(review): node.level is ignored, so relative imports such as
            # "from .pkg import x" lose their leading dots here.
            imported_name = f"{module}.{item.name}".strip(".")
            self.symbols.append(
                PySymbol(
                    symbol_id=self._make_id(local_name, "import_alias"),
                    qname=local_name,
                    kind="const",
                    path=self._path,
                    start_line=start_line,
                    end_line=end_line,
                    signature=f"{local_name} = {imported_name}",
                    lang_payload={"imported_from": imported_name, "import_alias": True},
                )
            )
        self.generic_visit(node)

    def visit_Import(self, node: ast.Import) -> None:
        """Record each module of a module-level ``import X [as Y]``."""
        # Imports nested inside a class/function are intentionally ignored.
        if self._stack:
            return
        start_line, end_line = self._line_span(node)
        for item in node.names:
            local_name = item.asname or item.name
            self.symbols.append(
                PySymbol(
                    symbol_id=self._make_id(local_name, "import"),
                    qname=local_name,
                    kind="const",
                    path=self._path,
                    start_line=start_line,
                    end_line=end_line,
                    signature=f"import {item.name}",
                    lang_payload={"imported_from": item.name, "import_alias": bool(item.asname)},
                )
            )
        self.generic_visit(node)

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        """Record a class, then visit its body with the class as enclosing scope."""
        self._add_symbol(node, "class", {"bases": [self._expr_name(base) for base in node.bases]})
        try:
            self.generic_visit(node)
        finally:
            self._stack.pop()

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        self._add_function(node, is_async=False)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        self._add_function(node, is_async=True)

    def _add_function(self, node, *, is_async: bool) -> None:
        """Record a (possibly async) function; it is a "method" directly inside a class."""
        kind = "method" if self._stack and self._stack[-1][0] == "class" else "function"
        self._add_symbol(node, kind, {"async": is_async})
        try:
            self.generic_visit(node)
        finally:
            self._stack.pop()

    def _add_symbol(self, node, kind: str, lang_payload: dict) -> None:
        """Append a PySymbol for a class/function *node* and push it as the new scope.

        The caller is responsible for popping ``self._stack`` afterwards.
        """
        qname = self._qualify(node.name)
        symbol_id = self._make_id(qname, kind)
        start_line, end_line = self._line_span(node)
        symbol = PySymbol(
            symbol_id=symbol_id,
            qname=qname,
            kind=kind,
            path=self._path,
            start_line=start_line,
            end_line=end_line,
            signature=self._signature(node),
            decorators=[self._expr_name(item) for item in getattr(node, "decorator_list", [])],
            docstring=ast.get_docstring(node),
            # Link to the enclosing definition's symbol_id (not its qname).
            parent_symbol_id=self._stack[-1][2] if self._stack else None,
            lang_payload=lang_payload,
        )
        self.symbols.append(symbol)
        self._stack.append((kind, qname, symbol_id))

    def _qualify(self, name: str) -> str:
        """Return *name* prefixed by the innermost enclosing definition's qname.

        The stack already stores fully qualified names, so only the innermost
        entry is needed (joining all entries would repeat outer prefixes).
        """
        if not self._stack:
            return name
        return f"{self._stack[-1][1]}.{name}"

    def _make_id(self, name: str, tag: str) -> str:
        """Return the sha256 hex digest of ``"{path}|{name}|{tag}"``."""
        return sha256(f"{self._path}|{name}|{tag}".encode("utf-8")).hexdigest()

    @staticmethod
    def _line_span(node) -> tuple[int, int]:
        """Return (start_line, end_line) for *node*, tolerating missing/None positions."""
        start = getattr(node, "lineno", None) or 1
        end = getattr(node, "end_lineno", None) or start
        return int(start), int(end)

    def _signature(self, node) -> str:
        """Render ``Name(bases)`` for classes and ``name(params)`` for functions."""
        if isinstance(node, ast.ClassDef):
            bases = ", ".join(self._expr_name(base) for base in node.bases)
            return f"{node.name}({bases})" if bases else node.name
        params = self._param_names(node.args)
        return f"{node.name}({', '.join(params)})"

    @staticmethod
    def _param_names(args: ast.arguments) -> list[str]:
        """List parameter names (no defaults/annotations) using Python's own markers.

        Positional-only names are followed by "/", ``*args``/``**kwargs`` keep
        their stars, and a bare "*" separates keyword-only names when there is
        no ``*args``.
        """
        names = [a.arg for a in args.posonlyargs]
        if names:
            names.append("/")
        names.extend(a.arg for a in args.args)
        if args.vararg is not None:
            names.append(f"*{args.vararg.arg}")
        elif args.kwonlyargs:
            names.append("*")
        names.extend(a.arg for a in args.kwonlyargs)
        if args.kwarg is not None:
            names.append(f"**{args.kwarg.arg}")
        return names

    def _expr_name(self, node) -> str:
        """Render a base-class or decorator expression as a short dotted name.

        Calls are reduced to the called name (``@route("/x")`` -> ``route``);
        any other expression (e.g. ``Generic[T]``) is rendered as source text.
        """
        if isinstance(node, ast.Name):
            return node.id
        if isinstance(node, ast.Attribute):
            return f"{self._expr_name(node.value)}.{node.attr}"
        if isinstance(node, ast.Call):
            return self._expr_name(node.func)
        return ast.unparse(node)