Skip to content

Commit 3c61b88

Browse files
committed
Numba CAReduce: respect acc_dtype
Also fix infinity identities for unsigned integers
1 parent a782753 commit 3c61b88

File tree

4 files changed

+193
-83
lines changed

4 files changed

+193
-83
lines changed

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 145 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from hashlib import sha256
33
from textwrap import dedent, indent
44

5-
import numba
65
import numpy as np
76
from numba.core.extending import overload
87
from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple
@@ -14,6 +13,7 @@
1413
)
1514
from pytensor.link.numba.dispatch import basic as numba_basic
1615
from pytensor.link.numba.dispatch.basic import (
16+
create_tuple_string,
1717
numba_funcify_and_cache_key,
1818
register_funcify_and_cache_key,
1919
register_funcify_default_op_cache_key,
@@ -125,10 +125,12 @@ def scalar_in_place_fn_Minimum(op, idx, res, arr):
125125

126126
def create_multiaxis_reducer(
127127
scalar_op,
128+
*,
128129
identity,
129130
axes,
130131
ndim,
131-
dtype,
132+
acc_dtype=None,
133+
out_dtype,
132134
keepdims: bool = False,
133135
):
134136
r"""Construct a function that reduces multiple axes.
@@ -138,17 +140,46 @@ def create_multiaxis_reducer(
138140
.. code-block:: python
139141
140142
def careduce_add(x):
141-
# For x.ndim == 3 and axes == (0, 1) and scalar_op == "Add"
142143
x_shape = x.shape
143-
res_shape = x_shape[2]
144-
res = np.full(res_shape, numba_basic.to_scalar(0.0), dtype=out_dtype)
144+
res_shape = (x_shape[0], x_shape[1])
145+
# identity = 0.0
146+
res = np.full(res_shape, identity, dtype=np.float64)
147+
for i0 in range(x_shape[0]):
148+
for i1 in range(x_shape[1]):
149+
for i2 in range(x_shape[2]):
150+
res[i0, i1] += x[i0, i1, i2]
151+
return res
152+
153+
If accumulation dtype differs from output_dtype
154+
155+
.. code-block:: python
145156
157+
def careduce_add(x):
158+
x_shape = x.shape
159+
res_shape = (x_shape[0], x_shape[1])
160+
# identity = 0.0
161+
res = np.full(res_shape, identity, dtype=np.float64)
146162
for i0 in range(x_shape[0]):
147163
for i1 in range(x_shape[1]):
148164
for i2 in range(x_shape[2]):
149-
res[i2] += x[i0, i1, i2]
165+
res[i0, i1] += x[i0, i1, i2]
166+
return res.astype(np.int32)
167+
168+
Full reductions accumulate on scalars
169+
170+
.. code-block:: python
171+
172+
def careduce_mul(x):
173+
x_shape = x.shape
174+
res_shape = ()
175+
# identity = 1.0
176+
res = identity
177+
for i0 in range(x_shape[0]):
178+
for i1 in range(x_shape[1]):
179+
for i2 in range(x_shape[2]):
180+
res *= x[i0, i1, i2]
181+
return np.array(res, dtype=np.int32)
150182
151-
return res
152183
153184
Parameters
154185
==========
@@ -160,7 +191,9 @@ def careduce_add(x):
160191
The axes to reduce.
161192
ndim:
162193
The number of dimensions of the input variable.
163-
dtype:
194+
acc_dtype: dtype, optional
195+
The data type used during accumulation. Defaults to out_dtype if not provided
196+
out_dtype:
164197
The data type of the result.
165198
keepdims: boolean, default False
166199
Whether to keep the reduced dimensions.
@@ -178,19 +211,23 @@ def careduce_add(x):
178211
"Cannot keep multiple dimensions when reducing multiple axes"
179212
)
180213

214+
out_dtype = np.dtype(out_dtype)
215+
acc_dtype = out_dtype if acc_dtype is None else np.dtype(acc_dtype)
216+
# Numba doesn't allow converting complex to real with a simple `astype`
217+
complex_to_real = acc_dtype.kind == "c" and out_dtype.kind != "c"
218+
out_dtype_str = f"np.{out_dtype.name}"
219+
acc_dtype_str = f"np.{acc_dtype.name}"
181220
careduce_fn_name = f"careduce_{scalar_op}"
182221

183-
identity = str(identity)
184-
if identity == "inf":
185-
identity = "np.inf"
186-
elif identity == "-inf":
187-
identity = "-np.inf"
188-
189-
global_env = {
190-
"np": np,
191-
"numba_basic": numba_basic,
192-
"out_dtype": dtype,
193-
}
222+
if acc_dtype.kind in "ui" and not np.isfinite(identity):
223+
if np.isposinf(identity):
224+
identity = np.iinfo(acc_dtype).max
225+
else:
226+
identity = np.iinfo(acc_dtype).min
227+
228+
# Make sure it has the correct dtype
229+
identity = getattr(np, acc_dtype.name)(identity)
230+
194231
complete_reduction = len(axes) == ndim
195232
kept_axis = tuple(i for i in range(ndim) if i not in axes)
196233

@@ -208,17 +245,23 @@ def careduce_add(x):
208245
scalar_op, res_indices, "res", f"x[{arr_indices}]"
209246
)
210247

211-
res_shape = f"({', '.join(f'x_shape[{i}]' for i in kept_axis)})"
248+
res_shape = create_tuple_string([f"x_shape[{i}]" for i in kept_axis])
212249
if complete_reduction and ndim > 0:
213250
# We accumulate on a scalar, not an array
214-
res_creator = f"np.asarray({identity}).astype(out_dtype).item()"
251+
res_creator = "identity"
215252
inplace_update_stmt = inplace_update_stmt.replace("res[()]", "res")
216-
return_obj = "np.asarray(res)"
253+
if complex_to_real:
254+
return_obj = f"np.array(res).real.astype({out_dtype_str})"
255+
else:
256+
return_obj = f"np.array(res, dtype={out_dtype_str})"
217257
else:
218-
res_creator = (
219-
f"np.full({res_shape}, np.asarray({identity}).item(), dtype=out_dtype)"
220-
)
221-
return_obj = "res"
258+
res_creator = f"np.full(res_shape, identity, dtype={acc_dtype_str})"
259+
if complex_to_real:
260+
return_obj = f"res.real.astype({out_dtype_str})"
261+
else:
262+
return_obj = (
263+
"res" if out_dtype == acc_dtype else f"res.astype({out_dtype_str})"
264+
)
222265

223266
if keepdims:
224267
[axis] = axes
@@ -229,6 +272,7 @@ def careduce_add(x):
229272
def {careduce_fn_name}(x):
230273
x_shape = x.shape
231274
res_shape = {res_shape}
275+
# identity = {identity}
232276
res = {res_creator}
233277
"""
234278
)
@@ -238,13 +282,12 @@ def {careduce_fn_name}(x):
238282
" " * (4 + 4 * axis),
239283
)
240284
careduce_def_src += indent(inplace_update_stmt, " " * (4 + 4 * ndim))
241-
careduce_def_src += "\n\n"
285+
careduce_def_src += "\n"
242286
careduce_def_src += indent(f"return {return_obj}", " " * 4)
243287

