From dbb46fb88713608f792650a963b75d5505d0aac2 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Fri, 30 Jan 2026 14:01:52 -0800 Subject: [PATCH] Add the nplike test, testing broadcasting --- pep.rst | 44 ++++++++++++++++++ scripts/update-examples.sh | 3 ++ tests/test_nplike.py | 95 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 142 insertions(+) create mode 100644 tests/test_nplike.py diff --git a/pep.rst b/pep.rst index c2f2e8d..0c8301d 100644 --- a/pep.rst +++ b/pep.rst @@ -1080,6 +1080,50 @@ dataclasses-style method generation ] +.. _numpy-impl: + +NumPy-style broadcasting +------------------------ + +:: + + class Array[DType, *Shape]: + def __add__[*Shape2]( + self, other: Array[DType, *Shape2] + ) -> Array[DType, *Broadcast[tuple[*Shape], tuple[*Shape2]]]: + raise BaseException + + type AppendTuple[A, B] = tuple[ + *[x for x in typing.Iter[A]], + B, + ] + + type MergeOne[T, S] = ( + T + if typing.Matches[T, S] or typing.Matches[S, Literal[1]] + else S + if typing.Matches[T, Literal[1]] + else typing.RaiseError[Literal["Broadcast mismatch"], T, S] + ) + + type DropLast[T] = typing.Slice[T, Literal[0], Literal[-1]] + type Last[T] = typing.GetArg[T, tuple, Literal[-1]] + + # Matching on Never here is intentional; it prevents infinite + # recursions when T is not a tuple. + type Empty[T] = typing.IsSub[typing.Length[T], Literal[0]] + + type Broadcast[T, S] = ( + S + if typing.Bool[Empty[T]] + else T + if typing.Bool[Empty[S]] + else AppendTuple[ + Broadcast[DropLast[T], DropLast[S]], MergeOne[Last[T], Last[S]] + ] + ) + + Rationale ========= diff --git a/scripts/update-examples.sh b/scripts/update-examples.sh index 954b3d9..a94c9bd 100755 --- a/scripts/update-examples.sh +++ b/scripts/update-examples.sh @@ -9,3 +9,6 @@ scripts/py2rst.py tests/test_fastapilike_2.py --start "Begin PEP section: datacl scripts/py2rst.py tests/test_fastapilike_2.py --start "Begin PEP section: Automatically deriving FastAPI CRUD models" --end "End PEP section" \ | scripts/rst_replace_section.py pep.rst fastapi-impl -i + +scripts/py2rst.py tests/test_nplike.py --start "Begin PEP section" --end "End PEP section" \ + | scripts/rst_replace_section.py pep.rst numpy-impl -i diff --git a/tests/test_nplike.py b/tests/test_nplike.py new file mode 100644 index 0000000..0b72ecb --- /dev/null +++ b/tests/test_nplike.py @@ -0,0 +1,95 @@ +from typing import Literal + +from typemap import typing + +import pytest + +# Begin PEP section + + +class Array[DType, *Shape]: + def __add__[*Shape2]( + self, other: Array[DType, *Shape2] + ) -> Array[DType, *Broadcast[tuple[*Shape], tuple[*Shape2]]]: + raise BaseException + + +type AppendTuple[A, B] = tuple[ + *[x for x in typing.Iter[A]], + B, +] + +type MergeOne[T, S] = ( + T + if typing.Matches[T, S] or typing.Matches[S, Literal[1]] + else S + if typing.Matches[T, Literal[1]] + else typing.RaiseError[Literal["Broadcast mismatch"], T, S] +) + +type DropLast[T] = typing.Slice[T, Literal[0], Literal[-1]] +type Last[T] = typing.GetArg[T, tuple, Literal[-1]] + +# Matching on Never here is intentional; it prevents infinite +# recursions when T is not a tuple. +type Empty[T] = typing.IsSub[typing.Length[T], Literal[0]] + +type Broadcast[T, S] = ( + S + if typing.Bool[Empty[T]] + else T + if typing.Bool[Empty[S]] + else AppendTuple[ + Broadcast[DropLast[T], DropLast[S]], MergeOne[Last[T], Last[S]] + ] +) + +# End PEP section + +type GetElem[T] = typing.GetArg[T, Array, Literal[0]] +type GetShape[T] = typing.Slice[typing.GetArgs[T, Array], Literal[1], None] + +# type Apply[T, S] = Array[GetElem[T], *Broadcast[GetShape[T], GetShape[S]]] +type Apply[T, S] = Array[ + GetElem[T], + *[x for x in typing.Iter[Broadcast[GetShape[T], GetShape[S]]]], +] + +###### +from typemap.type_eval import eval_typing, TypeMapError + +from typing import Literal as L + + +def test_nplike_1(): + a1 = Array[float, L[4], L[1]] + a2 = Array[float, L[3]] + res = eval_typing(Apply[a1, a2]) + + assert res == Array[float, L[4], L[3]] + + +def test_nplike_2(): + b1 = Array[float, int, int] + b2 = Array[float, int] + res = eval_typing(Apply[b1, b2]) + + assert res == Array[float, int, int] + + +def test_nplike_3(): + c1 = Array[float, L[4], L[1], L[5]] + c2 = Array[float, L[4], L[3], L[1]] + res = eval_typing(Apply[c1, c2]) + + assert res == Array[float, L[4], L[3], L[5]] + + +def test_nplike_4(): + err1 = Array[float, L[4], L[2]] + err2 = Array[float, L[3]] + + with pytest.raises( + TypeMapError, match=r"Broadcast mismatch:.*Literal\[2\].*Literal\[3\]" + ): + eval_typing(Apply[err1, err2])