38 changes: 32 additions & 6 deletions pytensor/link/jax/dispatch/subtensor.py
@@ -31,11 +31,18 @@
"""


@jax_funcify.register(AdvancedSubtensor1)
def jax_funcify_AdvancedSubtensor1(op, node, **kwargs):
def advanced_subtensor1(x, ilist):
return x[ilist]

return advanced_subtensor1


@jax_funcify.register(Subtensor)
@jax_funcify.register(AdvancedSubtensor)
@jax_funcify.register(AdvancedSubtensor1)
def jax_funcify_Subtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
idx_list = op.idx_list

def subtensor(x, *ilists):
indices = indices_from_subtensor(ilists, idx_list)
@@ -47,10 +54,24 @@ def subtensor(x, *ilists):
return subtensor


@jax_funcify.register(IncSubtensor)
@jax_funcify.register(AdvancedIncSubtensor1)
def jax_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
if getattr(op, "set_instead_of_inc", False):

def jax_fn(x, y, ilist):
return x.at[ilist].set(y)

else:

def jax_fn(x, y, ilist):
return x.at[ilist].add(y)

return jax_fn


@jax_funcify.register(IncSubtensor)
def jax_funcify_IncSubtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
idx_list = op.idx_list

if getattr(op, "set_instead_of_inc", False):

@@ -77,6 +98,8 @@ def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):

@jax_funcify.register(AdvancedIncSubtensor)
def jax_funcify_AdvancedIncSubtensor(op, node, **kwargs):
idx_list = op.idx_list

if getattr(op, "set_instead_of_inc", False):

def jax_fn(x, indices, y):
@@ -87,8 +110,11 @@ def jax_fn(x, indices, y):
def jax_fn(x, indices, y):
return x.at[indices].add(y)

def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn):
return jax_fn(x, ilist, y)
def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
indices = indices_from_subtensor(ilist, idx_list)
if len(indices) == 1:
indices = indices[0]
return jax_fn(x, indices, y)

return advancedincsubtensor

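For reference, a minimal usage sketch of the kind of graph these JAX dispatchers handle; it assumes `jax` is installed and compiles through PyTensor's `mode="JAX"`. The variable names and values are illustrative, not part of the PR.

```python
# Minimal sketch: compile an advanced inc_subtensor graph through the JAX backend.
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
idx = pt.lvector("idx")

# AdvancedIncSubtensor1: increment the selected rows of x.
out = pt.inc_subtensor(x[idx], 1.0)

f = pytensor.function([x, idx], out, mode="JAX")
# Duplicate indices accumulate (row 0 is incremented twice).
print(f(np.zeros((3, 2)), np.array([0, 0, 2], dtype="int64")))
```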
49 changes: 26 additions & 23 deletions pytensor/link/numba/dispatch/subtensor.py
@@ -20,7 +20,6 @@
)
from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
from pytensor.tensor import TensorType
from pytensor.tensor.rewriting.subtensor import is_full_slice
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
@@ -29,7 +28,7 @@
IncSubtensor,
Subtensor,
)
from pytensor.tensor.type_other import MakeSlice, NoneTypeT, SliceType
from pytensor.tensor.type_other import MakeSlice


def slice_new(self, start, stop, step):
@@ -239,28 +238,32 @@ def {function_name}({", ".join(input_names)}):
@register_funcify_and_cache_key(AdvancedIncSubtensor)
def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
if isinstance(op, AdvancedSubtensor):
_x, _y, idxs = node.inputs[0], None, node.inputs[1:]
tensor_inputs = node.inputs[1:]
else:
_x, _y, *idxs = node.inputs

basic_idxs = [
idx
for idx in idxs
if (
isinstance(idx.type, NoneTypeT)
or (isinstance(idx.type, SliceType) and not is_full_slice(idx))
)
]
adv_idxs = [
{
"axis": i,
"dtype": idx.type.dtype,
"bcast": idx.type.broadcastable,
"ndim": idx.type.ndim,
}
for i, idx in enumerate(idxs)
if isinstance(idx.type, TensorType)
]
tensor_inputs = node.inputs[2:]

# Reconstruct indexing information from idx_list and tensor inputs
basic_idxs = []
adv_idxs = []
input_idx = 0

for i, entry in enumerate(op.idx_list):
if isinstance(entry, slice):
# Basic slice index
basic_idxs.append(entry)
elif isinstance(entry, Type):
# Advanced tensor index
if input_idx < len(tensor_inputs):
idx_input = tensor_inputs[input_idx]
adv_idxs.append(
{
"axis": i,
"dtype": idx_input.type.dtype,
"bcast": idx_input.type.broadcastable,
"ndim": idx_input.type.ndim,
}
)
input_idx += 1

# Special implementation for consecutive integer vector indices
if (
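A toy sketch of the reconstruction loop above, with plain strings standing in for PyTensor `Type` entries and tensor inputs (both hypothetical), just to show how `op.idx_list` is split into basic and advanced indices:

```python
# Toy illustration only: strings stand in for PyTensor Type entries and tensor inputs.
idx_list = [slice(None), "Type<int64 vector>", slice(0, 5)]  # hypothetical op.idx_list
tensor_inputs = ["<ivector idx>"]                            # hypothetical tensor inputs

basic_idxs, adv_idxs, input_idx = [], [], 0
for axis, entry in enumerate(idx_list):
    if isinstance(entry, slice):
        # Basic slice index: kept as-is.
        basic_idxs.append(entry)
    else:
        # Advanced index: consume the next tensor input and record its axis.
        adv_idxs.append({"axis": axis, "input": tensor_inputs[input_idx]})
        input_idx += 1

print(basic_idxs)  # [slice(None, None, None), slice(0, 5, None)]
print(adv_idxs)    # [{'axis': 1, 'input': '<ivector idx>'}]
```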
25 changes: 18 additions & 7 deletions pytensor/link/pytorch/dispatch/subtensor.py
@@ -9,7 +9,7 @@
Subtensor,
indices_from_subtensor,
)
from pytensor.tensor.type_other import MakeSlice, SliceType
from pytensor.tensor.type_other import MakeSlice


def check_negative_steps(indices):
@@ -63,7 +63,10 @@ def makeslice(start, stop, step):
@pytorch_funcify.register(AdvancedSubtensor1)
@pytorch_funcify.register(AdvancedSubtensor)
def pytorch_funcify_AdvSubtensor(op, node, **kwargs):
def advsubtensor(x, *indices):
idx_list = op.idx_list

def advsubtensor(x, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices)
return x[indices]

@@ -102,12 +105,14 @@ def inc_subtensor(x, y, *flattened_indices):
@pytorch_funcify.register(AdvancedIncSubtensor)
@pytorch_funcify.register(AdvancedIncSubtensor1)
def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
idx_list = op.idx_list
inplace = op.inplace
ignore_duplicates = getattr(op, "ignore_duplicates", False)

if op.set_instead_of_inc:

def adv_set_subtensor(x, y, *indices):
def adv_set_subtensor(x, y, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices)
if isinstance(op, AdvancedIncSubtensor1):
op._check_runtime_broadcasting(node, x, y, indices)
@@ -120,7 +125,8 @@ def adv_set_subtensor(x, y, *indices):

elif ignore_duplicates:

def adv_inc_subtensor_no_duplicates(x, y, *indices):
def adv_inc_subtensor_no_duplicates(x, y, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices)
if isinstance(op, AdvancedIncSubtensor1):
op._check_runtime_broadcasting(node, x, y, indices)
@@ -132,13 +138,18 @@ def adv_inc_subtensor_no_duplicates(x, y, *indices):
return adv_inc_subtensor_no_duplicates

else:
if any(isinstance(idx.type, SliceType) for idx in node.inputs[2:]):
# Check if we have slice indexing in idx_list
has_slice_indexing = (
any(isinstance(entry, slice) for entry in idx_list) if idx_list else False
)
if has_slice_indexing:
raise NotImplementedError(
"IncSubtensor with potential duplicates indexes and slice indexing not implemented in PyTorch"
)

def adv_inc_subtensor(x, y, *indices):
# Not needed because slices aren't supported
def adv_inc_subtensor(x, y, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
# Not needed because slices aren't supported in this path
# check_negative_steps(indices)
if not inplace:
x = x.clone()
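The `ignore_duplicates` split above matters because plain indexed assignment and `index_put_(..., accumulate=True)` treat repeated indices differently. A small torch sketch (assuming `torch` is installed; this is not the backend's exact code):

```python
import torch

idx = torch.tensor([0, 0, 2])
y = torch.ones(3)

a = torch.zeros(4)
a[idx] += y  # duplicates collapse: only the last write to index 0 survives
b = torch.zeros(4)
b.index_put_((idx,), y, accumulate=True)  # duplicates accumulate at index 0

print(a)  # tensor([1., 0., 1., 0.])
print(b)  # tensor([2., 0., 1., 0.])
```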
124 changes: 120 additions & 4 deletions pytensor/sparse/basic.py
@@ -1317,8 +1317,6 @@ def perform(self, node, inputs, outputs):
z[0] = y

def grad(self, inputs, gout):
from pytensor.sparse.math import sp_sum

(x, s) = inputs
(gz,) = gout
return [col_scale(gz, s), sp_sum(x * gz, axis=0)]
@@ -1368,8 +1366,6 @@ def perform(self, node, inputs, outputs):
z[0] = scipy.sparse.csc_matrix((y_data, indices, indptr), (M, N))

def grad(self, inputs, gout):
from pytensor.sparse.math import sp_sum

(x, s) = inputs
(gz,) = gout
return [row_scale(gz, s), sp_sum(x * gz, axis=1)]
@@ -1435,6 +1431,126 @@ def row_scale(x, s):
return col_scale(x.T, s).T


class SpSum(Op):
"""

WARNING: judgement call...
We do not use `structured` in the comparison or hashing because it
does not change the `perform` method. We therefore *do* want Sums
with different `structured` values to be merged by the merge
optimization, and this requires them to compare equal.
"""

__props__ = ("axis",)

def __init__(self, axis=None, sparse_grad=True):
super().__init__()
self.axis = axis
self.structured = sparse_grad
if self.axis not in (None, 0, 1):
raise ValueError("Illegal value for self.axis.")

def make_node(self, x):
x = as_sparse_variable(x)
assert x.format in ("csr", "csc")

if self.axis is not None:
out_shape = (None,)
else:
out_shape = ()

z = TensorType(dtype=x.dtype, shape=out_shape)()
return Apply(self, [x], [z])

def perform(self, node, inputs, outputs):
(x,) = inputs
(z,) = outputs
if self.axis is None:
z[0] = np.asarray(x.sum())
else:
z[0] = np.asarray(x.sum(self.axis)).ravel()

def grad(self, inputs, gout):
(x,) = inputs
(gz,) = gout
if x.dtype not in continuous_dtypes:
return [x.zeros_like(dtype=config.floatX)]
if self.structured:
if self.axis is None:
r = gz * sp_ones_like(x)
elif self.axis == 0:
r = col_scale(sp_ones_like(x), gz)
elif self.axis == 1:
r = row_scale(sp_ones_like(x), gz)
else:
raise ValueError("Illegal value for self.axis.")
else:
o_format = x.format
x = dense_from_sparse(x)
if _is_sparse_variable(gz):
gz = dense_from_sparse(gz)
if self.axis is None:
r = ptb.second(x, gz)
else:
ones = ptb.ones_like(x)
if self.axis == 0:
r = specify_broadcastable(gz.dimshuffle("x", 0), 0) * ones
elif self.axis == 1:
r = specify_broadcastable(gz.dimshuffle(0, "x"), 1) * ones
else:
raise ValueError("Illegal value for self.axis.")
r = SparseFromDense(o_format)(r)
return [r]

def infer_shape(self, fgraph, node, shapes):
r = None
if self.axis is None:
r = [()]
elif self.axis == 0:
r = [(shapes[0][1],)]
else:
r = [(shapes[0][0],)]
return r

def __str__(self):
return f"{self.__class__.__name__}{{axis={self.axis}}}"


def sp_sum(x, axis=None, sparse_grad=False):
"""
Calculate the sum of a sparse matrix along the specified axis.

It performs a reduction along the specified axis; when `axis` is `None`,
the reduction is applied over all axes.

Parameters
----------
x
Sparse matrix.
axis
Axis along which the sum is applied. Integer or `None`.
sparse_grad : bool
`True` to have a structured grad.

Returns
-------
object
The sum of `x` in a dense format.

Notes
-----
The grad implementation is controlled with the `sparse_grad` parameter.
`True` will provide a structured grad and `False` will provide a regular
grad. For both choices, the grad returns a sparse matrix having the same
format as `x`.

This op does not return a sparse matrix; it returns a dense tensor.

"""

return SpSum(axis, sparse_grad)(x)


class Diag(Op):
"""Extract the diagonal of a square sparse matrix as a dense vector.

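Finally, a hedged usage sketch for the `sp_sum` helper added above. It assumes `sp_sum` is re-exported from `pytensor.sparse` and that a `csr_matrix` symbolic constructor is available there; the data and expected output are illustrative.

```python
import numpy as np
import scipy.sparse
import pytensor
import pytensor.sparse as ps

x = ps.csr_matrix("x", dtype="float64")  # symbolic sparse input (assumed constructor)
row_sums = ps.sp_sum(x, axis=1)          # dense vector of per-row sums
total = ps.sp_sum(x, sparse_grad=True)   # 0-d dense total with a structured grad

f = pytensor.function([x], [row_sums, total])
a = scipy.sparse.csr_matrix(np.array([[1.0, 0.0], [2.0, 3.0]]))
print(f(a))  # expected, roughly: [array([1., 5.]), array(6.)]
```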