diff --git a/effectful/ops/semantics.py b/effectful/ops/semantics.py index 16154bb7..3afd5d1c 100644 --- a/effectful/ops/semantics.py +++ b/effectful/ops/semantics.py @@ -1,5 +1,6 @@ import collections.abc import contextlib +import dataclasses import functools import types import typing @@ -276,8 +277,19 @@ def evaluate[T](expr: Expr[T], *, intp: Interpretation | None = None) -> Expr[T] return type(expr)(evaluate(item) for item in expr) # type: ignore elif isinstance(expr, collections.abc.ValuesView): return [evaluate(item) for item in expr] # type: ignore + elif dataclasses.is_dataclass(expr) and not isinstance(expr, type): + return typing.cast( + T, + dataclasses.replace( + expr, + **{ + field.name: evaluate(getattr(expr, field.name)) + for field in dataclasses.fields(expr) + }, + ), + ) else: - return expr + return typing.cast(T, expr) def typeof[T](term: Expr[T]) -> type[T]: diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index e97f1816..56074ede 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -1140,6 +1140,19 @@ def syntactic_eq[T](x: Expr[T], other: Expr[T]) -> bool: return len(x) == len(other) and all( syntactic_eq(a, b) for a, b in zip(x, other) ) + elif ( + dataclasses.is_dataclass(x) + and not isinstance(x, type) + and dataclasses.is_dataclass(other) + and not isinstance(other, type) + ): + return type(x) == type(other) and syntactic_eq( + {field.name: getattr(x, field.name) for field in dataclasses.fields(x)}, + { + field.name: getattr(other, field.name) + for field in dataclasses.fields(other) + }, + ) else: return x == other diff --git a/tests/test_ops_syntax.py b/tests/test_ops_syntax.py index 2d6fb215..82cae1ba 100644 --- a/tests/test_ops_syntax.py +++ b/tests/test_ops_syntax.py @@ -1,3 +1,4 @@ +import dataclasses import functools import inspect from collections.abc import Callable, Iterable, Iterator, Mapping @@ -541,3 +542,39 @@ def test_defstream_1(): # assert isinstance(tm_iter_next, numbers.Number) # TODO # assert issubclass(typeof(tm_iter_next), numbers.Number) assert tm_iter_next.op is next_ + + +def test_eval_dataclass(): + @dataclasses.dataclass + class Point: + x: int + y: int + + @dataclasses.dataclass + class Line: + start: Point + end: Point + + @dataclasses.dataclass + class Lines: + origin: Point + lines: list[Line] + + x, y = defop(int, name="x"), defop(int, name="y") + p1 = Point(x(), y()) + p2 = Point(x() + 1, y() + 1) + line = Line(p1, p2) + lines = Lines(p1, [line]) + + assert {x, y} <= fvsof(lines) + + assert p1 == lines.origin + + with handler({x: lambda: 3, y: lambda: 4}): + evaluated_lines = evaluate(lines) + + assert isinstance(evaluated_lines, Lines) + assert evaluated_lines == Lines( + origin=Point(3, 4), + lines=[Line(Point(3, 4), Point(4, 5))], + )