Skip to content
14 changes: 13 additions & 1 deletion effectful/ops/semantics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import collections.abc
import contextlib
import dataclasses
import functools
import types
import typing
Expand Down Expand Up @@ -276,8 +277,19 @@ def evaluate[T](expr: Expr[T], *, intp: Interpretation | None = None) -> Expr[T]
return type(expr)(evaluate(item) for item in expr) # type: ignore
elif isinstance(expr, collections.abc.ValuesView):
return [evaluate(item) for item in expr] # type: ignore
elif dataclasses.is_dataclass(expr) and not isinstance(expr, type):
return typing.cast(
T,
dataclasses.replace(
expr,
**{
field.name: evaluate(getattr(expr, field.name))
for field in dataclasses.fields(expr)
},
),
)
else:
return expr
return typing.cast(T, expr)


def typeof[T](term: Expr[T]) -> type[T]:
Expand Down
13 changes: 13 additions & 0 deletions effectful/ops/syntax.py
Original file line number Diff line number Diff line change
Expand Up @@ -1140,6 +1140,19 @@ def syntactic_eq[T](x: Expr[T], other: Expr[T]) -> bool:
return len(x) == len(other) and all(
syntactic_eq(a, b) for a, b in zip(x, other)
)
elif (
dataclasses.is_dataclass(x)
and not isinstance(x, type)
and dataclasses.is_dataclass(other)
and not isinstance(other, type)
):
return type(x) == type(other) and syntactic_eq(
{field.name: getattr(x, field.name) for field in dataclasses.fields(x)},
{
field.name: getattr(other, field.name)
for field in dataclasses.fields(other)
},
)
else:
return x == other

Expand Down
37 changes: 37 additions & 0 deletions tests/test_ops_syntax.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dataclasses
import functools
import inspect
from collections.abc import Callable, Iterable, Iterator, Mapping
Expand Down Expand Up @@ -541,3 +542,39 @@ def test_defstream_1():
# assert isinstance(tm_iter_next, numbers.Number) # TODO
# assert issubclass(typeof(tm_iter_next), numbers.Number)
assert tm_iter_next.op is next_


def test_eval_dataclass():
@dataclasses.dataclass
class Point:
x: int
y: int

@dataclasses.dataclass
class Line:
start: Point
end: Point

@dataclasses.dataclass
class Lines:
origin: Point
lines: list[Line]

x, y = defop(int, name="x"), defop(int, name="y")
p1 = Point(x(), y())
p2 = Point(x() + 1, y() + 1)
line = Line(p1, p2)
lines = Lines(p1, [line])

assert {x, y} <= fvsof(lines)

assert p1 == lines.origin

with handler({x: lambda: 3, y: lambda: 4}):
evaluated_lines = evaluate(lines)

assert isinstance(evaluated_lines, Lines)
assert evaluated_lines == Lines(
origin=Point(3, 4),
lines=[Line(Point(3, 4), Point(4, 5))],
)