3131
3232import operator
3333from functools import partial , reduce
34- from typing import TYPE_CHECKING , cast
34+ from typing import TYPE_CHECKING , cast , overload
3535from warnings import warn
3636
3737import numpy as np
4949from arraycontext .fake_numpy import BaseFakeNumpyLinalgNamespace
5050from arraycontext .impl .pyopencl .taggable_cl_array import TaggableCLArray
5151from arraycontext .loopy import LoopyBasedFakeNumpyNamespace
52- from arraycontext .typing import OrderCF , is_scalar_like
52+ from arraycontext .typing import ArrayOrContainer , OrderCF , ScalarLike , is_scalar_like
5353
5454
5555if TYPE_CHECKING :
@@ -341,7 +341,25 @@ def inner(ary: ArrayOrScalar) -> ArrayOrScalar:
341341
342342 # {{{ mathematical functions
343343
344- def sum (self , a , axis = None , dtype = None ):
344+ @overload
345+ def sum (self ,
346+ a : ArrayOrContainer ,
347+ axis : int | tuple [int , ...] | None = None ,
348+ dtype : DTypeLike = None ,
349+ ) -> Array : ...
350+ @overload
351+ def sum (self ,
352+ a : ScalarLike ,
353+ axis : int | tuple [int , ...] | None = None ,
354+ dtype : DTypeLike = None ,
355+ ) -> ScalarLike : ...
356+
357+ @override
358+ def sum (self ,
359+ a : ArrayOrContainerOrScalar ,
360+ axis : int | tuple [int , ...] | None = None ,
361+ dtype : DTypeLike = None ,
362+ ) -> ArrayOrScalar :
345363 if isinstance (axis , int ):
346364 axis = axis ,
347365
@@ -358,6 +376,17 @@ def maximum(self, x, y):
358376 partial (cl_array .maximum , queue = self ._array_context .queue ),
359377 x , y )
360378
379+ @overload
380+ def max (self ,
381+ a : ArrayOrContainer ,
382+ axis : int | tuple [int , ...] | None = None ,
383+ ) -> Array : ...
384+ @overload
385+ def max (self ,
386+ a : ScalarLike ,
387+ axis : int | tuple [int , ...] | None = None ,
388+ ) -> ScalarLike : ...
389+
361390 @override
362391 def max (self ,
363392 a : ArrayOrContainerOrScalar ,
@@ -379,13 +408,24 @@ def _rec_max(ary):
379408 _rec_max ,
380409 a )
381410
382- amax = max
411+ amax = max # pyright: ignore[reportAssignmentType, reportDeprecated]
383412
384413 def minimum (self , x , y ):
385414 return rec_multimap_array_container (
386415 partial (cl_array .minimum , queue = self ._array_context .queue ),
387416 x , y )
388417
418+ @overload
419+ def min (self ,
420+ a : ArrayOrContainer ,
421+ axis : int | tuple [int , ...] | None = None ,
422+ ) -> Array : ...
423+ @overload
424+ def min (self ,
425+ a : ScalarLike ,
426+ axis : int | tuple [int , ...] | None = None ,
427+ ) -> ScalarLike : ...
428+
389429 @override
390430 def min (self ,
391431 a : ArrayOrContainerOrScalar ,
@@ -406,7 +446,7 @@ def _rec_min(ary):
406446 _rec_min ,
407447 a )
408448
409- amin = min
449+ amin = min # pyright: ignore[reportAssignmentType, reportDeprecated]
410450
411451 def absolute (self , a ):
412452 return self .abs (a )
0 commit comments