Skip to content

Commit ad960d6

Browse files
committed
Numba CAReduce: respect acc_dtype
Also fix infinity identities for unsigned integers
1 parent 3ff7603 commit ad960d6

File tree

4 files changed

+193
-83
lines changed

4 files changed

+193
-83
lines changed

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 145 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from hashlib import sha256
33
from textwrap import dedent, indent
44

5-
import numba
65
import numpy as np
76
from numba.core.extending import overload
87
from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple
@@ -14,6 +13,7 @@
1413
)
1514
from pytensor.link.numba.dispatch import basic as numba_basic
1615
from pytensor.link.numba.dispatch.basic import (
16+
create_tuple_string,
1717
numba_funcify_and_cache_key,
1818
register_funcify_and_cache_key,
1919
register_funcify_default_op_cache_key,
@@ -125,10 +125,12 @@ def scalar_in_place_fn_Minimum(op, idx, res, arr):
125125

126126
def create_multiaxis_reducer(
127127
scalar_op,
128+
*,
128129
identity,
129130
axes,
130131
ndim,
131-
dtype,
132+
acc_dtype=None,
133+
out_dtype,
132134
keepdims: bool = False,
133135
):
134136
r"""Construct a function that reduces multiple axes.
@@ -138,17 +140,46 @@ def create_multiaxis_reducer(
138140
.. code-block:: python
139141
140142
def careduce_add(x):
141-
# For x.ndim == 3 and axes == (0, 1) and scalar_op == "Add"
142143
x_shape = x.shape
143-
res_shape = x_shape[2]
144-
res = np.full(res_shape, numba_basic.to_scalar(0.0), dtype=out_dtype)
144+
res_shape = (x_shape[0], x_shape[1])
145+
# identity = 0.0
146+
res = np.full(res_shape, identity, dtype=np.float64)
147+
for i0 in range(x_shape[0]):
148+
for i1 in range(x_shape[1]):
149+
for i2 in range(x_shape[2]):
150+
res[i0, i1] += x[i0, i1, i2]
151+
return res
152+
153+
If the accumulation dtype differs from the output dtype, the result is cast on return
154+
155+
.. code-block:: python
145156
157+
def careduce_add(x):
158+
x_shape = x.shape
159+
res_shape = (x_shape[0], x_shape[1])
160+
# identity = 0.0
161+
res = np.full(res_shape, identity, dtype=np.float64)
146162
for i0 in range(x_shape[0]):
147163
for i1 in range(x_shape[1]):
148164
for i2 in range(x_shape[2]):
149-
res[i2] += x[i0, i1, i2]
165+
res[i0, i1] += x[i0, i1, i2]
166+
return res.astype(np.int32)
167+
168+
Full reductions accumulate on a scalar instead of an array
169+
170+
.. code-block:: python
171+
172+
def careduce_mul(x):
173+
x_shape = x.shape
174+
res_shape = ()
175+
# identity = 1.0
176+
res = identity
177+
for i0 in range(x_shape[0]):
178+
for i1 in range(x_shape[1]):
179+
for i2 in range(x_shape[2]):
180+
res *= x[i0, i1, i2]
181+
return np.array(res, dtype=np.int32)
150182
151-
return res
152183
153184
Parameters
154185
==========
@@ -160,7 +191,9 @@ def careduce_add(x):
160191
The axes to reduce.
161192
ndim:
162193
The number of dimensions of the input variable.
163-
dtype:
194+
acc_dtype: dtype, optional
195+
The data type used during accumulation. Defaults to ``out_dtype`` if not provided.
196+
out_dtype:
164197
The data type of the result.
165198
keepdims: boolean, default False
166199
Whether to keep the reduced dimensions.
@@ -178,19 +211,23 @@ def careduce_add(x):
178211
"Cannot keep multiple dimensions when reducing multiple axes"
179212
)
180213

214+
out_dtype = np.dtype(out_dtype)
215+
acc_dtype = out_dtype if acc_dtype is None else np.dtype(acc_dtype)
216+
# Numba doesn't allow converting complex to real with a simple `astype`
217+
complex_to_real = acc_dtype.kind == "c" and out_dtype.kind != "c"
218+
out_dtype_str = f"np.{out_dtype.name}"
219+
acc_dtype_str = f"np.{acc_dtype.name}"
181220
careduce_fn_name = f"careduce_{scalar_op}"
182221

183-
identity = str(identity)
184-
if identity == "inf":
185-
identity = "np.inf"
186-
elif identity == "-inf":
187-
identity = "-np.inf"
188-
189-
global_env = {
190-
"np": np,
191-
"numba_basic": numba_basic,
192-
"out_dtype": dtype,
193-
}
222+
if acc_dtype.kind in "ui" and not np.isfinite(identity):
223+
if np.isposinf(identity):
224+
identity = np.iinfo(acc_dtype).max
225+
else:
226+
identity = np.iinfo(acc_dtype).min
227+
228+
# Make sure it has the correct dtype
229+
identity = getattr(np, acc_dtype.name)(identity)
230+
194231
complete_reduction = len(axes) == ndim
195232
kept_axis = tuple(i for i in range(ndim) if i not in axes)
196233

@@ -208,17 +245,23 @@ def careduce_add(x):
208245
scalar_op, res_indices, "res", f"x[{arr_indices}]"
209246
)
210247

211-
res_shape = f"({', '.join(f'x_shape[{i}]' for i in kept_axis)})"
248+
res_shape = create_tuple_string([f"x_shape[{i}]" for i in kept_axis])
212249
if complete_reduction and ndim > 0:
213250
# We accumulate on a scalar, not an array
214-
res_creator = f"np.asarray({identity}).astype(out_dtype).item()"
251+
res_creator = "identity"
215252
inplace_update_stmt = inplace_update_stmt.replace("res[()]", "res")
216-
return_obj = "np.asarray(res)"
253+
if complex_to_real:
254+
return_obj = f"np.array(res).real.astype({out_dtype_str})"
255+
else:
256+
return_obj = f"np.array(res, dtype={out_dtype_str})"
217257
else:
218-
res_creator = (
219-
f"np.full({res_shape}, np.asarray({identity}).item(), dtype=out_dtype)"
220-
)
221-
return_obj = "res"
258+
res_creator = f"np.full(res_shape, identity, dtype={acc_dtype_str})"
259+
if complex_to_real:
260+
return_obj = f"res.real.astype({out_dtype_str})"
261+
else:
262+
return_obj = (
263+
"res" if out_dtype == acc_dtype else f"res.astype({out_dtype_str})"
264+
)
222265

223266
if keepdims:
224267
[axis] = axes
@@ -229,6 +272,7 @@ def careduce_add(x):
229272
def {careduce_fn_name}(x):
230273
x_shape = x.shape
231274
res_shape = {res_shape}
275+
# identity = {identity}
232276
res = {res_creator}
233277
"""
234278
)
@@ -238,13 +282,12 @@ def {careduce_fn_name}(x):
238282
" " * (4 + 4 * axis),
239283
)
240284
careduce_def_src += indent(inplace_update_stmt, " " * (4 + 4 * ndim))
241-
careduce_def_src += "\n\n"
285+
careduce_def_src += "\n"
242286
careduce_def_src += indent(f"return {return_obj}", " " * 4)
243287

