Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 27 additions & 18 deletions pytensor/link/numba/dispatch/compile_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pytensor.compile.mode import NUMBA
from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
from pytensor.ifelse import IfElse
from pytensor.link.numba.cache import compile_numba_function_src
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
numba_funcify_and_cache_key,
Expand Down Expand Up @@ -103,30 +104,38 @@ def deepcopy(x):
@register_funcify_default_op_cache_key(IfElse)
def numba_funcify_IfElse(op, **kwargs):
n_outs = op.n_outs
as_view = op.as_view

if n_outs > 1:

@numba_basic.numba_njit
def ifelse(cond, *args):
if cond:
res = args[:n_outs]
else:
res = args[n_outs:]

return res
true_names = [f"t{i}" for i in range(n_outs)]
false_names = [f"f{i}" for i in range(n_outs)]
arg_list = ", ".join(true_names + false_names)

# Build return expressions
if as_view:
true_returns = ", ".join(true_names)
false_returns = ", ".join(false_names)
else:
true_returns = ", ".join(f"{name}.copy()" for name in true_names)
false_returns = ", ".join(f"{name}.copy()" for name in false_names)

# Build the code for the function
func_src = f"""
def ifelse_codegen(cond, {arg_list}):
if cond:
return ({true_returns})
else:
return ({false_returns})
"""

# Compile the generated source code into a Python function
ifelse_py = compile_numba_function_src(func_src, "ifelse_codegen", globals())

@numba_basic.numba_njit
def ifelse(cond, *args):
if cond:
res = args[:n_outs]
else:
res = args[n_outs:]
# JIT-compile using numba
ifelse_numba = numba_basic.numba_njit(ifelse_py)

return res[0]
cache_version = 1

return ifelse
return ifelse_numba, cache_version


@register_funcify_and_cache_key(CheckAndRaise)
Expand Down
74 changes: 73 additions & 1 deletion tests/link/numba/test_compile_ops.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import numpy as np
import pytest

from pytensor import OpFromGraph, config, function, ifelse
from pytensor import Mode, OpFromGraph, config, function, ifelse
from pytensor import tensor as pt
from pytensor.compile import ViewOp
from pytensor.ifelse import IfElse
from pytensor.raise_op import assert_op
from tests.link.numba.test_basic import compare_numba_and_py

Expand Down Expand Up @@ -153,3 +154,74 @@ def test_check_and_raise():
out = assert_op(x.sum(), np.array(True))

compare_numba_and_py([x], out, [x_test_value])


@pytest.mark.parametrize("as_view", [True, False])
def test_ifelse_single_output(as_view):
x = pt.vector("x")

op = IfElse(as_view=as_view, n_outs=1)
out = op(x.sum() > 0, [x], [x])[0] # returns tuple/list

fn = function([x], out, mode=Mode("numba", optimizer=None), accept_inplace=True)

# FALSE branch
a = np.zeros(3)
res_false = fn(a)

assert np.array_equal(res_false, a)
if as_view:
assert res_false is a
else:
assert res_false is not a

# TRUE branch
b = np.ones(3)
res_true = fn(b)

assert np.array_equal(res_true, b)
if as_view:
assert res_true is b
else:
assert res_true is not b


@pytest.mark.parametrize("as_view", [True, False])
def test_ifelse_multiple_outputs(as_view):
x = pt.vector("x")
y = pt.vector("y")

op = IfElse(as_view=as_view, n_outs=2)
out1, out2 = op(x.sum() > 0, x, y, y, x)

fn = function(
[x, y], [out1, out2], mode=Mode("numba", optimizer=None), accept_inplace=True
)

# TRUE branch
a = np.ones(3)
b = np.zeros(3)
r1_true, r2_true = fn(a, b)

assert np.array_equal(r1_true, a)
assert np.array_equal(r2_true, b)
if as_view:
assert r1_true is a
assert r2_true is b
else:
assert r1_true is not a
assert r2_true is not b

# FALSE branch
a2 = np.zeros(3)
b2 = np.arange(3)
r1_false, r2_false = fn(a2, b2)

assert np.array_equal(r1_false, b2)
assert np.array_equal(r2_false, a2)
if as_view:
assert r1_false is b2
assert r2_false is a2
else:
assert r1_false is not b2
assert r2_false is not a2
Loading