
[BUG] Error when using jax_callable with vmap #859

@HaoliangWang


Bug Description

I'm trying to use Warp's JAX FFI feature. I see that `jax_callable` supports `vmap`. However, when I use it with `vmap`, it complains about array dimensions. Below is a minimal example:

```python
import jax
import jax.numpy as jnp

import warp as wp
from warp.jax_experimental.ffi import jax_callable


@wp.kernel
def scale_kernel(a: wp.array(dtype=float), s: float, output: wp.array(dtype=float)):
    # One thread per element of the 1D input
    tid = wp.tid()
    output[tid] = a[tid] * s


def in_out_func(
    a: wp.array(dtype=float),  # input only
    c: wp.array(dtype=float),  # output only
):
    # Scale every element of `a` by 2.0 and write the result into `c`
    wp.launch(scale_kernel, dim=a.size, inputs=[a, 2.0], outputs=[c])


jax_func = jax_callable(in_out_func, vmap_method='broadcast_all', num_outputs=1)

f = jax.jit(jax_func)

# A batch of 10 rows; each row is a 1D array of length 10
a = jnp.ones(100, dtype=jnp.float32).reshape((10, 10))

# Map over the leading (batch) axis
c = jax.vmap(f, in_axes=0, out_axes=0)(a)
c = f(a)
print(c)
```

The error I'm getting is:

```text
RuntimeError: Error launching kernel 'scale_kernel', argument 'a' expects an array with 1 dimension(s) but the passed array has 2 dimension(s).

E0721 21:11:26.423976 3931630 pjrt_stream_executor_client.cc:3077] Execution of replica 0 failed: UNKNOWN: FFI callback error: RuntimeError: Error launching kernel 'scale_kernel', argument 'a' expects an array with 1 dimension(s) but the passed array has 2 dimension(s).
```

System Information

I'm using Warp 1.7.0.

Labels: bug (Something isn't working), interop (Interoperability of Warp with other libraries)