244288
careduce_fn = compile_numba_function_src(
245-
careduce_def_src, careduce_fn_name, {**globals(), **global_env}
289+
careduce_def_src, careduce_fn_name, globals() | {"np": np, "identity": identity}
246290
)
247-
248291
return careduce_fn
249292

250293

@@ -356,41 +399,45 @@ def numba_funcify_CAReduce(op, node, **kwargs):
356399
acc_dtype = op.acc_dtype
357400
else:
358401
acc_dtype = node.outputs[0].type.dtype
359-
np_acc_dtype = np.dtype(acc_dtype)
360-
361-
scalar_op_identity = op.scalar_op.identity
362-
if np_acc_dtype.kind == "i" and not np.isfinite(scalar_op_identity):
363-
if np.isposinf(scalar_op_identity):
364-
scalar_op_identity = np.iinfo(np_acc_dtype).max
365-
else:
366-
scalar_op_identity = np.iinfo(np_acc_dtype).min
367-
# Make sure it has the correct dtype
368-
scalar_op_identity = np.array(scalar_op_identity, dtype=np_acc_dtype)
369402

370403
out_dtype = np.dtype(node.outputs[0].type.dtype)
371404

372-
if isinstance(op, Sum) and node.inputs[0].ndim == len(axes):
405+
if (
406+
isinstance(op, Sum)
407+
and node.inputs[0].ndim == len(axes)
408+
and out_dtype == acc_dtype
409+
):
373410
# Slightly faster for this case
374411
@numba_basic.numba_njit
375412
def impl_sum(array):
376-
return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype)
413+
return np.array(array.sum())
377414