244288
careduce_fn = compile_numba_function_src(
245-
careduce_def_src, careduce_fn_name, {**globals(), **global_env}
289+
careduce_def_src, careduce_fn_name, globals() | {"np": np, "identity": identity}
246290
)
247-
248291
return careduce_fn
249292

250293

@@ -356,41 +399,45 @@ def numba_funcify_CAReduce(op, node, **kwargs):
356399
acc_dtype = op.acc_dtype
357400
else:
358401
acc_dtype = node.outputs[0].type.dtype
359-
np_acc_dtype = np.dtype(acc_dtype)
360-
361-
scalar_op_identity = op.scalar_op.identity
362-
if np_acc_dtype.kind == "i" and not np.isfinite(scalar_op_identity):
363-
if np.isposinf(scalar_op_identity):
364-
scalar_op_identity = np.iinfo(np_acc_dtype).max
365-
else:
366-
scalar_op_identity = np.iinfo(np_acc_dtype).min
367-
# Make sure it has the correct dtype
368-
scalar_op_identity = np.array(scalar_op_identity, dtype=np_acc_dtype)
369402

370403
out_dtype = np.dtype(node.outputs[0].type.dtype)
371404

372-
if isinstance(op, Sum) and node.inputs[0].ndim == len(axes):
405+
if (
406+
isinstance(op, Sum)
407+
and node.inputs[0].ndim == len(axes)
408+
and out_dtype == acc_dtype
409+
):
373410
# Slightly faster for this case
374411
@numba_basic.numba_njit
375412
def impl_sum(array):
376-
return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype)
413+
return np.array(array.sum())
377414

