|
14 | 14 | in2out, |
15 | 15 | node_rewriter, |
16 | 16 | ) |
| 17 | +from pytensor.graph.type import Type |
17 | 18 | from pytensor.raise_op import Assert |
18 | 19 | from pytensor.scalar import Add, ScalarConstant, ScalarType |
19 | 20 | from pytensor.scalar import constant as scalar_constant |
@@ -229,7 +230,7 @@ def local_replace_AdvancedSubtensor(fgraph, node): |
229 | 230 |
230 | 231 | indexed_var = node.inputs[0] |
231 | 232 | tensor_inputs = node.inputs[1:] |
232 | | - |
| 233 | + |
233 | 234 | # Reconstruct indices from idx_list and tensor inputs |
234 | 235 | indices = [] |
235 | 236 | input_idx = 0 |
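The reconstruction loop this hunk introduces is repeated in every rewrite the PR touches: walk the op's `idx_list` and splice a symbolic tensor input into each `Type` placeholder, keeping static entries as they are. A minimal sketch of that pattern, assuming static entries are stored directly (e.g. as slice objects) and only `Type` entries consume a tensor input; the helper name is illustrative, not part of the PR:

```python
from pytensor.graph.type import Type


def reconstruct_indices(idx_list, tensor_inputs):
    """Illustrative sketch: rebuild the full index tuple from idx_list."""
    indices = []
    input_idx = 0
    for entry in idx_list:
        if isinstance(entry, Type):
            # A Type entry is a placeholder for a runtime tensor index,
            # so it consumes the next symbolic input.
            indices.append(tensor_inputs[input_idx])
            input_idx += 1
        else:
            # Static entries (e.g. slices) are stored in idx_list verbatim.
            indices.append(entry)
    return indices
```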
@@ -267,7 +268,7 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node): |
267 | 268 | res = node.inputs[0] |
268 | 269 | val = node.inputs[1] |
269 | 270 | tensor_inputs = node.inputs[2:] |
270 | | - |
| 271 | + |
271 | 272 | # Reconstruct indices from idx_list and tensor inputs |
272 | 273 | indices = [] |
273 | 274 | input_idx = 0 |
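For context, `AdvancedIncSubtensor1` specializes the case of incrementing through a single 1-d integer index, so the graphs this rewrite converts look like the sketch below (hedged: the user-facing helpers may already build the specialized op in simple cases; the rewrite covers graphs where the generic `AdvancedIncSubtensor` appears instead):

```python
import pytensor.tensor as pt

x = pt.vector("x")
idx = pt.lvector("idx")  # a single 1-d integer index
y = pt.vector("y")

# Incrementing through one integer vector index is the pattern that the
# specialized AdvancedIncSubtensor1 op handles.
out = pt.inc_subtensor(x[idx], y)
```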
@@ -1112,6 +1113,7 @@ def local_inplace_AdvancedIncSubtensor1(fgraph, node): |
1112 | 1113 | def local_inplace_AdvancedIncSubtensor(fgraph, node): |
1113 | 1114 | if isinstance(node.op, AdvancedIncSubtensor) and not node.op.inplace: |
1114 | 1115 | new_op = type(node.op)( |
| 1116 | + node.op.idx_list, |
1115 | 1117 | inplace=True, |
1116 | 1118 | set_instead_of_inc=node.op.set_instead_of_inc, |
1117 | 1119 | ignore_duplicates=node.op.ignore_duplicates, |
@@ -1376,6 +1378,7 @@ def local_useless_inc_subtensor_alloc(fgraph, node): |
1376 | 1378 | z_broad[k] |
1377 | 1379 | and not same_shape(xi, y, dim_x=k, dim_y=k) |
1378 | 1380 | and shape_of[y][k] != 1 |
| 1381 | + and shape_of[xi][k] == 1 |
1379 | 1382 | ) |
1380 | 1383 | ] |
1381 | 1384 |
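The added `shape_of[xi][k] == 1` clause narrows the set of dimensions that get a runtime shape `Assert`. As a reminder of what this rewrite targets, a hedged sketch of the pattern (variable names are illustrative): the increment is an explicit `alloc` of a broadcastable value, and the rewrite tries to drop the `alloc` in favour of implicit broadcasting, guarded by `Assert`s where shape equality cannot be inferred statically.

```python
import pytensor.tensor as pt

x = pt.matrix("x")
y = pt.row("y")  # broadcastable along the first dimension

# The increment is an explicit alloc of a broadcastable value; the rewrite
# aims to behave like pt.inc_subtensor(x[:2], y), relying on implicit
# broadcasting (plus shape Asserts where needed) instead of the alloc.
out = pt.inc_subtensor(x[:2], pt.alloc(y, 2, x.shape[1]))
```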
@@ -1778,7 +1781,7 @@ def ravel_multidimensional_bool_idx(fgraph, node): |
1778 | 1781 | else: |
1779 | 1782 | x, y = node.inputs[0], node.inputs[1] |
1780 | 1783 | tensor_inputs = node.inputs[2:] |
1781 | | - |
| 1784 | + |
1782 | 1785 | # Reconstruct indices from idx_list and tensor inputs |
1783 | 1786 | idxs = [] |
1784 | 1787 | input_idx = 0 |
@@ -1829,36 +1832,36 @@ def ravel_multidimensional_bool_idx(fgraph, node): |
1829 | 1832 | # Create new AdvancedSubtensor with updated idx_list |
1830 | 1833 | new_idx_list = list(node.op.idx_list) |
1831 | 1834 | new_tensor_inputs = list(tensor_inputs) |
1832 | | - |
| 1835 | + |
1833 | 1836 | # Update the idx_list and tensor_inputs for the raveled boolean index |
1834 | 1837 | input_idx = 0 |
1835 | 1838 | for i, entry in enumerate(node.op.idx_list): |
1836 | 1839 | if isinstance(entry, Type): |
1837 | 1840 | if input_idx == bool_idx_pos: |
1838 | 1841 | new_tensor_inputs[input_idx] = raveled_bool_idx |
1839 | 1842 | input_idx += 1 |
1840 | | - |
| 1843 | + |
1841 | 1844 | new_out = AdvancedSubtensor(new_idx_list)(raveled_x, *new_tensor_inputs) |
1842 | 1845 | else: |
1843 | 1846 | # Create new AdvancedIncSubtensor with updated idx_list |
1844 | 1847 | new_idx_list = list(node.op.idx_list) |
1845 | 1848 | new_tensor_inputs = list(tensor_inputs) |
1846 | | - |
| 1849 | + |
1847 | 1850 | # Update the tensor_inputs for the raveled boolean index |
1848 | 1851 | input_idx = 0 |
1849 | 1852 | for i, entry in enumerate(node.op.idx_list): |
1850 | 1853 | if isinstance(entry, Type): |
1851 | 1854 | if input_idx == bool_idx_pos: |
1852 | 1855 | new_tensor_inputs[input_idx] = raveled_bool_idx |
1853 | 1856 | input_idx += 1 |
1854 | | - |
| 1857 | + |
1855 | 1858 | # The dimensions of y that correspond to the boolean indices |
1856 | 1859 | # must already be raveled in the original graph, so we don't need to do anything to it |
1857 | 1860 | new_out = AdvancedIncSubtensor( |
1858 | 1861 | new_idx_list, |
1859 | 1862 | inplace=node.op.inplace, |
1860 | 1863 | set_instead_of_inc=node.op.set_instead_of_inc, |
1861 | | - ignore_duplicates=node.op.ignore_duplicates |
| 1864 | + ignore_duplicates=node.op.ignore_duplicates, |
1862 | 1865 | )(raveled_x, y, *new_tensor_inputs) |
1863 | 1866 | # But we must reshape the output to match the original shape |
1864 | 1867 | new_out = new_out.reshape(x_shape) |
|
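The effect of the boolean-raveling branch is easiest to see in NumPy terms: indexing with an n-d boolean mask is equivalent to raveling the masked dimensions in both the array and the mask, which is why the inc/set result is reshaped back to `x_shape` at the end of the hunk. A small NumPy analogy of the semantics (a sketch of the idea, not PyTensor code):

```python
import numpy as np

x = np.arange(12).reshape(3, 4)
mask = x % 2 == 0  # 2-d boolean mask covering both dimensions

# An n-d boolean index selects elements in row-major order, so it matches
# indexing the raveled array with the raveled mask.
assert np.array_equal(x[mask], x.reshape(-1)[mask.reshape(-1)])
```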