Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 29 additions & 19 deletions effectful/internals/product_n.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,49 +58,59 @@ def _unpack(x, prompt):
return x


def map_structure(func, expr):
def map_structure(func, expr, _cache=None):
if _cache is None:
_cache = {}

key = id(expr)
if key in _cache:
ref, result = _cache[key]
if ref is expr:
return result

def recurse(x):
return map_structure(func, x, _cache)

if isinstance(expr, collections.abc.Mapping):
if isinstance(expr, collections.defaultdict):
return type(expr)(
expr.default_factory, map_structure(func, tuple(expr.items()))
)
result = type(expr)(expr.default_factory, recurse(tuple(expr.items())))
elif isinstance(expr, types.MappingProxyType):
return type(expr)(dict(map_structure(func, tuple(expr.items()))))
result = type(expr)(dict(recurse(tuple(expr.items()))))
else:
return type(expr)(map_structure(func, tuple(expr.items())))
result = type(expr)(recurse(tuple(expr.items())))
elif isinstance(expr, collections.abc.Sequence):
if isinstance(expr, str | bytes):
return expr
result = expr
elif (
isinstance(expr, tuple)
and hasattr(expr, "_fields")
and all(hasattr(expr, field) for field in getattr(expr, "_fields"))
): # namedtuple
return type(expr)(
**{
field: map_structure(func, getattr(expr, field))
for field in expr._fields
}
result = type(expr)(
**{field: recurse(getattr(expr, field)) for field in expr._fields}
)
else:
return type(expr)(map_structure(func, item) for item in expr)
result = type(expr)(recurse(item) for item in expr)
elif isinstance(expr, collections.abc.Set):
if isinstance(expr, collections.abc.ItemsView | collections.abc.KeysView):
return {map_structure(func, item) for item in expr}
result = {recurse(item) for item in expr}
else:
return type(expr)(map_structure(func, item) for item in expr)
result = type(expr)(recurse(item) for item in expr)
elif isinstance(expr, collections.abc.ValuesView):
return [map_structure(func, item) for item in expr]
result = [recurse(item) for item in expr]
elif dataclasses.is_dataclass(expr) and not isinstance(expr, type):
return dataclasses.replace(
result = dataclasses.replace(
expr,
**{
field.name: map_structure(func, getattr(expr, field.name))
field.name: recurse(getattr(expr, field.name))
for field in dataclasses.fields(expr)
},
)
else:
return func(expr)
result = func(expr)

_cache[key] = (expr, result)
return result


def productN(intps: Mapping[Operation, Interpretation]) -> Interpretation:
Expand Down
16 changes: 11 additions & 5 deletions effectful/internals/runtime.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import contextlib
import dataclasses
import functools
from collections.abc import Callable, Mapping
import typing
from collections.abc import Callable, Mapping, MutableMapping
from threading import local

from effectful.ops.syntax import defop
Expand All @@ -11,11 +12,12 @@
@dataclasses.dataclass
class Runtime[S, T](local):
interpretation: "Interpretation[S, T]"
cache: MutableMapping[int, typing.Any] | None


@functools.lru_cache(maxsize=1)
def get_runtime() -> Runtime:
return Runtime(interpretation={})
return Runtime(interpretation={}, cache=None)


def get_interpretation():
Expand All @@ -25,12 +27,16 @@ def get_interpretation():
@contextlib.contextmanager
def interpreter(intp: "Interpretation"):
r = get_runtime()
old_intp = r.interpretation
old_intp, old_cache = r.interpretation, r.cache
try:
old_intp, r.interpretation = r.interpretation, dict(intp)
old_intp, r.interpretation = r.interpretation, intp
old_cache, r.cache = (
r.cache,
old_cache if old_intp is intp and old_cache is not None else {},
)
yield intp
finally:
r.interpretation = old_intp
r.interpretation, r.cache = old_intp, old_cache


@defop
Expand Down
21 changes: 21 additions & 0 deletions effectful/internals/unification.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
import inspect
import numbers
import operator
import threading
import types
import typing
from dataclasses import dataclass
Expand Down Expand Up @@ -815,6 +816,26 @@ def _(value: str | bytes | range | None):
return Box(type(value))


_nested_type_dispatch = nested_type
_nested_type_state = threading.local()
_NESTED_TYPE_MAX_DEPTH = 5


def nested_type(value) -> Box[TypeExpression]: # type: ignore[no-redef]
depth = getattr(_nested_type_state, "depth", 0)
if depth >= _NESTED_TYPE_MAX_DEPTH:
return Box(type(value))
_nested_type_state.depth = depth + 1
try:
return _nested_type_dispatch(value)
finally:
_nested_type_state.depth = depth


nested_type.register = _nested_type_dispatch.register # type: ignore[attr-defined]
nested_type.dispatch = _nested_type_dispatch.dispatch # type: ignore[attr-defined]


def freetypevars(typ) -> collections.abc.Set[TypeVariable]:
"""
Return a set of free type variables in the given type expression.
Expand Down
19 changes: 13 additions & 6 deletions effectful/ops/semantics.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,12 +198,19 @@ def evaluate[T](
6

"""
from effectful.internals.runtime import interpreter

if intp is not None:
return interpreter(intp)(evaluate)(expr)

return __dispatch(type(expr))(expr)
from effectful.internals.runtime import get_runtime, interpreter

with interpreter(intp if intp is not None else get_runtime().interpretation):
cache = get_runtime().cache
assert cache is not None, "Cache should be initialized by interpreter"
key = id(expr)
if key in cache:
ref, result = cache[key]
if ref is expr:
return result
result = __dispatch(type(expr))(expr)
cache[key] = (expr, result)
return result


@evaluate.register(object)
Expand Down
14 changes: 10 additions & 4 deletions effectful/ops/syntax.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,16 +361,22 @@ def analyze(self, bound_sig: inspect.BoundArguments) -> frozenset[Operation]:
elif param_ordinal: # Only process if there's a Scoped annotation
# We can't use flatten here because we want to be able
# to see dict keys
def extract_operations(obj):
def extract_operations(obj, _seen=None):
if _seen is None:
_seen = set()
obj_id = id(obj)
if obj_id in _seen:
return
_seen.add(obj_id)
if isinstance(obj, Operation):
param_bound_vars.add(obj)
elif isinstance(obj, dict):
for k, v in obj.items():
extract_operations(k)
extract_operations(v)
extract_operations(k, _seen)
extract_operations(v, _seen)
elif isinstance(obj, list | set | tuple):
for v in obj:
extract_operations(v)
extract_operations(v, _seen)

extract_operations(param_value)

Expand Down
77 changes: 77 additions & 0 deletions tests/test_ops_semantics.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,3 +863,80 @@ def get_mixed() -> Literal[1, "a"]:

with pytest.raises(TypeError, match="Union types are not supported"):
typeof(get_mixed())


def test_evaluate_dag_no_exponential_blowup():
"""A DAG of nested tuples sharing the same Term is O(n), not O(2^n)."""
call_count = 0

@defop
def counted() -> int:
raise NotHandled

def counted_handler():
nonlocal call_count
call_count += 1
return 42

# Build a DAG of nested tuples: each level shares the same child object.
# As a tree this would have 2^depth leaves; as a DAG it's depth+1 objects.
depth = 20
node = counted()
for _ in range(depth):
node = (node, node)

call_count = 0
with handler({counted: counted_handler}):
result = evaluate(node)

deffn(node, counted)(0)

# The handler should only be called once (the shared Term)
assert call_count == 1
# The result should be nested tuples of 42
leaf = result
for _ in range(depth):
assert isinstance(leaf, tuple) and len(leaf) == 2
assert leaf[0] is leaf[1] # memoization returns same object
leaf = leaf[0]
assert leaf == 42


def test_evaluate_dag_cache_isolation():
"""Different interpretations produce different results for the same expr."""
x = defop(int, name="x")
shared = x()
expr = (shared, shared)

assert evaluate(expr, intp={x: lambda: 1}) == (1, 1)
assert evaluate(expr, intp={x: lambda: 99}) == (99, 99)


def test_evaluate_dag_nested_different_intp():
"""evaluate(expr, intp=...) inside a handler gets its own cache."""
x = defop(int, name="x")
y = defop(int, name="y")

shared = x()
inner_expr = (shared, shared)

result = evaluate(y(), intp={y: lambda: evaluate(inner_expr, intp={x: lambda: 7})})
assert result == (7, 7)


def test_evaluate_dag_matches_tree():
"""DAG evaluation produces the same result as evaluating an equivalent tree."""
x = defop(int, name="x")

@defop
def mul(a: int, b: int) -> int:
raise NotHandled

shared = x()
dag = (mul(shared, shared), mul(shared, shared))

# Equivalent tree with distinct Term objects
tree = (mul(x(), x()), mul(x(), x()))

intp = {x: lambda: 3, mul: lambda a, b: a * b}
assert evaluate(dag, intp=intp) == evaluate(tree, intp=intp) == (9, 9)
Loading