Bug Description
I'm trying to use Warp's JAX FFI feature. I see that `jax_callable` supports `vmap`, but when I try to use it, it complains about array dimensions. Below is a minimal example:
```python
import jax
import jax.numpy as jnp
import warp as wp
from warp.jax_experimental.ffi import jax_callable


@wp.kernel
def scale_kernel(a: wp.array(dtype=float), s: float, output: wp.array(dtype=float)):
    tid = wp.tid()
    output[tid] = a[tid] * s


def in_out_func(
    a: wp.array(dtype=float),  # input only
    c: wp.array(dtype=float),  # output only
):
    wp.launch(scale_kernel, dim=a.size, inputs=[a, 2.0], outputs=[c])


jax_func = jax_callable(in_out_func, vmap_method='broadcast_all', num_outputs=1)
f = jax.jit(jax_func)

a = jnp.ones(100, dtype=jnp.float32).reshape((10, 10))
# Batched call: map f over axis 0 of the (10, 10) input.
c = jax.vmap(f, in_axes=0, out_axes=0)(a)
# Unbatched call of f on the same 2-D input.
c = f(a)
print(c)
```
The error I'm getting is:
```
RuntimeError: Error launching kernel 'scale_kernel', argument 'a' expects an array with 1 dimension(s) but the passed array has 2 dimension(s).
E0721 21:11:26.423976 3931630 pjrt_stream_executor_client.cc:3077] Execution of replica 0 failed: UNKNOWN: FFI callback error: RuntimeError: Error launching kernel 'scale_kernel', argument 'a' expects an array with 1 dimension(s) but the passed array has 2 dimension(s).
```
System Information
I'm using Warp 1.7.0.