Add support for JAX FFI vmap in jax_kernel() and jax_callable() (GH-859) #1203
nvlukasz wants to merge 5 commits into NVIDIA:main from
Conversation
📝 Walkthrough

This pull request adds JAX vmap support to Warp's FFI functions (`jax_kernel` and `jax_callable`). It includes batch dimension handling utilities, hashable normalization of launch/output dimensions for caching, comprehensive documentation examples, and new test coverage for various vmap batching scenarios.
Sequence Diagram

sequenceDiagram
participant JAX as JAX vmap
participant FFI as jax_kernel/jax_callable
participant BatchUtil as Batch Utilities
participant Callback as FFI Callback
participant Kernel as Warp Kernel
JAX->>FFI: Call with batched inputs<br/>(batch_ndim > 0)
FFI->>BatchUtil: collapse_batch_dims(input_shape)
BatchUtil-->>FFI: Collapsed shape
FFI->>BatchUtil: compute_batch_size(shape, batch_ndim)
BatchUtil-->>FFI: Batch size value
FFI->>FFI: Normalize launch_dims/output_dims<br/>to hashable tuples
FFI->>FFI: Check/populate cache<br/>with hashable keys
alt Cache hit
FFI-->>JAX: Return cached result
else Cache miss
FFI->>Callback: Create FFI callback with<br/>batch-aware buffer handling
Callback->>Callback: Reshape inputs using<br/>collapse_batch_dims
Callback->>Callback: Infer launch_dims if None
Callback->>Kernel: Execute with adjusted<br/>launch dimensions
Kernel-->>Callback: Kernel results
Callback->>Callback: Reshape outputs using<br/>batch dimensions
Callback-->>FFI: Processed outputs
FFI->>FFI: Cache result
FFI-->>JAX: Return output
end
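The batch helpers named in the diagram are internal to `ffi.py`. As a rough sketch of their intent only (the actual implementations in `warp/_src/jax_experimental/ffi.py` may differ in signature and behavior), collapsing the vmapped leading dimensions and computing the batch size could look like this:

```python
# Illustrative only: the intent of the batch helpers named in the diagram above.
# The real helpers in warp/_src/jax_experimental/ffi.py may differ.

def compute_batch_size(shape, batch_ndim):
    # total batch size = product of the leading (vmapped) dimensions
    size = 1
    for dim in shape[:batch_ndim]:
        size *= dim
    return size

def collapse_batch_dims(shape, batch_ndim):
    # fold all leading batch dimensions into a single leading dimension,
    # so the kernel sees one flat batch axis followed by the logical shape
    if batch_ndim == 0:
        return tuple(shape)
    return (compute_batch_size(shape, batch_ndim), *shape[batch_ndim:])

# e.g. a doubly-vmapped input of shape (4, 5, 50) is presented as (20, 50),
# and the launch dimensions are scaled by the batch size accordingly
print(collapse_batch_dims((4, 5, 50), 2))  # (20, 50)
```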
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 3
In CHANGELOG.md:

- Fix JAX FFI multi-gpu graph caching ([GH-1181](https://github.com/NVIDIA/warp/pull/1181)).
- Fix tile * constant multiplication when one operand is a vector or matrix type ([GH-1175](https://github.com/NVIDIA/warp/issues/1175)).
- Fix kernel symbol resolution accepting invalid namespace paths like `wp.foo.bar.tid()` ([GH-1198](https://github.com/NVIDIA/warp/issues/1198)).
- Fix hashing errors when creating `jax_kernel()` and `jax_callable()`.
Add a GH issue reference for the hashing fix entry.

This line is missing the required issue link in the specified format.

Suggested update:

- Fix hashing errors when creating `jax_kernel()` and `jax_callable()`.
+ Fix hashing errors when creating `jax_kernel()` and `jax_callable()`
+   ([GH-XXX](https://github.com/NVIDIA/warp/issues/XXX)).

As per coding guidelines, use imperative present tense in CHANGELOG.md entries ('Add X', not 'Added X' or 'This adds X') and include issue references as ([GH-XXX](https://github.com/NVIDIA/warp/issues/XXX)).
🤖 Prompt for AI Agents
In `@CHANGELOG.md` at line 49, Update the CHANGELOG entry "Fix hashing errors when
creating `jax_kernel()` and `jax_callable()`." to include the GitHub issue
reference in the required format — append something like
`([GH-XXX](https://github.com/NVIDIA/warp/issues/XXX))` to that line while
keeping the imperative present tense; target the exact entry mentioning
`jax_kernel()` and `jax_callable()` and replace XXX with the correct issue
number.
In docs/user_guide/interoperability.rst:

.. code-block:: python

    jax_lookup = jax_kernel(lookup_kernel)

    # lookup table (not batched)
    N = 100
    table = jp.arange(N, dtype=jp.float32)

    # batched indices to look up
    key = jax.random.key(42)
    indices = jax.random.randint(key, (20, 50), 0, N, dtype=jp.int32)

    # vmap with batch dim 0: input 20 sets of 50 indices each, output shape (20, 50)
    (output,) = jax.jit(jax.vmap(partial(jax_lookup, launch_dims=50), in_axes=(None, 0)))(
        table, indices
    )
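For context, `lookup_kernel` is defined elsewhere in the docs and is not part of this snippet. A plausible shape for it (illustrative only, not necessarily the exact definition used in the documentation) is:

```python
import warp as wp

@wp.kernel
def lookup_kernel(
    table: wp.array(dtype=float),
    indices: wp.array(dtype=int),
    output: wp.array(dtype=float),
):
    # gather one table entry per thread
    tid = wp.tid()
    output[tid] = table[indices[tid]]
```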
There was a problem hiding this comment.
Add the missing imports in the custom launch VMAP example.

`partial` (and `jax`) are used but not imported in this snippet, which makes it fail when copied standalone.

Suggested update:

  .. code-block:: python

+     from functools import partial
+     import jax
+
      jax_lookup = jax_kernel(lookup_kernel)

      # lookup table (not batched)
🤖 Prompt for AI Agents
In `@docs/user_guide/interoperability.rst` around lines 1290 - 1305, The snippet
uses partial, jax, and jp but doesn't import them; add the missing imports at
the top (e.g., import jax, import jax.numpy as jp, and from functools import
partial) so symbols used by jax_kernel/jax_lookup, jax.random, and
jp.arange/jp.float32 resolve when the example is copied standalone.
In warp/_src/jax_experimental/ffi.py:

    if isinstance(output_dims, dict):
        hashable_output_dims = tuple(sorted(output_dims.items()))
    elif hasattr(output_dims, "__len__"):
        hashable_output_dims = tuple(output_dims)
    else:
        hashable_output_dims = output_dims
Normalize dict output_dims values before hashing.

Lists are valid shapes (per docs), but `tuple(sorted(output_dims.items()))` still embeds lists and remains unhashable, causing a TypeError for valid calls.

Suggested update:

- if isinstance(output_dims, dict):
-     hashable_output_dims = tuple(sorted(output_dims.items()))
+ if isinstance(output_dims, dict):
+     hashable_output_dims = tuple(
+         sorted(
+             (k, tuple(v) if hasattr(v, "__len__") else v)
+             for k, v in output_dims.items()
+         )
+     )
  elif hasattr(output_dims, "__len__"):
      hashable_output_dims = tuple(output_dims)
  else:
      hashable_output_dims = output_dims

Also applies to: 1496-1501
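To see why the normalization matters, here is a quick standalone illustration (not code from the PR itself):

```python
# A dict with list-valued shapes is a valid output_dims argument,
# but the naive cache key built from it is not hashable.
output_dims = {"y": [20, 50]}

naive_key = tuple(sorted(output_dims.items()))
try:
    hash(naive_key)
except TypeError as err:
    print(err)  # unhashable type: 'list'

# Converting list-like values to tuples yields a usable cache key.
safe_key = tuple(
    sorted((k, tuple(v) if hasattr(v, "__len__") else v) for k, v in output_dims.items())
)
print(hash(safe_key) is not None)  # True: the key can now be used in a cache dict
```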
🤖 Prompt for AI Agents
In `@warp/_src/jax_experimental/ffi.py` around lines 1174 - 1179, When computing
hashable_output_dims, ensure any list-like shape values are converted to tuples
before forming the hashable structure: in the dict branch (where output_dims is
a dict) build hashable_output_dims as tuple(sorted((k, tuple(v) if hasattr(v,
"__len__") and not isinstance(v, (str, bytes, tuple)) else v) for k, v in
output_dims.items())), and in the iterable branch convert list-like output_dims
to tuple (e.g., if hasattr(output_dims, "__len__") and not
isinstance(output_dims, tuple): hashable_output_dims = tuple(output_dims)).
Update the code that sets hashable_output_dims (the existing output_dims ->
hashable_output_dims logic) accordingly.
Implements #859.
Basic example:
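The example body is collapsed in this view. A minimal sketch of what basic vmap usage might look like, with a hypothetical kernel and shapes, modeled on the docs snippet quoted above and assuming the `warp.jax_experimental.ffi` import path:

```python
import jax
import jax.numpy as jp
import warp as wp
from warp.jax_experimental.ffi import jax_kernel

@wp.kernel
def double_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    # trailing array argument is treated as the output
    tid = wp.tid()
    y[tid] = 2.0 * x[tid]

jax_double = jax_kernel(double_kernel)

# 8 batches of 100 elements; vmap maps over the leading batch dimension
x = jp.ones((8, 100), dtype=jp.float32)
(y,) = jax.jit(jax.vmap(jax_double))(x)  # expected y.shape == (8, 100)
```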
Example with in-out arguments:
Example with custom launch/output dimensions:
Description
Before your PR is "Ready for review":
- Stub and reference files updated as needed (`__init__.pyi`, `docs/api_reference/`, `docs/language_reference/`)
- `pre-commit run -a` passes

Summary by CodeRabbit
New Features
- vmap support for `jax_kernel()` and `jax_callable()` to enable batched computations with Warp kernels.

Bug Fixes
- Fixed hashing errors during `jax_kernel()` and `jax_callable()` creation.

Documentation
- New examples covering vmap usage with Warp's JAX FFI functions.

Tests
- New test coverage for various vmap batching scenarios.