From 2fb706227d74bda9b80da7fa1828a42df3ec2135 Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 16 Feb 2026 18:39:07 -0500 Subject: [PATCH 1/2] add a couple tests --- effectful/internals/runtime.py | 16 ++++--- effectful/ops/semantics.py | 17 +++++--- tests/test_ops_semantics.py | 77 ++++++++++++++++++++++++++++++++++ 3 files changed, 99 insertions(+), 11 deletions(-) diff --git a/effectful/internals/runtime.py b/effectful/internals/runtime.py index 4ad7b534f..2ca0bf3a3 100644 --- a/effectful/internals/runtime.py +++ b/effectful/internals/runtime.py @@ -1,7 +1,8 @@ import contextlib import dataclasses import functools -from collections.abc import Callable, Mapping +import typing +from collections.abc import Callable, Mapping, MutableMapping from threading import local from effectful.ops.syntax import defop @@ -11,11 +12,12 @@ @dataclasses.dataclass class Runtime[S, T](local): interpretation: "Interpretation[S, T]" + cache: MutableMapping[int, typing.Any] | None @functools.lru_cache(maxsize=1) def get_runtime() -> Runtime: - return Runtime(interpretation={}) + return Runtime(interpretation={}, cache=None) def get_interpretation(): @@ -25,12 +27,16 @@ def get_interpretation(): @contextlib.contextmanager def interpreter(intp: "Interpretation"): r = get_runtime() - old_intp = r.interpretation + old_intp, old_cache = r.interpretation, r.cache try: - old_intp, r.interpretation = r.interpretation, dict(intp) + old_intp, r.interpretation = r.interpretation, intp + old_cache, r.cache = ( + r.cache, + old_cache if old_intp is intp and old_cache is not None else {}, + ) yield intp finally: - r.interpretation = old_intp + r.interpretation, r.cache = old_intp, old_cache @defop diff --git a/effectful/ops/semantics.py b/effectful/ops/semantics.py index f7475dc71..cd5394448 100644 --- a/effectful/ops/semantics.py +++ b/effectful/ops/semantics.py @@ -198,12 +198,17 @@ def evaluate[T]( 6 """ - from effectful.internals.runtime import interpreter - - if intp is not None: - return interpreter(intp)(evaluate)(expr) - - return __dispatch(type(expr))(expr) + from effectful.internals.runtime import get_runtime, interpreter + + with interpreter(intp if intp is not None else get_runtime().interpretation): + cache = get_runtime().cache + assert cache is not None, "Cache should be initialized by interpreter" + try: + return cache[id(expr)] + except KeyError: + result = __dispatch(type(expr))(expr) + cache[id(expr)] = result + return result @evaluate.register(object) diff --git a/tests/test_ops_semantics.py b/tests/test_ops_semantics.py index 79b806633..9ffa29f31 100644 --- a/tests/test_ops_semantics.py +++ b/tests/test_ops_semantics.py @@ -863,3 +863,80 @@ def get_mixed() -> Literal[1, "a"]: with pytest.raises(TypeError, match="Union types are not supported"): typeof(get_mixed()) + + +def test_evaluate_dag_no_exponential_blowup(): + """A DAG of nested tuples sharing the same Term is O(n), not O(2^n).""" + call_count = 0 + + @defop + def counted() -> int: + raise NotHandled + + def counted_handler(): + nonlocal call_count + call_count += 1 + return 42 + + # Build a DAG of nested tuples: each level shares the same child object. + # As a tree this would have 2^depth leaves; as a DAG it's depth+1 objects. + depth = 20 + node = counted() + for _ in range(depth): + node = (node, node) + + call_count = 0 + with handler({counted: counted_handler}): + result = evaluate(node) + + # The handler should only be called once (the shared Term) + assert call_count == 1 + # The result should be nested tuples of 42 + leaf = result + for _ in range(depth): + assert isinstance(leaf, tuple) and len(leaf) == 2 + assert leaf[0] is leaf[1] # memoization returns same object + leaf = leaf[0] + assert leaf == 42 + + +def test_evaluate_dag_cache_isolation(): + """Different interpretations produce different results for the same expr.""" + x = defop(int, name="x") + shared = x() + expr = (shared, shared) + + assert evaluate(expr, intp={x: lambda: 1}) == (1, 1) + assert evaluate(expr, intp={x: lambda: 99}) == (99, 99) + + +def test_evaluate_dag_nested_different_intp(): + """evaluate(expr, intp=...) inside a handler gets its own cache.""" + x = defop(int, name="x") + y = defop(int, name="y") + + shared = x() + inner_expr = (shared, shared) + + result = evaluate( + y(), intp={y: lambda: evaluate(inner_expr, intp={x: lambda: 7})} + ) + assert result == (7, 7) + + +def test_evaluate_dag_matches_tree(): + """DAG evaluation produces the same result as evaluating an equivalent tree.""" + x = defop(int, name="x") + + @defop + def mul(a: int, b: int) -> int: + raise NotHandled + + shared = x() + dag = (mul(shared, shared), mul(shared, shared)) + + # Equivalent tree with distinct Term objects + tree = (mul(x(), x()), mul(x(), x())) + + intp = {x: lambda: 3, mul: lambda a, b: a * b} + assert evaluate(dag, intp=intp) == evaluate(tree, intp=intp) == (9, 9) From a94da493a24360d63b54e7cfe532b6847b881de0 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 19 Feb 2026 16:45:25 -0500 Subject: [PATCH 2/2] had to make ugly changes to ugly internals to make this work --- effectful/internals/product_n.py | 48 ++++++++++++++++++------------ effectful/internals/unification.py | 21 +++++++++++++ effectful/ops/semantics.py | 14 +++++---- effectful/ops/syntax.py | 14 ++++++--- tests/test_ops_semantics.py | 6 ++-- 5 files changed, 71 insertions(+), 32 deletions(-) diff --git a/effectful/internals/product_n.py b/effectful/internals/product_n.py index 4b8bd2a81..ca7c2f0b3 100644 --- a/effectful/internals/product_n.py +++ b/effectful/internals/product_n.py @@ -58,49 +58,59 @@ def _unpack(x, prompt): return x -def map_structure(func, expr): +def map_structure(func, expr, _cache=None): + if _cache is None: + _cache = {} + + key = id(expr) + if key in _cache: + ref, result = _cache[key] + if ref is expr: + return result + + def recurse(x): + return map_structure(func, x, _cache) + if isinstance(expr, collections.abc.Mapping): if isinstance(expr, collections.defaultdict): - return type(expr)( - expr.default_factory, map_structure(func, tuple(expr.items())) - ) + result = type(expr)(expr.default_factory, recurse(tuple(expr.items()))) elif isinstance(expr, types.MappingProxyType): - return type(expr)(dict(map_structure(func, tuple(expr.items())))) + result = type(expr)(dict(recurse(tuple(expr.items())))) else: - return type(expr)(map_structure(func, tuple(expr.items()))) + result = type(expr)(recurse(tuple(expr.items()))) elif isinstance(expr, collections.abc.Sequence): if isinstance(expr, str | bytes): - return expr + result = expr elif ( isinstance(expr, tuple) and hasattr(expr, "_fields") and all(hasattr(expr, field) for field in getattr(expr, "_fields")) ): # namedtuple - return type(expr)( - **{ - field: map_structure(func, getattr(expr, field)) - for field in expr._fields - } + result = type(expr)( + **{field: recurse(getattr(expr, field)) for field in expr._fields} ) else: - return type(expr)(map_structure(func, item) for item in expr) + result = type(expr)(recurse(item) for item in expr) elif isinstance(expr, collections.abc.Set): if isinstance(expr, collections.abc.ItemsView | collections.abc.KeysView): - return {map_structure(func, item) for item in expr} + result = {recurse(item) for item in expr} else: - return type(expr)(map_structure(func, item) for item in expr) + result = type(expr)(recurse(item) for item in expr) elif isinstance(expr, collections.abc.ValuesView): - return [map_structure(func, item) for item in expr] + result = [recurse(item) for item in expr] elif dataclasses.is_dataclass(expr) and not isinstance(expr, type): - return dataclasses.replace( + result = dataclasses.replace( expr, **{ - field.name: map_structure(func, getattr(expr, field.name)) + field.name: recurse(getattr(expr, field.name)) for field in dataclasses.fields(expr) }, ) else: - return func(expr) + result = func(expr) + + _cache[key] = (expr, result) + return result def productN(intps: Mapping[Operation, Interpretation]) -> Interpretation: diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 71d6583f2..94345a9ee 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -64,6 +64,7 @@ import inspect import numbers import operator +import threading import types import typing from dataclasses import dataclass @@ -815,6 +816,26 @@ def _(value: str | bytes | range | None): return Box(type(value)) +_nested_type_dispatch = nested_type +_nested_type_state = threading.local() +_NESTED_TYPE_MAX_DEPTH = 5 + + +def nested_type(value) -> Box[TypeExpression]: # type: ignore[no-redef] + depth = getattr(_nested_type_state, "depth", 0) + if depth >= _NESTED_TYPE_MAX_DEPTH: + return Box(type(value)) + _nested_type_state.depth = depth + 1 + try: + return _nested_type_dispatch(value) + finally: + _nested_type_state.depth = depth + + +nested_type.register = _nested_type_dispatch.register # type: ignore[attr-defined] +nested_type.dispatch = _nested_type_dispatch.dispatch # type: ignore[attr-defined] + + def freetypevars(typ) -> collections.abc.Set[TypeVariable]: """ Return a set of free type variables in the given type expression. diff --git a/effectful/ops/semantics.py b/effectful/ops/semantics.py index cd5394448..918bd14c5 100644 --- a/effectful/ops/semantics.py +++ b/effectful/ops/semantics.py @@ -203,12 +203,14 @@ def evaluate[T]( with interpreter(intp if intp is not None else get_runtime().interpretation): cache = get_runtime().cache assert cache is not None, "Cache should be initialized by interpreter" - try: - return cache[id(expr)] - except KeyError: - result = __dispatch(type(expr))(expr) - cache[id(expr)] = result - return result + key = id(expr) + if key in cache: + ref, result = cache[key] + if ref is expr: + return result + result = __dispatch(type(expr))(expr) + cache[key] = (expr, result) + return result @evaluate.register(object) diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index 764016752..6d25b67c4 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -361,16 +361,22 @@ def analyze(self, bound_sig: inspect.BoundArguments) -> frozenset[Operation]: elif param_ordinal: # Only process if there's a Scoped annotation # We can't use flatten here because we want to be able # to see dict keys - def extract_operations(obj): + def extract_operations(obj, _seen=None): + if _seen is None: + _seen = set() + obj_id = id(obj) + if obj_id in _seen: + return + _seen.add(obj_id) if isinstance(obj, Operation): param_bound_vars.add(obj) elif isinstance(obj, dict): for k, v in obj.items(): - extract_operations(k) - extract_operations(v) + extract_operations(k, _seen) + extract_operations(v, _seen) elif isinstance(obj, list | set | tuple): for v in obj: - extract_operations(v) + extract_operations(v, _seen) extract_operations(param_value) diff --git a/tests/test_ops_semantics.py b/tests/test_ops_semantics.py index 9ffa29f31..3ce7ae029 100644 --- a/tests/test_ops_semantics.py +++ b/tests/test_ops_semantics.py @@ -889,6 +889,8 @@ def counted_handler(): with handler({counted: counted_handler}): result = evaluate(node) + deffn(node, counted)(0) + # The handler should only be called once (the shared Term) assert call_count == 1 # The result should be nested tuples of 42 @@ -918,9 +920,7 @@ def test_evaluate_dag_nested_different_intp(): shared = x() inner_expr = (shared, shared) - result = evaluate( - y(), intp={y: lambda: evaluate(inner_expr, intp={x: lambda: 7})} - ) + result = evaluate(y(), intp={y: lambda: evaluate(inner_expr, intp={x: lambda: 7})}) assert result == (7, 7)