diff --git a/effectful/ops/semantics.py b/effectful/ops/semantics.py index f7678fd2..8059b538 100644 --- a/effectful/ops/semantics.py +++ b/effectful/ops/semantics.py @@ -211,19 +211,31 @@ def evaluate[T]( @evaluate.register(bytes) def _evaluate_object[T](expr: T, **kwargs) -> T: if dataclasses.is_dataclass(expr) and not isinstance(expr, type): - return typing.cast( - T, - dataclasses.replace( - expr, - **{ - field.name: evaluate(getattr(expr, field.name)) - for field in dataclasses.fields(expr) - }, - ), - ) + return _evaluate_dataclass(expr, **kwargs) return expr +def _get_dataclass_constr_op(typ): + if hasattr(typ, "constr_op"): + return typ.constr_op + + @Operation.define + def constr_op(*args, **kwargs) -> typ: + return typ(*args, **kwargs) + + typ.constr_op = constr_op + return constr_op + + +def _evaluate_dataclass[T](expr: T, **kwargs) -> T: + dataclass_op = _get_dataclass_constr_op(type(expr)) + subst = { + field.name: evaluate(getattr(expr, field.name)) + for field in dataclasses.fields(expr) # type: ignore[arg-type] + } + return typing.cast(T, dataclass_op(**subst)) + + @evaluate.register(Term) def _evaluate_term(expr: Term, **kwargs): args = tuple(evaluate(arg) for arg in expr.args) diff --git a/tests/test_ops_semantics.py b/tests/test_ops_semantics.py index 78179d33..7eb65be5 100644 --- a/tests/test_ops_semantics.py +++ b/tests/test_ops_semantics.py @@ -879,6 +879,27 @@ def __init__(self, x: int): assert fvsof(A(v())) == {v} +def test_defdata_dataclass_init_effects() -> None: + @Operation.define + def f(x: int): + raise NotHandled + + @dataclasses.dataclass + class A: + x: int + + def __init__(self, x: int): + self.x = f(x) + + @Operation.define + def g(a: A): + raise NotHandled + + v = Operation.define(int) + t = g(A(v())) + assert isinstance(t.args[0].x, Term) + + def test_instanceop_super() -> None: class A: @Operation.define