From bfdd9fc7cef0d69a04dabfee93ad421693077e5a Mon Sep 17 00:00:00 2001 From: emekaokoli19 Date: Wed, 3 Dec 2025 00:06:51 +0100 Subject: [PATCH 1/3] fix-make copies of inputs in numba --- pytensor/link/numba/dispatch/compile_ops.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/pytensor/link/numba/dispatch/compile_ops.py b/pytensor/link/numba/dispatch/compile_ops.py index 8eb73d0111..676c5a1219 100644 --- a/pytensor/link/numba/dispatch/compile_ops.py +++ b/pytensor/link/numba/dispatch/compile_ops.py @@ -109,22 +109,28 @@ def numba_funcify_IfElse(op, **kwargs): @numba_basic.numba_njit def ifelse(cond, *args): if cond: - res = args[:n_outs] + selected = args[:n_outs] else: - res = args[n_outs:] + selected = args[n_outs:] - return res + # Return a tuple of copies + out = [None] * n_outs + for i in range(n_outs): + out[i] = selected[i].copy() + + return tuple(out) else: @numba_basic.numba_njit def ifelse(cond, *args): if cond: - res = args[:n_outs] + arr = args[0] else: - res = args[n_outs:] + arr = args[1] - return res[0] + # Return a copy + return arr.copy() return ifelse From a26bc32de527a54cd986ba998717871cac84d1b9 Mon Sep 17 00:00:00 2001 From: emekaokoli19 Date: Thu, 4 Dec 2025 11:57:23 +0100 Subject: [PATCH 2/3] added codegen and small fixes --- pytensor/link/numba/dispatch/compile_ops.py | 53 +++++++++++++-------- tests/link/numba/test_compile_ops.py | 51 +++++++++++++++++++- 2 files changed, 83 insertions(+), 21 deletions(-) diff --git a/pytensor/link/numba/dispatch/compile_ops.py b/pytensor/link/numba/dispatch/compile_ops.py index 676c5a1219..b1db046250 100644 --- a/pytensor/link/numba/dispatch/compile_ops.py +++ b/pytensor/link/numba/dispatch/compile_ops.py @@ -10,6 +10,7 @@ from pytensor.compile.mode import NUMBA from pytensor.compile.ops import DeepCopyOp, TypeCastingOp from pytensor.ifelse import IfElse +from pytensor.link.numba.cache import compile_numba_function_src from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import ( numba_funcify_and_cache_key, @@ -103,36 +104,48 @@ def deepcopy(x): @register_funcify_default_op_cache_key(IfElse) def numba_funcify_IfElse(op, **kwargs): n_outs = op.n_outs + as_view = op.as_view - if n_outs > 1: + if n_outs == 1: @numba_basic.numba_njit - def ifelse(cond, *args): - if cond: - selected = args[:n_outs] - else: - selected = args[n_outs:] + def ifelse(cond, x_true, x_false): + arr = x_true if cond else x_false + return arr if as_view else arr.copy() - # Return a tuple of copies - out = [None] * n_outs - for i in range(n_outs): - out[i] = selected[i].copy() + cache_version = 3 + return ifelse, cache_version - return tuple(out) + true_names = [f"t{i}" for i in range(n_outs)] + false_names = [f"f{i}" for i in range(n_outs)] + arg_list = ", ".join(true_names + false_names) + # Build return expressions + if as_view: + true_returns = ", ".join(true_names) + false_returns = ", ".join(false_names) else: + true_returns = ", ".join(f"{name}.copy()" for name in true_names) + false_returns = ", ".join(f"{name}.copy()" for name in false_names) + + # Build the code for the function + func_src = f""" +def ifelse_codegen(cond, {arg_list}): + if cond: + return ({true_returns}) + else: + return ({false_returns}) +""" - @numba_basic.numba_njit - def ifelse(cond, *args): - if cond: - arr = args[0] - else: - arr = args[1] + # Compile the generated source code into a Python function + ifelse_py = compile_numba_function_src(func_src, "ifelse_codegen", globals()) + + # JIT-compile using numba + ifelse_numba = numba_basic.numba_njit(ifelse_py) - # Return a copy - return arr.copy() + cache_version = 3 - return ifelse + return ifelse_numba, cache_version @register_funcify_and_cache_key(CheckAndRaise) diff --git a/tests/link/numba/test_compile_ops.py b/tests/link/numba/test_compile_ops.py index b51b359a08..5579c11933 100644 --- a/tests/link/numba/test_compile_ops.py +++ b/tests/link/numba/test_compile_ops.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from pytensor import OpFromGraph, config, function, ifelse +from pytensor import Mode, OpFromGraph, config, function, ifelse from pytensor import tensor as pt from pytensor.compile import ViewOp from pytensor.raise_op import assert_op @@ -153,3 +153,52 @@ def test_check_and_raise(): out = assert_op(x.sum(), np.array(True)) compare_numba_and_py([x], out, [x_test_value]) + + +def test_ifelse_single_output(): + x = pt.vector("x") + out = ifelse(x.sum() > 0, x, x) + + fn = function([x], out, mode=Mode("numba", optimizer=None)) + + x_test = np.zeros((5,)) + res = fn(x_test) + + # Returned array should not be the input (must be a copy) + assert res is not x_test + assert np.array_equal(res, x_test) + + +def test_ifelse_multiple_outputs(): + x = pt.vector("x") + y = pt.vector("y") + out1, out2 = ifelse(x.sum() > 0, (x, y), (y, x)) + + fn = function([x, y], [out1, out2], mode=Mode("numba", optimizer=None)) + + a = np.ones(3) + b = np.zeros(3) + + r1, r2 = fn(a, b) + + assert np.array_equal(r1, a) + assert np.array_equal(r2, b) + assert r1 is not a + assert r2 is not b + + +def test_ifelse_false_branch(): + x = pt.vector("x") + y = pt.vector("y") + + out = ifelse(x.sum() > 0, x, y) + + fn = function([x, y], out, mode=Mode("numba", optimizer=None)) + + a = np.zeros(3) + b = np.arange(3) + + res = fn(a, b) + + assert np.array_equal(res, b) + assert res is not b From 29bcbaf9119a528e6bf2144c02302d03b939a2eb Mon Sep 17 00:00:00 2001 From: emekaokoli19 Date: Thu, 4 Dec 2025 13:19:20 +0100 Subject: [PATCH 3/3] merged tests --- pytensor/link/numba/dispatch/compile_ops.py | 12 +-- tests/link/numba/test_compile_ops.py | 89 +++++++++++++-------- 2 files changed, 57 insertions(+), 44 deletions(-) diff --git a/pytensor/link/numba/dispatch/compile_ops.py b/pytensor/link/numba/dispatch/compile_ops.py index b1db046250..f869e49f8a 100644 --- a/pytensor/link/numba/dispatch/compile_ops.py +++ b/pytensor/link/numba/dispatch/compile_ops.py @@ -106,16 +106,6 @@ def numba_funcify_IfElse(op, **kwargs): n_outs = op.n_outs as_view = op.as_view - if n_outs == 1: - - @numba_basic.numba_njit - def ifelse(cond, x_true, x_false): - arr = x_true if cond else x_false - return arr if as_view else arr.copy() - - cache_version = 3 - return ifelse, cache_version - true_names = [f"t{i}" for i in range(n_outs)] false_names = [f"f{i}" for i in range(n_outs)] arg_list = ", ".join(true_names + false_names) @@ -143,7 +133,7 @@ def ifelse_codegen(cond, {arg_list}): # JIT-compile using numba ifelse_numba = numba_basic.numba_njit(ifelse_py) - cache_version = 3 + cache_version = 1 return ifelse_numba, cache_version diff --git a/tests/link/numba/test_compile_ops.py b/tests/link/numba/test_compile_ops.py index 5579c11933..7216cfb0bc 100644 --- a/tests/link/numba/test_compile_ops.py +++ b/tests/link/numba/test_compile_ops.py @@ -4,6 +4,7 @@ from pytensor import Mode, OpFromGraph, config, function, ifelse from pytensor import tensor as pt from pytensor.compile import ViewOp +from pytensor.ifelse import IfElse from pytensor.raise_op import assert_op from tests.link.numba.test_basic import compare_numba_and_py @@ -155,50 +156,72 @@ def test_check_and_raise(): compare_numba_and_py([x], out, [x_test_value]) -def test_ifelse_single_output(): +@pytest.mark.parametrize("as_view", [True, False]) +def test_ifelse_single_output(as_view): x = pt.vector("x") - out = ifelse(x.sum() > 0, x, x) - fn = function([x], out, mode=Mode("numba", optimizer=None)) + op = IfElse(as_view=as_view, n_outs=1) + out = op(x.sum() > 0, [x], [x])[0] # returns tuple/list - x_test = np.zeros((5,)) - res = fn(x_test) + fn = function([x], out, mode=Mode("numba", optimizer=None), accept_inplace=True) - # Returned array should not be the input (must be a copy) - assert res is not x_test - assert np.array_equal(res, x_test) - - -def test_ifelse_multiple_outputs(): - x = pt.vector("x") - y = pt.vector("y") - out1, out2 = ifelse(x.sum() > 0, (x, y), (y, x)) - - fn = function([x, y], [out1, out2], mode=Mode("numba", optimizer=None)) + # FALSE branch + a = np.zeros(3) + res_false = fn(a) - a = np.ones(3) - b = np.zeros(3) + assert np.array_equal(res_false, a) + if as_view: + assert res_false is a + else: + assert res_false is not a - r1, r2 = fn(a, b) + # TRUE branch + b = np.ones(3) + res_true = fn(b) - assert np.array_equal(r1, a) - assert np.array_equal(r2, b) - assert r1 is not a - assert r2 is not b + assert np.array_equal(res_true, b) + if as_view: + assert res_true is b + else: + assert res_true is not b -def test_ifelse_false_branch(): +@pytest.mark.parametrize("as_view", [True, False]) +def test_ifelse_multiple_outputs(as_view): x = pt.vector("x") y = pt.vector("y") - out = ifelse(x.sum() > 0, x, y) - - fn = function([x, y], out, mode=Mode("numba", optimizer=None)) - - a = np.zeros(3) - b = np.arange(3) + op = IfElse(as_view=as_view, n_outs=2) + out1, out2 = op(x.sum() > 0, x, y, y, x) - res = fn(a, b) + fn = function( + [x, y], [out1, out2], mode=Mode("numba", optimizer=None), accept_inplace=True + ) - assert np.array_equal(res, b) - assert res is not b + # TRUE branch + a = np.ones(3) + b = np.zeros(3) + r1_true, r2_true = fn(a, b) + + assert np.array_equal(r1_true, a) + assert np.array_equal(r2_true, b) + if as_view: + assert r1_true is a + assert r2_true is b + else: + assert r1_true is not a + assert r2_true is not b + + # FALSE branch + a2 = np.zeros(3) + b2 = np.arange(3) + r1_false, r2_false = fn(a2, b2) + + assert np.array_equal(r1_false, b2) + assert np.array_equal(r2_false, a2) + if as_view: + assert r1_false is b2 + assert r2_false is a2 + else: + assert r1_false is not b2 + assert r2_false is not a2