diff --git a/tests/test_web_driver.py b/tests/test_web_driver.py index 688cfdd..691e718 100644 --- a/tests/test_web_driver.py +++ b/tests/test_web_driver.py @@ -21,8 +21,11 @@ def test_view_tokens_includes_token_names(self) -> None: def test_view_ast_returns_module_dump(self) -> None: rendered = driver.view_ast("x = 1\n", optimize=False) - self.assertIn("Module(", rendered) - self.assertIn("Assign(", rendered) + self.assertTrue(rendered["html"]) + self.assertIn(">Module(", rendered["text"]) + self.assertIn(">Assign(", rendered["text"]) + # Every row maps back to source line 1 via lineno propagation. + self.assertTrue(all(ln == 1 for ln in rendered["lines"] if ln is not None)) def test_view_pseudo_smoke(self) -> None: rendered = driver.view_pseudo("def f(x):\n return x\n\nprint(f(42))\n") diff --git a/web/driver.py b/web/driver.py index 183489c..3750c00 100644 --- a/web/driver.py +++ b/web/driver.py @@ -2,8 +2,10 @@ import ast import dis +import html import io import json +import re import sys import tokenize import traceback @@ -42,67 +44,138 @@ def view_tokens(code: str) -> dict[str, Any]: return _as_view(rows) -def _has_ast_children(node: ast.AST) -> bool: - if isinstance(node, ast.Name): - return False - SENTINEL = object() - for name in node._fields: - value = getattr(node, name, SENTINEL) - if isinstance(value, (list, ast.AST)): - return True - return False - - -def _ast_attr_repr(node: ast.AST, attr: str) -> str: - value = getattr(node, attr, ...) - if isinstance(value, (ast.Load, ast.Store, ast.Del)): - return value.__class__.__name__ - return repr(value) - - -def _dump_ast(tree: ast.AST) -> Iterator[tuple[str, int | None]]: - SENTINEL = object() - indent = " " - - def walk( - node: Any, level: int = 0, last_line: int = 0, prepend: str = "" - ) -> Iterator[tuple[str, int]]: - prefix = f"{indent * level}{prepend}" - if isinstance(node, ast.AST): - fields = node._fields - start = getattr(node, "lineno", last_line) or last_line - if not _has_ast_children(node): - args = ", ".join(f"{n}={_ast_attr_repr(node, n)}" for n in fields) - yield f"{prefix}{node.__class__.__name__}({args})", start - else: - yield f"{prefix}{node.__class__.__name__}()", start - for name in fields: - value = getattr(node, name, SENTINEL) - if value is SENTINEL: - continue - yield from walk(value, level + 1, start, f"{name}=") - elif isinstance(node, list): - if len(node) == 1 and not _has_ast_children(node[0]): - inner = list(walk(node[0], level, last_line, prepend + "[")) - if len(inner) == 1: - text, line = inner[0] - yield text + "]", line - return - yield from inner - else: - yield f"{prefix}[]", last_line - for value in node: - yield from walk(value, level + 1, last_line) - else: - yield f"{prefix}{node!r}", last_line +_ANSI_RE = re.compile(r"\x1b\[([0-9;]*)m") +_LINENO_RE = re.compile(r"\blineno=(\d+)") +_ATTR_ROW_RE = re.compile(r"^\s*(?:lineno|col_offset|end_lineno|end_col_offset)=\d+") +_ANSI_CLASS = { + "36": "ast-node", + "34": "ast-field", + "90": "ast-attribute", + "32": "ast-string", + "33": "ast-number", + "1;34": "ast-keyword", +} + + +def _ansi_to_html(s: str) -> str: + out: list[str] = [] + pos = 0 + open_span = False + for m in _ANSI_RE.finditer(s): + out.append(html.escape(s[pos : m.start()])) + if open_span: + out.append("") + open_span = False + code = m.group(1) + cls = _ANSI_CLASS.get(code) if code and code != "0" else None + if cls: + out.append(f'') + open_span = True + pos = m.end() + out.append(html.escape(s[pos:])) + if open_span: + out.append("") + return "".join(out) + + +def _attach_linenos(plain_lines: list[str]) -> list[int | None]: + n = len(plain_lines) + result: list[int | None] = [None] * n + indents = [len(line) - len(line.lstrip(" ")) for line in plain_lines] + + for i, line in enumerate(plain_lines): + m = _LINENO_RE.search(line) + if m: + result[i] = int(m.group(1)) + + for i in range(n - 1, -1, -1): + if result[i] is not None: + continue + my_indent = indents[i] + for j in range(i + 1, n): + if indents[j] <= my_indent: + break + if indents[j] == my_indent + 4 and result[j] is not None: + result[i] = result[j] + break + + for i in range(n): + if result[i] is not None: + continue + for j in range(i - 1, -1, -1): + if indents[j] < indents[i] and result[j] is not None: + result[i] = result[j] + break + + return result - for text, line in walk(tree): - yield text, (line if line and line > 0 else None) + +_END_COL_RE = re.compile(r"^\s*end_col_offset=\d+([)\]]*)(,?)\s*$") + + +def _strip_attribute_rows( + plain_lines: list[str], + html_lines: list[str], + lineno_map: list[int | None], +) -> tuple[list[str], list[str], list[int | None]]: + n = len(plain_lines) + keep = [True] * n + plain_lines = list(plain_lines) + html_lines = list(html_lines) + + i = 0 + while i < n: + if not _ATTR_ROW_RE.match(plain_lines[i]): + i += 1 + continue + start = i + struct = "" + trailing_comma = False + while i < n and _ATTR_ROW_RE.match(plain_lines[i]): + keep[i] = False + m = _END_COL_RE.match(plain_lines[i]) + if m: + struct += m.group(1) + trailing_comma = bool(m.group(2)) + i += 1 + tail = struct + ("," if trailing_comma else "") + prev = start - 1 + if prev >= 0 and tail: + plain_lines[prev] = _replace_trailing_comma(plain_lines[prev], tail) + html_lines[prev] = _replace_trailing_comma(html_lines[prev], tail) + + new_plain = [line for i, line in enumerate(plain_lines) if keep[i]] + new_html = [line for i, line in enumerate(html_lines) if keep[i]] + new_lineno = [ln for i, ln in enumerate(lineno_map) if keep[i]] + return new_plain, new_html, new_lineno + + +def _replace_trailing_comma(line: str, tail: str) -> str: + rstripped = line.rstrip() + if rstripped.endswith(","): + rstripped = rstripped[:-1] + return rstripped + tail def view_ast(code: str, *, optimize: bool = False) -> dict[str, Any]: tree = ast.parse(code, optimize=1) if optimize else ast.parse(code) - return _as_view(list(_dump_ast(tree))) + dump_kwargs: dict[str, Any] = dict( + indent=4, include_attributes=True, show_empty=True + ) + if sys.version_info >= (3, 15): + dump_kwargs["color"] = True + colored = ast.dump(tree, **dump_kwargs) + plain_lines = _ANSI_RE.sub("", colored).split("\n") + html_lines = [_ansi_to_html(line) for line in colored.split("\n")] + lineno_map = _attach_linenos(plain_lines) + _, html_lines, lineno_map = _strip_attribute_rows( + plain_lines, html_lines, lineno_map + ) + return { + "text": "\n".join(html_lines), + "lines": lineno_map, + "html": True, + } class _PseudoArgResolver(dis.ArgResolver): diff --git a/web/index.html b/web/index.html index a23b275..67e2380 100644 --- a/web/index.html +++ b/web/index.html @@ -132,6 +132,12 @@ .panel > .content .line.highlight { background: rgba(255, 215, 0, 0.35); } + .ast-node { color: #5fc1e0; } + .ast-field { color: #6cb6ff; } + .ast-attribute { color: #8a93a0; } + .ast-string { color: #b5e890; } + .ast-number { color: #f0c674; } + .ast-keyword { color: #b9a0ff; font-weight: bold; } .ace-codoscope-highlight { position: absolute; background: rgba(255, 215, 0, 0.35); @@ -233,13 +239,19 @@