378415
careduce_fn = impl_sum # Some tests look for this name
379416

380417
else:
381418
ndim = node.inputs[0].ndim
382419
careduce_py_fn = create_multiaxis_reducer(
383420
op.scalar_op,
384-
scalar_op_identity,
385-
axes,
386-
ndim,
387-
out_dtype,
421+
identity=op.scalar_op.identity,
422+
axes=axes,
423+
ndim=ndim,
424+
acc_dtype=acc_dtype,
425+
out_dtype=out_dtype,
388426
)
389427
careduce_fn = numba_basic.numba_njit(careduce_py_fn, boundscheck=False)
390428

429+
cache_version = 1
391430
careduce_key = sha256(
392431
str(
393-
(type(op), type(op.scalar_op), axes, acc_dtype, scalar_op_identity.item())
432+
(
433+
type(op),
434+
type(op.scalar_op),
435+
axes,
436+
out_dtype,
437+
acc_dtype,
438+
op.scalar_op.identity,
439+
cache_version,
440+
)
394441
).encode()
395442
).hexdigest()
396443
return careduce_fn, careduce_key
@@ -436,18 +483,27 @@ def dimshuffle(x):
436483

437484
@register_funcify_default_op_cache_key(Softmax)
438485
def numba_funcify_Softmax(op, node, **kwargs):
439-
x_at = node.inputs[0]
440-
x_dtype = x_at.type.numpy_dtype
441-
x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
486+
ndim = node.inputs[0].type.ndim
487+
inp_dtype = node.inputs[0].type.numpy_dtype
442488
axis = op.axis
443489

444-
if axis is not None:
445-
axis = normalize_axis_index(axis, x_at.ndim)
490+
if ndim > 1 and axis is not None:
491+
axis = normalize_axis_index(axis, ndim)
446492
reduce_max_py = create_multiaxis_reducer(
447-
maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
493+
maximum,
494+
identity=-np.inf,
495+
axes=axis,
496+
ndim=ndim,
497+
out_dtype=inp_dtype,
498+
keepdims=True,
448499
)
449500
reduce_sum_py = create_multiaxis_reducer(
450-
add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
501+
add_as,
502+
identity=0.0,
503+
axes=(axis,),
504+
ndim=ndim,
505+
out_dtype=inp_dtype,
506+
keepdims=True,
451507
)
452508

453509
jit_fn = numba_basic.numba_njit(boundscheck=False)
@@ -457,66 +513,74 @@ def numba_funcify_Softmax(op, node, **kwargs):
457513
reduce_max = np.max
458514
reduce_sum = np.sum
459515

460-
def softmax_py_fn(x):
516+
@numba_basic.numba_njit(boundscheck=False)
517+
def softmax(x):
461518
z = reduce_max(x)
462519
e_x = np.exp(x - z)
463520
w = reduce_sum(e_x)
464521
sm = e_x / w
465522
return sm
466523

467-
softmax = numba_basic.numba_njit(softmax_py_fn, boundscheck=False)
468-
469-
return softmax
524+
cache_version = 1
525+
return softmax, cache_version
470526

