Skip to content

Commit a5fb911

Browse files
authored
Fix error formatting in C-Elemwise (#1749)
1 parent e274e0d commit a5fb911

File tree

3 files changed

+24
-4
lines changed

3 files changed

+24
-4
lines changed

pytensor/tensor/elemwise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1099,7 +1099,7 @@ def c_support_code_apply(self, node, nodename):
10991099
return support_code
11001100

11011101
def c_code_cache_version_apply(self, node):
1102-
version = [16] # the version corresponding to the c code in this Op
1102+
version = [17] # the version corresponding to the c code in this Op
11031103

11041104
# now we insert versions for the ops on which we depend...
11051105
scalar_node = Apply(

pytensor/tensor/elemwise_cgen.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def make_checks(loop_orders, dtypes, sub, compute_stride_jump=True):
8585
runtime_broadcast_error_msg = (
8686
"Runtime broadcasting not allowed. "
8787
"One input had a distinct dimension length of 1, but was not marked as broadcastable: "
88-
"(input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld). "
88+
"(input[%i].shape[%i] = %lld, input[%i].shape[%i] = %lld). "
8989
"If broadcasting was intended, use `specify_broadcastable` on the relevant input."
9090
)
9191

@@ -113,7 +113,7 @@ def make_checks(loop_orders, dtypes, sub, compute_stride_jump=True):
113113
(long long int) {sub[f"lv{j}"]}_n{x}
114114
);
115115
}} else {{
116-
PyErr_Format(PyExc_ValueError, "Input dimension mismatch: (input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld)",
116+
PyErr_Format(PyExc_ValueError, "Input dimension mismatch: (input[%i].shape[%i] = %lld, input[%i].shape[%i] = %lld)",
117117
{j0},
118118
{x0},
119119
(long long int) {sub[f"lv{j0}"]}_n{x0},

tests/tensor/test_elemwise.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
bmatrix,
3333
bscalar,
3434
discrete_dtypes,
35+
dmatrix,
3536
lscalar,
3637
matrix,
3738
scalar,
@@ -832,7 +833,26 @@ def test_runtime_broadcast_python(self):
832833
reason="G++ not available, so we need to skip this test.",
833834
)
834835
def test_runtime_broadcast_c(self):
835-
check_elemwise_runtime_broadcast(Mode(linker="c"))
836+
c_mode = Mode(linker="cvm")
837+
check_elemwise_runtime_broadcast(c_mode)
838+
839+
# Test C-backend specific error formatting
840+
x = dmatrix("x")
841+
y = dmatrix("y")
842+
fn = function([x, y], x * y, mode=c_mode)
843+
with pytest.raises(
844+
ValueError,
845+
match=r"Runtime broadcasting not allowed.*\(input\[0\]\.shape\[1\] = 4, input\[1\]\.shape\[1\] = 1\)",
846+
):
847+
fn(np.zeros((5, 4)), np.zeros((5, 1)))
848+
849+
with pytest.raises(
850+
ValueError,
851+
match=re.escape(
852+
"Input dimension mismatch: (input[0].shape[1] = 4, input[1].shape[1] = 3)"
853+
),
854+
):
855+
fn(np.zeros((5, 4)), np.zeros((5, 3)))
836856

837857
def test_str(self):
838858
op = Elemwise(ps.add, inplace_pattern={0: 0}, name=None)

0 commit comments

Comments
 (0)