From 48a86d6269f6c4a1387b037d2c3164e506424f2d Mon Sep 17 00:00:00 2001 From: Simeon David Schaub Date: Thu, 7 Aug 2025 13:42:16 +0200 Subject: [PATCH 1/4] optimized `mapreduce` using sub group shuffle ref #352 Unfortunately, I don't really see any performance improvements with this, any ideas why? I expected this to be quite a bit faster. --- lib/intrinsics/src/work_item.jl | 44 ++++++++++++++++++++ src/mapreduce.jl | 72 ++++++++++++++++++++++++++++++++- 2 files changed, 115 insertions(+), 1 deletion(-) diff --git a/lib/intrinsics/src/work_item.jl b/lib/intrinsics/src/work_item.jl index bbe85adb..bd30ae9a 100644 --- a/lib/intrinsics/src/work_item.jl +++ b/lib/intrinsics/src/work_item.jl @@ -34,6 +34,50 @@ for (julia_name, (spirv_name, julia_type, offset)) in [ end end + +# Sub-group shuffle intrinsics using a loop and @eval, matching the style of the 1D/3D value loops above +export sub_group_shuffle, sub_group_shuffle_xor + +for (jltype, llvmtype, julia_type_str) in [ + (Int8, "i8", :Int8), + (UInt8, "i8", :UInt8), + (Int16, "i16", :Int16), + (UInt16, "i16", :UInt16), + (Int32, "i32", :Int32), + (UInt32, "i32", :UInt32), + (Int64, "i64", :Int64), + (UInt64, "i64", :UInt64), + (Float16, "half", :Float16), + (Float32, "float", :Float32), + (Float64, "double",:Float64) + ] + @eval begin + export sub_group_shuffle, sub_group_shuffle_xor + function sub_group_shuffle(x::$jltype, idx::Integer) + Base.llvmcall( + $(""" + declare $llvmtype @__spirv_GroupNonUniformShuffle(i32, $llvmtype, i32) + define $llvmtype @entry($llvmtype %val, i32 %idx) #0 { + %res = call $llvmtype @__spirv_GroupNonUniformShuffle(i32 3, $llvmtype %val, i32 %idx) + ret $llvmtype %res + } + attributes #0 = { alwaysinline } + """, "entry"), $julia_type_str, Tuple{$julia_type_str, Int32}, x, idx % Int32 - 1i32) + end + function sub_group_shuffle_xor(x::$jltype, mask::Integer) + Base.llvmcall( + $(""" + declare $llvmtype @__spirv_GroupNonUniformShuffleXor(i32, $llvmtype, i32) + define $llvmtype @entry($llvmtype %val, i32 %mask) #0 { + %res = call $llvmtype @__spirv_GroupNonUniformShuffleXor(i32 3, $llvmtype %val, i32 %mask) + ret $llvmtype %res + } + attributes #0 = { alwaysinline } + """, "entry"), $julia_type_str, Tuple{$julia_type_str, Int32}, x, mask % UInt32) + end + end +end + # 3D values for (julia_name, (spirv_name, offset)) in [ # indices diff --git a/src/mapreduce.jl b/src/mapreduce.jl index e9a3f979..feb33016 100644 --- a/src/mapreduce.jl +++ b/src/mapreduce.jl @@ -5,8 +5,78 @@ # - group-stride loop to delay need for second kernel launch # - let the driver choose the local size +function shuffle_expr(::Type{T}) where {T} + if T in SPIRVIntrinsics.generic_integer_types || T in SPIRVIntrinsics.generic_types + return :(sub_group_shuffle(val, i)) + elseif Base.isstructtype(T) + ex = Expr(:new, T) + for f in fieldnames(T) + ex_f = shuffle_expr(fieldtype(T, f)) + ex_f === nothing && return nothing + push!(ex.args, :(let val = getfield(val, $(QuoteNode(f))) + $ex_f + end)) + end + return ex + else + return nothing + end +end + +@inline @generated function reduce_group(op, val::T, neutral, ::Val{maxitems}) where {T, maxitems} + ex = shuffle_expr(T) + if ex === nothing + return :(reduce_group_fallback(op, val, neutral, Val(maxitems))) + end + + quote + # Subgroup shuffle-based warp reduction + lane = get_sub_group_local_id() + width = get_sub_group_size() + + offset = 1 + while offset < width + if lane > offset + i = lane - offset + other = $ex + val = op(val, other) + end + offset <<= 1 + end + + items = get_num_sub_groups() + item = get_sub_group_id() + + shared = CLLocalArray(T, (maxitems,)) + if items > 1 && lane == 1 + @inbounds shared[item] = val + + d = 1 + while d < items + work_group_barrier(LOCAL_MEM_FENCE) + index = 2 * d * (item-1) + 1 + @inbounds if index <= items + other_val = if index + d <= items + shared[index+d] + else + neutral + end + shared[index] = op(shared[index], other_val) + end + d *= 2 + end + + if item == 1 + val = @inbounds shared[item] + end + end + + return val + end +end + # Reduce a value across a group, using local memory for communication -@inline function reduce_group(op, val::T, neutral, ::Val{maxitems}) where {T, maxitems} +@inline function reduce_group_fallback(op, val::T, neutral, ::Val{maxitems}) where {T, maxitems} items = get_local_size() item = get_local_id() From b9a0c43e824dd3740b1e035e0d542a5e1ae80fe1 Mon Sep 17 00:00:00 2001 From: Simeon David Schaub Date: Mon, 13 Oct 2025 11:07:57 +0200 Subject: [PATCH 2/4] fix non-power of 2 sub group size --- src/mapreduce.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/mapreduce.jl b/src/mapreduce.jl index feb33016..04e56ebe 100644 --- a/src/mapreduce.jl +++ b/src/mapreduce.jl @@ -36,9 +36,9 @@ end offset = 1 while offset < width - if lane > offset - i = lane - offset - other = $ex + i = lane + offset + other = $ex + if i <= width val = op(val, other) end offset <<= 1 From ccacd0a9e2b6e6c71bc445f33d1d2f75a4215bde Mon Sep 17 00:00:00 2001 From: Simeon David Schaub Date: Mon, 13 Oct 2025 11:42:32 +0200 Subject: [PATCH 3/4] properly mangle sub group shuffle builtin --- lib/intrinsics/src/SPIRVIntrinsics.jl | 1 + lib/intrinsics/src/shuffle.jl | 12 ++++++++ lib/intrinsics/src/utils.jl | 2 ++ lib/intrinsics/src/work_item.jl | 44 --------------------------- 4 files changed, 15 insertions(+), 44 deletions(-) create mode 100644 lib/intrinsics/src/shuffle.jl diff --git a/lib/intrinsics/src/SPIRVIntrinsics.jl b/lib/intrinsics/src/SPIRVIntrinsics.jl index bd15fdd9..b2bca59d 100644 --- a/lib/intrinsics/src/SPIRVIntrinsics.jl +++ b/lib/intrinsics/src/SPIRVIntrinsics.jl @@ -23,6 +23,7 @@ include("printf.jl") include("math.jl") include("integer.jl") include("atomic.jl") +include("shuffle.jl") # helper macro to import all names from this package, even non-exported ones. macro import_all() diff --git a/lib/intrinsics/src/shuffle.jl b/lib/intrinsics/src/shuffle.jl new file mode 100644 index 00000000..804b0e18 --- /dev/null +++ b/lib/intrinsics/src/shuffle.jl @@ -0,0 +1,12 @@ +export sub_group_shuffle, sub_group_shuffle_xor + +const gentypes = [Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64, Float16, Float32, Float64] + +for gentype in gentypes +@eval begin + +@device_function sub_group_shuffle(x::$gentype, i::Integer) = @builtin_ccall("sub_group_shuffle", $gentype, ($gentype, Int32), x, i % Int32 - 1i32) +@device_function sub_group_shuffle_xor(x::$gentype, mask::Integer) = @builtin_ccall("sub_group_shuffle_xor", $gentype, ($gentype, UInt32), x, mask % UInt32) + +end +end diff --git a/lib/intrinsics/src/utils.jl b/lib/intrinsics/src/utils.jl index e1a5a939..2c12db8a 100644 --- a/lib/intrinsics/src/utils.jl +++ b/lib/intrinsics/src/utils.jl @@ -26,6 +26,8 @@ macro builtin_ccall(name, ret, argtypes, args...) "c" elseif T == UInt8 "h" + elseif T == Float16 + "Dh" elseif T == Float32 "f" elseif T == Float64 diff --git a/lib/intrinsics/src/work_item.jl b/lib/intrinsics/src/work_item.jl index bd30ae9a..bbe85adb 100644 --- a/lib/intrinsics/src/work_item.jl +++ b/lib/intrinsics/src/work_item.jl @@ -34,50 +34,6 @@ for (julia_name, (spirv_name, julia_type, offset)) in [ end end - -# Sub-group shuffle intrinsics using a loop and @eval, matching the style of the 1D/3D value loops above -export sub_group_shuffle, sub_group_shuffle_xor - -for (jltype, llvmtype, julia_type_str) in [ - (Int8, "i8", :Int8), - (UInt8, "i8", :UInt8), - (Int16, "i16", :Int16), - (UInt16, "i16", :UInt16), - (Int32, "i32", :Int32), - (UInt32, "i32", :UInt32), - (Int64, "i64", :Int64), - (UInt64, "i64", :UInt64), - (Float16, "half", :Float16), - (Float32, "float", :Float32), - (Float64, "double",:Float64) - ] - @eval begin - export sub_group_shuffle, sub_group_shuffle_xor - function sub_group_shuffle(x::$jltype, idx::Integer) - Base.llvmcall( - $(""" - declare $llvmtype @__spirv_GroupNonUniformShuffle(i32, $llvmtype, i32) - define $llvmtype @entry($llvmtype %val, i32 %idx) #0 { - %res = call $llvmtype @__spirv_GroupNonUniformShuffle(i32 3, $llvmtype %val, i32 %idx) - ret $llvmtype %res - } - attributes #0 = { alwaysinline } - """, "entry"), $julia_type_str, Tuple{$julia_type_str, Int32}, x, idx % Int32 - 1i32) - end - function sub_group_shuffle_xor(x::$jltype, mask::Integer) - Base.llvmcall( - $(""" - declare $llvmtype @__spirv_GroupNonUniformShuffleXor(i32, $llvmtype, i32) - define $llvmtype @entry($llvmtype %val, i32 %mask) #0 { - %res = call $llvmtype @__spirv_GroupNonUniformShuffleXor(i32 3, $llvmtype %val, i32 %mask) - ret $llvmtype %res - } - attributes #0 = { alwaysinline } - """, "entry"), $julia_type_str, Tuple{$julia_type_str, Int32}, x, mask % UInt32) - end - end -end - # 3D values for (julia_name, (spirv_name, offset)) in [ # indices From 96192c21f879fa1d5678f01c596c3ac9ce721179 Mon Sep 17 00:00:00 2001 From: Simeon David Schaub Date: Mon, 13 Oct 2025 15:53:09 +0200 Subject: [PATCH 4/4] wip --- src/mapreduce.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/mapreduce.jl b/src/mapreduce.jl index 04e56ebe..60064d06 100644 --- a/src/mapreduce.jl +++ b/src/mapreduce.jl @@ -6,7 +6,7 @@ # - let the driver choose the local size function shuffle_expr(::Type{T}) where {T} - if T in SPIRVIntrinsics.generic_integer_types || T in SPIRVIntrinsics.generic_types + if T in SPIRVIntrinsics.gentypes return :(sub_group_shuffle(val, i)) elseif Base.isstructtype(T) ex = Expr(:new, T) @@ -115,12 +115,13 @@ Base.@propagate_inbounds _map_getindex(args::Tuple{}, I) = () # Reduce an array across the grid. All elements to be processed can be addressed by the # product of the two iterators `Rreduce` and `Rother`, where the latter iterator will have # singleton entries for the dimensions that should be reduced (and vice versa). -function partial_mapreduce_device(f, op, neutral, maxitems, Rreduce, Rother, R, As...) +function partial_mapreduce_device(f, op, neutral, maxitems, Rreduce, Rother, R, A) + As = (A,) # decompose the 1D hardware indices into separate ones for reduction (across items # and possibly groups if it doesn't fit) and other elements (remaining groups) localIdx_reduce = get_local_id() localDim_reduce = get_local_size() - groupIdx_reduce, groupIdx_other = fldmod1(get_group_id(), length(Rother)) + groupIdx_reduce, groupIdx_other = @inline fldmod1(get_group_id(), length(Rother)) groupDim_reduce = get_num_groups() รท length(Rother) # group-based indexing into the values outside of the reduction dimension @@ -137,7 +138,7 @@ function partial_mapreduce_device(f, op, neutral, maxitems, Rreduce, Rother, R, neutral end - val = op(neutral, neutral) + val = neutral # reduce serially across chunks of input vector that don't fit in a group ireduce = localIdx_reduce + (groupIdx_reduce - 1) * localDim_reduce