471527

472528
@register_funcify_default_op_cache_key(SoftmaxGrad)
473529
def numba_funcify_SoftmaxGrad(op, node, **kwargs):
474-
sm_at = node.inputs[1]
475-
sm_dtype = sm_at.type.numpy_dtype
476-
sm_dtype = numba.np.numpy_support.from_dtype(sm_dtype)
530+
ndim = node.inputs[0].type.ndim
531+
inp_dtype = node.inputs[0].type.numpy_dtype
477532

478533
axis = op.axis
479-
if axis is not None:
480-
axis = normalize_axis_index(axis, sm_at.ndim)
534+
if ndim > 1 and axis is not None:
535+
axis = normalize_axis_index(axis, ndim)
481536
reduce_sum_py = create_multiaxis_reducer(
482-
add_as, 0.0, (axis,), sm_at.ndim, sm_dtype, keepdims=True
537+
add_as,
538+
identity=0.0,
539+
axes=(axis,),
540+
ndim=ndim,
541+
out_dtype=inp_dtype,
542+
keepdims=True,
483543
)
484544

485545
jit_fn = numba_basic.numba_njit(boundscheck=False)
486546
reduce_sum = jit_fn(reduce_sum_py)
487547
else:
488548
reduce_sum = np.sum
489549

490-
def softmax_grad_py_fn(dy, sm):
550+
@numba_basic.numba_njit(boundscheck=False)
551+
def softmax_grad(dy, sm):
491552
dy_times_sm = dy * sm
492553
sum_dy_times_sm = reduce_sum(dy_times_sm)
493554
dx = dy_times_sm - sum_dy_times_sm * sm
494555
return dx
495556

496-
softmax_grad = numba_basic.numba_njit(softmax_grad_py_fn, boundscheck=False)
497-
498-
return softmax_grad
557+
cache_version = 1
558+
return softmax_grad, cache_version
499559

500560

501561
@register_funcify_default_op_cache_key(LogSoftmax)
502562
def numba_funcify_LogSoftmax(op, node, **kwargs):
503-
x_at = node.inputs[0]
504-
x_dtype = x_at.type.numpy_dtype
505-
x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
563+
ndim = node.inputs[0].type.ndim
564+
inp_dtype = node.inputs[0].type.numpy_dtype
506565
axis = op.axis
507566

508-
if axis is not None:
509-
axis = normalize_axis_index(axis, x_at.ndim)
567+
if ndim > 1 and axis is not None:
568+
axis = normalize_axis_index(axis, ndim)
510569
reduce_max_py = create_multiaxis_reducer(
511570
maximum,
512-
-np.inf,
513-
(axis,),
514-
x_at.ndim,
515-
x_dtype,
571+
identity=-np.inf,
572+
axes=(axis,),
573+
ndim=ndim,
574+
out_dtype=inp_dtype,
516575
keepdims=True,
517576
)
518577
reduce_sum_py = create_multiaxis_reducer(
519-
add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
578+
add_as,
579+
identity=0.0,
580+
axes=(axis,),
581+
ndim=ndim,
582+
out_dtype=inp_dtype,
583+
keepdims=True,
520584
)
521585

522586
jit_fn = numba_basic.numba_njit(boundscheck=False)
@@ -526,13 +590,14 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
526590
reduce_max = np.max
527591
reduce_sum = np.sum
528592

529-
def log_softmax_py_fn(x):
593+
@numba_basic.numba_njit(boundscheck=False)
594+
def log_softmax(x):
530595
xdev = x - reduce_max(x)
531596
lsm = xdev - np.log(reduce_sum(np.exp(xdev)))
532597
return lsm
533598

534-
log_softmax = numba_basic.numba_njit(log_softmax_py_fn, boundscheck=False)
535-
return log_softmax
599+
cache_version = 1
600+
return log_softmax, cache_version
536601

537602

538603
@register_funcify_default_op_cache_key(Argmax)

0 commit comments

Comments
 (0)