Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 22 additions & 10 deletions effectful/ops/semantics.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,19 +211,31 @@ def evaluate[T](
@evaluate.register(bytes)
def _evaluate_object[T](expr: T, **kwargs) -> T:
if dataclasses.is_dataclass(expr) and not isinstance(expr, type):
return typing.cast(
T,
dataclasses.replace(
expr,
**{
field.name: evaluate(getattr(expr, field.name))
for field in dataclasses.fields(expr)
},
),
)
return _evaluate_dataclass(expr, **kwargs)
return expr


def _get_dataclass_constr_op(typ):
if hasattr(typ, "constr_op"):
return typ.constr_op

@Operation.define
def constr_op(*args, **kwargs) -> typ:
return typ(*args, **kwargs)

typ.constr_op = constr_op
return constr_op


def _evaluate_dataclass[T](expr: T, **kwargs) -> T:
dataclass_op = _get_dataclass_constr_op(type(expr))
subst = {
field.name: evaluate(getattr(expr, field.name))
for field in dataclasses.fields(expr) # type: ignore[arg-type]
}
return typing.cast(T, dataclass_op(**subst))


@evaluate.register(Term)
def _evaluate_term(expr: Term, **kwargs):
args = tuple(evaluate(arg) for arg in expr.args)
Expand Down
21 changes: 21 additions & 0 deletions tests/test_ops_semantics.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,6 +879,27 @@ def __init__(self, x: int):
assert fvsof(A(v())) == {v}


def test_defdata_dataclass_init_effects() -> None:
@Operation.define
def f(x: int):
raise NotHandled

@dataclasses.dataclass
class A:
x: int

def __init__(self, x: int):
self.x = f(x)

@Operation.define
def g(a: A):
raise NotHandled

v = Operation.define(int)
t = g(A(v()))
assert isinstance(t.args[0].x, Term)


def test_instanceop_super() -> None:
class A:
@Operation.define
Expand Down
Loading