22from hashlib import sha256
33from textwrap import dedent , indent
44
5- import numba
65import numpy as np
76from numba .core .extending import overload
87from numpy .lib .array_utils import normalize_axis_index , normalize_axis_tuple
1413)
1514from pytensor .link .numba .dispatch import basic as numba_basic
1615from pytensor .link .numba .dispatch .basic import (
16+ create_tuple_string ,
1717 numba_funcify_and_cache_key ,
1818 register_funcify_and_cache_key ,
1919 register_funcify_default_op_cache_key ,
@@ -125,10 +125,12 @@ def scalar_in_place_fn_Minimum(op, idx, res, arr):
125125
126126def create_multiaxis_reducer (
127127 scalar_op ,
128+ * ,
128129 identity ,
129130 axes ,
130131 ndim ,
131- dtype ,
132+ acc_dtype = None ,
133+ out_dtype ,
132134 keepdims : bool = False ,
133135):
134136 r"""Construct a function that reduces multiple axes.
@@ -138,17 +140,46 @@ def create_multiaxis_reducer(
138140 .. code-block:: python
139141
140142 def careduce_add(x):
141- # For x.ndim == 3 and axes == (0, 1) and scalar_op == "Add"
142143 x_shape = x.shape
143- res_shape = x_shape[2]
144- res = np.full(res_shape, numba_basic.to_scalar(0.0), dtype=out_dtype)
144+ res_shape = (x_shape[0], x_shape[1])
145+ # identity = 0.0
146+ res = np.full(res_shape, identity, dtype=np.float64)
147+ for i0 in range(x_shape[0]):
148+ for i1 in range(x_shape[1]):
149+ for i2 in range(x_shape[2]):
150+ res[i0, i1] += x[i0, i1, i2]
151+ return res
152+
153+ If accumulation dtype differs from output_dtype
154+
155+ .. code-block:: python
145156
157+ def careduce_add(x):
158+ x_shape = x.shape
159+ res_shape = (x_shape[0], x_shape[1])
160+ # identity = 0.0
161+ res = np.full(res_shape, identity, dtype=np.float64)
146162 for i0 in range(x_shape[0]):
147163 for i1 in range(x_shape[1]):
148164 for i2 in range(x_shape[2]):
149- res[i2] += x[i0, i1, i2]
165+ res[i0, i1] += x[i0, i1, i2]
166+ return res.astype(np.int32)
167+
168+ Full reductions accumulate on scalars
169+
170+ .. code-block:: python
171+
172+ def careduce_mul(x):
173+ x_shape = x.shape
174+ res_shape = ()
175+ # identity = 1.0
176+ res = identity
177+ for i0 in range(x_shape[0]):
178+ for i1 in range(x_shape[1]):
179+ for i2 in range(x_shape[2]):
180+ res *= x[i0, i1, i2]
181+ return np.array(res, dtype=np.int32)
150182
151- return res
152183
153184 Parameters
154185 ==========
@@ -160,7 +191,9 @@ def careduce_add(x):
160191 The axes to reduce.
161192 ndim:
162193 The number of dimensions of the input variable.
163- dtype:
194+ acc_dtype: dtype, optional
195+ The data type used during accumulation. Defaults to out_dtype if not provided
196+ out_dtype:
164197 The data type of the result.
165198 keepdims: boolean, default False
166199 Whether to keep the reduced dimensions.
@@ -178,19 +211,23 @@ def careduce_add(x):
178211 "Cannot keep multiple dimensions when reducing multiple axes"
179212 )
180213
214+ out_dtype = np .dtype (out_dtype )
215+ acc_dtype = out_dtype if acc_dtype is None else np .dtype (acc_dtype )
216+ # Numba doesn't allow converting complex to real with a simple `astype`
217+ complex_to_real = acc_dtype .kind == "c" and out_dtype .kind != "c"
218+ out_dtype_str = f"np.{ out_dtype .name } "
219+ acc_dtype_str = f"np.{ acc_dtype .name } "
181220 careduce_fn_name = f"careduce_{ scalar_op } "
182221
183- identity = str (identity )
184- if identity == "inf" :
185- identity = "np.inf"
186- elif identity == "-inf" :
187- identity = "-np.inf"
188-
189- global_env = {
190- "np" : np ,
191- "numba_basic" : numba_basic ,
192- "out_dtype" : dtype ,
193- }
222+ if acc_dtype .kind in "ui" and not np .isfinite (identity ):
223+ if np .isposinf (identity ):
224+ identity = np .iinfo (acc_dtype ).max
225+ else :
226+ identity = np .iinfo (acc_dtype ).min
227+
228+ # Make sure it has the correct dtype
229+ identity = getattr (np , acc_dtype .name )(identity )
230+
194231 complete_reduction = len (axes ) == ndim
195232 kept_axis = tuple (i for i in range (ndim ) if i not in axes )
196233
@@ -208,17 +245,23 @@ def careduce_add(x):
208245 scalar_op , res_indices , "res" , f"x[{ arr_indices } ]"
209246 )
210247
211- res_shape = f"( { ', ' . join ( f' x_shape[{ i } ]' for i in kept_axis ) } )"
248+ res_shape = create_tuple_string ([ f" x_shape[{ i } ]" for i in kept_axis ])
212249 if complete_reduction and ndim > 0 :
213250 # We accumulate on a scalar, not an array
214- res_creator = f"np.asarray( { identity } ).astype(out_dtype).item() "
251+ res_creator = " identity"
215252 inplace_update_stmt = inplace_update_stmt .replace ("res[()]" , "res" )
216- return_obj = "np.asarray(res)"
253+ if complex_to_real :
254+ return_obj = f"np.array(res).real.astype({ out_dtype_str } )"
255+ else :
256+ return_obj = f"np.array(res, dtype={ out_dtype_str } )"
217257 else :
218- res_creator = (
219- f"np.full({ res_shape } , np.asarray({ identity } ).item(), dtype=out_dtype)"
220- )
221- return_obj = "res"
258+ res_creator = f"np.full(res_shape, identity, dtype={ acc_dtype_str } )"
259+ if complex_to_real :
260+ return_obj = f"res.real.astype({ out_dtype_str } )"
261+ else :
262+ return_obj = (
263+ "res" if out_dtype == acc_dtype else f"res.astype({ out_dtype_str } )"
264+ )
222265
223266 if keepdims :
224267 [axis ] = axes
@@ -229,6 +272,7 @@ def careduce_add(x):
229272 def { careduce_fn_name } (x):
230273 x_shape = x.shape
231274 res_shape = { res_shape }
275+ # identity = { identity }
232276 res = { res_creator }
233277 """
234278 )
@@ -238,13 +282,12 @@ def {careduce_fn_name}(x):
238282 " " * (4 + 4 * axis ),
239283 )
240284 careduce_def_src += indent (inplace_update_stmt , " " * (4 + 4 * ndim ))
241- careduce_def_src += "\n \n "
285+ careduce_def_src += "\n "
242286 careduce_def_src += indent (f"return { return_obj } " , " " * 4 )
243287
244288 careduce_fn = compile_numba_function_src (
245- careduce_def_src , careduce_fn_name , { ** globals (), ** global_env }
289+ careduce_def_src , careduce_fn_name , globals () | { "np" : np , "identity" : identity }
246290 )
247-
248291 return careduce_fn
249292
250293
@@ -356,41 +399,45 @@ def numba_funcify_CAReduce(op, node, **kwargs):
356399 acc_dtype = op .acc_dtype
357400 else :
358401 acc_dtype = node .outputs [0 ].type .dtype
359- np_acc_dtype = np .dtype (acc_dtype )
360-
361- scalar_op_identity = op .scalar_op .identity
362- if np_acc_dtype .kind == "i" and not np .isfinite (scalar_op_identity ):
363- if np .isposinf (scalar_op_identity ):
364- scalar_op_identity = np .iinfo (np_acc_dtype ).max
365- else :
366- scalar_op_identity = np .iinfo (np_acc_dtype ).min
367- # Make sure it has the correct dtype
368- scalar_op_identity = np .array (scalar_op_identity , dtype = np_acc_dtype )
369402
370403 out_dtype = np .dtype (node .outputs [0 ].type .dtype )
371404
372- if isinstance (op , Sum ) and node .inputs [0 ].ndim == len (axes ):
405+ if (
406+ isinstance (op , Sum )
407+ and node .inputs [0 ].ndim == len (axes )
408+ and out_dtype == acc_dtype
409+ ):
373410 # Slightly faster for this case
374411 @numba_basic .numba_njit
375412 def impl_sum (array ):
376- return np .asarray (array .sum (), dtype = np_acc_dtype ). astype ( out_dtype )
413+ return np .array (array .sum ())
377414
378415 careduce_fn = impl_sum # Some tests look for this name
379416
380417 else :
381418 ndim = node .inputs [0 ].ndim
382419 careduce_py_fn = create_multiaxis_reducer (
383420 op .scalar_op ,
384- scalar_op_identity ,
385- axes ,
386- ndim ,
387- out_dtype ,
421+ identity = op .scalar_op .identity ,
422+ axes = axes ,
423+ ndim = ndim ,
424+ acc_dtype = acc_dtype ,
425+ out_dtype = out_dtype ,
388426 )
389427 careduce_fn = numba_basic .numba_njit (careduce_py_fn , boundscheck = False )
390428
429+ cache_version = 1
391430 careduce_key = sha256 (
392431 str (
393- (type (op ), type (op .scalar_op ), axes , acc_dtype , scalar_op_identity .item ())
432+ (
433+ type (op ),
434+ type (op .scalar_op ),
435+ axes ,
436+ out_dtype ,
437+ acc_dtype ,
438+ op .scalar_op .identity ,
439+ cache_version ,
440+ )
394441 ).encode ()
395442 ).hexdigest ()
396443 return careduce_fn , careduce_key
@@ -449,18 +496,26 @@ def dimshuffle(x):
449496
450497@register_funcify_default_op_cache_key (Softmax )
451498def numba_funcify_Softmax (op , node , ** kwargs ):
452- x_at = node .inputs [0 ]
453- x_dtype = x_at .type .numpy_dtype
454- x_dtype = numba .np .numpy_support .from_dtype (x_dtype )
499+ ndim = node .inputs [0 ].type .ndim
500+ inp_dtype = node .inputs [0 ].type .numpy_dtype
455501 axis = op .axis
456502
457- if axis is not None :
458- axis = normalize_axis_index (axis , x_at .ndim )
503+ if ndim > 1 and axis is not None :
459504 reduce_max_py = create_multiaxis_reducer (
460- maximum , - np .inf , axis , x_at .ndim , x_dtype , keepdims = True
505+ maximum ,
506+ identity = - np .inf ,
507+ axes = (axis ,),
508+ ndim = ndim ,
509+ out_dtype = inp_dtype ,
510+ keepdims = True ,
461511 )
462512 reduce_sum_py = create_multiaxis_reducer (
463- add_as , 0.0 , (axis ,), x_at .ndim , x_dtype , keepdims = True
513+ add_as ,
514+ identity = 0.0 ,
515+ axes = (axis ,),
516+ ndim = ndim ,
517+ out_dtype = inp_dtype ,
518+ keepdims = True ,
464519 )
465520
466521 jit_fn = numba_basic .numba_njit (boundscheck = False )
@@ -470,66 +525,72 @@ def numba_funcify_Softmax(op, node, **kwargs):
470525 reduce_max = np .max
471526 reduce_sum = np .sum
472527
473- def softmax_py_fn (x ):
528+ @numba_basic .numba_njit (boundscheck = False )
529+ def softmax (x ):
474530 z = reduce_max (x )
475531 e_x = np .exp (x - z )
476532 w = reduce_sum (e_x )
477533 sm = e_x / w
478534 return sm
479535
480- softmax = numba_basic .numba_njit (softmax_py_fn , boundscheck = False )
481-
482- return softmax
536+ cache_version = 1
537+ return softmax , cache_version
483538
484539
485540@register_funcify_default_op_cache_key (SoftmaxGrad )
486541def numba_funcify_SoftmaxGrad (op , node , ** kwargs ):
487- sm_at = node .inputs [1 ]
488- sm_dtype = sm_at .type .numpy_dtype
489- sm_dtype = numba .np .numpy_support .from_dtype (sm_dtype )
542+ ndim = node .inputs [0 ].type .ndim
543+ inp_dtype = node .inputs [0 ].type .numpy_dtype
490544
491545 axis = op .axis
492- if axis is not None :
493- axis = normalize_axis_index (axis , sm_at .ndim )
546+ if ndim > 1 and axis is not None :
494547 reduce_sum_py = create_multiaxis_reducer (
495- add_as , 0.0 , (axis ,), sm_at .ndim , sm_dtype , keepdims = True
548+ add_as ,
549+ identity = 0.0 ,
550+ axes = (axis ,),
551+ ndim = ndim ,
552+ out_dtype = inp_dtype ,
553+ keepdims = True ,
496554 )
497555
498556 jit_fn = numba_basic .numba_njit (boundscheck = False )
499557 reduce_sum = jit_fn (reduce_sum_py )
500558 else :
501559 reduce_sum = np .sum
502560
503- def softmax_grad_py_fn (dy , sm ):
561+ @numba_basic .numba_njit (boundscheck = False )
562+ def softmax_grad (dy , sm ):
504563 dy_times_sm = dy * sm
505564 sum_dy_times_sm = reduce_sum (dy_times_sm )
506565 dx = dy_times_sm - sum_dy_times_sm * sm
507566 return dx
508567
509- softmax_grad = numba_basic .numba_njit (softmax_grad_py_fn , boundscheck = False )
510-
511- return softmax_grad
568+ cache_version = 1
569+ return softmax_grad , cache_version
512570
513571
514572@register_funcify_default_op_cache_key (LogSoftmax )
515573def numba_funcify_LogSoftmax (op , node , ** kwargs ):
516- x_at = node .inputs [0 ]
517- x_dtype = x_at .type .numpy_dtype
518- x_dtype = numba .np .numpy_support .from_dtype (x_dtype )
574+ ndim = node .inputs [0 ].type .ndim
575+ inp_dtype = node .inputs [0 ].type .numpy_dtype
519576 axis = op .axis
520577
521- if axis is not None :
522- axis = normalize_axis_index (axis , x_at .ndim )
578+ if ndim > 1 and axis is not None :
523579 reduce_max_py = create_multiaxis_reducer (
524580 maximum ,
525- - np .inf ,
526- (axis ,),
527- x_at . ndim ,
528- x_dtype ,
581+ identity = - np .inf ,
582+ axes = (axis ,),
583+ ndim = ndim ,
584+ out_dtype = inp_dtype ,
529585 keepdims = True ,
530586 )
531587 reduce_sum_py = create_multiaxis_reducer (
532- add_as , 0.0 , (axis ,), x_at .ndim , x_dtype , keepdims = True
588+ add_as ,
589+ identity = 0.0 ,
590+ axes = (axis ,),
591+ ndim = ndim ,
592+ out_dtype = inp_dtype ,
593+ keepdims = True ,
533594 )
534595
535596 jit_fn = numba_basic .numba_njit (boundscheck = False )
@@ -539,13 +600,14 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
539600 reduce_max = np .max
540601 reduce_sum = np .sum
541602
542- def log_softmax_py_fn (x ):
603+ @numba_basic .numba_njit (boundscheck = False )
604+ def log_softmax (x ):
543605 xdev = x - reduce_max (x )
544606 lsm = xdev - np .log (reduce_sum (np .exp (xdev )))
545607 return lsm
546608
547- log_softmax = numba_basic . numba_njit ( log_softmax_py_fn , boundscheck = False )
548- return log_softmax
609+ cache_version = 1
610+ return log_softmax , cache_version
549611
550612
551613@register_funcify_default_op_cache_key (Argmax )
0 commit comments