Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 58 additions & 13 deletions packages/reflex-base/src/reflex_base/context/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,58 @@

from __future__ import annotations

import dataclasses
from contextvars import ContextVar, Token
from types import TracebackType
from typing import ClassVar

from typing_extensions import Self


@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
class BaseContext:
"""Base context class that acts as an async context manager to set the context var."""
"""Base context class that acts as a sync/async context manager for a per-subclass ContextVar.

Each subclass gets its own :class:`ContextVar` and a class-level mapping from
attached instances to their reset tokens, so any number of subclasses can be
entered concurrently without interfering with each other.

Instances use identity equality (and identity-based hashing) so that two
distinct contexts with the same field values are still considered different.
"""

__slots__ = ()

_context_var: ClassVar[ContextVar[Self]]
_attached_context_token: ClassVar[dict[Self, Token[Self]]]

__eq__ = object.__eq__
__hash__ = object.__hash__

@classmethod
def __init_subclass__(cls, **kwargs):
"""Initialize the context variable for the subclass."""
super(BaseContext, cls).__init_subclass__(**kwargs)
"""Initialize the context variable and token registry for the subclass.

Args:
**kwargs: Forwarded to ``super().__init_subclass__``.
"""
super().__init_subclass__(**kwargs)
cls._context_var = ContextVar(cls.__name__)
cls._attached_context_token = {}

@classmethod
def get(cls) -> Self:
"""Get the context from the context variable.
"""Get the active context from the context variable.

Returns:
The context instance.
The active context instance.

Raises:
LookupError: If no context has been set for this class.
"""
return cls._context_var.get()

@classmethod
def set(cls, context: Self) -> Token[Self]:
"""Set the context in the context variable.
"""Set the active context in the context variable.

Args:
context: The context instance to set.
Expand All @@ -54,23 +73,49 @@ def reset(cls, token: Token[Self]) -> None:
cls._context_var.reset(token)

def __enter__(self) -> Self:
"""Enter the context.
"""Attach this context to the current task.

Returns:
This context instance.

Raises:
RuntimeError: If this instance is already attached.
"""
if self._attached_context_token.get(self) is not None:
msg = "Context is already attached, cannot enter context manager."
raise RuntimeError(msg)
self._attached_context_token[self] = self._context_var.set(self)
return self

def __exit__(self, *exc_info):
"""Exit the context."""
if (token := self._attached_context_token.pop(self)) is not None:
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
"""Detach this context from the current task."""
del exc_type, exc_val, exc_tb
if (token := self._attached_context_token.pop(self, None)) is not None:
self._context_var.reset(token)

def ensure_context_attached(self):
async def __aenter__(self) -> Self:
"""Attach this context to the current task asynchronously.

Returns:
This context instance.
"""
return self.__enter__()

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
"""Detach this context from the current task asynchronously."""
self.__exit__(exc_type, exc_val, exc_tb)

def ensure_context_attached(self) -> None:
"""Ensure that the context is attached to the current context variable.

Raises:
Expand Down
101 changes: 4 additions & 97 deletions packages/reflex-base/src/reflex_base/plugins/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,10 @@
import dataclasses
import inspect
from collections.abc import Callable, Sequence
from contextvars import ContextVar, Token
from types import TracebackType
from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeAlias, TypeVar, cast

from typing_extensions import Self
from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, TypeVar, cast

from reflex_base.components.component import BaseComponent, Component
from reflex_base.context.base import BaseContext
from reflex_base.utils.imports import ParsedImportDict, collapse_imports, merge_imports
from reflex_base.vars import VarData

Expand Down Expand Up @@ -581,97 +578,7 @@ def _apply_replacement(
return replacement, children


@dataclasses.dataclass(kw_only=True)
class BaseContext:
"""Context manager that exposes itself through a class-local context var."""

__context_var__: ClassVar[ContextVar[Self | None]]

_attached_context_token: Token[Self | None] | None = dataclasses.field(
default=None,
init=False,
repr=False,
)

@classmethod
def __init_subclass__(cls, **kwargs: Any) -> None:
"""Initialize a dedicated context variable for each subclass."""
super().__init_subclass__(**kwargs)
cls.__context_var__ = ContextVar(cls.__name__, default=None)

@classmethod
def get(cls) -> Self:
"""Return the active context instance for the current task.

Returns:
The active context instance for the current task.
"""
context = cls.__context_var__.get()
if context is None:
msg = f"No active {cls.__name__} is attached to the current context."
raise RuntimeError(msg)
return context

def __enter__(self) -> Self:
"""Attach this context to the current task.

Returns:
The attached context instance.
"""
if self._attached_context_token is not None:
msg = "Context is already attached and cannot be entered twice."
raise RuntimeError(msg)
self._attached_context_token = type(self).__context_var__.set(self)
return self

def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
"""Detach this context from the current task."""
del exc_type, exc_val, exc_tb
if self._attached_context_token is None:
return
try:
type(self).__context_var__.reset(self._attached_context_token)
finally:
self._attached_context_token = None

async def __aenter__(self) -> Self:
"""Attach this context to the current task asynchronously.

