
Commit ea26794

Numba uint: fix Sigmoid and Softplus with uint inputs

1 parent f6ebb5c · commit ea26794

File tree

3 files changed: +45 -17 lines changed
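
Background on the fix: negating an unsigned integer wraps around instead of producing a negative value, so expressions such as np.exp(-x) inside Sigmoid and Softplus evaluate the wrong quantity for uint inputs. A minimal NumPy sketch (illustrative only, not taken from the commit) of the failure mode and of the upcast-before-negate workaround used below:

import numpy as np

x = np.uint8(5)
# Unsigned negation wraps modulo 2**8: -x is 251, not -5,
# so 1 / (1 + np.exp(-x)) no longer evaluates sigmoid(5).
print(-x)  # 251

# Casting to a float dtype before negating restores the intended result.
float_x = np.float64(x)
print(1 / (1 + np.exp(-float_x)))  # ~0.993307, i.e. sigmoid(5)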

pytensor/link/numba/dispatch/scalar.py

Lines changed: 37 additions & 7 deletions

@@ -31,10 +31,10 @@
 from pytensor.scalar.math import Erf, Erfc, GammaLn, Log1mexp, Sigmoid, Softplus


-def scalar_op_cache_key(op):
+def scalar_op_cache_key(op, **extra_fields):
     # Scalar Ops don't have _props, because of their weird outputs_types_preference function
     # So we create hash differently
-    return sha256(str(type(op)).encode()).hexdigest()
+    return sha256(str((type(op), tuple(extra_fields.items()))).encode()).hexdigest()


 @register_funcify_and_cache_key(ScalarOp)

@@ -267,11 +267,28 @@ def reciprocal(x):

 @register_funcify_and_cache_key(Sigmoid)
 def numba_funcify_Sigmoid(op, node, **kwargs):
-    @numba_basic.numba_njit
-    def sigmoid(x):
-        return 1 / (1 + np.exp(-x))
+    inp_dtype = node.inputs[0].type.dtype
+    if inp_dtype.startswith("uint"):
+        upcast_uint_dtype = {
+            "uint8": np.float32,  # numpy uses float16, but not Numba
+            "uint16": np.float32,
+            "uint32": np.float64,
+            "uint64": np.float64,
+        }[inp_dtype]
+
+        @numba_basic.numba_njit
+        def sigmoid(x):
+            # Can't negate uint
+            float_x = numba_basic.direct_cast(x, upcast_uint_dtype)
+            return 1 / (1 + np.exp(-float_x))
+
+    else:

-    return sigmoid, scalar_op_cache_key(op)
+        @numba_basic.numba_njit
+        def sigmoid(x):
+            return 1 / (1 + np.exp(-x))
+
+    return sigmoid, scalar_op_cache_key(op, cache_version=1)


 @register_funcify_and_cache_key(GammaLn)

@@ -319,6 +336,16 @@ def erfc(x):

 @register_funcify_and_cache_key(Softplus)
 def numba_funcify_Softplus(op, node, **kwargs):
+    inp_dtype = node.inputs[0].type.dtype
+    if inp_dtype.startswith("uint"):
+        upcast_uint_dtype = {
+            "uint8": np.float32,  # numpy uses float16, but not Numba
+            "uint16": np.float32,
+            "uint32": np.float64,
+            "uint64": np.float64,
+        }[inp_dtype]
+    else:
+        upcast_uint_dtype = None
     out_dtype = np.dtype(node.outputs[0].type.dtype)

     @numba_basic.numba_njit

@@ -328,9 +355,12 @@ def softplus(x):
         elif x < 18.0:
             value = np.log1p(np.exp(x))
         elif x < 33.3:
+            if upcast_uint_dtype is not None:
+                # Can't negate uint
+                x = numba_basic.direct_cast(x, upcast_uint_dtype)
             value = x + np.exp(-x)
         else:
             value = x
         return numba_basic.direct_cast(value, out_dtype)

-    return softplus, scalar_op_cache_key(op)
+    return softplus, scalar_op_cache_key(op, cache_version=1)

pytensor/scalar/math.py

Lines changed: 5 additions & 1 deletion

@@ -1259,14 +1259,18 @@ class Softplus(UnaryScalarOp):
     def impl(self, x):
         # If x is an int8 or uint8, numpy.exp will compute the result in
         # half-precision (float16), where we want float32.
-        not_int8 = str(getattr(x, "dtype", "")) not in ("int8", "uint8")
+        x_dtype = getattr(x, "dtype", None)
+        not_int8 = x_dtype is None or x_dtype.itemsize > 1
         if x < -37.0:
             return np.exp(x) if not_int8 else np.exp(x, signature="f")
         elif x < 18.0:
             return (
                 np.log1p(np.exp(x)) if not_int8 else np.log1p(np.exp(x, signature="f"))
             )
         elif x < 33.3:
+            if x_dtype is not None and x_dtype.kind == "u":
+                # Negate uint will not do what we want
+                x = x.astype("float32" if x_dtype.itemsize <= 2 else "float64")
             return x + np.exp(-x) if not_int8 else x + np.exp(-x, signature="f")
         else:
             return x
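
For reference, the half-precision behaviour the comment above refers to is plain NumPy type promotion and can be checked directly, which is why the float32 ("f") signature is forced for 1-byte integer inputs (a small check, assuming a standard NumPy install):

import numpy as np

# np.exp on 1-byte integer inputs runs in the float16 loop by default ...
print(np.exp(np.int8(1)).dtype)   # float16
print(np.exp(np.uint8(1)).dtype)  # float16

# ... while requesting the "f" signature keeps the computation in float32.
print(np.exp(np.int8(1), signature="f").dtype)  # float32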

tests/link/numba/test_scalar.py

Lines changed: 3 additions & 9 deletions

@@ -158,15 +158,9 @@ def test_isnan(composite):
 @pytest.mark.parametrize(
     "dtype",
     [
-        pytest.param(
-            "float32",
-            marks=pytest.mark.xfail(reason="Scalar downcasting not supported in numba"),
-        ),
+        "float32",
         "float64",
-        pytest.param(
-            "int16",
-            marks=pytest.mark.xfail(reason="Scalar downcasting not supported in numba"),
-        ),
+        "int16",
         "int64",
         "uint32",
     ],

@@ -183,7 +177,7 @@ def test_Softplus(dtype):
     test_x = np.dtype(dtype).type(value)
     np.testing.assert_allclose(
         py_fn(test_x),
-        numba_fn(test_x),
+        getattr(np, g.dtype)(numba_fn(test_x)),
         strict=True,
         err_msg=f"Failed for value {value}",
     )
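
One detail in the test change: np.testing.assert_allclose with strict=True also requires matching dtypes, which is why the numba result is wrapped in getattr(np, g.dtype)(...) to cast it to the graph's output dtype before comparison. A small illustration of the dtype check (assuming a NumPy version whose assert_allclose accepts strict, as the test itself does):

import numpy as np

a = np.float32(1.0)
b = np.float64(1.0)

# The values agree, but strict=True also compares dtypes, so this raises.
try:
    np.testing.assert_allclose(a, b, strict=True)
except AssertionError:
    print("rejected: float32 vs float64")

# Casting one side to the other's dtype lets the strict comparison pass.
np.testing.assert_allclose(np.float64(a), b, strict=True)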
