22from hashlib import sha256
33from textwrap import dedent , indent
44
5- import numba
65import numpy as np
76from numba .core .extending import overload
87from numpy .lib .array_utils import normalize_axis_index , normalize_axis_tuple
1413)
1514from pytensor .link .numba .dispatch import basic as numba_basic
1615from pytensor .link .numba .dispatch .basic import (
16+ create_tuple_string ,
1717 numba_funcify_and_cache_key ,
1818 register_funcify_and_cache_key ,
1919 register_funcify_default_op_cache_key ,
@@ -125,10 +125,12 @@ def scalar_in_place_fn_Minimum(op, idx, res, arr):
125125
126126def create_multiaxis_reducer (
127127 scalar_op ,
128+ * ,
128129 identity ,
129130 axes ,
130131 ndim ,
131- dtype ,
132+ acc_dtype = None ,
133+ out_dtype ,
132134 keepdims : bool = False ,
133135):
134136 r"""Construct a function that reduces multiple axes.
@@ -138,17 +140,46 @@ def create_multiaxis_reducer(
138140 .. code-block:: python
139141
140142 def careduce_add(x):
141- # For x.ndim == 3 and axes == (0, 1) and scalar_op == "Add"
142143 x_shape = x.shape
143- res_shape = x_shape[2]
144- res = np.full(res_shape, numba_basic.to_scalar(0.0), dtype=out_dtype)
144+ res_shape = (x_shape[0], x_shape[1])
145+ # identity = 0.0
146+ res = np.full(res_shape, identity, dtype=np.float64)
147+ for i0 in range(x_shape[0]):
148+ for i1 in range(x_shape[1]):
149+ for i2 in range(x_shape[2]):
150+ res[i0, i1] += x[i0, i1, i2]
151+ return res
152+
153+ If accumulation dtype differs from output_dtype
154+
155+ .. code-block:: python
145156
157+ def careduce_add(x):
158+ x_shape = x.shape
159+ res_shape = (x_shape[0], x_shape[1])
160+ # identity = 0.0
161+ res = np.full(res_shape, identity, dtype=np.float64)
146162 for i0 in range(x_shape[0]):
147163 for i1 in range(x_shape[1]):
148164 for i2 in range(x_shape[2]):
149- res[i2] += x[i0, i1, i2]
165+ res[i0, i1] += x[i0, i1, i2]
166+ return res.astype(np.int32)
167+
168+ Full reductions accumulate on scalars
169+
170+ .. code-block:: python
171+
172+ def careduce_mul(x):
173+ x_shape = x.shape
174+ res_shape = ()
175+ # identity = 1.0
176+ res = identity
177+ for i0 in range(x_shape[0]):
178+ for i1 in range(x_shape[1]):
179+ for i2 in range(x_shape[2]):
180+ res *= x[i0, i1, i2]
181+ return np.array(res, dtype=np.int32)
150182
151- return res
152183
153184 Parameters
154185 ==========
@@ -160,7 +191,9 @@ def careduce_add(x):
160191 The axes to reduce.
161192 ndim:
162193 The number of dimensions of the input variable.
163- dtype:
194+ acc_dtype: dtype, optional
195+ The data type used during accumulation. Defaults to out_dtype if not provided
196+ out_dtype:
164197 The data type of the result.
165198 keepdims: boolean, default False
166199 Whether to keep the reduced dimensions.
@@ -178,19 +211,23 @@ def careduce_add(x):
178211 "Cannot keep multiple dimensions when reducing multiple axes"
179212 )
180213
214+ out_dtype = np .dtype (out_dtype )
215+ acc_dtype = out_dtype if acc_dtype is None else np .dtype (acc_dtype )
216+ # Numba doesn't allow converting complex to real with a simple `astype`
217+ complex_to_real = acc_dtype .kind == "c" and out_dtype .kind != "c"
218+ out_dtype_str = f"np.{ out_dtype .name } "
219+ acc_dtype_str = f"np.{ acc_dtype .name } "
181220 careduce_fn_name = f"careduce_{ scalar_op } "
182221
183- identity = str (identity )
184- if identity == "inf" :
185- identity = "np.inf"
186- elif identity == "-inf" :
187- identity = "-np.inf"
188-
189- global_env = {
190- "np" : np ,
191- "numba_basic" : numba_basic ,
192- "out_dtype" : dtype ,
193- }
222+ if acc_dtype .kind in "ui" and not np .isfinite (identity ):
223+ if np .isposinf (identity ):
224+ identity = np .iinfo (acc_dtype ).max
225+ else :
226+ identity = np .iinfo (acc_dtype ).min
227+
228+ # Make sure it has the correct dtype
229+ identity = getattr (np , acc_dtype .name )(identity )
230+
194231 complete_reduction = len (axes ) == ndim
195232 kept_axis = tuple (i for i in range (ndim ) if i not in axes )
196233
@@ -208,17 +245,23 @@ def careduce_add(x):
208245 scalar_op , res_indices , "res" , f"x[{ arr_indices } ]"
209246 )
210247
211- res_shape = f"( { ', ' . join ( f' x_shape[{ i } ]' for i in kept_axis ) } )"
248+ res_shape = create_tuple_string ([ f" x_shape[{ i } ]" for i in kept_axis ])
212249 if complete_reduction and ndim > 0 :
213250 # We accumulate on a scalar, not an array
214- res_creator = f"np.asarray( { identity } ).astype(out_dtype).item() "
251+ res_creator = " identity"
215252 inplace_update_stmt = inplace_update_stmt .replace ("res[()]" , "res" )
216- return_obj = "np.asarray(res)"
253+ if complex_to_real :
254+ return_obj = f"np.array(res).real.astype({ out_dtype_str } )"
255+ else :
256+ return_obj = f"np.array(res, dtype={ out_dtype_str } )"
217257 else :
218- res_creator = (
219- f"np.full({ res_shape } , np.asarray({ identity } ).item(), dtype=out_dtype)"
220- )
221- return_obj = "res"
258+ res_creator = f"np.full(res_shape, identity, dtype={ acc_dtype_str } )"
259+ if complex_to_real :
260+ return_obj = f"res.real.astype({ out_dtype_str } )"
261+ else :
262+ return_obj = (
263+ "res" if out_dtype == acc_dtype else f"res.astype({ out_dtype_str } )"
264+ )
222265
223266 if keepdims :
224267 [axis ] = axes
@@ -229,6 +272,7 @@ def careduce_add(x):
229272 def { careduce_fn_name } (x):
230273 x_shape = x.shape
231274 res_shape = { res_shape }
275+ # identity = { identity }
232276 res = { res_creator }
233277 """
234278 )
@@ -238,13 +282,12 @@ def {careduce_fn_name}(x):
238282 " " * (4 + 4 * axis ),
239283 )
240284 careduce_def_src += indent (inplace_update_stmt , " " * (4 + 4 * ndim ))
241- careduce_def_src += "\n \n "
285+ careduce_def_src += "\n "
242286 careduce_def_src += indent (f"return { return_obj } " , " " * 4 )
243287
244288 careduce_fn = compile_numba_function_src (
245- careduce_def_src , careduce_fn_name , { ** globals (), ** global_env }
289+ careduce_def_src , careduce_fn_name , globals () | { "np" : np , "identity" : identity }
246290 )
247-
248291 return careduce_fn
249292
250293
@@ -356,41 +399,45 @@ def numba_funcify_CAReduce(op, node, **kwargs):
356399 acc_dtype = op .acc_dtype
357400 else :
358401 acc_dtype = node .outputs [0 ].type .dtype
359- np_acc_dtype = np .dtype (acc_dtype )
360-
361- scalar_op_identity = op .scalar_op .identity
362- if np_acc_dtype .kind == "i" and not np .isfinite (scalar_op_identity ):
363- if np .isposinf (scalar_op_identity ):
364- scalar_op_identity = np .iinfo (np_acc_dtype ).max
365- else :
366- scalar_op_identity = np .iinfo (np_acc_dtype ).min
367- # Make sure it has the correct dtype
368- scalar_op_identity = np .array (scalar_op_identity , dtype = np_acc_dtype )
369402
370403 out_dtype = np .dtype (node .outputs [0 ].type .dtype )
371404
372- if isinstance (op , Sum ) and node .inputs [0 ].ndim == len (axes ):
405+ if (
406+ isinstance (op , Sum )
407+ and node .inputs [0 ].ndim == len (axes )
408+ and out_dtype == acc_dtype
409+ ):
373410 # Slightly faster for this case
374411 @numba_basic .numba_njit
375412 def impl_sum (array ):
376- return np .asarray (array .sum (), dtype = np_acc_dtype ). astype ( out_dtype )
413+ return np .array (array .sum ())
377414
378415 careduce_fn = impl_sum # Some tests look for this name
379416
380417 else :
381418 ndim = node .inputs [0 ].ndim
382419 careduce_py_fn = create_multiaxis_reducer (
383420 op .scalar_op ,
384- scalar_op_identity ,
385- axes ,
386- ndim ,
387- out_dtype ,
421+ identity = op .scalar_op .identity ,
422+ axes = axes ,
423+ ndim = ndim ,
424+ acc_dtype = acc_dtype ,
425+ out_dtype = out_dtype ,
388426 )
389427 careduce_fn = numba_basic .numba_njit (careduce_py_fn , boundscheck = False )
390428
429+ cache_version = 1
391430 careduce_key = sha256 (
392431 str (
393- (type (op ), type (op .scalar_op ), axes , acc_dtype , scalar_op_identity .item ())
432+ (
433+ type (op ),
434+ type (op .scalar_op ),
435+ axes ,
436+ out_dtype ,
437+ acc_dtype ,
438+ op .scalar_op .identity ,
439+ cache_version ,
440+ )
394441 ).encode ()
395442 ).hexdigest ()
396443 return careduce_fn , careduce_key
@@ -436,18 +483,27 @@ def dimshuffle(x):
436483
437484@register_funcify_default_op_cache_key (Softmax )
438485def numba_funcify_Softmax (op , node , ** kwargs ):
439- x_at = node .inputs [0 ]
440- x_dtype = x_at .type .numpy_dtype
441- x_dtype = numba .np .numpy_support .from_dtype (x_dtype )
486+ ndim = node .inputs [0 ].type .ndim
487+ inp_dtype = node .inputs [0 ].type .numpy_dtype
442488 axis = op .axis
443489
444- if axis is not None :
445- axis = normalize_axis_index (axis , x_at . ndim )
490+ if ndim > 1 and axis is not None :
491+ axis = normalize_axis_index (axis , ndim )
446492 reduce_max_py = create_multiaxis_reducer (
447- maximum , - np .inf , axis , x_at .ndim , x_dtype , keepdims = True
493+ maximum ,
494+ identity = - np .inf ,
495+ axes = axis ,
496+ ndim = ndim ,
497+ out_dtype = inp_dtype ,
498+ keepdims = True ,
448499 )
449500 reduce_sum_py = create_multiaxis_reducer (
450- add_as , 0.0 , (axis ,), x_at .ndim , x_dtype , keepdims = True
501+ add_as ,
502+ identity = 0.0 ,
503+ axes = (axis ,),
504+ ndim = ndim ,
505+ out_dtype = inp_dtype ,
506+ keepdims = True ,
451507 )
452508
453509 jit_fn = numba_basic .numba_njit (boundscheck = False )
@@ -457,66 +513,74 @@ def numba_funcify_Softmax(op, node, **kwargs):
457513 reduce_max = np .max
458514 reduce_sum = np .sum
459515
460- def softmax_py_fn (x ):
516+ @numba_basic .numba_njit (boundscheck = False )
517+ def softmax (x ):
461518 z = reduce_max (x )
462519 e_x = np .exp (x - z )
463520 w = reduce_sum (e_x )
464521 sm = e_x / w
465522 return sm
466523
467- softmax = numba_basic .numba_njit (softmax_py_fn , boundscheck = False )
468-
469- return softmax
524+ cache_version = 1
525+ return softmax , cache_version
470526
471527
@register_funcify_default_op_cache_key(SoftmaxGrad)
def numba_funcify_SoftmaxGrad(op, node, **kwargs):
    """Build the Numba implementation of the ``SoftmaxGrad`` op.

    The compiled function takes ``(dy, sm)`` and returns
    ``dy * sm - sum(dy * sm) * sm``, where the sum is taken along ``op.axis``
    (keeping that dimension) when the input has more than one dimension and an
    axis is set, and over the whole array with ``np.sum`` otherwise.

    Parameters
    ==========
    op:
        The ``SoftmaxGrad`` op instance; only ``op.axis`` is read.
    node:
        The apply node; the number of dimensions and the dtype handed to the
        reducer are read from ``node.inputs[0]``.
    kwargs:
        Accepted for dispatcher compatibility; unused here.

    Returns
    =======
    tuple
        ``(softmax_grad, cache_version)``: the jitted function and an integer
        version tag returned alongside it for caching.
    """
    ndim = node.inputs[0].type.ndim
    # NOTE(review): the dtype is taken from inputs[0]; the previous revision of
    # this function read it from inputs[1] -- confirm both inputs always share
    # the same dtype.
    inp_dtype = node.inputs[0].type.numpy_dtype

    axis = op.axis
    if ndim > 1 and axis is not None:
        axis = normalize_axis_index(axis, ndim)
        reduce_sum_py = create_multiaxis_reducer(
            add_as,
            identity=0.0,
            axes=(axis,),
            ndim=ndim,
            out_dtype=inp_dtype,
            keepdims=True,
        )

        jit_fn = numba_basic.numba_njit(boundscheck=False)
        reduce_sum = jit_fn(reduce_sum_py)
    else:
        # Inputs with at most one dimension (or no axis) use a full reduction;
        # the scalar result broadcasts against ``sm`` below.
        reduce_sum = np.sum

    @numba_basic.numba_njit(boundscheck=False)
    def softmax_grad(dy, sm):
        dy_times_sm = dy * sm
        sum_dy_times_sm = reduce_sum(dy_times_sm)
        dx = dy_times_sm - sum_dy_times_sm * sm
        return dx

    # Returned next to the function; presumably bumped when the generated code
    # changes so stale cache entries are not reused -- verify against the
    # register_funcify_default_op_cache_key contract.
    cache_version = 1
    return softmax_grad, cache_version
