diff --git a/README.rst b/README.rst index e5e648b4..8bd8fdfe 100644 --- a/README.rst +++ b/README.rst @@ -50,9 +50,10 @@ Here's an example demonstrating how ``effectful`` can be used to implement a sim import functools from effectful.ops.types import Term - from effectful.ops.syntax import defop + from effectful.ops.syntax import defdata, defop from effectful.ops.semantics import handler, evaluate, coproduct, fwd - from effectful.handlers.numbers import add + + add = defdata.dispatch(int).__add__ def beta_add(x: int, y: int) -> int: match x, y: diff --git a/docs/source/effectful.rst b/docs/source/effectful.rst index 5ef25be8..818b3ce0 100644 --- a/docs/source/effectful.rst +++ b/docs/source/effectful.rst @@ -44,13 +44,6 @@ Handlers :members: :undoc-members: -Numbers -^^^^^^^ - -.. automodule:: effectful.handlers.numbers - :members: - :undoc-members: - Pyro ^^^^ diff --git a/docs/source/lambda_.py b/docs/source/lambda_.py index b270ebd3..822f32d9 100644 --- a/docs/source/lambda_.py +++ b/docs/source/lambda_.py @@ -1,11 +1,12 @@ import functools from typing import Annotated, Callable -from effectful.handlers.numbers import add from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler -from effectful.ops.syntax import Scoped, defop, syntactic_eq +from effectful.ops.syntax import Scoped, defdata, defop, syntactic_eq from effectful.ops.types import Expr, Interpretation, NotHandled, Operation, Term +add = defdata.dispatch(int).__add__ + @defop def App[S, T](f: Callable[[S], T], arg: S) -> T: diff --git a/docs/source/readme_example.py b/docs/source/readme_example.py index 4bf9e7d8..aa9e04fb 100644 --- a/docs/source/readme_example.py +++ b/docs/source/readme_example.py @@ -1,10 +1,11 @@ import functools -from effectful.handlers.numbers import add from effectful.ops.semantics import coproduct, evaluate, fwd, handler -from effectful.ops.syntax import defop +from effectful.ops.syntax import defdata, defop from effectful.ops.types import Term +add = defdata.dispatch(int).__add__ + def beta_add(x: int, y: int) -> int: match x, y: diff --git a/docs/source/semi_ring.py b/docs/source/semi_ring.py index a8233a11..fd799c1b 100644 --- a/docs/source/semi_ring.py +++ b/docs/source/semi_ring.py @@ -3,7 +3,6 @@ import types from typing import Annotated, Tuple, Union, cast, overload -import effectful.handlers.numbers # noqa: F401 from effectful.ops.semantics import coproduct, evaluate, fwd, handler from effectful.ops.syntax import Scoped, defop from effectful.ops.types import Interpretation, NotHandled, Operation, Term diff --git a/effectful/handlers/jax/_handlers.py b/effectful/handlers/jax/_handlers.py index 54cc8147..5bfefa47 100644 --- a/effectful/handlers/jax/_handlers.py +++ b/effectful/handlers/jax/_handlers.py @@ -12,7 +12,6 @@ import tree -import effectful.handlers.numbers # noqa: F401 from effectful.ops.semantics import fvsof, typeof from effectful.ops.syntax import ( Scoped, diff --git a/effectful/handlers/numbers.py b/effectful/handlers/numbers.py deleted file mode 100644 index d83150fb..00000000 --- a/effectful/handlers/numbers.py +++ /dev/null @@ -1,246 +0,0 @@ -""" -This module provides a term representation for numbers and operations on them. -""" - -import numbers -import operator -from typing import Any - -from effectful.ops.syntax import defdata, defop -from effectful.ops.types import Expr, NotHandled, Operation, Term - - -@defdata.register(numbers.Number) -@numbers.Number.register -class _NumberTerm(Term[numbers.Number]): - def __init__( - self, op: Operation[..., numbers.Number], *args: Expr, **kwargs: Expr - ) -> None: - self._op = op - self._args = args - self._kwargs = kwargs - - @property - def op(self) -> Operation[..., numbers.Number]: - return self._op - - @property - def args(self) -> tuple: - return self._args - - @property - def kwargs(self) -> dict: - return self._kwargs - - def __hash__(self): - return hash((self.op, tuple(self.args), tuple(self.kwargs.items()))) - - -# Complex specific methods -def _wrap_cmp(op): - def _wrapped_op[T_Number: numbers.Number](x: T_Number, y: T_Number) -> bool: - if not any(isinstance(a, Term) for a in (x, y)): - return op(x, y) - else: - raise NotHandled - - _wrapped_op.__name__ = op.__name__ - return _wrapped_op - - -def _wrap_binop(op): - def _wrapped_op[T_Number: numbers.Number](x: T_Number, y: T_Number) -> T_Number: - if not any(isinstance(a, Term) for a in (x, y)): - return op(x, y) - else: - raise NotHandled - - _wrapped_op.__name__ = op.__name__ - return _wrapped_op - - -def _wrap_unop(op): - def _wrapped_op[T_Number: numbers.Number](x: T_Number) -> T_Number: - if not isinstance(x, Term): - return op(x) - else: - raise NotHandled - - _wrapped_op.__name__ = op.__name__ - return _wrapped_op - - -add = defop(_wrap_binop(operator.add)) -neg = defop(_wrap_unop(operator.neg)) -pos = defop(_wrap_unop(operator.pos)) -sub = defop(_wrap_binop(operator.sub)) -mul = defop(_wrap_binop(operator.mul)) -truediv = defop(_wrap_binop(operator.truediv)) -pow = defop(_wrap_binop(operator.pow)) -abs = defop(_wrap_unop(operator.abs)) -eq = defop(_wrap_cmp(operator.eq)) - - -@defdata.register(numbers.Complex) -@numbers.Complex.register -class _ComplexTerm(_NumberTerm, Term[numbers.Complex]): - def __bool__(self) -> bool: - raise ValueError("Cannot convert term to bool") - - def __add__(self, other: Any) -> numbers.Real: - return add(self, other) - - def __radd__(self, other: Any) -> numbers.Real: - return add(other, self) - - def __neg__(self): - return neg(self) - - def __pos__(self): - return pos(self) - - def __sub__(self, other: Any) -> numbers.Real: - return sub(self, other) - - def __rsub__(self, other: Any) -> numbers.Real: - return sub(other, self) - - def __mul__(self, other: Any) -> numbers.Real: - return mul(self, other) - - def __rmul__(self, other: Any) -> numbers.Real: - return mul(other, self) - - def __truediv__(self, other: Any) -> numbers.Real: - return truediv(self, other) - - def __rtruediv__(self, other: Any) -> numbers.Real: - return truediv(other, self) - - def __pow__(self, other: Any) -> numbers.Real: - return pow(self, other) - - def __rpow__(self, other: Any) -> numbers.Real: - return pow(other, self) - - def __abs__(self) -> numbers.Real: - return abs(self) - - def __eq__(self, other: Any) -> bool: - return eq(self, other) - - -# Real specific methods -floordiv = defop(_wrap_binop(operator.floordiv)) -mod = defop(_wrap_binop(operator.mod)) -lt = defop(_wrap_cmp(operator.lt)) -le = defop(_wrap_cmp(operator.le)) -gt = defop(_wrap_cmp(operator.gt)) -ge = defop(_wrap_cmp(operator.ge)) - - -@defdata.register(numbers.Real) -@numbers.Real.register -class _RealTerm(_ComplexTerm, Term[numbers.Real]): - # Real specific methods - def __float__(self) -> float: - raise ValueError("Cannot convert term to float") - - def __trunc__(self) -> numbers.Integral: - raise NotImplementedError - - def __floor__(self) -> numbers.Integral: - raise NotImplementedError - - def __ceil__(self) -> numbers.Integral: - raise NotImplementedError - - def __round__(self, ndigits=None) -> numbers.Integral: - raise NotImplementedError - - def __floordiv__(self, other): - return floordiv(self, other) - - def __rfloordiv__(self, other): - return floordiv(other, self) - - def __mod__(self, other): - return mod(self, other) - - def __rmod__(self, other): - return mod(other, self) - - def __lt__(self, other): - return lt(self, other) - - def __le__(self, other): - return le(self, other) - - -@defdata.register(numbers.Rational) -@numbers.Rational.register -class _RationalTerm(_RealTerm, Term[numbers.Rational]): - @property - def numerator(self): - raise NotImplementedError - - @property - def denominator(self): - raise NotImplementedError - - -# Integral specific methods -index = defop(_wrap_unop(operator.index)) -lshift = defop(_wrap_binop(operator.lshift)) -rshift = defop(_wrap_binop(operator.rshift)) -and_ = defop(_wrap_binop(operator.and_)) -xor = defop(_wrap_binop(operator.xor)) -or_ = defop(_wrap_binop(operator.or_)) -invert = defop(_wrap_unop(operator.invert)) - - -@defdata.register(numbers.Integral) -@numbers.Integral.register -class _IntegralTerm(_RationalTerm, Term[numbers.Integral]): - # Integral specific methods - def __int__(self) -> int: - raise ValueError("Cannot convert term to int") - - def __index__(self) -> numbers.Integral: - return index(self) - - def __pow__(self, exponent: Any, modulus=None) -> numbers.Integral: - return pow(self, exponent) - - def __lshift__(self, other): - return lshift(self, other) - - def __rlshift__(self, other): - return lshift(other, self) - - def __rshift__(self, other): - return rshift(self, other) - - def __rrshift__(self, other): - return rshift(other, self) - - def __and__(self, other): - return and_(self, other) - - def __rand__(self, other): - return and_(other, self) - - def __xor__(self, other): - return xor(self, other) - - def __rxor__(self, other): - return xor(other, self) - - def __or__(self, other): - return or_(self, other) - - def __ror__(self, other): - return or_(other, self) - - def __invert__(self): - return invert(self) diff --git a/effectful/handlers/torch.py b/effectful/handlers/torch.py index 5fe57953..71a7f02c 100644 --- a/effectful/handlers/torch.py +++ b/effectful/handlers/torch.py @@ -11,7 +11,6 @@ import tree -import effectful.handlers.numbers # noqa: F401 from effectful.internals.runtime import interpreter from effectful.internals.tensor_utils import _desugar_tensor_index from effectful.ops.semantics import apply, evaluate, fvsof, handler, typeof diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index 76a2d4cb..a4fe9014 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -2,6 +2,8 @@ import dataclasses import functools import inspect +import numbers +import operator import random import types import typing @@ -50,7 +52,6 @@ class Scoped(Annotation): >>> from typing import Annotated >>> from effectful.ops.syntax import Scoped, defop >>> from effectful.ops.semantics import fvsof - >>> from effectful.handlers.numbers import add >>> x, y = defop(int, name='x'), defop(int, name='y') * For example, we can define a higher-order operation :func:`Lambda` @@ -69,11 +70,11 @@ class Scoped(Annotation): passed to :func:`Lambda` may appear free in ``body``, but not in the resulting function. In other words, it is bound by :func:`Lambda`: - >>> assert x not in fvsof(Lambda(x, add(x(), 1))) + >>> assert x not in fvsof(Lambda(x, x() + 1)) However, variables in ``body`` other than ``var`` still appear free in the result: - >>> assert y in fvsof(Lambda(x, add(x(), y()))) + >>> assert y in fvsof(Lambda(x, x() + y())) * :class:`Scoped` can also be used with variadic arguments and keyword arguments. For example, we can define a generalized :func:`LambdaN` that takes a variable @@ -89,7 +90,7 @@ class Scoped(Annotation): This is equivalent to the built-in :class:`Operation` :func:`deffn`: - >>> assert not {x, y} & fvsof(LambdaN(add(x(), y()), x, y)) + >>> assert not {x, y} & fvsof(LambdaN(x() + y(), x, y)) * :class:`Scoped` and :func:`defop` can also express more complex scoping semantics. For example, we can define a :func:`Let` operation that binds a variable in @@ -105,15 +106,15 @@ class Scoped(Annotation): Here the variable ``var`` is bound by :func:`Let` in `body` but not in ``val`` : - >>> assert x not in fvsof(Let(x, add(y(), 1), add(x(), y()))) + >>> assert x not in fvsof(Let(x, y() + 1, x() + y())) - >>> fvs = fvsof(Let(x, add(y(), x()), add(x(), y()))) + >>> fvs = fvsof(Let(x, y() + x(), x() + y())) >>> assert x in fvs and y in fvs This is reflected in the free variables of subterms of the result: - >>> assert x in fvsof(Let(x, add(x(), y()), add(x(), y())).args[1]) - >>> assert x not in fvsof(Let(x, add(y(), 1), add(x(), y())).args[2]) + >>> assert x in fvsof(Let(x, x() + y(), x() + y()).args[1]) + >>> assert x not in fvsof(Let(x, y() + 1, x() + y()).args[2]) """ ordinal: collections.abc.Set @@ -438,7 +439,6 @@ def defop[**P, T]( Passing :func:`defop` a type is a handy way to create a free variable. - >>> import effectful.handlers.numbers >>> from effectful.ops.semantics import evaluate >>> x = defop(int, name='x') >>> y = x() + 1 @@ -446,7 +446,7 @@ def defop[**P, T]( ``y`` is free in ``x``, so it is not fully evaluated: >>> print(str(y)) - add(x(), 1) + __add__(x(), 1) We bind ``x`` by installing a handler for it: @@ -471,7 +471,6 @@ def defop[**P, T]( operation object. In this example, ``scale`` returns a term with a free variable ``x``: - >>> import effectful.handlers.numbers >>> x = defop(float, name='x') >>> def scale(a: float) -> float: ... return x() * a @@ -482,7 +481,7 @@ def defop[**P, T]( >>> fresh_x = defop(float, name='x') >>> with handler({fresh_x: lambda: 2.0}): ... print(str(evaluate(term))) - mul(x(), 3.0) + __mul__(x(), 3.0) Only the original operation object will work: @@ -807,14 +806,13 @@ def deffn[T, A, B]( Here :func:`deffn` is used to define a term that represents the function ``lambda x, y=1: 2 * x + y``: - >>> import effectful.handlers.numbers >>> import random >>> random.seed(0) >>> x, y = defop(int, name='x'), defop(int, name='y') >>> term = deffn(2 * x() + y(), x, y=y) - >>> print(str(term)) - deffn(add(mul(2, x()), y()), x, y=y) + >>> print(str(term)) # doctest: +ELLIPSIS + deffn(...) >>> term(3, y=4) 10 @@ -1050,7 +1048,7 @@ def trace[**P, T](value: Callable[P, T]) -> Callable[P, T]: >>> term = trace(incr) >>> print(str(term)) - deffn(add(int(), 1), int) + deffn(__add__(int(), 1), int) >>> term(2) 3 @@ -1285,3 +1283,348 @@ def implements[**P, V](op: Operation[P, V]): """ return _ImplementedOperation(op) + + +@defdata.register(numbers.Number) +@functools.total_ordering +class _NumberTerm[T: numbers.Number](_BaseTerm[T], numbers.Number): + def __hash__(self): + return id(self) + + def __complex__(self) -> complex: + raise ValueError("Cannot convert term to complex number") + + def __float__(self) -> float: + raise ValueError("Cannot convert term to float") + + def __int__(self) -> int: + raise ValueError("Cannot convert term to int") + + def __bool__(self) -> bool: + raise ValueError("Cannot convert term to bool") + + @defop # type: ignore[prop-decorator] + @property + def real(self) -> float: + if not isinstance(self, Term): + return self.real + else: + raise NotHandled + + @defop # type: ignore[prop-decorator] + @property + def imag(self) -> float: + if not isinstance(self, Term): + return self.imag + else: + raise NotHandled + + @defop + def conjugate(self) -> complex: + if not isinstance(self, Term): + return self.conjugate() + else: + raise NotHandled + + @defop # type: ignore[prop-decorator] + @property + def numerator(self) -> int: + if not isinstance(self, Term): + return self.numerator + else: + raise NotHandled + + @defop # type: ignore[prop-decorator] + @property + def denominator(self) -> int: + if not isinstance(self, Term): + return self.denominator + else: + raise NotHandled + + @defop + def __abs__(self) -> float: + """Return the absolute value of the term.""" + if not isinstance(self, Term): + return self.__abs__() + else: + raise NotHandled + + @defop + def __neg__(self: T) -> T: + if not isinstance(self, Term): + return self.__neg__() # type: ignore + else: + raise NotHandled + + @defop + def __pos__(self: T) -> T: + if not isinstance(self, Term): + return self.__pos__() # type: ignore + else: + raise NotHandled + + @defop + def __trunc__(self) -> int: + if not isinstance(self, Term): + return self.__trunc__() + else: + raise NotHandled + + @defop + def __floor__(self) -> int: + if not isinstance(self, Term): + return self.__floor__() + else: + raise NotHandled + + @defop + def __ceil__(self) -> int: + if not isinstance(self, Term): + return self.__ceil__() + else: + raise NotHandled + + @defop + def __round__(self, ndigits: int | None = None) -> numbers.Real: + if not isinstance(self, Term) and not isinstance(ndigits, Term): + return self.__round__(ndigits) + else: + raise NotHandled + + @defop + def __invert__(self) -> int: + if not isinstance(self, Term): + return self.__invert__() + else: + raise NotHandled + + @defop + def __index__(self) -> int: + if not isinstance(self, Term): + return self.__index__() + else: + raise NotHandled + + @defop + def __eq__(self, other) -> bool: # type: ignore[override] + if not isinstance(self, Term) and not isinstance(other, Term): + return self.__eq__(other) + else: + return syntactic_eq(self, other) + + @defop + def __lt__(self, other) -> bool: + if not isinstance(self, Term) and not isinstance(other, Term): + return self.__lt__(other) + else: + raise NotHandled + + @defop + def __add__(self, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__add__(self, other) + else: + raise NotHandled + + def __radd__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__add__(self) + elif not isinstance(other, Term): + return type(self).__add__(other, self) + else: + return NotImplemented + + @defop + def __sub__(self, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__sub__(self, other) + else: + raise NotHandled + + def __rsub__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__sub__(self) + elif not isinstance(other, Term): + return type(self).__sub__(other, self) + else: + return NotImplemented + + @defop + def __mul__(self, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__mul__(self, other) + else: + raise NotHandled + + def __rmul__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__mul__(self) + elif not isinstance(other, Term): + return type(self).__mul__(other, self) + else: + return NotImplemented + + @defop + def __truediv__(self, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__truediv__(self, other) + else: + raise NotHandled + + def __rtruediv__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__truediv__(self) + elif not isinstance(other, Term): + return type(self).__truediv__(other, self) + else: + return NotImplemented + + @defop + def __floordiv__(self, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__floordiv__(self, other) + else: + raise NotHandled + + def __rfloordiv__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__floordiv__(self) + elif not isinstance(other, Term): + return type(self).__floordiv__(other, self) + else: + return NotImplemented + + @defop + def __mod__(self, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__mod__(self, other) + else: + raise NotHandled + + def __rmod__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__mod__(self) + elif not isinstance(other, Term): + return type(self).__mod__(other, self) + else: + return NotImplemented + + @defop + def __pow__(self, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__pow__(self, other) + else: + raise NotHandled + + def __rpow__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__pow__(self) + elif not isinstance(other, Term): + return type(self).__pow__(other, self) + else: + return NotImplemented + + @defop + def __lshift__(self, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__lshift__(self, other) + else: + raise NotHandled + + def __rlshift__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__lshift__(self) + elif not isinstance(other, Term): + return type(self).__lshift__(other, self) + else: + return NotImplemented + + @defop + def __rshift__(self, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__rshift__(self, other) + else: + raise NotHandled + + def __rrshift__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__rshift__(self) + elif not isinstance(other, Term): + return type(self).__rshift__(other, self) + else: + return NotImplemented + + @defop + def __and__(self, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__and__(self, other) + else: + raise NotHandled + + def __rand__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__and__(self) + elif not isinstance(other, Term): + return type(self).__and__(other, self) + else: + return NotImplemented + + @defop + def __xor__(self, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__xor__(self, other) + else: + raise NotHandled + + def __rxor__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__xor__(self) + elif not isinstance(other, Term): + return type(self).__xor__(other, self) + else: + return NotImplemented + + @defop + def __or__(self, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__or__(self, other) + else: + raise NotHandled + + def __ror__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__or__(self) + elif not isinstance(other, Term): + return type(self).__or__(other, self) + else: + return NotImplemented + + +@defdata.register(numbers.Complex) +@numbers.Complex.register +class _ComplexTerm[T: numbers.Complex](_NumberTerm[T]): + pass + + +@defdata.register(numbers.Real) +@numbers.Real.register +class _RealTerm[T: numbers.Real](_ComplexTerm[T]): + pass + + +@defdata.register(numbers.Rational) +@numbers.Rational.register +class _RationalTerm[T: numbers.Rational](_RealTerm[T]): + pass + + +@defdata.register(numbers.Integral) +@numbers.Integral.register +class _IntegralTerm[T: numbers.Integral](_RationalTerm[T]): + pass + + +@defdata.register(bool) +class _BoolTerm[T: bool](_IntegralTerm[T]): # type: ignore + pass diff --git a/tests/test_handlers_numbers.py b/tests/test_handlers_numbers.py deleted file mode 100644 index 31d117dd..00000000 --- a/tests/test_handlers_numbers.py +++ /dev/null @@ -1,273 +0,0 @@ -import collections -import collections.abc -import logging -import os -import typing - -import pytest - -from docs.source.lambda_ import App, Lam, Let, eager_mixed -from effectful.ops.semantics import evaluate, fvsof, handler, typeof -from effectful.ops.syntax import defop, syntactic_eq, trace -from effectful.ops.types import Term - -logger = logging.getLogger(__name__) - -T = typing.TypeVar("T") - - -def test_lambda_calculus_1(): - x, y = defop(int), defop(int) - - with handler(eager_mixed): - e1 = x() + 1 - f1 = Lam(x, e1) - - assert syntactic_eq(App(f1, 1), 2) - assert syntactic_eq(Lam(y, f1), f1) - assert syntactic_eq(Lam(x, f1.args[1]), f1.args[1]) - - assert fvsof(e1) == fvsof(x() + 1) - assert fvsof(Lam(x, e1).args[1]) != fvsof(Lam(x, e1).args[1]) - - assert typeof(e1) is int - assert typeof(f1) is collections.abc.Callable - - -def test_lambda_calculus_2(): - x, y = defop(int), defop(int) - - with handler(eager_mixed): - f2 = Lam(x, Lam(y, (x() + y()))) - assert syntactic_eq(App(App(f2, 1), 2), 3) - assert syntactic_eq(Lam(y, f2), f2) - - -def test_lambda_calculus_3(): - x, y, f = ( - defop(int), - defop(int), - defop(collections.abc.Callable[[int], collections.abc.Callable[[int], int]]), - ) - - with handler(eager_mixed): - f2 = Lam(x, Lam(y, (x() + y()))) - app2 = Lam(f, Lam(x, Lam(y, App(App(f(), x()), y())))) - assert syntactic_eq(App(App(App(app2, f2), 1), 2), 3) - - -def test_lambda_calculus_4(): - x, f, g = ( - defop(int), - defop(collections.abc.Callable[[T], T]), - defop(collections.abc.Callable[[T], T]), - ) - - with handler(eager_mixed): - add1 = Lam(x, (x() + 1)) - compose = Lam(f, Lam(g, Lam(x, App(f(), App(g(), x()))))) - f1_twice = App(App(compose, add1), add1) - assert syntactic_eq(App(f1_twice, 1), 3) - - -def test_lambda_calculus_5(): - x = defop(int) - - with handler(eager_mixed): - e_add1 = Let(x, x(), (x() + 1)) - f_add1 = Lam(x, e_add1) - - assert x in fvsof(e_add1) - assert e_add1.args[0] != x - - assert x not in fvsof(f_add1) - assert f_add1.args[0] != f_add1.args[1].args[0] - - assert syntactic_eq(App(f_add1, 1), 2) - assert syntactic_eq(Let(x, 1, e_add1), 2) - - -def test_arithmetic_1(): - x_, y_ = defop(int), defop(int) - x, y = x_(), y_() - - with handler(eager_mixed): - assert syntactic_eq((1 + 2) + x, x + 3) - assert not syntactic_eq(x + 1, y + 1) - assert syntactic_eq(x + 0, 0 + x) and syntactic_eq(0 + x, x) - - -def test_arithmetic_2(): - x_, y_ = defop(int), defop(int) - x, y = x_(), y_() - - with handler(eager_mixed): - assert syntactic_eq(x + y, y + x) - assert syntactic_eq(3 + x, x + 3) - assert syntactic_eq(1 + (x + 2), x + 3) - assert syntactic_eq((x + 1) + 2, x + 3) - - -def test_arithmetic_3(): - x_, y_ = defop(int), defop(int) - x, y = x_(), y_() - - with handler(eager_mixed): - assert syntactic_eq((1 + (y + 1)) + (1 + (x + 1)), (y + x) + 4) - assert syntactic_eq(1 + ((x + y) + 2), (x + y) + 3) - assert syntactic_eq(1 + ((x + (y + 1)) + 1), (x + y) + 3) - - -def test_arithmetic_4(): - x_, y_ = defop(int), defop(int) - x, y = x_(), y_() - - with handler(eager_mixed): - expr1 = ((x + x) + (x + x)) + ((x + x) + (x + x)) - expr2 = x + (x + (x + (x + (x + (x + (x + x)))))) - expr3 = ((((((x + x) + x) + x) + x) + x) + x) + x - assert syntactic_eq(expr1, expr2) and syntactic_eq(expr2, expr3) - - expr4 = (x + y) + (y + x) - expr5 = (y + (x + x)) + y - expr6 = y + (x + (y + x)) - assert syntactic_eq(expr4, expr5) and syntactic_eq(expr5, expr6) - - -def test_arithmetic_5(): - x, y = defop(int), defop(int) - - with handler(eager_mixed): - assert syntactic_eq(Let(x, x() + 3, x() + 1), x() + 4) - assert syntactic_eq(Let(x, x() + 3, x() + y() + 1), y() + x() + 4) - - assert syntactic_eq(Let(x, x() + 3, Let(x, x() + 4, x() + y())), x() + y() + 7) - - -def test_defun_1(): - x, y = defop(int), defop(int) - - with handler(eager_mixed): - - @trace - def f1(x: int) -> int: - return x + y() + 1 - - assert typeof(f1) is collections.abc.Callable - assert y in fvsof(f1) - assert x not in fvsof(f1) - - assert syntactic_eq(f1(1), y() + 2) - assert syntactic_eq(f1(x()), x() + y() + 1) - - -def test_defun_2(): - with handler(eager_mixed): - - @trace - def f1(x: int, y: int) -> int: - return x + y - - @trace - def f2(x: int, y: int) -> int: - @trace - def f2_inner(y: int) -> int: - return x + y - - return f2_inner(y) - - assert syntactic_eq(f1(1, 2), 3) and syntactic_eq(f2(1, 2), 3) - - -def test_defun_3(): - with handler(eager_mixed): - - @trace - def f2(x: int, y: int) -> int: - return x + y - - @trace - def app2(f: collections.abc.Callable[[int, int], int], x: int, y: int) -> int: - return f(x, y) - - assert syntactic_eq(app2(f2, 1, 2), 3) - - -@pytest.mark.xfail(condition=os.getenv("CI") == "true", reason="Fails on CI") -def test_defun_4(): - x = defop(int) - - with handler(eager_mixed): - - @trace - def compose( - f: collections.abc.Callable[[int], int], - g: collections.abc.Callable[[int], int], - ) -> collections.abc.Callable[[int], int]: - @trace - def fg(x: int) -> int: - assert callable(f), f"f is not callable: {f}" - assert callable(g), f"g is not callable: {g}" - return f(g(x)) - - return fg - - assert callable(compose), f"compose is not callable: {compose}" - - @trace - def add1(x: int) -> int: - return x + 1 - - assert callable(add1), f"add1 is not callable: {add1}" - - @trace - def add1_twice(x: int) -> int: - return compose(add1, add1)(x) - - assert callable(add1_twice), f"add1_twice is not callable: {add1_twice}" - - assert syntactic_eq(add1_twice(1), 3) and syntactic_eq( - compose(add1, add1)(1), 3 - ) - assert syntactic_eq(add1_twice(x()), x() + 2) and syntactic_eq( - compose(add1, add1)(x()), x() + 2 - ) - - -def test_defun_5(): - with pytest.raises(ValueError, match="variadic"): - trace(lambda *xs: None) - - with pytest.raises(ValueError, match="variadic"): - trace(lambda **ys: None) - - with pytest.raises(ValueError, match="variadic"): - trace(lambda y=1, **ys: None) - - with pytest.raises(ValueError, match="variadic"): - trace(lambda x, *xs, y=1, **ys: None) - - -def test_evaluate_2(): - x = defop(int, name="x") - y = defop(int, name="y") - t = x() + y() - assert isinstance(t, Term) - assert t.op.__name__ == "add" - with handler({x: lambda: 1, y: lambda: 3}): - assert evaluate(t) == 4 - - t = x() * y() - assert isinstance(t, Term) - with handler({x: lambda: 2, y: lambda: 3}): - assert evaluate(t) == 6 - - t = x() - y() - assert isinstance(t, Term) - with handler({x: lambda: 2, y: lambda: 3}): - assert evaluate(t) == -1 - - t = x() ^ y() - assert isinstance(t, Term) - with handler({x: lambda: 1, y: lambda: 2}): - assert evaluate(t) == 3 diff --git a/tests/test_ops_semantics.py b/tests/test_ops_semantics.py index 1cfd7461..d07bb749 100644 --- a/tests/test_ops_semantics.py +++ b/tests/test_ops_semantics.py @@ -6,7 +6,6 @@ import pytest -import effectful.handlers.numbers # noqa: F401 from effectful.ops.semantics import ( coproduct, evaluate, diff --git a/tests/test_ops_syntax.py b/tests/test_ops_syntax.py index 6e53087a..6319e687 100644 --- a/tests/test_ops_syntax.py +++ b/tests/test_ops_syntax.py @@ -1,13 +1,17 @@ +import collections import collections.abc import dataclasses import functools import inspect +import logging +import os +import typing from collections.abc import Callable, Iterable, Iterator, Mapping from typing import Annotated, ClassVar import pytest -import effectful.handlers.numbers # noqa: F401 +from docs.source.lambda_ import App, Lam, Let, eager_mixed from effectful.ops.semantics import evaluate, fvsof, handler, typeof from effectful.ops.syntax import ( Scoped, @@ -19,9 +23,15 @@ defterm, iter_, next_, + syntactic_eq, + trace, ) from effectful.ops.types import NotHandled, Operation, Term +logger = logging.getLogger(__name__) + +T = typing.TypeVar("T") + call = defdata.dispatch(collections.abc.Callable).__call__ @@ -202,11 +212,11 @@ def test_term_str(): assert str(x1) == str(x2) == str(x3) == "x" assert repr(x1) != repr(x2) != repr(x3) - assert str(x1() + x2()) == "add(x(), x!1())" - assert str(x1() + x1()) == "add(x(), x())" - assert str(deffn(x1() + x1(), x1)) == "deffn(add(x(), x()), x)" - assert str(deffn(x1() + x1(), x2)) == "deffn(add(x(), x()), x!1)" - assert str(deffn(x1() + x2(), x1)) == "deffn(add(x(), x!1()), x)" + assert str(x1() + x2()) == "__add__(x(), x!1())" + assert str(x1() + x1()) == "__add__(x(), x())" + assert str(deffn(x1() + x1(), x1)) == "deffn(__add__(x(), x()), x)" + assert str(deffn(x1() + x1(), x2)) == "deffn(__add__(x(), x()), x!1)" + assert str(deffn(x1() + x2(), x1)) == "deffn(__add__(x(), x!1()), x)" def test_defop_singledispatch(): @@ -582,3 +592,259 @@ class Lines: origin=Point(3, 4), lines=[Line(Point(3, 4), Point(4, 5))], ) + + +def test_lambda_calculus_1(): + x, y = defop(int), defop(int) + + with handler(eager_mixed): + e1 = x() + 1 + f1 = Lam(x, e1) + + assert syntactic_eq(App(f1, 1), 2) + assert syntactic_eq(Lam(y, f1), f1) + assert syntactic_eq(Lam(x, f1.args[1]), f1.args[1]) + + assert fvsof(e1) == fvsof(x() + 1) + assert fvsof(Lam(x, e1).args[1]) != fvsof(Lam(x, e1).args[1]) + + assert typeof(e1) is int + assert typeof(f1) is collections.abc.Callable + + +def test_lambda_calculus_2(): + x, y = defop(int), defop(int) + + with handler(eager_mixed): + f2 = Lam(x, Lam(y, (x() + y()))) + assert syntactic_eq(App(App(f2, 1), 2), 3) + assert syntactic_eq(Lam(y, f2), f2) + + +def test_lambda_calculus_3(): + x, y, f = ( + defop(int), + defop(int), + defop(collections.abc.Callable[[int], collections.abc.Callable[[int], int]]), + ) + + with handler(eager_mixed): + f2 = Lam(x, Lam(y, (x() + y()))) + app2 = Lam(f, Lam(x, Lam(y, App(App(f(), x()), y())))) + assert syntactic_eq(App(App(App(app2, f2), 1), 2), 3) + + +def test_lambda_calculus_4(): + x, f, g = ( + defop(int), + defop(collections.abc.Callable[[T], T]), + defop(collections.abc.Callable[[T], T]), + ) + + with handler(eager_mixed): + add1 = Lam(x, (x() + 1)) + compose = Lam(f, Lam(g, Lam(x, App(f(), App(g(), x()))))) + f1_twice = App(App(compose, add1), add1) + assert syntactic_eq(App(f1_twice, 1), 3) + + +def test_lambda_calculus_5(): + x = defop(int) + + with handler(eager_mixed): + e_add1 = Let(x, x(), (x() + 1)) + f_add1 = Lam(x, e_add1) + + assert x in fvsof(e_add1) + assert e_add1.args[0] != x + + assert x not in fvsof(f_add1) + assert f_add1.args[0] != f_add1.args[1].args[0] + + assert syntactic_eq(App(f_add1, 1), 2) + assert syntactic_eq(Let(x, 1, e_add1), 2) + + +def test_arithmetic_1(): + x_, y_ = defop(int), defop(int) + x, y = x_(), y_() + + with handler(eager_mixed): + assert syntactic_eq((1 + 2) + x, x + 3) + assert not syntactic_eq(x + 1, y + 1) + assert syntactic_eq(x + 0, 0 + x) and syntactic_eq(0 + x, x) + + +def test_arithmetic_2(): + x_, y_ = defop(int), defop(int) + x, y = x_(), y_() + + with handler(eager_mixed): + assert syntactic_eq(x + y, y + x) + assert syntactic_eq(3 + x, x + 3) + assert syntactic_eq(1 + (x + 2), x + 3) + assert syntactic_eq((x + 1) + 2, x + 3) + + +def test_arithmetic_3(): + x_, y_ = defop(int), defop(int) + x, y = x_(), y_() + + with handler(eager_mixed): + assert syntactic_eq((1 + (y + 1)) + (1 + (x + 1)), (y + x) + 4) + assert syntactic_eq(1 + ((x + y) + 2), (x + y) + 3) + assert syntactic_eq(1 + ((x + (y + 1)) + 1), (x + y) + 3) + + +def test_arithmetic_4(): + x_, y_ = defop(int), defop(int) + x, y = x_(), y_() + + with handler(eager_mixed): + expr1 = ((x + x) + (x + x)) + ((x + x) + (x + x)) + expr2 = x + (x + (x + (x + (x + (x + (x + x)))))) + expr3 = ((((((x + x) + x) + x) + x) + x) + x) + x + assert syntactic_eq(expr1, expr2) and syntactic_eq(expr2, expr3) + + expr4 = (x + y) + (y + x) + expr5 = (y + (x + x)) + y + expr6 = y + (x + (y + x)) + assert syntactic_eq(expr4, expr5) and syntactic_eq(expr5, expr6) + + +def test_arithmetic_5(): + x, y = defop(int), defop(int) + + with handler(eager_mixed): + assert syntactic_eq(Let(x, x() + 3, x() + 1), x() + 4) + assert syntactic_eq(Let(x, x() + 3, x() + y() + 1), y() + x() + 4) + + assert syntactic_eq(Let(x, x() + 3, Let(x, x() + 4, x() + y())), x() + y() + 7) + + +def test_defun_1(): + x, y = defop(int), defop(int) + + with handler(eager_mixed): + + @trace + def f1(x: int) -> int: + return x + y() + 1 + + assert typeof(f1) is collections.abc.Callable + assert y in fvsof(f1) + assert x not in fvsof(f1) + + assert syntactic_eq(f1(1), y() + 2) + assert syntactic_eq(f1(x()), x() + y() + 1) + + +def test_defun_2(): + with handler(eager_mixed): + + @trace + def f1(x: int, y: int) -> int: + return x + y + + @trace + def f2(x: int, y: int) -> int: + @trace + def f2_inner(y: int) -> int: + return x + y + + return f2_inner(y) + + assert syntactic_eq(f1(1, 2), 3) and syntactic_eq(f2(1, 2), 3) + + +def test_defun_3(): + with handler(eager_mixed): + + @trace + def f2(x: int, y: int) -> int: + return x + y + + @trace + def app2(f: collections.abc.Callable[[int, int], int], x: int, y: int) -> int: + return f(x, y) + + assert syntactic_eq(app2(f2, 1, 2), 3) + + +@pytest.mark.xfail(condition=os.getenv("CI") == "true", reason="Fails on CI") +def test_defun_4(): + x = defop(int) + + with handler(eager_mixed): + + @trace + def compose( + f: collections.abc.Callable[[int], int], + g: collections.abc.Callable[[int], int], + ) -> collections.abc.Callable[[int], int]: + @trace + def fg(x: int) -> int: + assert callable(f), f"f is not callable: {f}" + assert callable(g), f"g is not callable: {g}" + return f(g(x)) + + return fg + + assert callable(compose), f"compose is not callable: {compose}" + + @trace + def add1(x: int) -> int: + return x + 1 + + assert callable(add1), f"add1 is not callable: {add1}" + + @trace + def add1_twice(x: int) -> int: + return compose(add1, add1)(x) + + assert callable(add1_twice), f"add1_twice is not callable: {add1_twice}" + + assert syntactic_eq(add1_twice(1), 3) and syntactic_eq( + compose(add1, add1)(1), 3 + ) + assert syntactic_eq(add1_twice(x()), x() + 2) and syntactic_eq( + compose(add1, add1)(x()), x() + 2 + ) + + +def test_defun_5(): + with pytest.raises(ValueError, match="variadic"): + trace(lambda *xs: None) + + with pytest.raises(ValueError, match="variadic"): + trace(lambda **ys: None) + + with pytest.raises(ValueError, match="variadic"): + trace(lambda y=1, **ys: None) + + with pytest.raises(ValueError, match="variadic"): + trace(lambda x, *xs, y=1, **ys: None) + + +def test_evaluate_2(): + x = defop(int, name="x") + y = defop(int, name="y") + t = x() + y() + assert isinstance(t, Term) + with handler({x: lambda: 1, y: lambda: 3}): + assert evaluate(t) == 4 + + t = x() * y() + assert isinstance(t, Term) + with handler({x: lambda: 2, y: lambda: 3}): + assert evaluate(t) == 6 + + t = x() - y() + assert isinstance(t, Term) + with handler({x: lambda: 2, y: lambda: 3}): + assert evaluate(t) == -1 + + t = x() ^ y() + assert isinstance(t, Term) + with handler({x: lambda: 1, y: lambda: 2}): + assert evaluate(t) == 3