This tutorial will explain how JAX and Numba implementations are created for an :class:`Op`. It will
focus specifically on the JAX case, but the same mechanisms are used for Numba as well.

Step 1: Identify the PyTensor :class:`Op` you'd like to implement in JAX
----------------------------------------------------------------------------

Find the source for the PyTensor :class:`Op` you'd like to be supported in JAX, and
identify the function signature and return values. These can be determined by
looking at the :meth:`Op.make_node` implementation. In general, one needs to be familiar
with PyTensor :class:`Op`\s in order to provide a conversion implementation, so first read
:ref:`creating_an_op` if you are not familiar.

For example, suppose you want to extend support for :class:`CumsumOp`:

.. code:: python

    class CumsumOp(Op):
        __props__ = ("axis",)

        def __new__(typ, *args, **kwargs):
            obj = object.__new__(CumOp, *args, **kwargs)
            obj.mode = "add"
            return obj


:class:`CumsumOp` turns out to be a variant of :class:`CumOp`\ :class:`Op`
which currently has an :meth:`Op.make_node` as follows:

.. code:: python

    def make_node(self, x):
        x = ptb.as_tensor_variable(x)
        out_type = x.type()

        if self.axis is None:
            out_type = vector(dtype=x.dtype)  # Flatten
        elif self.axis >= x.ndim or self.axis < -x.ndim:
            raise ValueError(f"axis(={self.axis}) out of bounds")

        return Apply(self, [x], [out_type])

The :class:`Apply` instance that's returned specifies the exact types of inputs that
our JAX implementation will receive and the exact types of outputs it's expected to
return, both in terms of their data types and number of dimensions/shapes.
The actual inputs our implementation will receive are necessarily numeric values
or NumPy :class:`ndarray`\s; all that :meth:`Op.make_node` tells us is the
general signature of the underlying computation.

More specifically, the :class:`Apply` implies that there is one input that is
automatically converted to a PyTensor variable via :func:`as_tensor_variable`.
There is another parameter, `axis`, that determines the direction of the
operation, and hence the shape of the output. The check that follows implies that
`axis` must refer to a dimension of the input tensor. The input's elements
could have any data type (e.g. floats, ints), so our JAX implementation
must be able to handle all the possible data types.
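
As a quick way to confirm this signature, you can build the graph symbolically and
inspect the output type. The following is a small sketch (assuming an interactive
session with PyTensor installed; the variable names are only illustrative):

.. code:: python

    import pytensor.tensor as pt

    x = pt.matrix("x", dtype="float64")

    # CumOp(axis=0, mode="add") keeps the input's type: a float64 matrix
    print(pt.cumsum(x, axis=0).type)

    # With axis=None the input is flattened, so the output is a float64 vector
    print(pt.cumsum(x).type)

Note that the output dtype simply follows the input, so the conversion should not
assume any particular precision.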

It also tells us that there's only one return value, that it has a data type
determined by :meth:`x.type()`, i.e. the data type of the original tensor.
This implies that the output has the same type and shape as the input, except
when `axis` is ``None``, in which case the output is flattened into a vector.

Some :class:`Op`\s have more complex behavior. For example, the :class:`CumOp`\ :class:`Op`
also has another variant, :class:`CumprodOp`\ :class:`Op`, with the exact same signature
as :class:`CumsumOp`\ :class:`Op`. The difference lies in the `mode` attribute of the
:class:`CumOp` definition:

.. code:: python

    class CumOp(COp):
        # See function cumsum/cumprod for docstring

        __props__ = ("axis", "mode")
        check_input = False
        params_type = ParamsType(
            c_axis=int_t, mode=EnumList(("MODE_ADD", "add"), ("MODE_MUL", "mul"))
        )

        def __init__(self, axis: int | None = None, mode="add"):
            if mode not in ("add", "mul"):
                raise ValueError(f'{type(self).__name__}: Unknown mode "{mode}"')
            self.axis = axis
            self.mode = mode

        c_axis = property(lambda self: np.MAXDIMS if self.axis is None else self.axis)

`__props__` is used to parametrize the general behavior of the :class:`Op`. One needs to
pay attention to this to decide whether the JAX implementation should support all variants
or raise an explicit NotImplementedError for cases that are not supported, e.g. when
:class:`CumsumOp` of :class:`CumOp("add")` is supported but not :class:`CumprodOp` of
:class:`CumOp("mul")`.
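
For instance, because ``axis`` and ``mode`` are listed in `__props__`, two :class:`CumOp`
instances with the same values for those attributes are treated as the same :class:`Op`.
A small illustrative check (assuming the class shown above):

.. code:: python

    from pytensor.tensor.extra_ops import CumOp

    # __props__ drives equality and hashing, so Ops with the same parametrization
    # compare equal, while a different `mode` gives a genuinely different Op
    assert CumOp(axis=0, mode="add") == CumOp(axis=0, mode="add")
    assert CumOp(axis=0, mode="add") != CumOp(axis=0, mode="mul")

Any conversion therefore only needs to branch on these two attributes.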

Next, we look at the :meth:`Op.perform` implementation to see exactly
how the inputs and outputs are used to compute the outputs for an :class:`Op`
in Python. This method is effectively what needs to be implemented in JAX.

Step 2: Find the relevant JAX method (or something close)
------------------------------------------------------------

In this case, :class:`CumOp` is implemented with NumPy's :func:`numpy.cumsum`
and :func:`numpy.cumprod`, which have JAX equivalents: :func:`jax.numpy.cumsum`
and :func:`jax.numpy.cumprod`.

.. code:: python

    def perform(self, node, inputs, output_storage):
        x = inputs[0]
        z = output_storage[0]
        if self.mode == "add":
            z[0] = np.cumsum(x, axis=self.axis)
        else:
            z[0] = np.cumprod(x, axis=self.axis)
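
Before writing the conversion it can be reassuring to check that the JAX functions
really do mirror the NumPy calls used in :meth:`Op.perform`. A quick sanity check
(a sketch, assuming JAX is installed):

.. code:: python

    import numpy as np
    import jax.numpy as jnp

    x = np.arange(9, dtype="float64").reshape(3, 3)

    # The JAX equivalents accept the same `axis` argument and return the same values
    np.testing.assert_allclose(np.cumsum(x, axis=0), jnp.cumsum(x, axis=0))
    np.testing.assert_allclose(np.cumprod(x, axis=1), jnp.cumprod(x, axis=1))

Since both modes map one-to-one onto existing JAX functions, the conversion itself
stays small.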

Step 3: Register the function with the `jax_funcify` dispatcher
------------------------------------------------------------------

With the PyTensor `Op` replicated in JAX, we'll need to register the
function with the PyTensor JAX `Linker`. This is done through the use of
`singledispatch`. If you don't know how `singledispatch` works, see the
`Python documentation <https://docs.python.org/3/library/functools.html#functools.singledispatch>`_.

The relevant dispatch functions created by `singledispatch` are :func:`pytensor.link.numba.dispatch.numba_funcify` and
:func:`pytensor.link.jax.dispatch.jax_funcify`.
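
As a short illustration of the mechanism, `singledispatch` selects an implementation
based on the type of its first argument. The following generic sketch (not PyTensor's
actual code; `MyOp` is a hypothetical stand-in) mimics how a dispatcher can fall back
when no conversion is registered:

.. code:: python

    from functools import singledispatch


    @singledispatch
    def funcify(op, **kwargs):
        # Fallback used when no implementation is registered for type(op)
        raise NotImplementedError(f"No conversion available for: {op}")


    class MyOp:
        """A stand-in class used only for this illustration."""


    @funcify.register(MyOp)
    def funcify_MyOp(op, **kwargs):
        return lambda *inputs: inputs


    funcify(MyOp())  # dispatches to funcify_MyOp

PyTensor's `jax_funcify` works the same way, dispatching on the :class:`Op` subclass it
receives.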

Here's an example for the `CumOp`\ `Op`:

.. code:: python

    import jax.numpy as jnp

    from pytensor.tensor.extra_ops import CumOp
    from pytensor.link.jax.dispatch import jax_funcify


    @jax_funcify.register(CumOp)
    def jax_funcify_CumOp(op, **kwargs):
        axis = op.axis
        mode = op.mode

        def cumop(x, axis=axis, mode=mode):
            if mode == "add":
                return jnp.cumsum(x, axis=axis)
            else:
                return jnp.cumprod(x, axis=axis)

        return cumop
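
Once the dispatch function is registered, you can exercise it by compiling a graph with
PyTensor's JAX mode. This is a usage sketch (assuming JAX is installed):

.. code:: python

    import numpy as np

    import pytensor
    import pytensor.tensor as pt

    a = pt.matrix("a")
    out = pt.cumsum(a, axis=0)

    # mode="JAX" routes compilation through the JAX linker, which calls jax_funcify
    fn = pytensor.function([a], out, mode="JAX")
    print(fn(np.arange(9, dtype=pytensor.config.floatX).reshape(3, 3)))

If the printed result matches :func:`numpy.cumsum`, the dispatch is wired up correctly.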

Suppose `jnp.cumprod` did not exist; we would then need to register the function as follows:

.. code:: python

    import jax.numpy as jnp

    from pytensor.tensor.extra_ops import CumOp
    from pytensor.link.jax.dispatch import jax_funcify


    @jax_funcify.register(CumOp)
    def jax_funcify_CumOp(op, **kwargs):
        axis = op.axis
        mode = op.mode

        def cumop(x, axis=axis, mode=mode):
            if mode == "add":
                return jnp.cumsum(x, axis=axis)
            else:
                raise NotImplementedError("JAX does not support cumprod function at the moment.")

        return cumop

Step 4: Write tests
-------------------

Test that your registered `Op` is working correctly by adding tests to the
appropriate test suites in PyTensor (e.g. in ``tests.link.jax`` and one of
the modules in ``tests.link.numba``). The tests should ensure that your implementation can
handle the appropriate types of inputs and produce outputs equivalent to `Op.perform`.
Check the existing tests for the general outline of these kinds of tests. In
most cases, a helper function can be used to easily verify the correspondence
between a JAX/Numba implementation and its `Op`.

For example, the :func:`compare_jax_and_py` function streamlines the steps
involved in making comparisons with `Op.perform`.

Here's a small example of a test for :class:`CumOp` above:

.. code:: python

    import numpy as np
    import pytensor.tensor as pt
    from pytensor.configdefaults import config
    from tests.link.jax.test_basic import compare_jax_and_py
    from pytensor.graph import FunctionGraph
    from pytensor.graph.op import get_test_value

    def test_jax_CumOp():
        """Test JAX conversion of the `CumOp` `Op`."""

        # Create a symbolic input for the first input of `CumOp`
        a = pt.matrix("a")

        # Create test value tag for a
        a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3))

        # Create the output variable
        out = pt.cumsum(a, axis=0)

        # Create a PyTensor `FunctionGraph`
        fgraph = FunctionGraph([a], [out])

        # Pass the graph and inputs to the testing function
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

        # For the second mode of CumOp
        out = pt.cumprod(a, axis=1)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
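
The two modes can also be covered with a single parametrized test. This is a sketch using
`pytest.mark.parametrize` (the test name and structure are illustrative, not the exact
test from the PyTensor suite):

.. code:: python

    import pytest


    @pytest.mark.parametrize("mode", ["add", "mul"])
    def test_jax_CumOp_parametrized(mode):
        """Hypothetical parametrized variant of the test above."""
        a = pt.matrix("a")
        test_value = np.arange(9, dtype=config.floatX).reshape((3, 3))

        out = pt.cumsum(a, axis=0) if mode == "add" else pt.cumprod(a, axis=0)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [test_value])

Parametrizing keeps the two modes symmetric and makes a failure in either one easier to
localize.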

If the variant :class:`CumprodOp` is not implemented, we can add a test for it as follows:

.. code:: python

    import pytest

    def test_jax_CumOp():
        """Test JAX conversion of the `CumOp` `Op`."""
        a = pt.matrix("a")
        a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3))

        with pytest.raises(NotImplementedError):
            out = pt.cumprod(a, axis=1)
            fgraph = FunctionGraph([a], [out])
            compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

Note
----

In our previous example of extending JAX, the :class:`Eye`\ :class:`Op` was used with the test function as follows:

.. code:: python

    def test_jax_Eye():
        """Test JAX conversion of the `Eye` `Op`."""

        # Create a symbolic input for `Eye`
        x_at = pt.scalar()

        # Create a variable that is the output of an `Eye` `Op`
        eye_var = pt.eye(x_at)

        # Create a PyTensor `FunctionGraph`
        out_fg = FunctionGraph(outputs=[eye_var])

        # Pass the graph and any inputs to the testing function
        compare_jax_and_py(out_fg, [3])

This test now fails due to new restrictions in JAX + JIT,
as reported in issue `#654 <https://github.com/pymc-devs/pytensor/issues/654>`_.
All jitted functions now must have constant shapes, which means a graph like the
one of :class:`Eye` can never be translated to JAX, since it is fundamentally a
function with dynamic shapes. In other words, only PyTensor graphs with static shapes
can be translated to JAX at the moment.
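
For instance, the distinction can be seen directly (a sketch, assuming JAX is installed;
the failing line is left commented out):

.. code:: python

    import pytensor
    import pytensor.tensor as pt

    # Static shapes: the output shape only depends on the *shape* of the input,
    # so this graph can be compiled with the JAX backend
    x = pt.matrix("x")
    jax_fn = pytensor.function([x], pt.cumsum(x, axis=0), mode="JAX")

    # Dynamic shapes: the output shape depends on a runtime *value* (`n`), which
    # JAX's JIT cannot handle, so the following would fail
    n = pt.iscalar("n")
    # pytensor.function([n], pt.eye(n), mode="JAX")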