Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions src/compiler/compilation.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
## gpucompiler interface

Base.@kwdef struct OpenCLCompilerParams <: AbstractCompilerParams
sub_group_size::Int # Some devices support multiple sizes. This is used to force one when needed
# request a fixed sub-group width via `intel_reqd_sub_group_size`
sub_group_size::Union{Nothing,Int} = nothing
end

const OpenCLCompilerConfig = CompilerConfig{SPIRVCompilerTarget, OpenCLCompilerParams}
Expand Down Expand Up @@ -32,9 +33,8 @@ function GPUCompiler.finish_module!(@nospecialize(job::OpenCLCompilerJob),
Tuple{CompilerJob{SPIRVCompilerTarget}, LLVM.Module, LLVM.Function},
job, mod, entry)

# Set the subgroup size if supported
sg_size = job.config.params.sub_group_size
if sg_size >= 0
if sg_size !== nothing
metadata(entry)["intel_reqd_sub_group_size"] = MDNode([ConstantInt(Int32(sg_size))])
end

Expand Down Expand Up @@ -136,15 +136,13 @@ function compiler_config(dev::cl.Device; kwargs...)
end
return config
end
@noinline function _compiler_config(dev; kernel=true, name=nothing, always_inline=false, kwargs...)
@noinline function _compiler_config(dev; kernel=true, name=nothing, always_inline=false,
sub_group_size::Union{Nothing,Int}=nothing, kwargs...)
supports_fp16 = "cl_khr_fp16" in dev.extensions
supports_fp64 = "cl_khr_fp64" in dev.extensions

# Set to -1 if specifying a subgroup size is not supported
sub_group_size = if "cl_intel_required_subgroup_size" in dev.extensions
cl.sub_group_size(dev)
else
-1
if sub_group_size !== nothing && !("cl_intel_required_subgroup_size" in dev.extensions)
error("Device does not support cl_intel_required_subgroup_size")
end

# create GPUCompiler objects
Expand Down
2 changes: 1 addition & 1 deletion src/compiler/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ export @opencl, clfunction
## high-level @opencl interface

const MACRO_KWARGS = [:launch]
const COMPILER_KWARGS = [:kernel, :name, :always_inline, :extensions, :backend, :validate]
const COMPILER_KWARGS = [:kernel, :name, :always_inline, :extensions, :backend, :validate, :sub_group_size]
const LAUNCH_KWARGS = [:global_size, :local_size, :queue]

macro opencl(ex...)
Expand Down
6 changes: 3 additions & 3 deletions test/intrinsics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ end
N = local_size * numworkgroups

results = CLVector{SubgroupData}(undef, N)
kernel = @opencl launch = false test_subgroup_kernel(results)
kernel = @opencl launch = false sub_group_size = sg_size test_subgroup_kernel(results)

kernel(results; local_size, global_size=N)

Expand Down Expand Up @@ -248,7 +248,7 @@ end
@testset for T in cl.sub_group_shuffle_supported_types(cl.device())
a = rand(T, sg_size)
d_a = CLArray(a)
@opencl local_size = sg_size global_size = sg_size shfl_idx_kernel(d_a)
@opencl local_size = sg_size global_size = sg_size sub_group_size = sg_size shfl_idx_kernel(d_a)
@test Array(d_a) == reverse(a)
end
end
Expand All @@ -267,7 +267,7 @@ end
in = rand(T, sg_size)
idxs = xor.(0:(sg_size - 1), 1) .+ 1
d_in = CLArray(in)
@opencl local_size = sg_size global_size = sg_size shfl_xor_kernel(d_in)
@opencl local_size = sg_size global_size = sg_size sub_group_size = sg_size shfl_xor_kernel(d_in)
@test Array(d_in) == in[idxs]
end
end
Expand Down
Loading