499559
500560
501561@register_funcify_default_op_cache_key (LogSoftmax )
502562def numba_funcify_LogSoftmax (op , node , ** kwargs ):
503- x_at = node .inputs [0 ]
504- x_dtype = x_at .type .numpy_dtype
505- x_dtype = numba .np .numpy_support .from_dtype (x_dtype )
563+ ndim = node .inputs [0 ].type .ndim
564+ inp_dtype = node .inputs [0 ].type .numpy_dtype
506565 axis = op .axis
507566
508- if axis is not None :
509- axis = normalize_axis_index (axis , x_at . ndim )
567+ if ndim > 1 and axis is not None :
568+ axis = normalize_axis_index (axis , ndim )
510569 reduce_max_py = create_multiaxis_reducer (
511570 maximum ,
512- - np .inf ,
513- (axis ,),
514- x_at . ndim ,
515- x_dtype ,
571+ identity = - np .inf ,
572+ axes = (axis ,),
573+ ndim = ndim ,
574+ out_dtype = inp_dtype ,
516575 keepdims = True ,
517576 )
518577 reduce_sum_py = create_multiaxis_reducer (
519- add_as , 0.0 , (axis ,), x_at .ndim , x_dtype , keepdims = True
578+ add_as ,
579+ identity = 0.0 ,
580+ axes = (axis ,),
581+ ndim = ndim ,
582+ out_dtype = inp_dtype ,
583+ keepdims = True ,
520584 )
521585
522586 jit_fn = numba_basic .numba_njit (boundscheck = False )
@@ -526,13 +590,14 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
526590 reduce_max = np .max
527591 reduce_sum = np .sum
528592
529- def log_softmax_py_fn (x ):
593+ @numba_basic .numba_njit (boundscheck = False )
594+ def log_softmax (x ):
530595 xdev = x - reduce_max (x )
531596 lsm = xdev - np .log (reduce_sum (np .exp (xdev )))
532597 return lsm
533598
534- log_softmax = numba_basic . numba_njit ( log_softmax_py_fn , boundscheck = False )
535- return log_softmax
599+ cache_version = 1
600+ return log_softmax , cache_version
536601
537602
538603@register_funcify_default_op_cache_key (Argmax )
0 commit comments