378415
careduce_fn = impl_sum # Some tests look for this name
379416

380417
else:
381418
ndim = node.inputs[0].ndim
382419
careduce_py_fn = create_multiaxis_reducer(
383420
op.scalar_op,
384-
scalar_op_identity,
385-
axes,
386-
ndim,
387-
out_dtype,
421+
identity=op.scalar_op.identity,
422+
axes=axes,
423+
ndim=ndim,
424+
acc_dtype=acc_dtype,
425+
out_dtype=out_dtype,
388426
)
389427
careduce_fn = numba_basic.numba_njit(careduce_py_fn, boundscheck=False)
390428

429+
cache_version = 1
391430
careduce_key = sha256(
392431
str(
393-
(type(op), type(op.scalar_op), axes, acc_dtype, scalar_op_identity.item())
432+
(
433+
type(op),
434+
type(op.scalar_op),
435+
axes,
436+
out_dtype,
437+
acc_dtype,
438+
op.scalar_op.identity,
439+
cache_version,
440+
)
394441
).encode()
395442
).hexdigest()
396443
return careduce_fn, careduce_key
@@ -449,18 +496,27 @@ def dimshuffle(x):
449496

450497
@register_funcify_default_op_cache_key(Softmax)
451498
def numba_funcify_Softmax(op, node, **kwargs):
452-
x_at = node.inputs[0]
453-
x_dtype = x_at.type.numpy_dtype
454-
x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
499+
ndim = node.inputs[0].type.ndim
500+
inp_dtype = node.inputs[0].type.numpy_dtype
455501
axis = op.axis
456502

457-
if axis is not None:
458-
axis = normalize_axis_index(axis, x_at.ndim)
503+
if ndim > 1 and axis is not None:
504+
axis = normalize_axis_index(axis, ndim)
459505
reduce_max_py = create_multiaxis_reducer(
460-
maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
506+
maximum,
507+
identity=-np.inf,
508+
axes=axis,
509+
ndim=ndim,
510+
out_dtype=inp_dtype,
511+
keepdims=True,
461512
)
462513
reduce_sum_py = create_multiaxis_reducer(
463-
add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
514+
add_as,
515+
identity=0.0,
516+
axes=(axis,),
517+
ndim=ndim,
518+
out_dtype=inp_dtype,
519+
keepdims=True,
464520
)
465521

466522
jit_fn = numba_basic.numba_njit(boundscheck=False)
@@ -470,66 +526,74 @@ def numba_funcify_Softmax(op, node, **kwargs):
470526
reduce_max = np.max
471527
reduce_sum = np.sum
472528

473-
def softmax_py_fn(x):
529+
@numba_basic.numba_njit(boundscheck=False)
530+
def softmax(x):
474531
z = reduce_max(x)
475532
e_x = np.exp(x - z)
476533
w = reduce_sum(e_x)
477534
sm = e_x / w
478535
return sm
479536

480-
softmax = numba_basic.numba_njit(softmax_py_fn, boundscheck=False)
481-
482-
return softmax
537+
cache_version = 1
538+
return softmax, cache_version
483539

