diff --git a/marko/helpers.py b/marko/helpers.py index ff3041a..9e6ec82 100644 --- a/marko/helpers.py +++ b/marko/helpers.py @@ -8,19 +8,20 @@ import re from functools import partial from importlib import import_module -from typing import TYPE_CHECKING, overload +from typing import TYPE_CHECKING, Callable, Generic, TypeVar, overload from marko.renderer import Renderer if TYPE_CHECKING: - from typing import Any, Callable, Container, Iterable, TypeVar + from typing import Any, Container, Iterable - from .element import Element + from marko.element import Element - RendererFunc = Callable[[Any, Element], Any] - TRenderer = TypeVar("TRenderer", bound=RendererFunc) - D = TypeVar("D", bound="_RendererDispatcher") +T = TypeVar("T") +U = TypeVar("U") +ElementT = TypeVar("ElementT", bound="Element") +RendererFunc = Callable[[T, ElementT], U] def camel_to_snake_case(name: str) -> str: """Takes a camelCased string and converts to snake_case.""" @@ -109,7 +110,7 @@ def partition_by_spaces(text: str, spaces: str = " \t") -> tuple[str, str, str]: class MarkoExtension: parser_mixins: list[type] = dataclasses.field(default_factory=list) renderer_mixins: list[type] = dataclasses.field(default_factory=list) - elements: list[type[Element]] = dataclasses.field(default_factory=list) + elements: list[type["Element"]] = dataclasses.field(default_factory=list) def load_extension(name: str, **kwargs: Any) -> MarkoExtension: @@ -136,11 +137,13 @@ def load_extension(name: str, **kwargs: Any) -> MarkoExtension: ) from None -class _RendererDispatcher: +class _RendererDispatcher(Generic[T, ElementT, U]): name: str def __init__( - self, types: type[Renderer] | tuple[type[Renderer], ...], func: RendererFunc + self, + types: type[Renderer] | tuple[type[Renderer], ...], + func: RendererFunc[T, ElementT, U], ) -> None: from marko.ast_renderer import ASTRenderer, XMLRenderer @@ -148,9 +151,10 @@ def __init__( self._mapping.setdefault((ASTRenderer, XMLRenderer), self.render_ast) def dispatch( - self: D, types: type[Renderer] | tuple[type[Renderer], ...] - ) -> Callable[[RendererFunc], D]: - def decorator(func: RendererFunc) -> D: + self: _RendererDispatcher[T, ElementT, U], + types: type[Renderer] | tuple[type[Renderer], ...], + ) -> Callable[[RendererFunc[T, ElementT, U]], _RendererDispatcher[T, ElementT, U]]: + def decorator(func: RendererFunc[T, ElementT, U]) -> _RendererDispatcher[T, ElementT, U]: self._mapping[types] = func return self @@ -160,10 +164,10 @@ def __set_name__(self, owner: type, name: str) -> None: self.name = name @staticmethod - def render_ast(self, element: Element) -> Any: + def render_ast(self, element: "Element") -> Any: return self.render_children(element) - def super_render(self, r: Any, element: Element) -> Any: + def super_render(self, r: Any, element: "Element") -> Any: """Call on the next class in the MRO which has the same method.""" klasses = (c for c in type(r).mro() if self.name in c.__dict__) try: @@ -175,12 +179,23 @@ def super_render(self, r: Any, element: Element) -> Any: return getattr(parent, self.name)(r, element) @overload - def __get__(self: D, obj: None, owner: type) -> D: ... + def __get__( + self: _RendererDispatcher[T, ElementT, U], + obj: None, + owner: type, + ) -> _RendererDispatcher[T, ElementT, U]: ... @overload - def __get__(self: D, obj: Renderer, owner: type) -> RendererFunc: ... - - def __get__(self: D, obj: Renderer | None, owner: type) -> RendererFunc | D: + def __get__( + self: _RendererDispatcher[T, ElementT, U], + obj: Renderer, + owner: type, + ) -> RendererFunc[T, ElementT, U]: ... + + def __get__( + self: _RendererDispatcher[T, ElementT, U], + obj: Renderer | None, owner: type, + ) -> RendererFunc[T, ElementT, U] | _RendererDispatcher[T, ElementT, U]: if obj is None: return self for types, func in self._mapping.items(): @@ -191,8 +206,8 @@ def __get__(self: D, obj: Renderer | None, owner: type) -> RendererFunc | D: def render_dispatch( types: type[Renderer] | tuple[type[Renderer], ...], -) -> Callable[[RendererFunc], _RendererDispatcher]: - def decorator(func: RendererFunc) -> _RendererDispatcher: +) -> Callable[[RendererFunc[T, ElementT, U]], _RendererDispatcher[T, ElementT, U]]: + def decorator(func: RendererFunc[T, ElementT, U]) -> _RendererDispatcher[T, ElementT, U]: return _RendererDispatcher(types, func) return decorator