Add support for JAX FFI vmap in jax_kernel() and jax_callable() (GH-859) #1203

Draft

nvlukasz wants to merge 5 commits into NVIDIA:main from nvlukasz:lwawrzyniak/jax-ffi-vmap

Conversation

@nvlukasz (Contributor) commented Jan 30, 2026

Implements #859.

Basic example:

import warp as wp
from warp.jax_experimental import jax_kernel

import jax
import jax.numpy as jp

@wp.kernel
def add_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float), output: wp.array(dtype=float)):
    tid = wp.tid()
    output[tid] = a[tid] + b[tid]

jax_add = jax_kernel(add_kernel)

# batched inputs
a = jp.arange(3 * 4, dtype=jp.float32).reshape((3, 4))
b = jp.ones(3 * 4, dtype=jp.float32).reshape((3, 4))

(output,) = jax.jit(jax.vmap(jax_add))(a, b)
print(output)
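
Here vmap maps over the leading axis of size 3, so the kernel sees 1D slices of shape (4,); since b is all ones, the printed result equals a + 1 with shape (3, 4).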

Example with in-out arguments:

import warp as wp
from warp.jax_experimental import jax_kernel

import jax
import jax.numpy as jp

@wp.kernel
def rowsum_kernel(matrix: wp.array2d(dtype=float), sums: wp.array1d(dtype=float)):
    i, j = wp.tid()
    wp.atomic_add(sums, i, matrix[i, j])

jax_rowsum = jax_kernel(rowsum_kernel, in_out_argnames=["sums"])

# batched input with shape (2, 3, 4)
matrices = jp.arange(2 * 3 * 4, dtype=jp.float32).reshape((2, 3, 4))
print(matrices)
print()

# vmap with batch dim 0: input 2 matrices with shape (3, 4), output shape (2, 3)
sums = jp.zeros((2, 3), dtype=jp.float32)
(output,) = jax.jit(jax.vmap(jax_rowsum, in_axes=(0, 0)))(matrices, sums)
print(output)
print()

# vmap with batch dim 1: input 3 matrices with shape (2, 4), output shape (3, 2)
sums = jp.zeros((3, 2), dtype=jp.float32)
(output,) = jax.jit(jax.vmap(jax_rowsum, in_axes=(1, 0)))(matrices, sums)
print(output)
print()

# vmap with batch dim 2: input 4 matrices with shape (2, 3), output shape (4, 2)
sums = jp.zeros((4, 2), dtype=jp.float32)
(output,) = jax.jit(jax.vmap(jax_rowsum, in_axes=(2, 0)))(matrices, sums)
print(output)
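
Because sums is declared via in_out_argnames, it serves as both a batched input (zero-initialized above) and the returned output, and its leading axis is the batch axis, which is why its in_axes entry is 0 in every call.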

Example with custom launch/output dimensions:

from functools import partial

import warp as wp
from warp.jax_experimental import jax_kernel

import jax
import jax.numpy as jp

@wp.kernel
def lookup_kernel(table: wp.array(dtype=float), indices: wp.array(dtype=int), output: wp.array(dtype=float)):
    i = wp.tid()
    output[i] = table[indices[i]]

jax_lookup = jax_kernel(lookup_kernel)

# lookup table (not batched)
N = 100
table = jp.arange(N, dtype=jp.float32)

# batched indices to look up
key = jax.random.key(42)
indices = jax.random.randint(key, (20, 50), 0, N, dtype=jp.int32)

# vmap with batch dim 0: input 20 sets of 50 indices each, output shape (20, 50)
(output,) = jax.jit(jax.vmap(partial(jax_lookup, launch_dims=50), in_axes=(None, 0)))(
    table, indices
)
print(output)
print()

# vmap with batch dim 1: input 50 sets of 20 indices each, output shape (50, 20)
(output,) = jax.jit(jax.vmap(partial(jax_lookup, launch_dims=20), in_axes=(None, 1)))(
    table, indices
)
print(output)
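
In this example the launch dimensions are supplied explicitly via functools.partial so that each per-batch launch matches the number of indices being looked up (50 or 20) rather than the size of the unbatched table.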

Description

Before your PR is "Ready for review"

  • All commits are signed-off to indicate that your contribution adheres to the Developer Certificate of Origin requirements
  • Necessary tests have been added
  • Documentation is up-to-date
  • Auto-generated files modified by compiling Warp and building the documentation have been updated (e.g. __init__.pyi, docs/api_reference/, docs/language_reference/)
  • Code passes formatting and linting checks with pre-commit run -a

Summary by CodeRabbit

  • New Features

    • Added JAX vmap support for jax_kernel() and jax_callable() to enable batched computations with Warp kernels.
  • Bug Fixes

    • Fixed hashing errors in jax_kernel() and jax_callable() creation.
  • Documentation

    • Added comprehensive VMAP integration guide with examples covering batched operations, in-out arguments, and custom launch configurations.
  • Tests

    • Added new vmap interoperability test coverage.


@copy-pr-bot (bot) commented Jan 30, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai (bot) commented Jan 30, 2026

📝 Walkthrough

This pull request adds JAX vmap support to Warp's FFI functions (jax_kernel and jax_callable). It includes batch dimension handling utilities, hashable normalization of launch/output dimensions for caching, comprehensive documentation examples, and new test coverage for various vmap batching scenarios.
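
For illustration only, a minimal sketch of what such batch-dimension helpers could look like (the actual implementations in warp/_src/jax_experimental/ffi.py may differ in signatures and behavior):

import math

# Hypothetical sketch only; not the actual Warp implementation.

def compute_batch_size(shape, batch_ndim):
    # Total number of batch elements: product of the leading batch dimensions.
    return math.prod(shape[:batch_ndim])

def collapse_batch_dims(shape, batch_ndim):
    # Merge all leading batch dimensions into a single one,
    # e.g. (2, 3, 4) with batch_ndim=2 -> (6, 4).
    if batch_ndim <= 1:
        return tuple(shape)
    return (math.prod(shape[:batch_ndim]), *shape[batch_ndim:])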

Changes

Documentation Updates (CHANGELOG.md, docs/user_guide/interoperability.rst)
  Added a changelog entry and three detailed documentation sections demonstrating JAX vmap usage with jax_kernel/jax_callable, including basic batching, in-out arguments, and custom launch/output dimensions.

Core FFI Implementation (warp/_src/jax_experimental/ffi.py)
  Introduced collapse_batch_dims() and compute_batch_size() utilities for batch dimension handling. Normalized launch_dims/output_dims to hashable tuples for stable caching. Enhanced both the jax_kernel and jax_callable paths with batch-aware dimension inference, batch size computation, and proper handling of batched inputs/outputs in callbacks and kernel execution.

Test Coverage (warp/tests/interop/test_jax.py)
  Added add2d_kernel for 2D array operations and three comprehensive vmap tests (test_ffi_vmap_add, test_ffi_vmap_rowsum, test_ffi_vmap_lookup) covering various vmap configurations, in_axes/out_axes combinations, in_out argument handling, and custom launch dimensions; a rough test sketch follows below.
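
The test names above come from the walkthrough; as a rough, hypothetical sketch only, such a vmap test might compare the batched result against a NumPy reference (the actual tests in warp/tests/interop/test_jax.py may be structured differently):

import numpy as np

import warp as wp
from warp.jax_experimental import jax_kernel

import jax
import jax.numpy as jp

@wp.kernel
def add_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float), output: wp.array(dtype=float)):
    tid = wp.tid()
    output[tid] = a[tid] + b[tid]

jax_add = jax_kernel(add_kernel)

def test_ffi_vmap_add():
    # Batched inputs: vmap maps over the leading axis of size 3.
    a = jp.arange(3 * 4, dtype=jp.float32).reshape((3, 4))
    b = jp.ones((3, 4), dtype=jp.float32)
    (output,) = jax.jit(jax.vmap(jax_add))(a, b)
    # Compare against a NumPy reference.
    np.testing.assert_allclose(np.asarray(output), np.asarray(a) + np.asarray(b))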

Sequence Diagram

sequenceDiagram
    participant JAX as JAX vmap
    participant FFI as jax_kernel/jax_callable
    participant BatchUtil as Batch Utilities
    participant Callback as FFI Callback
    participant Kernel as Warp Kernel
    
    JAX->>FFI: Call with batched inputs<br/>(batch_ndim > 0)
    FFI->>BatchUtil: collapse_batch_dims(input_shape)
    BatchUtil-->>FFI: Collapsed shape
    FFI->>BatchUtil: compute_batch_size(shape, batch_ndim)
    BatchUtil-->>FFI: Batch size value
    FFI->>FFI: Normalize launch_dims/output_dims<br/>to hashable tuples
    FFI->>FFI: Check/populate cache<br/>with hashable keys
    alt Cache hit
        FFI-->>JAX: Return cached result
    else Cache miss
        FFI->>Callback: Create FFI callback with<br/>batch-aware buffer handling
        Callback->>Callback: Reshape inputs using<br/>collapse_batch_dims
        Callback->>Callback: Infer launch_dims if None
        Callback->>Kernel: Execute with adjusted<br/>launch dimensions
        Kernel-->>Callback: Kernel results
        Callback->>Callback: Reshape outputs using<br/>batch dimensions
        Callback-->>FFI: Processed outputs
        FFI->>FFI: Cache result
        FFI-->>JAX: Return output
    end

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: docstring coverage is 20.83%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)
  • Description Check ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: the title clearly and specifically summarizes the main change, adding JAX FFI vmap support to jax_kernel() and jax_callable(), which aligns with the primary objective of the pull request.



@coderabbitai (bot) left a comment

Actionable comments posted: 3


Comment on CHANGELOG.md (around line 49)

- Fix JAX FFI multi-gpu graph caching ([GH-1181](https://github.com/NVIDIA/warp/pull/1181)).
- Fix tile * constant multiplication when one operand is a vector or matrix type ([GH-1175](https://github.com/NVIDIA/warp/issues/1175)).
- Fix kernel symbol resolution accepting invalid namespace paths like `wp.foo.bar.tid()` ([GH-1198](https://github.com/NVIDIA/warp/issues/1198)).
- Fix hashing errors when creating `jax_kernel()` and `jax_callable()`.

⚠️ Potential issue | 🟡 Minor

Add a GH issue reference for the hashing fix entry.

This line is missing the required issue link in the specified format.

Suggested update
- Fix hashing errors when creating `jax_kernel()` and `jax_callable()`.
+ Fix hashing errors when creating `jax_kernel()` and `jax_callable()`
+   ([GH-XXX](https://github.com/NVIDIA/warp/issues/XXX)).

As per coding guidelines, use imperative present tense in CHANGELOG.md entries ('Add X', not 'Added X' or 'This adds X') and include issue references as ([GH-XXX](https://github.com/NVIDIA/warp/issues/XXX)).

🤖 Prompt for AI Agents
In `@CHANGELOG.md` at line 49, Update the CHANGELOG entry "Fix hashing errors when
creating `jax_kernel()` and `jax_callable()`." to include the GitHub issue
reference in the required format — append something like
`([GH-XXX](https://github.com/NVIDIA/warp/issues/XXX))` to that line while
keeping the imperative present tense; target the exact entry mentioning
`jax_kernel()` and `jax_callable()` and replace XXX with the correct issue
number.

Comment on docs/user_guide/interoperability.rst, lines +1290 to +1305
.. code-block:: python

    jax_lookup = jax_kernel(lookup_kernel)

    # lookup table (not batched)
    N = 100
    table = jp.arange(N, dtype=jp.float32)

    # batched indices to look up
    key = jax.random.key(42)
    indices = jax.random.randint(key, (20, 50), 0, N, dtype=jp.int32)

    # vmap with batch dim 0: input 20 sets of 50 indices each, output shape (20, 50)
    (output,) = jax.jit(jax.vmap(partial(jax_lookup, launch_dims=50), in_axes=(None, 0)))(
        table, indices
    )

⚠️ Potential issue | 🟡 Minor

Add the missing imports in the custom launch VMAP example.

partial (and jax) are used but not imported in this snippet, which makes it fail when copied standalone.

Suggested update
 .. code-block:: python
 
+    from functools import partial
+    import jax
+
     jax_lookup = jax_kernel(lookup_kernel)
 
     # lookup table (not batched)
🤖 Prompt for AI Agents
In `@docs/user_guide/interoperability.rst` around lines 1290 - 1305, The snippet
uses partial, jax, and jp but doesn't import them; add the missing imports at
the top (e.g., import jax, import jax.numpy as jp, and from functools import
partial) so symbols used by jax_kernel/jax_lookup, jax.random, and
jp.arange/jp.float32 resolve when the example is copied standalone.

Comment on warp/_src/jax_experimental/ffi.py, lines +1174 to +1179
    if isinstance(output_dims, dict):
        hashable_output_dims = tuple(sorted(output_dims.items()))
    elif hasattr(output_dims, "__len__"):
        hashable_output_dims = tuple(output_dims)
    else:
        hashable_output_dims = output_dims

⚠️ Potential issue | 🟠 Major

Normalize dict output_dims values before hashing.

Lists are valid shapes (per docs), but tuple(sorted(output_dims.items())) still embeds lists and remains unhashable, causing TypeError for valid calls.
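
For instance (illustrative only), a dict with a list-valued shape produces a tuple that still contains a list and therefore cannot be hashed:

output_dims = {"output": [2, 3]}          # list-valued shape, valid per the docs
key = tuple(sorted(output_dims.items()))  # (('output', [2, 3]),)
hash(key)                                 # raises TypeError: unhashable type: 'list'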

Suggested update
-    if isinstance(output_dims, dict):
-        hashable_output_dims = tuple(sorted(output_dims.items()))
+    if isinstance(output_dims, dict):
+        hashable_output_dims = tuple(
+            sorted(
+                (k, tuple(v) if hasattr(v, "__len__") else v)
+                for k, v in output_dims.items()
+            )
+        )
     elif hasattr(output_dims, "__len__"):
         hashable_output_dims = tuple(output_dims)
     else:
         hashable_output_dims = output_dims

Also applies to: 1496-1501

🤖 Prompt for AI Agents
In `@warp/_src/jax_experimental/ffi.py` around lines 1174 - 1179, When computing
hashable_output_dims, ensure any list-like shape values are converted to tuples
before forming the hashable structure: in the dict branch (where output_dims is
a dict) build hashable_output_dims as tuple(sorted((k, tuple(v) if hasattr(v,
"__len__") and not isinstance(v, (str, bytes, tuple)) else v) for k, v in
output_dims.items())), and in the iterable branch convert list-like output_dims
to tuple (e.g., if hasattr(output_dims, "__len__") and not
isinstance(output_dims, tuple): hashable_output_dims = tuple(output_dims)).
Update the code that sets hashable_output_dims (the existing output_dims ->
hashable_output_dims logic) accordingly.