Returns:
The attached context instance.
"""
return self.__enter__()

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
"""Detach this context from the current task asynchronously."""
self.__exit__(exc_type, exc_val, exc_tb)

def ensure_context_attached(self) -> None:
"""Ensure this instance is the active context for the current task."""
try:
current = type(self).get()
except RuntimeError as err:
msg = (
f"{type(self).__name__} must be entered with 'with' or 'async with' "
"before calling this method."
)
raise RuntimeError(msg) from err
if current is not self:
msg = f"{type(self).__name__} is not attached to the current task context."
raise RuntimeError(msg)


@dataclasses.dataclass(slots=True, kw_only=True)
@dataclasses.dataclass(slots=True, kw_only=True, eq=False)
class PageContext(BaseContext):
"""Mutable compilation state for a single page."""

Expand Down Expand Up @@ -749,7 +656,7 @@ def custom_code_dict(self) -> dict[str, None]:
return dict(self.module_code)


@dataclasses.dataclass(slots=True, kw_only=True)
@dataclasses.dataclass(slots=True, kw_only=True, eq=False)
class CompileContext(BaseContext):
"""Mutable compilation state for an entire compile run."""

Expand Down
20 changes: 8 additions & 12 deletions tests/units/compiler/test_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,33 +620,31 @@ def test_context_lifecycle_and_cleanup() -> None:
root_component=Fragment.create(),
)

with pytest.raises(RuntimeError, match="No active CompileContext"):
with pytest.raises(LookupError):
CompileContext.get()
with pytest.raises(
RuntimeError, match="must be entered with 'with' or 'async with'"
):
with pytest.raises(RuntimeError, match="must be entered"):
compile_ctx.ensure_context_attached()

with compile_ctx:
assert CompileContext.get() is compile_ctx
with pytest.raises(RuntimeError, match="No active PageContext"):
with pytest.raises(LookupError):
PageContext.get()
with page_ctx:
assert CompileContext.get() is compile_ctx
assert PageContext.get() is page_ctx
page_ctx.ensure_context_attached()
with pytest.raises(RuntimeError, match="No active PageContext"):
with pytest.raises(LookupError):
PageContext.get()
assert CompileContext.get() is compile_ctx

with pytest.raises(RuntimeError, match="No active CompileContext"):
with pytest.raises(LookupError):
CompileContext.get()

with pytest.raises(ValueError, match="boom"), compile_ctx:
msg = "boom"
raise ValueError(msg)

with pytest.raises(RuntimeError, match="No active CompileContext"):
with pytest.raises(LookupError):
CompileContext.get()


Expand Down Expand Up @@ -707,7 +705,7 @@ class DynamicContext(BaseContext):
class AnotherDynamicContext(BaseContext):
pass

assert DynamicContext.__context_var__ is not AnotherDynamicContext.__context_var__
assert DynamicContext._context_var is not AnotherDynamicContext._context_var


def test_apply_style_plugin_matches_legacy_style_behavior() -> None:
Expand Down Expand Up @@ -1031,9 +1029,7 @@ def test_compile_context_requires_attached_context() -> None:
hooks=CompilerHooks(),
)

with pytest.raises(
RuntimeError, match="must be entered with 'with' or 'async with'"
):
with pytest.raises(RuntimeError, match="must be entered"):
compile_ctx.compile()


Expand Down
34 changes: 32 additions & 2 deletions tests/units/reflex_base/context/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from reflex_base.context.base import BaseContext


@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
@dataclasses.dataclass(frozen=True, kw_only=True, slots=True, eq=False)
class _TestContext(BaseContext):
"""Minimal BaseContext subclass for unit testing."""

Expand Down Expand Up @@ -83,7 +83,7 @@ def test_ensure_context_attached():
def test_subclasses_have_independent_context_vars():
"""Two BaseContext subclasses do not share their ContextVar."""

@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
@dataclasses.dataclass(frozen=True, kw_only=True, slots=True, eq=False)
class _OtherContext(BaseContext):
value: int = 0

Expand All @@ -92,3 +92,33 @@ class _OtherContext(BaseContext):
with ctx_a, ctx_b:
assert _TestContext.get().label == "a"
assert _OtherContext.get().value == 42


def test_identity_equality_for_subclasses_with_eq_false():
"""Two BaseContext subclass instances with the same fields are not equal."""
ctx_a = _TestContext(label="same")
ctx_b = _TestContext(label="same")
assert ctx_a is not ctx_b
assert ctx_a != ctx_b
assert hash(ctx_a) != hash(ctx_b)


def test_identity_equality_isolates_entered_state():
"""Two equal-by-field instances can be entered independently."""
ctx_a = _TestContext(label="same")
ctx_b = _TestContext(label="same")
with ctx_a:
# Entering ctx_b must not see ctx_a's attachment as its own.
with ctx_b:
assert _TestContext.get() is ctx_b
assert _TestContext.get() is ctx_a


async def test_async_context_manager():
"""Async __aenter__/__aexit__ attaches and detaches the context."""
ctx = _TestContext(label="async")
async with ctx as entered:
assert entered is ctx
assert _TestContext.get() is ctx
with pytest.raises(LookupError):
_TestContext.get()
Loading