484540

485541
@register_funcify_default_op_cache_key(SoftmaxGrad)
486542
def numba_funcify_SoftmaxGrad(op, node, **kwargs):
487-
sm_at = node.inputs[1]
488-
sm_dtype = sm_at.type.numpy_dtype
489-
sm_dtype = numba.np.numpy_support.from_dtype(sm_dtype)
543+
ndim = node.inputs[0].type.ndim
544+
inp_dtype = node.inputs[0].type.numpy_dtype
490545

491546
axis = op.axis
492-
if axis is not None:
493-
axis = normalize_axis_index(axis, sm_at.ndim)
547+
if ndim > 1 and axis is not None:
548+
axis = normalize_axis_index(axis, ndim)
494549
reduce_sum_py = create_multiaxis_reducer(
495-
add_as, 0.0, (axis,), sm_at.ndim, sm_dtype, keepdims=True
550+
add_as,
551+
identity=0.0,
552+
axes=(axis,),
553+
ndim=ndim,
554+
out_dtype=inp_dtype,
555+
keepdims=True,
496556
)
497557

498558
jit_fn = numba_basic.numba_njit(boundscheck=False)
499559
reduce_sum = jit_fn(reduce_sum_py)
500560
else:
501561
reduce_sum = np.sum
502562

503-
def softmax_grad_py_fn(dy, sm):
563+
@numba_basic.numba_njit(boundscheck=False)
564+
def softmax_grad(dy, sm):
504565
dy_times_sm = dy * sm
505566
sum_dy_times_sm = reduce_sum(dy_times_sm)
506567
dx = dy_times_sm - sum_dy_times_sm * sm
507568
return dx
508569

509-
softmax_grad = numba_basic.numba_njit(softmax_grad_py_fn, boundscheck=False)
510-
511-
return softmax_grad
570+
cache_version = 1
571+
return softmax_grad, cache_version
512572

513573

514574
@register_funcify_default_op_cache_key(LogSoftmax)
515575
def numba_funcify_LogSoftmax(op, node, **kwargs):
516-
x_at = node.inputs[0]
517-
x_dtype = x_at.type.numpy_dtype
518-
x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
576+
ndim = node.inputs[0].type.ndim
577+
inp_dtype = node.inputs[0].type.numpy_dtype
519578
axis = op.axis
520579

521-
if axis is not None:
522-
axis = normalize_axis_index(axis, x_at.ndim)
580+
if ndim > 1 and axis is not None:
581+
axis = normalize_axis_index(axis, ndim)
523582
reduce_max_py = create_multiaxis_reducer(
524583
maximum,
525-
-np.inf,
526-
(axis,),
527-
x_at.ndim,
528-
x_dtype,
584+
identity=-np.inf,
585+
axes=(axis,),
586+
ndim=ndim,
587+
out_dtype=inp_dtype,
529588
keepdims=True,
530589
)
531590
reduce_sum_py = create_multiaxis_reducer(
532-
add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
591+
add_as,
592+
identity=0.0,
593+
axes=(axis,),
594+
ndim=ndim,
595+
out_dtype=inp_dtype,
596+
keepdims=True,
533597
)
534598

535599
jit_fn = numba_basic.numba_njit(boundscheck=False)
@@ -539,13 +603,14 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
539603
reduce_max = np.max
540604
reduce_sum = np.sum
541605

542-
def log_softmax_py_fn(x):
606+
@numba_basic.numba_njit(boundscheck=False)
607+
def log_softmax(x):
543608
xdev = x - reduce_max(x)
544609
lsm = xdev - np.log(reduce_sum(np.exp(xdev)))
545610
return lsm
546611

547-
log_softmax = numba_basic.numba_njit(log_softmax_py_fn, boundscheck=False)
548-
return log_softmax
612+
cache_version = 1
613+
return log_softmax, cache_version
549614

550615

551616
@register_funcify_default_op_cache_key(Argmax)

0 commit comments

Comments
 (0)