[Feature] Add JAX integration for FlyDSL kernels#257
wenchenvincent wants to merge 4 commits into ROCm:main
Conversation
Pull request overview
Adds a first-class JAX integration layer for FlyDSL, enabling both eager execution on jax.Array inputs and jax.jit execution via XLA custom calls, while also making PyTorch an optional dependency for importing FlyDSL.
Changes:
- Introduces `python/flydsl/jax/` (adapter, primitive lowering, FFI bridge, and C trampoline) to run FlyDSL kernels from JAX (eager + `jax.jit`).
- Makes PyTorch optional by guarding torch-specific registrations in `jit_argument.py`.
- Exposes `JitFunction.get_last_artifact()` to allow external integrations (the JAX bridge) to retrieve compiled artifacts.
Reviewed changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| tests/unit/test_jax_integration.py | Adds unit tests for adapter/primitive/bridge behavior (skips if JAX missing). |
| tests/test_jax_vecadd.py | Adds a JAX vector-add integration script/test (currently unguarded JAX import). |
| python/flydsl/jax/__init__.py | Public entrypoints for from_jax, jax_kernel, and lazy wrappers. |
| python/flydsl/jax/adapter.py | Implements JaxTensorAdaptor via DLPack for eager-mode @flyc.jit calls. |
| python/flydsl/jax/primitive.py | Implements a JAX primitive + StableHLO CustomCallOp lowering for jax.jit. |
| python/flydsl/jax/ffi_bridge.py | Compiles/registers FlyDSL kernels as XLA custom call targets + loads/builds trampoline. |
| python/flydsl/jax/_xla_bridge.c | C trampoline bridging XLA GPU custom-call ABI to FlyDSL ptr-packing ABI. |
| python/flydsl/compiler/jit_function.py | Adds get_last_artifact() and tracks last compilation result. |
| python/flydsl/compiler/jit_argument.py | Makes torch optional; keeps PyTorch tensor support when installed. |
| python/flydsl/compiler/__init__.py | Convenience re-export flydsl.compiler.from_jax. |
| examples/04-vectorAdd-jax.py | JAX version of vector add (eager + jax.jit). |
| examples/05-tiledCopy-jax.py | JAX version of tiled copy (eager + jax.jit). |
| examples/06-tiledMma-jax.py | JAX version of tiled MMA (eager + jax.jit). |
```python
# 0 = API_VERSION_ORIGINAL (CPU: fn(out, ins))
# 1 = API_VERSION_STATUS_RETURNING (CPU: fn(out, ins, status))
# 2 = API_VERSION_STATUS_RETURNING_UNIFIED (GPU: fn(stream, buffers, opaque, opaque_len))
# 4 = API_VERSION_TYPED_FFI
# We use 2 for GPU custom calls with the old untyped convention.
# backend_config carries the opaque bytes (slot index for the C trampoline).
i32_type = jax_ir.IntegerType.get_signless(32)
call = stablehlo.CustomCallOp(
    result_types,
    list(args),
    call_target_name=target_name,
    api_version=jax_ir.IntegerAttr.get(i32_type, 2),
```
stablehlo.CustomCallOp is emitted with api_version=2, but the target is registered in ffi_bridge._register_with_xla() using api_version=0. XLA/JAX expects these API versions to match; otherwise the runtime will call the bridge with a different calling convention than the C trampoline implements. Align the api_version used in both places (and update the docstring/comments accordingly).
Suggested change:

```diff
 # 0 = API_VERSION_ORIGINAL (CPU: fn(out, ins))
 # 1 = API_VERSION_STATUS_RETURNING (CPU: fn(out, ins, status))
 # 2 = API_VERSION_STATUS_RETURNING_UNIFIED (GPU: fn(stream, buffers, opaque, opaque_len))
 # 4 = API_VERSION_TYPED_FFI
-# We use 2 for GPU custom calls with the old untyped convention.
+# We use 0 here to match the api_version used in ffi_bridge._register_with_xla().
 # backend_config carries the opaque bytes (slot index for the C trampoline).
 i32_type = jax_ir.IntegerType.get_signless(32)
 call = stablehlo.CustomCallOp(
     result_types,
     list(args),
     call_target_name=target_name,
-    api_version=jax_ir.IntegerAttr.get(i32_type, 2),
+    api_version=jax_ir.IntegerAttr.get(i32_type, 0),
```
python/flydsl/jax/ffi_bridge.py (outdated)
```python
# api_version=0: old custom-call convention
#   void fn(stream, void** buffers, const char* opaque, size_t opaque_len)
_xla_client.register_custom_call_target(
    target.name, capsule, xla_platform_name, api_version=0,
)
```
_register_with_xla() registers the custom call target with api_version=0, but the lowering in primitive.py sets api_version=2 on the StableHLO CustomCallOp. This mismatch will typically lead to the wrong runtime ABI being used for the callback. Update the registration to use the same API version as the emitted CustomCallOp (and keep the inline comments/docstring consistent).
```python
def _ensure_bridge_lib() -> ctypes.CDLL:
    """Load ``_xla_bridge.so``, compiling from source if necessary."""
    if not _BRIDGE_SO.exists():
        if not _BRIDGE_C.exists():
            raise FileNotFoundError(
                f"Cannot find XLA bridge source: {_BRIDGE_C}\n"
                f"Please rebuild or reinstall flydsl."
            )
        subprocess.check_call(
            ["gcc", "-shared", "-fPIC", "-O2", "-o", str(_BRIDGE_SO), str(_BRIDGE_C)],
            cwd=str(_THIS_DIR),
        )
    lib = ctypes.CDLL(str(_BRIDGE_SO))
```
Building _xla_bridge.so at import time via a hard-coded gcc invocation is brittle operationally (requires a compiler at runtime, write permissions to the installed package dir, and fails on non-GNU toolchains). If you do keep this fallback, add the proper thread library flags (-pthread / -lpthread) since _xla_bridge.c uses pthreads, and consider moving compilation to the package build step or caching in a writable user cache dir instead of the source tree.
```python
with _lock:
    if target_name in _registered_targets:
        return target_name
```
compile_and_register() does a membership check under _lock and then releases the lock before compilation/registration. Two threads lowering the same shapes can race and both compile + register the same target_name, potentially causing duplicate registrations or nondeterministic behavior. Consider holding the lock through the whole register path, or store an “in-progress” sentinel (e.g., a threading.Event/Future) so only one thread compiles while others wait.
python/flydsl/jax/_xla_bridge.c (outdated)
```c
pthread_mutex_unlock(&g_lock);

g_targets[idx].func = (flydsl_func_t)func_ptr;
g_targets[idx].n_buffers = n_buffers;
g_targets[idx].n_scalars = n_scalars;
for (int i = 0; i < n_scalars; i++)
    g_targets[idx].scalar_vals[i] = scalar_vals ? scalar_vals[i] : 0;
```
flydsl_xla_register() increments g_n_targets under the mutex, then releases the lock and only afterwards populates g_targets[idx]. xla_bridge_dispatch() can run concurrently, observe idx < g_n_targets, and read a partially initialized slot (including a NULL/garbage func pointer). Populate the slot while still holding g_lock (or use an atomic publish pattern) to make registration + dispatch thread-safe.
Suggested change:

```diff
-pthread_mutex_unlock(&g_lock);
 g_targets[idx].func = (flydsl_func_t)func_ptr;
 g_targets[idx].n_buffers = n_buffers;
 g_targets[idx].n_scalars = n_scalars;
 for (int i = 0; i < n_scalars; i++)
     g_targets[idx].scalar_vals[i] = scalar_vals ? scalar_vals[i] : 0;
+pthread_mutex_unlock(&g_lock);
```
python/flydsl/jax/_xla_bridge.c (outdated)
```c
int idx = 0;
if (opaque_len >= sizeof(int))
    memcpy(&idx, opaque, sizeof(int));
```
In xla_bridge_dispatch(), if opaque_len < sizeof(int) you leave idx as 0 and will dispatch slot 0 even when the opaque payload is missing/invalid. This can silently call the wrong kernel. Consider treating short/invalid opaque as an error and returning early (and ideally validate opaque_len == sizeof(int) since you control the encoding).
Suggested change:

```diff
-int idx = 0;
-if (opaque_len >= sizeof(int))
-    memcpy(&idx, opaque, sizeof(int));
+/* Expect opaque to contain exactly one int index. Treat mismatched
+ * or missing opaque as an error to avoid dispatching the wrong slot.
+ */
+if (opaque == NULL || opaque_len != sizeof(int))
+    return;
+int idx;
+memcpy(&idx, opaque, sizeof(int));
```
```c
for (int i = 0; i < ns; i++) {
    storage[nb + i] = (void*)t->scalar_vals[i];
    packed[nb + i] = &storage[nb + i];
}
```
Scalar packing currently does storage[nb + i] = (void*)t->scalar_vals[i]; which relies on casting integers to pointers and on the callee interpreting the in-memory representation of a void* as an integer scalar. This is non-portable/undefined behavior and can break for large/negative values or non-integer scalars. Prefer storing scalars in a dedicated int64_t scalar_storage[] array on the stack (or similar) and set packed[...] to point at the scalar bytes (matching how FlyDSL packs scalar args via ctypes).
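The recommended layout can be illustrated in Python with ctypes (mirroring how FlyDSL packs scalar args; `pack_args` is a hypothetical helper, not PR code): scalars live in a dedicated int64 array, and the pointer-table entries point at those bytes rather than reinterpreting a `void*` itself as the value.

```python
import ctypes

def pack_args(buffer_ptrs, scalar_vals):
    """Build the void* table: buffer pointers, then pointers to scalar storage."""
    nb, ns = len(buffer_ptrs), len(scalar_vals)
    # Dedicated storage owns the scalar bytes (the C analogue would be an
    # int64_t scalar_storage[] array in the trampoline).
    scalar_storage = (ctypes.c_int64 * ns)(*scalar_vals)
    packed = (ctypes.c_void_p * (nb + ns))()
    for i, p in enumerate(buffer_ptrs):
        packed[i] = p
    step = ctypes.sizeof(ctypes.c_int64)
    for i in range(ns):
        # Point at the scalar's bytes instead of casting the value to void*.
        packed[nb + i] = ctypes.addressof(scalar_storage) + i * step
    # Caller must keep scalar_storage alive as long as packed is in use.
    return packed, scalar_storage
```

The callee then dereferences `packed[nb + i]` to read an int64, which stays well-defined for large and negative values.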
tests/test_jax_vecadd.py (outdated)
```python
import jax
import jax.numpy as jnp
import numpy as np

import flydsl.compiler as flyc
import flydsl.expr as fx
from flydsl.jax import from_jax
```
This file imports jax unconditionally at module import time. Because it’s under tests/ and named test_*.py, pytest will import it during collection, causing the whole test run to error in environments where JAX isn’t installed (even if JAX support is meant to be optional). Add a try/except ImportError + pytest.skip(..., allow_module_level=True) guard (like tests/unit/test_jax_integration.py), or move/rename it so it’s not collected by pytest when JAX is absent.
Suggested change:

```diff
-import jax
-import jax.numpy as jnp
 import numpy as np
 
 import flydsl.compiler as flyc
 import flydsl.expr as fx
-from flydsl.jax import from_jax
+
+try:
+    import jax
+    import jax.numpy as jnp
+    from flydsl.jax import from_jax
+except ImportError:
+    import pytest
+    pytest.skip("JAX not installed; skipping JAX vecAdd test.", allow_module_level=True)
```
Add two levels of JAX integration so FlyDSL GPU kernels can be called
from JAX code on AMD GPUs:
Level 1 -- Eager mode (from_jax):
Wrap jax.Array objects via DLPack so they can be passed directly to
@flyc.jit functions. Works with all dtypes including float8 variants.
Level 2 -- jax.jit integration (jax_kernel):
Register compiled FlyDSL kernels as XLA custom-call targets via a C
trampoline that bridges XLA's GPU calling convention (stream, buffers,
opaque) to FlyDSL's bare-pointer convention. Supports tensor buffers,
baked scalar arguments, and compile-time constants.
New package python/flydsl/jax/:
- adapter.py: JaxTensorAdaptor with DLPack, float8, layout dynamism
- primitive.py: JAX primitive with abstract eval, eager impl, and
StableHLO CustomCallOp lowering for rocm/gpu platforms
- ffi_bridge.py: Compilation, XLA target registration, opaque encoding
- _xla_bridge.c: Thread-safe C trampoline with scalar insertion
Also makes PyTorch an optional dependency by guarding the torch import
in jit_argument.py behind try/except. The JitArgumentRegistry, compiler
decorators, and JAX integration all work without torch installed.
Adds JitFunction.get_last_artifact() public API for external integrations
to retrieve compiled kernels without accessing private attributes.
Tested on MI300X with JAX 0.8.2, all examples produce correct results.
Signed-off-by: Wen Chen <Wen.Chen@amd.com>
JAX equivalents of examples 01-03, demonstrating both eager mode (from_jax + @flyc.jit) and jax.jit mode (jax_kernel + XLA custom call):
- 04-vectorAdd-jax.py: Vector addition with layout algebra
- 05-tiledCopy-jax.py: Tiled copy with partitioned tensors
- 06-tiledMma-jax.py: Single-tile GEMM (64x64x8) using MFMA instructions

All produce correct results on MI300X with JAX 0.8.2.

Signed-off-by: Wen Chen <Wen.Chen@amd.com>
30 unit tests in tests/unit/test_jax_integration.py:
- Adapter: 18 tests (dtypes, shapes, fly_types, fly_ptrs, cache signature, layout dynamic, alignment, error cases, float8)
- Primitive: 5 tests (eager single/multi output, constexpr/scalar forwarding, abstract eval)
- C trampoline: 5 tests (dispatch, scalar insertion, slot allocation, bounds rejection)
- Registration dedup: 2 tests (name hashing stability)

End-to-end test in tests/test_jax_vecadd.py:
- 102,400-element vectorized add using JAX arrays, max error = 0.00

All 30 unit tests pass in ~4s on JAX 0.8.2 without torch.

Signed-off-by: Wen Chen <Wen.Chen@amd.com>
1. _xla_bridge.c: Populate slot data inside the mutex-protected section so concurrent dispatchers never see uninitialized data (Comment 5). Add strict opaque validation -- return early if NULL or wrong size instead of silently defaulting to slot 0 (Comment 6). Use a dedicated int64_t scalar_storage[] array instead of casting integers to void* pointers, avoiding undefined behavior (Comment 7).
2. ffi_bridge.py: Hold the lock through the full compile+register path to prevent concurrent threads from duplicating work (Comment 4). Use 'cc' instead of 'gcc' for portability, add -lpthread flag (Comment 3). Clarify api_version cross-references between the XLA registration API (0=untyped) and StableHLO CustomCallOp (2=STATUS_RETURNING_UNIFIED), which refer to the same calling convention (Comments 1+2).
3. test_jax_vecadd.py: Add try/except ImportError guard with pytest.skip so the test suite doesn't break when JAX is absent (Comment 8).

Signed-off-by: Wen Chen <Wen.Chen@amd.com>
Force-pushed from a9976b0 to a58e87a.
Motivation
FlyDSL currently only supports PyTorch tensors. This PR adds JAX support with two levels of integration:
- Eager mode (`from_jax`): wrap JAX arrays and pass them directly to `@flyc.jit` functions
- `jax.jit` mode (`jax_kernel`): call FlyDSL kernels inside `jax.jit` via XLA custom calls

PyTorch is also made an optional dependency — FlyDSL can now be imported and used without torch.
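The eager path rests on DLPack's zero-copy exchange; a minimal stand-alone illustration, using NumPy as the producer/consumer stand-in (the real adapter wraps `jax.Array`, and `JaxTensorAdaptor` itself is not reproduced here):

```python
import numpy as np

# Any DLPack producer's buffer can be adopted without a copy via the
# __dlpack__ protocol -- the same mechanism from_jax uses for jax.Array.
a = np.arange(8, dtype=np.float32)
b = np.from_dlpack(a)            # consumes a's __dlpack__ capsule
assert np.shares_memory(a, b)    # zero copy: both views share one buffer
```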
Technical Details
New package `python/flydsl/jax/`:
- `adapter.py` — `JaxTensorAdaptor` wrapping `jax.Array` via DLPack, all dtypes including float8
- `primitive.py` — JAX primitive with abstract eval, eager impl, and StableHLO `CustomCallOp` lowering
- `ffi_bridge.py` — Compiles kernels and registers them as XLA custom-call targets
- `_xla_bridge.c` — Thread-safe C trampoline bridging XLA's GPU calling convention to FlyDSL's bare-pointer convention, with baked scalar argument support

Modified files:
- `compiler/jit_argument.py` — torch import guarded behind `try/except`
- `compiler/jit_function.py` — Added `get_last_artifact()` public API for external integrations

Examples: JAX versions of vectorAdd, tiledCopy, and tiledMma (both eager and jax.jit)
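The opaque payload that ties the custom call to the C trampoline (a slot index, per the lowering comments) can be sketched with `struct`; the helper names here are assumptions, but the encoding must agree with the C side's `memcpy(&idx, opaque, sizeof(int))`:

```python
import struct

_FMT = "=i"  # native byte order and size: matches a C int on the same host

def encode_opaque(slot: int) -> bytes:
    """Encode the trampoline slot index as the custom call's opaque bytes."""
    return struct.pack(_FMT, slot)

def decode_opaque(opaque: bytes) -> int:
    """Mirror of the C-side strict check: exactly one int, or an error."""
    if len(opaque) != struct.calcsize(_FMT):
        raise ValueError("opaque must contain exactly one int slot index")
    return struct.unpack(_FMT, opaque)[0]
```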
Test Plan
Test Result
All tests pass on MI300X, JAX 0.8.2. All examples produce correct results (max diff: 0.00e+00).
Submission Checklist