33# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
44from __future__ import annotations
55
6+ import math
67from collections .abc import Callable , Sequence
78from functools import wraps
89from types import ModuleType
@@ -30,11 +31,19 @@ class P: # pylint: disable=missing-class-docstring
3031 kwargs : dict
3132
3233
34+ class UnknownShapeError (ValueError ):
35+ """
36+ `shape` contains one or more None elements.
37+
38+ This is unsupported when running inside `jax.jit`.
39+ """
40+
41+
3342@overload
3443def apply_numpy_func ( # type: ignore[valid-type]
3544 func : Callable [P , NumPyObject ],
3645 * args : Array ,
37- shape : tuple [int , ...] | None = None ,
46+ shape : tuple [int | None , ...] | None = None ,
3847 dtype : DType | None = None ,
3948 xp : ModuleType | None = None ,
4049 ** kwargs : P .kwargs , # pyright: ignore[reportGeneralTypeIssues]
@@ -45,7 +54,7 @@ def apply_numpy_func( # type: ignore[valid-type]
4554def apply_numpy_func ( # type: ignore[valid-type]
4655 func : Callable [P , Sequence [NumPyObject ]],
4756 * args : Array ,
48- shape : Sequence [tuple [int , ...]],
57+ shape : Sequence [tuple [int | None , ...]],
4958 dtype : Sequence [DType ] | None = None ,
5059 xp : ModuleType | None = None ,
5160 ** kwargs : P .kwargs , # pyright: ignore[reportGeneralTypeIssues]
@@ -55,7 +64,7 @@ def apply_numpy_func( # type: ignore[valid-type]
5564def apply_numpy_func ( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
5665 func : Callable [P , NumPyObject | Sequence [NumPyObject ]],
5766 * args : Array ,
58- shape : tuple [int , ...] | Sequence [tuple [int , ...]] | None = None ,
67+ shape : tuple [int | None , ...] | Sequence [tuple [int | None , ...]] | None = None ,
5968 dtype : DType | Sequence [DType ] | None = None ,
6069 xp : ModuleType | None = None ,
6170 ** kwargs : P .kwargs , # pyright: ignore[reportGeneralTypeIssues]
@@ -76,7 +85,7 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
7685 One or more Array API compliant arrays. You need to be able to apply
7786 :func:`numpy.asarray` to them to convert them to numpy; read notes below about
7887 specific backends.
79- shape : tuple[int, ...] | Sequence[tuple[int, ...]], optional
88+ shape : tuple[int | None , ...] | Sequence[tuple[int | None , ...]], optional
8089 Output shape or sequence of output shapes, one for each output of `func`.
8190 Default: assume single output and broadcast shapes of the input arrays.
8291 dtype : DType | Sequence[DType], optional
@@ -102,6 +111,8 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
102111 JAX
103112 This allows applying eager functions to jitted JAX arrays, which are lazy.
104113 The function won't be applied until the JAX array is materialized.
114+ When running inside `jax.jit`, `shape` must be fully known, i.e. it cannot
115+ contain any `None` elements.
105116
106117 The :doc:`jax:transfer_guard` may prevent arrays on a GPU device from being
107118 transferred back to CPU. This is treated as an implicit transfer.
@@ -135,6 +146,18 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
135146 :func:`dask.array.blockwise`, or a native Dask wrapper instead of
136147 `apply_numpy_func`.
137148
149+ Raises
150+ ------
151+ UnknownShapeError
152+ When `shape` is unknown (one or more sizes are None) and this function is
153+ called inside `jax.jit`.
154+
155+ Exception (varies)
156+
157+ - When the backend disallows implicit device to host transfers and the input
158+ arrays are on a device, e.g. on GPU;
159+ - When the backend is sparse and auto-densification is disabled.
160+
138161 See Also
139162 --------
140163 jax.transfer_guard
@@ -147,13 +170,16 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
147170 xp = array_namespace (* args )
148171
149172 # Normalize and validate shape and dtype
173+ shapes : list [tuple [int | None , ...]]
174+ dtypes : list [DType ]
150175 multi_output = False
176+
151177 if shape is None :
152178 shapes = [xp .broadcast_shapes (* (arg .shape for arg in args ))]
153- elif isinstance (shape , tuple ) and all (isinstance (s , int ) for s in shape ):
154- shapes = [shape ]
179+ elif isinstance (shape , tuple ) and all (isinstance (s , int | None ) for s in shape ):
180+ shapes = [shape ] # pyright: ignore[reportAssignmentType]
155181 else :
156- shapes = list (shape )
182+ shapes = list (shape ) # type: ignore[arg-type] # pyright: ignore[reportAssignmentType]
157183 multi_output = True
158184
159185 if dtype is None :
@@ -186,13 +212,19 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
186212 meta_xp = array_namespace (* metas )
187213
188214 wrapped = dask .delayed (_npfunc_wrapper (func , multi_output , meta_xp ), pure = True )
189- # This finalizes each arg, which is the same as arg.rechunk(-1)
215+ # This finalizes each arg, which is the same as arg.rechunk(-1).
190216 # Please read docstring above for why we're not using
191217 # dask.array.map_blocks or dask.array.blockwise!
192218 delayed_out = wrapped (* args , ** kwargs )
193219
194220 out = tuple (
195- xp .from_delayed (delayed_out [i ], shape = shape , dtype = dtype , meta = metas [0 ])
221+ xp .from_delayed (
222+ delayed_out [i ],
223+ # Dask's unknown shapes diverge from the Array API specification
224+ shape = tuple (math .nan if s is None else s for s in shape ),
225+ dtype = dtype ,
226+ meta = metas [0 ],
227+ )
196228 for i , (shape , dtype ) in enumerate (zip (shapes , dtypes , strict = True ))
197229 )
198230
@@ -205,18 +237,33 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
205237 import jax # type: ignore[import-not-found] # pylint: disable=import-outside-toplevel,import-error # pyright: ignore[reportMissingImports]
206238
207239 wrapped = _npfunc_wrapper (func , multi_output , xp )
208- out = cast (
209- tuple [Array , ...],
210- jax .pure_callback (
211- wrapped ,
212- tuple (
213- jax .ShapeDtypeStruct (s , dt ) # pyright: ignore[reportUnknownArgumentType]
214- for s , dt in zip (shapes , dtypes , strict = True )
240+
241+ if any (s is None for shape in shapes for s in shape ):
242+ # Unknown output shape. Won't work with jax.jit, but it
243+ # can work with eager jax.
244+ try :
245+ out = wrapped (* args , ** kwargs )
246+ except jax .errors .TracerArrayConversionError :
247+ msg = (
248+ "jax.jit can't delay application of numpy functions when the shape "
249+ "of the returned array(s) is unknown. "
250+ f"shape={ shapes if multi_output else shapes [0 ]} "
251+ )
252+ raise UnknownShapeError (msg ) from None
253+
254+ else :
255+ out = cast (
256+ tuple [Array , ...],
257+ jax .pure_callback (
258+ wrapped ,
259+ tuple (
260+ jax .ShapeDtypeStruct (shape , dtype ) # pyright: ignore[reportUnknownArgumentType]
261+ for shape , dtype in zip (shapes , dtypes , strict = True )
262+ ),
263+ * args ,
264+ ** kwargs ,
215265 ),
216- * args ,
217- ** kwargs ,
218- ),
219- )
266+ )
220267
221268 else :
222269 # Eager backends
0 commit comments