
[Feature] Add JAX integration for FlyDSL kernels #257

Open
wenchenvincent wants to merge 4 commits into ROCm:main from wenchenvincent:feat/jax-integration

Conversation

@wenchenvincent

Motivation

FlyDSL currently only supports PyTorch tensors. This PR adds JAX support with two levels of integration:

  • Eager mode (from_jax): wrap JAX arrays and pass them directly to @flyc.jit functions
  • jax.jit mode (jax_kernel): call FlyDSL kernels inside jax.jit via XLA custom calls

PyTorch is also made an optional dependency — FlyDSL can now be imported and used without torch.
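In sketch form (pseudocode only; the kernel name `vector_add` is hypothetical and exact signatures may differ from the final API), the two levels look roughly like this:

```
import jax
import jax.numpy as jnp
from flydsl.jax import from_jax, jax_kernel

x = jnp.ones(1024, dtype=jnp.float32)
y = jnp.ones(1024, dtype=jnp.float32)

# Level 1 -- eager: wrap jax.Array values and call the @flyc.jit kernel directly
out = vector_add(from_jax(x), from_jax(y))

# Level 2 -- traced: wrap the kernel so it lowers to an XLA custom call
out = jax.jit(jax_kernel(vector_add))(x, y)
```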

Technical Details

New package python/flydsl/jax/:

  • adapter.py — JaxTensorAdaptor wrapping jax.Array via DLPack, supporting all dtypes including float8
  • primitive.py — JAX primitive with abstract eval, eager impl, and StableHLO CustomCallOp lowering
  • ffi_bridge.py — Compiles kernels and registers them as XLA custom-call targets
  • _xla_bridge.c — Thread-safe C trampoline bridging XLA's GPU calling convention to FlyDSL's bare-pointer convention, with baked scalar argument support

Modified files:

  • compiler/jit_argument.py — torch import guarded behind try/except
  • compiler/jit_function.py — Added get_last_artifact() public API for external integrations

Examples: JAX versions of vectorAdd, tiledCopy, and tiledMma (both eager and jax.jit)

Test Plan

  • 30 unit tests covering adapter, primitive, C trampoline, and registration dedup
  • End-to-end integration test (102K-element vectorized add)
  • All 3 examples verified with both eager and jax.jit paths
  • Existing unit tests unaffected (34 passed, 2 pre-existing skips)

Test Result

All tests pass on MI300X, JAX 0.8.2. All examples produce correct results (max diff: 0.00e+00).

Submission Checklist

Copilot AI review requested due to automatic review settings March 21, 2026 08:34
Contributor

Copilot AI left a comment


Pull request overview

Adds a first-class JAX integration layer for FlyDSL, enabling both eager execution on jax.Array inputs and jax.jit execution via XLA custom calls, while also making PyTorch an optional dependency for importing FlyDSL.

Changes:

  • Introduces python/flydsl/jax/ (adapter, primitive lowering, FFI bridge, and C trampoline) to run FlyDSL kernels from JAX (eager + jax.jit).
  • Makes PyTorch optional by guarding torch-specific registrations in jit_argument.py.
  • Exposes JitFunction.get_last_artifact() to allow external integrations (JAX bridge) to retrieve compiled artifacts.

Reviewed changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 8 comments.

Show a summary per file
File Description
tests/unit/test_jax_integration.py Adds unit tests for adapter/primitive/bridge behavior (skips if JAX missing).
tests/test_jax_vecadd.py Adds a JAX vector-add integration script/test (currently unguarded JAX import).
python/flydsl/jax/__init__.py Public entrypoints for from_jax, jax_kernel, and lazy wrappers.
python/flydsl/jax/adapter.py Implements JaxTensorAdaptor via DLPack for eager-mode @flyc.jit calls.
python/flydsl/jax/primitive.py Implements a JAX primitive + StableHLO CustomCallOp lowering for jax.jit.
python/flydsl/jax/ffi_bridge.py Compiles/registers FlyDSL kernels as XLA custom call targets + loads/builds trampoline.
python/flydsl/jax/_xla_bridge.c C trampoline bridging XLA GPU custom-call ABI to FlyDSL ptr-packing ABI.
python/flydsl/compiler/jit_function.py Adds get_last_artifact() and tracks last compilation result.
python/flydsl/compiler/jit_argument.py Makes torch optional; keeps PyTorch tensor support when installed.
python/flydsl/compiler/__init__.py Convenience re-export flydsl.compiler.from_jax.
examples/04-vectorAdd-jax.py JAX version of vector add (eager + jax.jit).
examples/05-tiledCopy-jax.py JAX version of tiled copy (eager + jax.jit).
examples/06-tiledMma-jax.py JAX version of tiled MMA (eager + jax.jit).


Comment on lines +203 to +214
# 0 = API_VERSION_ORIGINAL (CPU: fn(out, ins))
# 1 = API_VERSION_STATUS_RETURNING (CPU: fn(out, ins, status))
# 2 = API_VERSION_STATUS_RETURNING_UNIFIED (GPU: fn(stream, buffers, opaque, opaque_len))
# 4 = API_VERSION_TYPED_FFI
# We use 2 for GPU custom calls with the old untyped convention.
# backend_config carries the opaque bytes (slot index for the C trampoline).
i32_type = jax_ir.IntegerType.get_signless(32)
call = stablehlo.CustomCallOp(
    result_types,
    list(args),
    call_target_name=target_name,
    api_version=jax_ir.IntegerAttr.get(i32_type, 2),

Copilot AI Mar 21, 2026


stablehlo.CustomCallOp is emitted with api_version=2, but the target is registered in ffi_bridge._register_with_xla() using api_version=0. XLA/JAX expects these API versions to match; otherwise the runtime will call the bridge with a different calling convention than the C trampoline implements. Align the api_version used in both places (and update the docstring/comments accordingly).

Suggested change
# 0 = API_VERSION_ORIGINAL (CPU: fn(out, ins))
# 1 = API_VERSION_STATUS_RETURNING (CPU: fn(out, ins, status))
# 2 = API_VERSION_STATUS_RETURNING_UNIFIED (GPU: fn(stream, buffers, opaque, opaque_len))
# 4 = API_VERSION_TYPED_FFI
# We use 2 for GPU custom calls with the old untyped convention.
# backend_config carries the opaque bytes (slot index for the C trampoline).
i32_type = jax_ir.IntegerType.get_signless(32)
call = stablehlo.CustomCallOp(
    result_types,
    list(args),
    call_target_name=target_name,
    api_version=jax_ir.IntegerAttr.get(i32_type, 2),

# 0 = API_VERSION_ORIGINAL (fn(out, ins))
# 1 = API_VERSION_STATUS_RETURNING (fn(out, ins, status))
# 2 = API_VERSION_STATUS_RETURNING_UNIFIED (GPU: fn(stream, buffers, opaque, opaque_len))
# 4 = API_VERSION_TYPED_FFI
# We use 0 here to match the api_version used in ffi_bridge._register_with_xla().
# backend_config carries the opaque bytes (slot index for the C trampoline).
i32_type = jax_ir.IntegerType.get_signless(32)
call = stablehlo.CustomCallOp(
    result_types,
    list(args),
    call_target_name=target_name,
    api_version=jax_ir.IntegerAttr.get(i32_type, 0),

Comment on lines +262 to +266
# api_version=0: old custom-call convention
# void fn(stream, void** buffers, const char* opaque, size_t opaque_len)
_xla_client.register_custom_call_target(
    target.name, capsule, xla_platform_name, api_version=0,
)

Copilot AI Mar 21, 2026


_register_with_xla() registers the custom call target with api_version=0, but the lowering in primitive.py sets api_version=2 on the StableHLO CustomCallOp. This mismatch will typically lead to the wrong runtime ABI being used for the callback. Update the registration to use the same API version as the emitted CustomCallOp (and keep the inline comments/docstring consistent).

Comment on lines +59 to +71
def _ensure_bridge_lib() -> ctypes.CDLL:
    """Load ``_xla_bridge.so``, compiling from source if necessary."""
    if not _BRIDGE_SO.exists():
        if not _BRIDGE_C.exists():
            raise FileNotFoundError(
                f"Cannot find XLA bridge source: {_BRIDGE_C}\n"
                f"Please rebuild or reinstall flydsl."
            )
        subprocess.check_call(
            ["gcc", "-shared", "-fPIC", "-O2", "-o", str(_BRIDGE_SO), str(_BRIDGE_C)],
            cwd=str(_THIS_DIR),
        )
    lib = ctypes.CDLL(str(_BRIDGE_SO))

Copilot AI Mar 21, 2026


Building _xla_bridge.so at import time via a hard-coded gcc invocation is brittle operationally (requires a compiler at runtime, write permissions to the installed package dir, and fails on non-GNU toolchains). If you do keep this fallback, add the proper thread library flags (-pthread / -lpthread) since _xla_bridge.c uses pthreads, and consider moving compilation to the package build step or caching in a writable user cache dir instead of the source tree.
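One way to implement the reviewer's suggestion is to compile into a per-user cache directory and probe for an available compiler rather than hard-coding gcc. A minimal sketch (the helper names `bridge_cache_path` and `find_c_compiler` are hypothetical, not part of this PR):

```python
import os
import shutil
from pathlib import Path
from typing import Optional


def bridge_cache_path(name: str = "_xla_bridge.so") -> Path:
    """Return a writable per-user location for the compiled bridge,
    honoring XDG_CACHE_HOME, instead of writing into site-packages."""
    base = os.environ.get("XDG_CACHE_HOME") or str(Path.home() / ".cache")
    cache_dir = Path(base) / "flydsl"
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir / name


def find_c_compiler() -> Optional[str]:
    """Prefer $CC, then common compiler names, so non-GNU toolchains work."""
    for cand in (os.environ.get("CC"), "cc", "gcc", "clang"):
        if cand and shutil.which(cand):
            return cand
    return None
```

The compile command would then also need `-pthread` appended, since `_xla_bridge.c` uses pthreads.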

Comment on lines +172 to +175
with _lock:
    if target_name in _registered_targets:
        return target_name


Copilot AI Mar 21, 2026


compile_and_register() does a membership check under _lock and then releases the lock before compilation/registration. Two threads lowering the same shapes can race and both compile + register the same target_name, potentially causing duplicate registrations or nondeterministic behavior. Consider holding the lock through the whole register path, or store an “in-progress” sentinel (e.g., a threading.Event/Future) so only one thread compiles while others wait.
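The in-progress sentinel the reviewer describes can be sketched with a `Future` published under the lock, so exactly one thread compiles while later callers block on the result (illustrative only; `compile_and_register_once` is not the PR's actual function):

```python
import threading
from concurrent.futures import Future

_lock = threading.Lock()
_registered = {}  # target_name -> Future (in-progress sentinel or finished result)


def compile_and_register_once(target_name, compile_fn):
    """Compile/register each target exactly once, even under concurrent lowering."""
    with _lock:
        fut = _registered.get(target_name)
        owner = fut is None
        if owner:
            fut = Future()           # publish the sentinel while holding the lock
            _registered[target_name] = fut
    if owner:
        try:
            fut.set_result(compile_fn(target_name))
        except BaseException as exc:
            fut.set_exception(exc)
            raise
    return fut.result()              # non-owners block until compilation finishes
```

Only the dictionary check-and-insert happens under the lock, so a slow compilation does not serialize unrelated targets.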

Comment on lines +51 to +57
pthread_mutex_unlock(&g_lock);

g_targets[idx].func = (flydsl_func_t)func_ptr;
g_targets[idx].n_buffers = n_buffers;
g_targets[idx].n_scalars = n_scalars;
for (int i = 0; i < n_scalars; i++)
    g_targets[idx].scalar_vals[i] = scalar_vals ? scalar_vals[i] : 0;

Copilot AI Mar 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

flydsl_xla_register() increments g_n_targets under the mutex, then releases the lock and only afterwards populates g_targets[idx]. xla_bridge_dispatch() can run concurrently, observe idx < g_n_targets, and read a partially initialized slot (including a NULL/garbage func pointer). Populate the slot while still holding g_lock (or use an atomic publish pattern) to make registration + dispatch thread-safe.

Suggested change
pthread_mutex_unlock(&g_lock);
g_targets[idx].func = (flydsl_func_t)func_ptr;
g_targets[idx].n_buffers = n_buffers;
g_targets[idx].n_scalars = n_scalars;
for (int i = 0; i < n_scalars; i++)
    g_targets[idx].scalar_vals[i] = scalar_vals ? scalar_vals[i] : 0;

g_targets[idx].func = (flydsl_func_t)func_ptr;
g_targets[idx].n_buffers = n_buffers;
g_targets[idx].n_scalars = n_scalars;
for (int i = 0; i < n_scalars; i++)
    g_targets[idx].scalar_vals[i] = scalar_vals ? scalar_vals[i] : 0;
pthread_mutex_unlock(&g_lock);

Comment on lines +63 to +66
int idx = 0;
if (opaque_len >= sizeof(int))
    memcpy(&idx, opaque, sizeof(int));


Copilot AI Mar 21, 2026


In xla_bridge_dispatch(), if opaque_len < sizeof(int) you leave idx as 0 and will dispatch slot 0 even when the opaque payload is missing/invalid. This can silently call the wrong kernel. Consider treating short/invalid opaque as an error and returning early (and ideally validate opaque_len == sizeof(int) since you control the encoding).

Suggested change
int idx = 0;
if (opaque_len >= sizeof(int))
    memcpy(&idx, opaque, sizeof(int));

/* Expect opaque to contain exactly one int index. Treat mismatched
 * or missing opaque as an error to avoid dispatching the wrong slot. */
if (opaque == NULL || opaque_len != sizeof(int))
    return;
int idx;
memcpy(&idx, opaque, sizeof(int));

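Since the opaque payload is a single native int slot index, the encode side in Python can enforce the same strict contract. A sketch of the pair (hypothetical helper names, mirroring the validation the suggestion adds in C):

```python
import struct

SLOT_FMT = "=i"  # native-endian 4-byte int, matching the C trampoline's int


def encode_opaque(slot: int) -> bytes:
    """Encode the trampoline slot index as the custom call's opaque payload."""
    return struct.pack(SLOT_FMT, slot)


def decode_opaque(opaque: bytes) -> int:
    """Strictly decode; reject missing or wrong-sized payloads instead of
    silently falling back to slot 0."""
    expected = struct.calcsize(SLOT_FMT)
    if len(opaque) != expected:
        raise ValueError(f"opaque must be {expected} bytes, got {len(opaque)}")
    return struct.unpack(SLOT_FMT, opaque)[0]
```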
Comment on lines +84 to +87
for (int i = 0; i < ns; i++) {
    storage[nb + i] = (void*)t->scalar_vals[i];
    packed[nb + i] = &storage[nb + i];
}

Copilot AI Mar 21, 2026


Scalar packing currently does storage[nb + i] = (void*)t->scalar_vals[i]; which relies on casting integers to pointers and on the callee interpreting the in-memory representation of a void* as an integer scalar. This is non-portable/undefined behavior and can break for large/negative values or non-integer scalars. Prefer storing scalars in a dedicated int64_t scalar_storage[] array on the stack (or similar) and set packed[...] to point at the scalar bytes (matching how FlyDSL packs scalar args via ctypes).
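The dedicated scalar-storage fix the reviewer suggests can be illustrated with the same ctypes approach FlyDSL uses on the Python side (a sketch; `pack_args` is a hypothetical name):

```python
import ctypes


def pack_args(buffer_ptrs, scalar_vals):
    """Pack device-buffer pointers followed by pointers to int64 scalar storage.

    Scalars live in their own int64 array and packed[] holds the *address* of
    each scalar, instead of reinterpreting a void* bit-pattern as an integer.
    """
    n = len(buffer_ptrs) + len(scalar_vals)
    packed = (ctypes.c_void_p * n)()
    scalar_storage = (ctypes.c_int64 * max(len(scalar_vals), 1))(*scalar_vals)
    for i, p in enumerate(buffer_ptrs):
        packed[i] = p
    base = ctypes.addressof(scalar_storage)
    step = ctypes.sizeof(ctypes.c_int64)
    for i in range(len(scalar_vals)):
        packed[len(buffer_ptrs) + i] = base + i * step
    return packed, scalar_storage  # caller must keep scalar_storage alive
```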

Comment on lines +12 to +20
import jax
import jax.numpy as jnp
import numpy as np

import flydsl.compiler as flyc
import flydsl.expr as fx
from flydsl.jax import from_jax



Copilot AI Mar 21, 2026


This file imports jax unconditionally at module import time. Because it’s under tests/ and named test_*.py, pytest will import it during collection, causing the whole test run to error in environments where JAX isn’t installed (even if JAX support is meant to be optional). Add a try/except ImportError + pytest.skip(..., allow_module_level=True) guard (like tests/unit/test_jax_integration.py), or move/rename it so it’s not collected by pytest when JAX is absent.

Suggested change
import jax
import jax.numpy as jnp
import numpy as np

import flydsl.compiler as flyc
import flydsl.expr as fx
from flydsl.jax import from_jax

import numpy as np

import flydsl.compiler as flyc
import flydsl.expr as fx

try:
    import jax
    import jax.numpy as jnp

    from flydsl.jax import from_jax
except ImportError:
    import pytest

    pytest.skip("JAX not installed; skipping JAX vecAdd test.", allow_module_level=True)

Add two levels of JAX integration so FlyDSL GPU kernels can be called
from JAX code on AMD GPUs:

Level 1 -- Eager mode (from_jax):
  Wrap jax.Array objects via DLPack so they can be passed directly to
  @flyc.jit functions. Works with all dtypes including float8 variants.

Level 2 -- jax.jit integration (jax_kernel):
  Register compiled FlyDSL kernels as XLA custom-call targets via a C
  trampoline that bridges XLA's GPU calling convention (stream, buffers,
  opaque) to FlyDSL's bare-pointer convention. Supports tensor buffers,
  baked scalar arguments, and compile-time constants.

New package python/flydsl/jax/:
  - adapter.py: JaxTensorAdaptor with DLPack, float8, layout dynamism
  - primitive.py: JAX primitive with abstract eval, eager impl, and
    StableHLO CustomCallOp lowering for rocm/gpu platforms
  - ffi_bridge.py: Compilation, XLA target registration, opaque encoding
  - _xla_bridge.c: Thread-safe C trampoline with scalar insertion

Also makes PyTorch an optional dependency by guarding the torch import
in jit_argument.py behind try/except. The JitArgumentRegistry, compiler
decorators, and JAX integration all work without torch installed.

Adds JitFunction.get_last_artifact() public API for external integrations
to retrieve compiled kernels without accessing private attributes.

Tested on MI300X with JAX 0.8.2, all examples produce correct results.

Signed-off-by: Wen Chen <Wen.Chen@amd.com>
JAX equivalents of examples 01-03, demonstrating both eager mode
(from_jax + @flyc.jit) and jax.jit mode (jax_kernel + XLA custom call):

- 04-vectorAdd-jax.py: Vector addition with layout algebra
- 05-tiledCopy-jax.py: Tiled copy with partitioned tensors
- 06-tiledMma-jax.py: Single-tile GEMM (64x64x8) using MFMA instructions

All produce correct results on MI300X with JAX 0.8.2.

Signed-off-by: Wen Chen <Wen.Chen@amd.com>
30 unit tests in tests/unit/test_jax_integration.py:
- Adapter: 18 tests (dtypes, shapes, fly_types, fly_ptrs, cache
  signature, layout dynamic, alignment, error cases, float8)
- Primitive: 5 tests (eager single/multi output, constexpr/scalar
  forwarding, abstract eval)
- C trampoline: 5 tests (dispatch, scalar insertion, slot allocation,
  bounds rejection)
- Registration dedup: 2 tests (name hashing stability)

End-to-end test in tests/test_jax_vecadd.py:
- 102,400-element vectorized add using JAX arrays, max error = 0.00

All 30 unit tests pass in ~4s on JAX 0.8.2 without torch.

Signed-off-by: Wen Chen <Wen.Chen@amd.com>
1. _xla_bridge.c: Populate slot data inside the mutex-protected section
   so concurrent dispatchers never see uninitialized data (Comment 5).
   Add strict opaque validation -- return early if NULL or wrong size
   instead of silently defaulting to slot 0 (Comment 6). Use dedicated
   int64_t scalar_storage[] array instead of casting integers to void*
   pointers, avoiding undefined behavior (Comment 7).

2. ffi_bridge.py: Hold the lock through the full compile+register path
   to prevent concurrent threads from duplicating work (Comment 4).
   Use 'cc' instead of 'gcc' for portability, add -lpthread flag
   (Comment 3). Clarify api_version cross-references between the XLA
   registration API (0=untyped) and StableHLO CustomCallOp
   (2=STATUS_RETURNING_UNIFIED) which refer to the same calling
   convention (Comments 1+2).

3. test_jax_vecadd.py: Add try/except ImportError guard with
   pytest.skip so the test suite doesn't break when JAX is absent
   (Comment 8).

Signed-off-by: Wen Chen <Wen.Chen@amd.com>