9 changes: 7 additions & 2 deletions pytensor/compile/builders.py
@@ -4,6 +4,7 @@
from collections.abc import Callable, Sequence
from copy import copy
from functools import partial
from itertools import chain
from typing import Union, cast

from pytensor.compile.function import function
@@ -47,11 +48,15 @@ def infer_shape(outs, inputs, input_shapes):
assert len(inp_shp) == inp.type.ndim

shape_feature = ShapeFeature()
shape_feature.on_attach(FunctionGraph([], []))
fgraph = FunctionGraph([], [], features=[shape_feature])
for v in chain.from_iterable(s for s in input_shapes if s is not None):
# Import input_shape nodes, as for some graphs ShapeFeature assumes these were seen before
if (node := v.owner) is not None:
fgraph.import_node(node, import_missing=True)
Member Author:
ShapeFeature is quite picky about having seen some variables. It may even count as a bug, but I don't want to open that can of worms.

Anyway, OpFromGraph uses this helper for infer_shape, and the way Blockwise triggers it when it has an OFG inside would end up raising an error.


# Initialize shape_of with the input shapes
for inp, inp_shp in zip(inputs, input_shapes, strict=True):
shape_feature.set_shape(inp, inp_shp)
shape_feature.set_shape(inp, inp_shp, override=True)

def local_traverse(out):
"""
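For context on the pattern in this hunk, a minimal, self-contained sketch of attaching a ShapeFeature through the FunctionGraph constructor and importing the owners of shape variables so the feature has "seen" them; the variable names here are illustrative and not taken from this PR:

```python
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import vector
from pytensor.tensor.rewriting.shape import ShapeFeature

x = vector("x")
shape_feature = ShapeFeature()
# Passing the feature to the constructor registers it and runs its on_attach hook,
# so it also observes nodes imported into the graph afterwards.
fgraph = FunctionGraph([], [], features=[shape_feature])

# Import the node that produces x.shape; import_missing=True pulls in inputs
# (here x itself) that are not yet part of the graph.
if (shape_node := x.shape.owner) is not None:
    fgraph.import_node(shape_node, import_missing=True)

# shape_feature.shape_of should now have entries for the imported variables.
```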
1 change: 0 additions & 1 deletion pytensor/link/numba/dispatch/blockwise.py
@@ -36,7 +36,6 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
core_op_fn, core_op_key = numba_funcify_and_cache_key(
core_op,
node=core_node,
parent_node=node,
**kwargs,
)
core_op_fn = store_core_outputs(core_op_fn, nin=nin, nout=nout)
1 change: 0 additions & 1 deletion pytensor/link/numba/dispatch/elemwise.py
@@ -273,7 +273,6 @@ def numba_funcify_Elemwise(op, node, **kwargs):
scalar_op_fn, scalar_cache_key = numba_funcify_and_cache_key(
op.scalar_op,
node=scalar_node,
parent_node=node,
**kwargs,
)

8 changes: 5 additions & 3 deletions pytensor/link/numba/dispatch/tensor_basic.py
@@ -74,24 +74,26 @@ def numba_funcify_Alloc(op, node, **kwargs):
f'if val.shape[{-i - 1}] == 1 and scalar_shape[{-i - 1}] != 1: raise ValueError("{Alloc._runtime_broadcast_error_msg}")'
)
check_runtime_broadcast_src = indent("\n".join(check_runtime_broadcast), " " * 4)

dtype = node.inputs[0].type.dtype
alloc_def_src = f"""
def alloc(val, {", ".join(shape_var_names)}):
{shapes_to_items_src}
scalar_shape = {create_tuple_string(shape_var_item_names)}
{check_runtime_broadcast_src}
res = np.empty(scalar_shape, dtype=val.dtype)
res = np.empty(scalar_shape, dtype=np.{dtype})
Member Author:
This would fail for a scalar-valued Alloc due to numba/numba#10358

res[...] = val
return res
"""
alloc_fn = compile_numba_function_src(
alloc_def_src,
"alloc",
globals() | {"np": np},
write_to_disk=True,
)

cache_version = -1
cache_key = sha256(
str((type(op), node.inputs[0].type.broadcastable)).encode()
str((type(op), node.inputs[0].type.broadcastable, cache_version)).encode()
).hexdigest()
return numba_basic.numba_njit(alloc_fn), cache_key

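To make the dtype change concrete, here is a rough sketch of the kind of source string this dispatcher builds, assuming a float64 `val` and two symbolic shape inputs; the variable names and the item-extraction lines are illustrative rather than the exact generated text, and the runtime-broadcast checks are omitted:

```python
import numpy as np


# Illustrative approximation of alloc_def_src for dtype == "float64" and two shape arguments.
def alloc(val, shape_0, shape_1):
    shape_0_item = shape_0.item()
    shape_1_item = shape_1.item()
    scalar_shape = (shape_0_item, shape_1_item)
    # The dtype is baked in from node.inputs[0].type.dtype when the source is generated,
    # instead of being read from val.dtype at runtime, which sidesteps the scalar-valued
    # case mentioned in the comment above.
    res = np.empty(scalar_shape, dtype=np.float64)
    res[...] = val
    return res
```

Folding the new `cache_version` constant into the sha256 key means kernels cached under the previous generated source are invalidated rather than reused.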
14 changes: 12 additions & 2 deletions tests/link/numba/test_blockwise.py
@@ -2,8 +2,8 @@
import pytest

from pytensor import function
from pytensor.tensor import tensor, tensor3
from pytensor.tensor.basic import ARange
from pytensor.tensor import lvector, tensor, tensor3
from pytensor.tensor.basic import Alloc, ARange, constant
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
from pytensor.tensor.nlinalg import SVD, Det
from pytensor.tensor.slinalg import Cholesky, cholesky
@@ -70,3 +70,13 @@ def test_repeated_args():
final_node = fn.maker.fgraph.outputs[0].owner
assert isinstance(final_node.op, BlockwiseWithCoreShape)
assert final_node.inputs[0] is final_node.inputs[1]


def test_blockwise_alloc():
val = lvector("val")
out = Blockwise(Alloc(), signature="(),(),()->(2,3)")(
val, constant(2, dtype="int64"), constant(3, dtype="int64")
)
assert out.type.ndim == 3

compare_numba_and_py([val], [out], [np.arange(5)], eval_obj_mode=False)
21 changes: 20 additions & 1 deletion tests/link/numba/test_compile_ops.py
@@ -4,9 +4,10 @@
from pytensor import OpFromGraph, config, function, ifelse
from pytensor import tensor as pt
from pytensor.compile import ViewOp
from pytensor.graph import vectorize_graph
from pytensor.raise_op import assert_op
from pytensor.scalar import Add
from pytensor.tensor import matrix
from pytensor.tensor import dmatrix, dtensor3, matrix
from pytensor.tensor.elemwise import Elemwise
from tests.link.numba.test_basic import compare_numba_and_py

@@ -171,6 +172,24 @@ def test_ofg_aliased_outputs():
np.testing.assert_allclose(res, np.ones((2, 2)))


def test_ofg_elemwise_regression():
# Regression test for https://github.com/pymc-devs/pytensor/issues/1507
x = dmatrix("x", shape=(None, None))
z = OpFromGraph(
inputs=[x],
outputs=[x + 1],
)(x)

x_batched = dtensor3("X_batched", shape=(None, None, None))
z_batched = vectorize_graph(z, {x: x_batched})
compare_numba_and_py(
[x_batched],
[z_batched],
[np.random.normal(size=(3, 2, 4))],
eval_obj_mode=False,
)


def test_check_and_raise():
x = pt.vector()
x_test_value = np.array([1.0, 2.0], dtype=config.floatX)