73 changes: 73 additions & 0 deletions docs/src/tutorials/kernels.md
@@ -0,0 +1,73 @@
# Kernels
Collaborator:
Just a random thought, but maybe the tutorial can be called something like "Compute Kernels" or "GPU Kernels"? IMHO, kernel has a lot of different meanings in computing.


Suppose your codebase contains custom GPU kernels, typically those defined with [KernelAbstractions.jl](https://github.com/JuliaGPU/KernelAbstractions.jl). This tutorial shows how to compile such kernels with Reactant and differentiate them with Enzyme.

## Example

```@example kernels
using KernelAbstractions

@kernel function square_kernel!(y, @Const(x))
i = @index(Global)
@inbounds y[i] = x[i] * x[i]
end

function square(x)
y = similar(x)
backend = KernelAbstractions.get_backend(x)
kernel! = square_kernel!(backend)
kernel!(y, x; ndrange=length(x))
return y
end
```
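
The backend (CPU or GPU) is picked automatically by `KernelAbstractions.get_backend` from the type of `x`, so the same function runs unchanged on a regular `Array`: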

```jldoctest kernels
x = float.(1:5)
y = square(x)

# output

5-element Vector{Float64}:
1.0
4.0
9.0
16.0
25.0
```

## Kernel compilation

To compile such kernels with Reactant, you need to pass the option `raise=true` to the `@compile` or `@jit` macro.
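
For instance, here is a minimal sketch of both variants (the name `square_raised` is illustrative; it assumes the `square` function from above and the CUDA.jl import discussed below):

```julia
import CUDA   # required for raising, even without an NVIDIA GPU (see below)
using Reactant

x = float.(1:5)
xr = ConcreteRArray(x)

# `@jit` compiles and calls in a single step:
yr = @jit raise=true square(xr)

# `@compile` returns a compiled function that can be reused
# on arguments of the same shape and type:
square_raised = @compile raise=true square(xr)
yr = square_raised(xr)
```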
Member:
You don't need to pass `raise=true` to compile with Reactant; the kernel will just run natively as the existing kernel. If you want to convert it to a tensor form to enable linear algebra optimizations, `raise` needs to be set to `true`.
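
As an illustration, one could compare the IR generated with and without raising. This is a sketch only: it assumes `@code_hlo` accepts the same `raise` option as `@compile`/`@jit`, which is not verified here.

```julia
import CUDA          # see the note on CUDA.jl below
using Reactant

xr = ConcreteRArray(float.(1:5))

# Without raising, the kernel is kept as a native kernel launch:
@code_hlo square(xr)

# With raising, the kernel is rewritten into tensor operations,
# which Reactant's optimization passes can then work with:
@code_hlo raise=true square(xr)
```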

Contributor Author:
I'm on a Mac M3, and when I try kernel compilation without raising, I get:

```
julia> y = @jit square(x)
ERROR: CUDA driver not found
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] functional
    @ ~/.julia/packages/CUDA/x8d2s/src/initialization.jl:24 [inlined]
  [3] task_local_state!()
    @ CUDA ~/.julia/packages/CUDA/x8d2s/lib/cudadrv/state.jl:77
  [4] active_state
    @ ~/.julia/packages/CUDA/x8d2s/lib/cudadrv/state.jl:110 [inlined]
  [5] cufunction(f::typeof(gpu_square_kernel!), tt::Type{Tuple{…}}; kwargs::@Kwargs{})
    @ CUDA ~/.julia/packages/CUDA/x8d2s/src/compiler/execution.jl:366
  [6] cufunction(f::typeof(gpu_square_kernel!), tt::Type{Tuple{…}})
    @ CUDA ~/.julia/packages/CUDA/x8d2s/src/compiler/execution.jl:365
  [7] launch_configuration(f::ReactantCUDAExt.LLVMFunc{…}; shmem::Int64, max_threads::Int64)
    @ ReactantCUDAExt ~/.julia/packages/Reactant/zlIsO/ext/ReactantCUDAExt.jl:619
  [8] ka_with_reactant
    @ ~/.julia/packages/Reactant/zlIsO/ext/ReactantCUDAExt.jl:525 [inlined]
  [9] (::Nothing)(none::typeof(Reactant.ka_with_reactant), none::Int64, none::Nothing, none::KernelAbstractions.Kernel{…}, none::Tuple{…})
    @ Reactant ./<missing>:0
 [10] launch_config
    @ ~/.julia/packages/Reactant/zlIsO/ext/ReactantKernelAbstractionsExt.jl:68 [inlined]
 [11] ka_with_reactant
    @ ~/.julia/packages/Reactant/zlIsO/ext/ReactantCUDAExt.jl:504 [inlined]
 [12] call_with_reactant(::typeof(Reactant.ka_with_reactant), ::Int64, ::Nothing, ::KernelAbstractions.Kernel{…}, ::Reactant.TracedRArray{…}, ::Reactant.TracedRArray{…})
    @ Reactant ~/.julia/packages/Reactant/zlIsO/src/utils.jl:0
 [13] (::KernelAbstractions.Kernel{…})(::Any, ::Vararg{…}; ndrange::Int64, workgroupsize::Nothing)
    @ ReactantKernelAbstractionsExt ~/.julia/packages/Reactant/zlIsO/ext/ReactantKernelAbstractionsExt.jl:118
 [14] #kwcall
    @ ~/.julia/packages/Reactant/zlIsO/ext/ReactantKernelAbstractionsExt.jl:113 [inlined]
 [15] (::Nothing)(none::typeof(Core.kwcall), none::@NamedTuple{}, none::KernelAbstractions.Kernel{…}, none::Tuple{…})
    @ Reactant ./<missing>:0
 [16] call_with_reactant(::typeof(Core.kwcall), ::@NamedTuple{}, ::KernelAbstractions.Kernel{…}, ::Reactant.TracedRArray{…}, ::Reactant.TracedRArray{…})
    @ Reactant ~/.julia/packages/Reactant/zlIsO/src/utils.jl:523
 [17] square
    @ ~/Documents/GitHub/Julia/Reactant.jl/test/playground.jl:15 [inlined]
 [18] (::Nothing)(none::typeof(square), none::Reactant.TracedRArray{Float64, 1})
    @ Reactant ./<missing>:0
 [19] getproperty
    @ ./Base.jl:49 [inlined]
 [20] size
    @ ~/.julia/packages/Reactant/zlIsO/src/TracedRArray.jl:259 [inlined]
 [21] axes
    @ ./abstractarray.jl:98 [inlined]
 [22] similar
    @ ./abstractarray.jl:821 [inlined]
 [23] similar
    @ ./abstractarray.jl:820 [inlined]
 [24] square
    @ ~/Documents/GitHub/Julia/Reactant.jl/test/playground.jl:12 [inlined]
 [25] call_with_reactant(::typeof(square), ::Reactant.TracedRArray{Float64, 1})
    @ Reactant ~/.julia/packages/Reactant/zlIsO/src/utils.jl:0
 [26] make_mlir_fn(f::typeof(square), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/zlIsO/src/TracedUtils.jl:345
 [27] make_mlir_fn
    @ ~/.julia/packages/Reactant/zlIsO/src/TracedUtils.jl:275 [inlined]
 [28] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::typeof(square), args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}, sdygroupidcache::Tuple{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, client::Reactant.XLA.PJRT.Client, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/zlIsO/src/Compiler.jl:1614
 [29] compile_mlir!
    @ ~/.julia/packages/Reactant/zlIsO/src/Compiler.jl:1576 [inlined]
 [30] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/zlIsO/src/Compiler.jl:3524
 [31] compile_xla
    @ ~/.julia/packages/Reactant/zlIsO/src/Compiler.jl:3496 [inlined]
 [32] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/zlIsO/src/Compiler.jl:3600
 [33] top-level scope
    @ ~/.julia/packages/Reactant/zlIsO/src/Compiler.jl:2669
Some type information was truncated. Use `show(err)` to see complete types.
```

Furthermore, the [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) package needs to be loaded for raising to work (even on non-NVIDIA hardware).

Member:
Hm, this is a bug; in principle we should force-load this for you (though in the background it will load CUDA.jl).

Contributor Author (@gdalle, Nov 29, 2025):
Without loading CUDA first, I get the error I mentioned on Discourse:

```
julia> y = @jit raise=true square(x)
ERROR: MethodError: no method matching ka_with_reactant(::Int64, ::Nothing, ::KernelAbstractions.Kernel{…}, ::Reactant.TracedRArray{…}, ::Reactant.TracedRArray{…})
The function `ka_with_reactant` exists, but no method is defined for this combination of argument types.
Attempted to raise a KernelAbstractions kernel with Reactant but CUDA.jl is not loaded.
Load CUDA.jl using `using CUDA`. You might need to restart the Julia process (even if Revise.jl is loaded).
Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Reactant/zlIsO/src/utils.jl:0 [inlined]
  [2] call_with_reactant(::typeof(Reactant.ka_with_reactant), ::Int64, ::Nothing, ::KernelAbstractions.Kernel{…}, ::Reactant.TracedRArray{…}, ::Reactant.TracedRArray{…})
    @ Reactant ~/.julia/packages/Reactant/zlIsO/src/utils.jl:944
  [3] (::KernelAbstractions.Kernel{…})(::Any, ::Vararg{…}; ndrange::Int64, workgroupsize::Nothing)
    @ ReactantKernelAbstractionsExt ~/.julia/packages/Reactant/zlIsO/ext/ReactantKernelAbstractionsExt.jl:118
...
```

```jldoctest kernels
import CUDA
using Reactant

xr = ConcreteRArray(x)
yr = @jit raise=true square(xr)

# output

5-element ConcretePJRTArray{Float64,1}:
1.0
4.0
9.0
16.0
25.0
```

## Kernel differentiation

In addition, if you want to compute derivatives of your kernel with [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl), the option `raise_first=true` also becomes necessary.
Member:
Can you explain more about what raising is? I think this doc would be better written as a tutorial on raising rather than kernels.

Member:
Re this specifically: currently we only support differentiation rules for the raised tensors (derivatives for the internal kernel representation are in progress, cc @Pangoraw @avik-pal etc.).

For now, raising will enable differentiation to succeed, but raising must also be performed prior to differentiation, hence `raise_first`.

Contributor Author:
> Can you explain more about what raising is? I think this doc would be better written as a tutorial on raising rather than kernels.

No, I can't. I don't know the first thing about how raising works (besides what @yolhan83 has taught me on Discourse); I only know that it seems to be necessary on my laptop for handwritten kernels to compile.

The goals of this PR are:

  1. Showing that interoperability with custom kernels is an important aspect for Reactant users, one that needs to be documented.
  2. Getting the page started so that people who actually know what they're talking about can finish writing proper documentation.

Collaborator:
I wrote a small section about raising if you want to include it, or I can open a PR after this one is merged: 5fe6e01


```jldoctest kernels
import Enzyme

sumsquare(x) = sum(square(x))
gr = @jit raise=true raise_first=true Enzyme.gradient(Enzyme.Reverse, sumsquare, xr)

# output

(ConcretePJRTArray{Float64, 1, 1}([2.0, 4.0, 6.0, 8.0, 10.0]),)
```