
Commit 53adf9a

Copilot authored and ricardoV94 committed
Update dispatch functions and rewrite rules for new AdvancedSubtensor interface, store expected_inputs_len
Co-authored-by: ricardoV94 <28983449+ricardoV94@users.noreply.github.com>
1 parent a3634dd commit 53adf9a

File tree: 5 files changed, +127 -45 lines

pytensor/link/jax/dispatch/subtensor.py

Lines changed: 7 additions & 2 deletions
@@ -77,6 +77,8 @@ def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
 
 @jax_funcify.register(AdvancedIncSubtensor)
 def jax_funcify_AdvancedIncSubtensor(op, node, **kwargs):
+    idx_list = getattr(op, "idx_list", None)
+
     if getattr(op, "set_instead_of_inc", False):
 
         def jax_fn(x, indices, y):
@@ -87,8 +89,11 @@ def jax_fn(x, indices, y):
         def jax_fn(x, indices, y):
             return x.at[indices].add(y)
 
-    def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn):
-        return jax_fn(x, ilist, y)
+    def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
+        indices = indices_from_subtensor(ilist, idx_list)
+        if len(indices) == 1:
+            indices = indices[0]
+        return jax_fn(x, indices, y)
 
     return advancedincsubtensor
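The pattern introduced here recurs in every backend below: index inputs now arrive flattened, and the runtime function recombines them with the op's static idx_list before indexing (the commit's diffs use PyTensor's indices_from_subtensor for this). A minimal, self-contained sketch of that recombination in plain NumPy; the TENSOR sentinel and rebuild_indices helper are illustrative stand-ins, not PyTensor's API:

import numpy as np

TENSOR = object()  # stand-in for a Type entry that consumes one runtime input

def rebuild_indices(flat_inputs, idx_list):
    # Slices are stored literally in idx_list; each placeholder entry
    # consumes the next flattened runtime input, in order.
    if idx_list is None:
        return tuple(flat_inputs)
    it = iter(flat_inputs)
    return tuple(entry if isinstance(entry, slice) else next(it) for entry in idx_list)

x = np.arange(24).reshape(4, 6)
idx_list = (slice(None, 2), TENSOR)   # static structure of x[:2, <tensor>]
flat_inputs = (np.array([0, 3, 5]),)  # the runtime tensor index
indices = rebuild_indices(flat_inputs, idx_list)
assert np.array_equal(x[indices], x[:2, [0, 3, 5]])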

pytensor/link/numba/dispatch/subtensor.py

Lines changed: 23 additions & 21 deletions
@@ -239,28 +239,30 @@ def {function_name}({", ".join(input_names)}):
 @register_funcify_and_cache_key(AdvancedIncSubtensor)
 def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
     if isinstance(op, AdvancedSubtensor):
-        _x, _y, idxs = node.inputs[0], None, node.inputs[1:]
+        x, y, tensor_inputs = node.inputs[0], None, node.inputs[1:]
     else:
-        _x, _y, *idxs = node.inputs
-
-    basic_idxs = [
-        idx
-        for idx in idxs
-        if (
-            isinstance(idx.type, NoneTypeT)
-            or (isinstance(idx.type, SliceType) and not is_full_slice(idx))
-        )
-    ]
-    adv_idxs = [
-        {
-            "axis": i,
-            "dtype": idx.type.dtype,
-            "bcast": idx.type.broadcastable,
-            "ndim": idx.type.ndim,
-        }
-        for i, idx in enumerate(idxs)
-        if isinstance(idx.type, TensorType)
-    ]
+        x, y, *tensor_inputs = node.inputs
+
+    # Reconstruct indexing information from idx_list and tensor inputs
+    basic_idxs = []
+    adv_idxs = []
+    input_idx = 0
+
+    for i, entry in enumerate(op.idx_list):
+        if isinstance(entry, slice):
+            # Basic slice index
+            basic_idxs.append(entry)
+        elif isinstance(entry, Type):
+            # Advanced tensor index
+            if input_idx < len(tensor_inputs):
+                idx_input = tensor_inputs[input_idx]
+                adv_idxs.append({
+                    "axis": i,
+                    "dtype": idx_input.type.dtype,
+                    "bcast": idx_input.type.broadcastable,
+                    "ndim": idx_input.type.ndim,
+                })
+            input_idx += 1
 
     # Special implementation for consecutive integer vector indices
     if (
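The new loop classifies indices statically: slice entries in op.idx_list become basic indices, and each Type entry is paired with the next tensor input to describe an advanced index. A runnable sketch of that classification; TypeEntry and fake_input are illustrative stand-ins, not PyTensor classes:

from types import SimpleNamespace

class TypeEntry:  # stand-in for a pytensor Type placeholder in idx_list
    pass

def fake_input(dtype, broadcastable):
    # stand-in for a tensor variable carrying .type metadata
    return SimpleNamespace(
        type=SimpleNamespace(dtype=dtype, broadcastable=broadcastable, ndim=len(broadcastable))
    )

idx_list = (slice(None), TypeEntry(), TypeEntry())
tensor_inputs = [fake_input("int64", (False,)), fake_input("bool", (False, False))]

basic_idxs, adv_idxs, input_idx = [], [], 0
for i, entry in enumerate(idx_list):
    if isinstance(entry, slice):
        basic_idxs.append(entry)
    elif isinstance(entry, TypeEntry):
        inp = tensor_inputs[input_idx]
        adv_idxs.append({"axis": i, "dtype": inp.type.dtype,
                         "bcast": inp.type.broadcastable, "ndim": inp.type.ndim})
        input_idx += 1

assert basic_idxs == [slice(None)]
assert [d["axis"] for d in adv_idxs] == [1, 2]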

pytensor/link/pytorch/dispatch/subtensor.py

Lines changed: 15 additions & 6 deletions
@@ -63,7 +63,10 @@ def makeslice(start, stop, step):
 @pytorch_funcify.register(AdvancedSubtensor1)
 @pytorch_funcify.register(AdvancedSubtensor)
 def pytorch_funcify_AdvSubtensor(op, node, **kwargs):
-    def advsubtensor(x, *indices):
+    idx_list = getattr(op, "idx_list", None)
+
+    def advsubtensor(x, *flattened_indices):
+        indices = indices_from_subtensor(flattened_indices, idx_list)
         check_negative_steps(indices)
         return x[indices]
 
@@ -102,12 +105,14 @@ def inc_subtensor(x, y, *flattened_indices):
 @pytorch_funcify.register(AdvancedIncSubtensor)
 @pytorch_funcify.register(AdvancedIncSubtensor1)
 def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
+    idx_list = getattr(op, "idx_list", None)
     inplace = op.inplace
     ignore_duplicates = getattr(op, "ignore_duplicates", False)
 
     if op.set_instead_of_inc:
 
-        def adv_set_subtensor(x, y, *indices):
+        def adv_set_subtensor(x, y, *flattened_indices):
+            indices = indices_from_subtensor(flattened_indices, idx_list)
             check_negative_steps(indices)
             if isinstance(op, AdvancedIncSubtensor1):
                 op._check_runtime_broadcasting(node, x, y, indices)
@@ -120,7 +125,8 @@ def adv_set_subtensor(x, y, *indices):
 
     elif ignore_duplicates:
 
-        def adv_inc_subtensor_no_duplicates(x, y, *indices):
+        def adv_inc_subtensor_no_duplicates(x, y, *flattened_indices):
+            indices = indices_from_subtensor(flattened_indices, idx_list)
             check_negative_steps(indices)
             if isinstance(op, AdvancedIncSubtensor1):
                 op._check_runtime_broadcasting(node, x, y, indices)
@@ -132,13 +138,16 @@ def adv_inc_subtensor_no_duplicates(x, y, *indices):
         return adv_inc_subtensor_no_duplicates
 
     else:
-        if any(isinstance(idx.type, SliceType) for idx in node.inputs[2:]):
+        # Check if we have slice indexing in idx_list
+        has_slice_indexing = any(isinstance(entry, slice) for entry in idx_list) if idx_list else False
+        if has_slice_indexing:
             raise NotImplementedError(
                 "IncSubtensor with potential duplicates indexes and slice indexing not implemented in PyTorch"
             )
 
-        def adv_inc_subtensor(x, y, *indices):
-            # Not needed because slices aren't supported
+        def adv_inc_subtensor(x, y, *flattened_indices):
+            indices = indices_from_subtensor(flattened_indices, idx_list)
+            # Not needed because slices aren't supported in this path
             # check_negative_steps(indices)
             if not inplace:
                 x = x.clone()
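Note the duplicate-unsafe branch no longer scans node.inputs for SliceType variables: with the new interface, slice structure is stored in idx_list itself, so plain slice instances can be detected statically. A minimal sketch of that check, assuming idx_list mixes slice objects and Type placeholders:

def has_slice_indexing(idx_list):
    # Ops without idx_list metadata are treated as having no slice indexing.
    return any(isinstance(entry, slice) for entry in idx_list) if idx_list else False

assert has_slice_indexing((slice(None, 3), object())) is True
assert has_slice_indexing((object(), object())) is False
assert has_slice_indexing(None) is False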

pytensor/tensor/rewriting/subtensor.py

Lines changed: 71 additions & 7 deletions
@@ -228,7 +228,18 @@ def local_replace_AdvancedSubtensor(fgraph, node):
         return
 
     indexed_var = node.inputs[0]
-    indices = node.inputs[1:]
+    tensor_inputs = node.inputs[1:]
+
+    # Reconstruct indices from idx_list and tensor inputs
+    indices = []
+    input_idx = 0
+    for entry in node.op.idx_list:
+        if isinstance(entry, slice):
+            indices.append(entry)
+        elif isinstance(entry, Type):
+            if input_idx < len(tensor_inputs):
+                indices.append(tensor_inputs[input_idx])
+            input_idx += 1
 
     axis = get_advsubtensor_axis(indices)
 
@@ -255,7 +266,18 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):
 
     res = node.inputs[0]
     val = node.inputs[1]
-    indices = node.inputs[2:]
+    tensor_inputs = node.inputs[2:]
+
+    # Reconstruct indices from idx_list and tensor inputs
+    indices = []
+    input_idx = 0
+    for entry in node.op.idx_list:
+        if isinstance(entry, slice):
+            indices.append(entry)
+        elif isinstance(entry, Type):
+            if input_idx < len(tensor_inputs):
+                indices.append(tensor_inputs[input_idx])
+            input_idx += 1
 
     axis = get_advsubtensor_axis(indices)
 
@@ -1751,9 +1773,22 @@ def ravel_multidimensional_bool_idx(fgraph, node):
     x[eye(3, dtype=bool)].set(y) -> x.ravel()[eye(3).ravel()].set(y).reshape(x.shape)
     """
     if isinstance(node.op, AdvancedSubtensor):
-        x, *idxs = node.inputs
+        x = node.inputs[0]
+        tensor_inputs = node.inputs[1:]
     else:
-        x, y, *idxs = node.inputs
+        x, y = node.inputs[0], node.inputs[1]
+        tensor_inputs = node.inputs[2:]
+
+    # Reconstruct indices from idx_list and tensor inputs
+    idxs = []
+    input_idx = 0
+    for entry in node.op.idx_list:
+        if isinstance(entry, slice):
+            idxs.append(entry)
+        elif isinstance(entry, Type):
+            if input_idx < len(tensor_inputs):
+                idxs.append(tensor_inputs[input_idx])
+            input_idx += 1
 
     if any(
         (
@@ -1791,12 +1826,41 @@ def ravel_multidimensional_bool_idx(fgraph, node):
     new_idxs[bool_idx_pos] = raveled_bool_idx
 
     if isinstance(node.op, AdvancedSubtensor):
-        new_out = node.op(raveled_x, *new_idxs)
+        # Create new AdvancedSubtensor with updated idx_list
+        new_idx_list = list(node.op.idx_list)
+        new_tensor_inputs = list(tensor_inputs)
+
+        # Update the idx_list and tensor_inputs for the raveled boolean index
+        input_idx = 0
+        for i, entry in enumerate(node.op.idx_list):
+            if isinstance(entry, Type):
+                if input_idx == bool_idx_pos:
+                    new_tensor_inputs[input_idx] = raveled_bool_idx
+                input_idx += 1
+
+        new_out = AdvancedSubtensor(new_idx_list)(raveled_x, *new_tensor_inputs)
     else:
+        # Create new AdvancedIncSubtensor with updated idx_list
+        new_idx_list = list(node.op.idx_list)
+        new_tensor_inputs = list(tensor_inputs)
+
+        # Update the tensor_inputs for the raveled boolean index
+        input_idx = 0
+        for i, entry in enumerate(node.op.idx_list):
+            if isinstance(entry, Type):
+                if input_idx == bool_idx_pos:
+                    new_tensor_inputs[input_idx] = raveled_bool_idx
+                input_idx += 1
+
         # The dimensions of y that correspond to the boolean indices
         # must already be raveled in the original graph, so we don't need to do anything to it
-        new_out = node.op(raveled_x, y, *new_idxs)
-        # But we must reshape the output to math the original shape
+        new_out = AdvancedIncSubtensor(
+            new_idx_list,
+            inplace=node.op.inplace,
+            set_instead_of_inc=node.op.set_instead_of_inc,
+            ignore_duplicates=node.op.ignore_duplicates
+        )(raveled_x, y, *new_tensor_inputs)
+        # But we must reshape the output to match the original shape
     new_out = new_out.reshape(x_shape)
 
     return [copy_stack_trace(node.outputs[0], new_out)]
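The rewrite can reuse the same idx_list structure and swap only the tensor input because the underlying identity is purely about the boolean mask, not the index layout. A quick NumPy check of that identity, under the same assumptions as the docstring's x[eye(3, dtype=bool)] example:

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 3))
mask = np.eye(3, dtype=bool)

# Subtensor form: indexing with a multidimensional boolean mask equals
# indexing the raveled array with the raveled mask.
assert np.array_equal(x[mask], x.ravel()[mask.ravel()])

# Inc/set form: setting through the raveled pair and reshaping back
# reproduces the direct masked assignment.
y = np.zeros(3)
out_direct = x.copy(); out_direct[mask] = y
out_raveled = x.ravel().copy(); out_raveled[mask.ravel()] = y
assert np.array_equal(out_direct, out_raveled.reshape(x.shape))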

pytensor/tensor/subtensor.py

Lines changed: 11 additions & 9 deletions
@@ -2585,10 +2585,12 @@ def __init__(self, idx_list):
         Parameters
         ----------
         idx_list : tuple
-            List of indices where slices and newaxis are stored as-is,
+            List of indices where slices are stored as-is,
             and numerical indices are replaced by their types.
         """
         self.idx_list = tuple(map(index_vars_to_types, idx_list))
+        # Store expected number of tensor inputs for validation
+        self.expected_inputs_len = len(get_slice_elements(self.idx_list, lambda entry: isinstance(entry, Type)))
 
     def make_node(self, x, *inputs):
         """
@@ -2604,15 +2606,14 @@ def make_node(self, x, *inputs):
         inputs = tuple(as_tensor_variable(a) for a in inputs)
 
         idx_list = list(self.idx_list)
-        if (len([entry for entry in idx_list if entry is not np.newaxis]) > x.type.ndim):
+        if len(idx_list) > x.type.ndim:
             raise IndexError("too many indices for array")
 
         # Validate input count matches expected from idx_list
-        expected_inputs = get_slice_elements(idx_list, lambda entry: isinstance(entry, Type))
-        if len(inputs) != len(expected_inputs):
-            raise ValueError(f"Expected {len(expected_inputs)} inputs but got {len(inputs)}")
+        if len(inputs) != self.expected_inputs_len:
+            raise ValueError(f"Expected {self.expected_inputs_len} inputs but got {len(inputs)}")
 
-        # Build explicit_indices for shape inference (newaxis handled by __getitem__)
+        # Build explicit_indices for shape inference
         explicit_indices = []
         input_idx = 0
 
@@ -2982,6 +2983,8 @@ def __init__(
         self, idx_list, inplace=False, set_instead_of_inc=False, ignore_duplicates=False
     ):
         self.idx_list = tuple(map(index_vars_to_types, idx_list))
+        # Store expected number of tensor inputs for validation
+        self.expected_inputs_len = len(get_slice_elements(self.idx_list, lambda entry: isinstance(entry, Type)))
         self.set_instead_of_inc = set_instead_of_inc
         self.inplace = inplace
         if inplace:
@@ -3000,9 +3003,8 @@ def make_node(self, x, y, *inputs):
         y = as_tensor_variable(y)
 
         # Validate that we have the right number of tensor inputs for our idx_list
-        expected_tensor_inputs = sum(1 for entry in self.idx_list if isinstance(entry, Type))
-        if len(inputs) != expected_tensor_inputs:
-            raise ValueError(f"Expected {expected_tensor_inputs} tensor inputs but got {len(inputs)}")
+        if len(inputs) != self.expected_inputs_len:
+            raise ValueError(f"Expected {self.expected_inputs_len} tensor inputs but got {len(inputs)}")
 
         new_inputs = []
         for inp in inputs:
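The expected_inputs_len attribute is computed once in __init__ so both make_node implementations reduce to a length comparison. A sketch of that bookkeeping; TypePlaceholder and count_expected_inputs are illustrative stand-ins, and unlike this flat count, the real get_slice_elements also finds Type entries nested inside slice components:

class TypePlaceholder:  # stand-in for a pytensor Type entry in idx_list
    pass

def count_expected_inputs(idx_list):
    # One runtime tensor input is expected per Type placeholder.
    return sum(1 for entry in idx_list if isinstance(entry, TypePlaceholder))

idx_list = (slice(None), TypePlaceholder(), TypePlaceholder())
expected_inputs_len = count_expected_inputs(idx_list)

inputs = ("i0",)  # one runtime index too few
if len(inputs) != expected_inputs_len:
    msg = f"Expected {expected_inputs_len} inputs but got {len(inputs)}"
assert msg == "Expected 2 inputs but got 1"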
