Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 35 additions & 20 deletions marko/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,20 @@
import re
from functools import partial
from importlib import import_module
from typing import TYPE_CHECKING, overload
from typing import TYPE_CHECKING, Callable, Generic, TypeVar, overload

from marko.renderer import Renderer

if TYPE_CHECKING:
from typing import Any, Callable, Container, Iterable, TypeVar
from typing import Any, Container, Iterable

from .element import Element
from marko.element import Element

RendererFunc = Callable[[Any, Element], Any]
TRenderer = TypeVar("TRenderer", bound=RendererFunc)
D = TypeVar("D", bound="_RendererDispatcher")

T = TypeVar("T")
U = TypeVar("U")
ElementT = TypeVar("ElementT", bound="Element")
RendererFunc = Callable[[T, ElementT], U]

def camel_to_snake_case(name: str) -> str:
"""Takes a camelCased string and converts to snake_case."""
Expand Down Expand Up @@ -109,7 +110,7 @@ def partition_by_spaces(text: str, spaces: str = " \t") -> tuple[str, str, str]:
class MarkoExtension:
parser_mixins: list[type] = dataclasses.field(default_factory=list)
renderer_mixins: list[type] = dataclasses.field(default_factory=list)
elements: list[type[Element]] = dataclasses.field(default_factory=list)
elements: list[type["Element"]] = dataclasses.field(default_factory=list)


def load_extension(name: str, **kwargs: Any) -> MarkoExtension:
Expand All @@ -136,21 +137,24 @@ def load_extension(name: str, **kwargs: Any) -> MarkoExtension:
) from None


class _RendererDispatcher:
class _RendererDispatcher(Generic[T, ElementT, U]):
name: str

def __init__(
self, types: type[Renderer] | tuple[type[Renderer], ...], func: RendererFunc
self,
types: type[Renderer] | tuple[type[Renderer], ...],
func: RendererFunc[T, ElementT, U],
) -> None:
from marko.ast_renderer import ASTRenderer, XMLRenderer

self._mapping = {types: func}
self._mapping.setdefault((ASTRenderer, XMLRenderer), self.render_ast)

def dispatch(
self: D, types: type[Renderer] | tuple[type[Renderer], ...]
) -> Callable[[RendererFunc], D]:
def decorator(func: RendererFunc) -> D:
self: _RendererDispatcher[T, ElementT, U],
types: type[Renderer] | tuple[type[Renderer], ...],
) -> Callable[[RendererFunc[T, ElementT, U]], _RendererDispatcher[T, ElementT, U]]:
def decorator(func: RendererFunc[T, ElementT, U]) -> _RendererDispatcher[T, ElementT, U]:
self._mapping[types] = func
return self

Expand All @@ -160,10 +164,10 @@ def __set_name__(self, owner: type, name: str) -> None:
self.name = name

@staticmethod
def render_ast(self, element: Element) -> Any:
def render_ast(self, element: "Element") -> Any:
return self.render_children(element)

def super_render(self, r: Any, element: Element) -> Any:
def super_render(self, r: Any, element: "Element") -> Any:
"""Call on the next class in the MRO which has the same method."""
klasses = (c for c in type(r).mro() if self.name in c.__dict__)
try:
Expand All @@ -175,12 +179,23 @@ def super_render(self, r: Any, element: Element) -> Any:
return getattr(parent, self.name)(r, element)

@overload
def __get__(self: D, obj: None, owner: type) -> D: ...
def __get__(
self: _RendererDispatcher[T, ElementT, U],
obj: None,
owner: type,
) -> _RendererDispatcher[T, ElementT, U]: ...

@overload
def __get__(self: D, obj: Renderer, owner: type) -> RendererFunc: ...

def __get__(self: D, obj: Renderer | None, owner: type) -> RendererFunc | D:
def __get__(
self: _RendererDispatcher[T, ElementT, U],
obj: Renderer,
owner: type,
) -> RendererFunc[T, ElementT, U]: ...

def __get__(
self: _RendererDispatcher[T, ElementT, U],
obj: Renderer | None, owner: type,
) -> RendererFunc[T, ElementT, U] | _RendererDispatcher[T, ElementT, U]:
if obj is None:
return self
for types, func in self._mapping.items():
Expand All @@ -191,8 +206,8 @@ def __get__(self: D, obj: Renderer | None, owner: type) -> RendererFunc | D:

def render_dispatch(
types: type[Renderer] | tuple[type[Renderer], ...],
) -> Callable[[RendererFunc], _RendererDispatcher]:
def decorator(func: RendererFunc) -> _RendererDispatcher:
) -> Callable[[RendererFunc[T, ElementT, U]], _RendererDispatcher[T, ElementT, U]]:
def decorator(func: RendererFunc[T, ElementT, U]) -> _RendererDispatcher[T, ElementT, U]:
return _RendererDispatcher(types, func)

return decorator
Loading