From 690015f685aa7da99dab672a334aab035b63f19d Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Fri, 5 Dec 2025 10:42:05 +0100
Subject: [PATCH] Numba CAReduce: respect acc_dtype

Also fix infinity identities for unsigned integers
---
 pytensor/link/numba/dispatch/elemwise.py | 222 +++++++++++++++--------
 pytensor/tensor/elemwise.py              |   5 +-
 pytensor/tensor/math.py                  |   5 +-
 tests/link/numba/test_elemwise.py        |  41 ++++-
 4 files changed, 190 insertions(+), 83 deletions(-)

diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py
index 9b2c9f514c..ad0832bdd9 100644
--- a/pytensor/link/numba/dispatch/elemwise.py
+++ b/pytensor/link/numba/dispatch/elemwise.py
@@ -2,7 +2,6 @@
 from hashlib import sha256
 from textwrap import dedent, indent
 
-import numba
 import numpy as np
 from numba.core.extending import overload
 from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple
@@ -14,6 +13,7 @@
 )
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import (
+    create_tuple_string,
     numba_funcify_and_cache_key,
     register_funcify_and_cache_key,
     register_funcify_default_op_cache_key,
@@ -125,10 +125,12 @@ def scalar_in_place_fn_Minimum(op, idx, res, arr):
 
 def create_multiaxis_reducer(
     scalar_op,
+    *,
     identity,
     axes,
     ndim,
-    dtype,
+    acc_dtype=None,
+    out_dtype,
     keepdims: bool = False,
 ):
     r"""Construct a function that reduces multiple axes.
@@ -138,17 +140,47 @@ def create_multiaxis_reducer(
     .. code-block:: python
 
         def careduce_add(x):
-            # For x.ndim == 3 and axes == (0, 1) and scalar_op == "Add"
             x_shape = x.shape
-            res_shape = x_shape[2]
-            res = np.full(res_shape, numba_basic.to_scalar(0.0), dtype=out_dtype)
-
+            res_shape = (x_shape[0], x_shape[1])
+            # identity = 0.0
+            res = np.full(res_shape, identity, dtype=np.float64)
+            for i0 in range(x_shape[0]):
+                for i1 in range(x_shape[1]):
+                    for i2 in range(x_shape[2]):
+                        res[i0, i1] += x[i0, i1, i2]
+            return res
+
+    If the accumulation dtype differs from the output dtype:
+
+    .. code-block:: python
+
+        def careduce_add(x):
+            x_shape = x.shape
+            res_shape = (x_shape[0], x_shape[1])
+            # identity = 0.0
+            res = np.full(res_shape, identity, dtype=np.float64)
             for i0 in range(x_shape[0]):
                 for i1 in range(x_shape[1]):
                     for i2 in range(x_shape[2]):
-                        res[i2] += x[i0, i1, i2]
+                        res[i0, i1] += x[i0, i1, i2]
+            return res.astype(np.int32)
+
+    Full reductions accumulate on a scalar:
+
+    .. code-block:: python
+
+        def careduce_mul(x):
+            x_shape = x.shape
+            res_shape = ()
+            # identity = 1.0
+            res = identity
+            for i0 in range(x_shape[0]):
+                for i1 in range(x_shape[1]):
+                    for i2 in range(x_shape[2]):
+                        res *= x[i0, i1, i2]
+            return np.array(res, dtype=np.int32)
 
-            return res
 
     Parameters
     ==========
@@ -160,7 +192,9 @@ def careduce_add(x):
         The axes to reduce.
     ndim:
         The number of dimensions of the input variable.
-    dtype:
+    acc_dtype: dtype, optional
+        The data type used during accumulation. Defaults to out_dtype if not provided.
+    out_dtype:
         The data type of the result.
     keepdims: boolean, default False
         Whether to keep the reduced dimensions.
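
For intuition, the accumulate-then-cast contract documented above can be sketched
in plain NumPy (an illustration only; `reduce_sum_acc` is a hypothetical helper,
not part of this patch):

    import numpy as np

    def reduce_sum_acc(x, axis, acc_dtype, out_dtype):
        # Accumulate in the (typically wider) acc_dtype and cast to out_dtype
        # only once at the end, mirroring the generated res/astype pattern.
        acc = np.add.reduce(x, axis=axis, dtype=acc_dtype)
        return np.asarray(acc).astype(out_dtype)

    max_int8 = np.iinfo(np.int8).max
    x = np.tile(
        np.array([max_int8, 5, max_int8, -max_int8, 5, -max_int8], dtype=np.int8),
        (6, 2, 1),
    )
    # Row sums are 10, which fits in int8 even though int8 accumulation overflows.
    assert np.all(reduce_sum_acc(x, -1, np.int64, np.int8) == 10)
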
@@ -178,19 +211,23 @@ def careduce_add(x): "Cannot keep multiple dimensions when reducing multiple axes" ) + out_dtype = np.dtype(out_dtype) + acc_dtype = out_dtype if acc_dtype is None else np.dtype(acc_dtype) + # Numba doesn't allow converting complex to real with a simple `astype` + complex_to_real = acc_dtype.kind == "c" and out_dtype.kind != "c" + out_dtype_str = f"np.{out_dtype.name}" + acc_dtype_str = f"np.{acc_dtype.name}" careduce_fn_name = f"careduce_{scalar_op}" - identity = str(identity) - if identity == "inf": - identity = "np.inf" - elif identity == "-inf": - identity = "-np.inf" - - global_env = { - "np": np, - "numba_basic": numba_basic, - "out_dtype": dtype, - } + if acc_dtype.kind in "ui" and not np.isfinite(identity): + if np.isposinf(identity): + identity = np.iinfo(acc_dtype).max + else: + identity = np.iinfo(acc_dtype).min + + # Make sure it has the correct dtype + identity = getattr(np, acc_dtype.name)(identity) + complete_reduction = len(axes) == ndim kept_axis = tuple(i for i in range(ndim) if i not in axes) @@ -208,17 +245,23 @@ def careduce_add(x): scalar_op, res_indices, "res", f"x[{arr_indices}]" ) - res_shape = f"({', '.join(f'x_shape[{i}]' for i in kept_axis)})" + res_shape = create_tuple_string([f"x_shape[{i}]" for i in kept_axis]) if complete_reduction and ndim > 0: # We accumulate on a scalar, not an array - res_creator = f"np.asarray({identity}).astype(out_dtype).item()" + res_creator = "identity" inplace_update_stmt = inplace_update_stmt.replace("res[()]", "res") - return_obj = "np.asarray(res)" + if complex_to_real: + return_obj = f"np.array(res).real.astype({out_dtype_str})" + else: + return_obj = f"np.array(res, dtype={out_dtype_str})" else: - res_creator = ( - f"np.full({res_shape}, np.asarray({identity}).item(), dtype=out_dtype)" - ) - return_obj = "res" + res_creator = f"np.full(res_shape, identity, dtype={acc_dtype_str})" + if complex_to_real: + return_obj = f"res.real.astype({out_dtype_str})" + else: + return_obj = ( + "res" if out_dtype == acc_dtype else f"res.astype({out_dtype_str})" + ) if keepdims: [axis] = axes @@ -229,6 +272,7 @@ def careduce_add(x): def {careduce_fn_name}(x): x_shape = x.shape res_shape = {res_shape} + # identity = {identity} res = {res_creator} """ ) @@ -238,13 +282,12 @@ def {careduce_fn_name}(x): " " * (4 + 4 * axis), ) careduce_def_src += indent(inplace_update_stmt, " " * (4 + 4 * ndim)) - careduce_def_src += "\n\n" + careduce_def_src += "\n" careduce_def_src += indent(f"return {return_obj}", " " * 4) careduce_fn = compile_numba_function_src( - careduce_def_src, careduce_fn_name, {**globals(), **global_env} + careduce_def_src, careduce_fn_name, globals() | {"np": np, "identity": identity} ) - return careduce_fn @@ -356,24 +399,18 @@ def numba_funcify_CAReduce(op, node, **kwargs): acc_dtype = op.acc_dtype else: acc_dtype = node.outputs[0].type.dtype - np_acc_dtype = np.dtype(acc_dtype) - - scalar_op_identity = op.scalar_op.identity - if np_acc_dtype.kind == "i" and not np.isfinite(scalar_op_identity): - if np.isposinf(scalar_op_identity): - scalar_op_identity = np.iinfo(np_acc_dtype).max - else: - scalar_op_identity = np.iinfo(np_acc_dtype).min - # Make sure it has the correct dtype - scalar_op_identity = np.array(scalar_op_identity, dtype=np_acc_dtype) out_dtype = np.dtype(node.outputs[0].type.dtype) - if isinstance(op, Sum) and node.inputs[0].ndim == len(axes): + if ( + isinstance(op, Sum) + and node.inputs[0].ndim == len(axes) + and out_dtype == acc_dtype + ): # Slightly faster for this case @numba_basic.numba_njit 
def impl_sum(array): - return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype) + return np.array(array.sum()) careduce_fn = impl_sum # Some tests look for this name @@ -381,16 +418,26 @@ def impl_sum(array): ndim = node.inputs[0].ndim careduce_py_fn = create_multiaxis_reducer( op.scalar_op, - scalar_op_identity, - axes, - ndim, - out_dtype, + identity=op.scalar_op.identity, + axes=axes, + ndim=ndim, + acc_dtype=acc_dtype, + out_dtype=out_dtype, ) careduce_fn = numba_basic.numba_njit(careduce_py_fn, boundscheck=False) + cache_version = 1 careduce_key = sha256( str( - (type(op), type(op.scalar_op), axes, acc_dtype, scalar_op_identity.item()) + ( + type(op), + type(op.scalar_op), + axes, + out_dtype, + acc_dtype, + op.scalar_op.identity, + cache_version, + ) ).encode() ).hexdigest() return careduce_fn, careduce_key @@ -449,18 +496,26 @@ def dimshuffle(x): @register_funcify_default_op_cache_key(Softmax) def numba_funcify_Softmax(op, node, **kwargs): - x_at = node.inputs[0] - x_dtype = x_at.type.numpy_dtype - x_dtype = numba.np.numpy_support.from_dtype(x_dtype) + ndim = node.inputs[0].type.ndim + inp_dtype = node.inputs[0].type.numpy_dtype axis = op.axis - if axis is not None: - axis = normalize_axis_index(axis, x_at.ndim) + if ndim > 1 and axis is not None: reduce_max_py = create_multiaxis_reducer( - maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True + maximum, + identity=-np.inf, + axes=(axis,), + ndim=ndim, + out_dtype=inp_dtype, + keepdims=True, ) reduce_sum_py = create_multiaxis_reducer( - add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True + add_as, + identity=0.0, + axes=(axis,), + ndim=ndim, + out_dtype=inp_dtype, + keepdims=True, ) jit_fn = numba_basic.numba_njit(boundscheck=False) @@ -470,29 +525,32 @@ def numba_funcify_Softmax(op, node, **kwargs): reduce_max = np.max reduce_sum = np.sum - def softmax_py_fn(x): + @numba_basic.numba_njit(boundscheck=False) + def softmax(x): z = reduce_max(x) e_x = np.exp(x - z) w = reduce_sum(e_x) sm = e_x / w return sm - softmax = numba_basic.numba_njit(softmax_py_fn, boundscheck=False) - - return softmax + cache_version = 1 + return softmax, cache_version @register_funcify_default_op_cache_key(SoftmaxGrad) def numba_funcify_SoftmaxGrad(op, node, **kwargs): - sm_at = node.inputs[1] - sm_dtype = sm_at.type.numpy_dtype - sm_dtype = numba.np.numpy_support.from_dtype(sm_dtype) + ndim = node.inputs[0].type.ndim + inp_dtype = node.inputs[0].type.numpy_dtype axis = op.axis - if axis is not None: - axis = normalize_axis_index(axis, sm_at.ndim) + if ndim > 1 and axis is not None: reduce_sum_py = create_multiaxis_reducer( - add_as, 0.0, (axis,), sm_at.ndim, sm_dtype, keepdims=True + add_as, + identity=0.0, + axes=(axis,), + ndim=ndim, + out_dtype=inp_dtype, + keepdims=True, ) jit_fn = numba_basic.numba_njit(boundscheck=False) @@ -500,36 +558,39 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs): else: reduce_sum = np.sum - def softmax_grad_py_fn(dy, sm): + @numba_basic.numba_njit(boundscheck=False) + def softmax_grad(dy, sm): dy_times_sm = dy * sm sum_dy_times_sm = reduce_sum(dy_times_sm) dx = dy_times_sm - sum_dy_times_sm * sm return dx - softmax_grad = numba_basic.numba_njit(softmax_grad_py_fn, boundscheck=False) - - return softmax_grad + cache_version = 1 + return softmax_grad, cache_version @register_funcify_default_op_cache_key(LogSoftmax) def numba_funcify_LogSoftmax(op, node, **kwargs): - x_at = node.inputs[0] - x_dtype = x_at.type.numpy_dtype - x_dtype = numba.np.numpy_support.from_dtype(x_dtype) + ndim = 
node.inputs[0].type.ndim + inp_dtype = node.inputs[0].type.numpy_dtype axis = op.axis - if axis is not None: - axis = normalize_axis_index(axis, x_at.ndim) + if ndim > 1 and axis is not None: reduce_max_py = create_multiaxis_reducer( maximum, - -np.inf, - (axis,), - x_at.ndim, - x_dtype, + identity=-np.inf, + axes=(axis,), + ndim=ndim, + out_dtype=inp_dtype, keepdims=True, ) reduce_sum_py = create_multiaxis_reducer( - add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True + add_as, + identity=0.0, + axes=(axis,), + ndim=ndim, + out_dtype=inp_dtype, + keepdims=True, ) jit_fn = numba_basic.numba_njit(boundscheck=False) @@ -539,13 +600,14 @@ def numba_funcify_LogSoftmax(op, node, **kwargs): reduce_max = np.max reduce_sum = np.sum - def log_softmax_py_fn(x): + @numba_basic.numba_njit(boundscheck=False) + def log_softmax(x): xdev = x - reduce_max(x) lsm = xdev - np.log(reduce_sum(np.exp(xdev))) return lsm - log_softmax = numba_basic.numba_njit(log_softmax_py_fn, boundscheck=False) - return log_softmax + cache_version = 1 + return log_softmax, cache_version @register_funcify_default_op_cache_key(Argmax) diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index f1d8bc09df..1616666a63 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -1391,7 +1391,10 @@ def _axis_str(self): return f"axes={list(axis)}" def __str__(self): - return f"{type(self).__name__}{{{self.scalar_op}, {self._axis_str()}}}" + if self.acc_dtype != self.dtype: + return f"{type(self).__name__}{{{self.scalar_op}, {self._axis_str()}, acc={self.acc_dtype}}}" + else: + return f"{type(self).__name__}{{{self.scalar_op}, {self._axis_str()}}}" def perform(self, node, inp, out): (input,) = inp diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index ac858e5107..fa424e4679 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -357,7 +357,10 @@ def max_and_argmax(a, axis=None, keepdims=False): class FixedOpCAReduce(CAReduce): def __str__(self): - return f"{type(self).__name__}{{{self._axis_str()}}}" + if self.dtype != self.acc_dtype: + return f"{type(self).__name__}{{{self._axis_str()}, acc={self.acc_dtype}}}" + else: + return f"{type(self).__name__}{{{self._axis_str()}}}" class NonZeroDimsCAReduce(FixedOpCAReduce): diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py index 614d5a092e..3dea6c4d39 100644 --- a/tests/link/numba/test_elemwise.py +++ b/tests/link/numba/test_elemwise.py @@ -13,7 +13,7 @@ from pytensor.gradient import grad from pytensor.scalar import Composite, float64 from pytensor.scalar import add as scalar_add -from pytensor.tensor import blas, tensor +from pytensor.tensor import blas, matrix, tensor, tensor3 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad @@ -366,6 +366,45 @@ def test_CAReduce(careduce_fn, axis, v): assert isinstance(node.op, CAReduce) +@pytest.mark.parametrize("axis", (-1, (0, -1), None)) +def test_CAReduce_respects_acc_dtype(axis): + x = tensor3("x", dtype="int8") + out = x.sum(dtype="int8", acc_dtype="int64", axis=axis) + # Choose values that would overflow if accumulated internally in int8 + max_int8 = np.iinfo(np.int8).max + test_x = np.array([max_int8, 5, max_int8, -max_int8, 5, -max_int8], dtype=np.int8) + test_x = np.broadcast_to(test_x, (6, 2, 6)).copy() + _, [res] = compare_numba_and_py( + [x], + [out], + [test_x], + ) + if axis 
== -1: + assert np.all(res == 10) + elif axis == (0, -1): + assert np.all(res == 60) + elif axis is None: + assert res == 120 + + +@pytest.mark.parametrize("axis", (1, None)) +def test_CAReduce_acc_complex_out_float(axis): + x = matrix("x", dtype="complex128") + out = x.sum(dtype="float64", axis=axis) + test_x = np.array([[1 + 0.5j, 2 - 0.5j], [3 + 0.5j, 4 - 0.5j]], dtype="complex128") + compare_numba_and_py([x], [out], [test_x]) + + +@pytest.mark.parametrize("axis", (-1, (0, -1), None)) +def test_CAReduce_discrete_infinity_identity(axis): + rng = np.random.default_rng(337) + x = tensor3("x", dtype="int8") + out = x.max(axis) + compare_numba_and_py( + [x], [out], [rng.integers(-127, 127, size=(6, 6, 6)).astype("int8")] + ) + + def test_scalar_Elemwise_Clip(): a = pt.scalar("a") b = pt.scalar("b")
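
The unsigned-integer identity fix can likewise be illustrated standalone. A minimal
sketch of the clamping logic added in `create_multiaxis_reducer` (`clamp_identity`
is a hypothetical name for it):

    import numpy as np

    def clamp_identity(identity, acc_dtype):
        # +/-inf identities (e.g. -inf for max, +inf for min) have no
        # representation in integer dtypes, so clamp to the dtype's bounds.
        acc_dtype = np.dtype(acc_dtype)
        if acc_dtype.kind in "ui" and not np.isfinite(identity):
            info = np.iinfo(acc_dtype)
            identity = info.max if np.isposinf(identity) else info.min
        # Make sure it has the correct dtype
        return getattr(np, acc_dtype.name)(identity)

    assert clamp_identity(-np.inf, np.uint8) == 0    # max-reduction identity
    assert clamp_identity(np.inf, np.uint8) == 255   # min-reduction identity
    assert clamp_identity(-np.inf, np.int8) == -128  # the previously handled "i" case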