Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 143 additions & 3 deletions src/gcn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,161 @@ runtime_slug(job::CompilerJob{GCNCompilerTarget}) = "gcn-$(job.config.target.dev
const gcn_intrinsics = () # TODO: ("vprintf", "__assertfail", "malloc", "free")
isintrinsic(::CompilerJob{GCNCompilerTarget}, fn::String) = in(fn, gcn_intrinsics)

# GCN passes aggregates byref (instead of byval + lower_byval) so that field loads can
# be served straight from the kernarg segment; see `add_kernarg_address_spaces!` below.
pass_by_ref(@nospecialize(job::CompilerJob{GCNCompilerTarget})) = true

# Finalize a GCN module: lower remaining throw calls, and for kernels set the AMDGPU
# kernel calling convention and rewrite byval arguments. Returns the (possibly new)
# entry function.
function finish_module!(@nospecialize(job::CompilerJob{GCNCompilerTarget}),
                        mod::LLVM.Module, entry::LLVM.Function)
    lower_throw_extra!(mod)

    # device functions need no further processing
    job.config.kernel || return entry

    # kernels must use the AMDGPU kernel calling convention
    callconv!(entry, LLVM.API.LLVMAMDGPUKERNELCallConv)

    # work around bad byval codegen (JuliaGPU/GPUCompiler.jl#92)
    return lower_byval(job, mod, entry)
end

# Final IR fixups for GCN: rewrite kernel byref arguments into the constant address
# space, then re-optimize so the new address spaces propagate through the function.
# Returns the (possibly new) entry function.
function finish_ir!(
    @nospecialize(job::CompilerJob{GCNCompilerTarget}), mod::LLVM.Module,
    entry::LLVM.Function
)
    # only kernels have kernarg-backed parameters; device functions are left alone
    job.config.kernel || return entry

    entry = add_kernarg_address_spaces!(job, mod, entry)

    # optimize after the address space rewrite: push addrspace(4) through the
    # addrspacecast chains, then clean up the opportunities that exposes
    tm = llvm_machine(job.config.target)
    @dispose pb=NewPMPassBuilder() begin
        add!(pb, NewPMFunctionPassManager()) do fpm
            add!(fpm, InferAddressSpacesPass())
            add!(fpm, SROAPass())
            add!(fpm, InstCombinePass())
            add!(fpm, EarlyCSEPass())
            add!(fpm, SimplifyCFGPass())
        end
        run!(pb, mod, tm)
    end

    return entry
end

# Rewrite byref kernel parameters from flat (addrspace 0) to constant (addrspace 4).
#
# On AMDGPU, kernel arguments reside in the constant address space (addrspace 4),
# which is scalar-loadable via s_load. Julia initially emits byref parameters as
# pointers in addrspace(11) (tracked/derived), but RemoveJuliaAddrspacesPass strips
# all non-integral address spaces to flat (addrspace 0) during optimization. This pass
# restores addrspace(4) on byref parameters so that the backend can emit s_load
# instead of flat_load for struct field accesses.
#
# NOTE: must run after optimization, where RemoveJuliaAddrspacesPass has already
# converted Julia's addrspace(11) to flat (addrspace 0) on these parameters.
# Rewrite `f`'s flat (addrspace 0) byref parameters to constant (addrspace 4) pointers
# by creating a replacement function with the new signature, cloning the body into it,
# and bridging the types with addrspacecasts. Erases `f`; returns the replacement
# (which takes over `f`'s name). No-op (returns `f`) if nothing needs rewriting.
function add_kernarg_address_spaces!(
    @nospecialize(job::CompilerJob), mod::LLVM.Module,
    f::LLVM.Function
)
    ft = function_type(f)

    # find the byref parameters by checking for the byref attribute directly,
    # rather than re-classifying arguments (which can fail on typed-pointer LLVM
    # due to element type mismatches in classify_arguments assertions).
    # the 5 is the length of the attribute name "byref".
    byref_kind = LLVM.API.LLVMGetEnumAttributeKindForName("byref", 5)
    byref_mask = BitVector(undef, length(parameters(ft)))
    for i in 1:length(parameters(ft))
        attrs = collect(parameter_attributes(f, i))
        byref_mask[i] = any(a -> a isa TypeAttribute && kind(a) == byref_kind, attrs)
    end

    # check if any flat pointer byref params need rewriting
    needs_rewrite = false
    for (i, param) in enumerate(parameters(ft))
        if byref_mask[i] && param isa LLVM.PointerType && addrspace(param) == 0
            needs_rewrite = true
            break
        end
    end
    needs_rewrite || return f

    # generate the new function type with constant address space on byref params
    new_types = LLVMType[]
    for (i, param) in enumerate(parameters(ft))
        if byref_mask[i] && param isa LLVM.PointerType && addrspace(param) == 0
            if supports_typed_pointers(context())
                push!(new_types, LLVM.PointerType(eltype(param), #=constant=# 4))
            else
                push!(new_types, LLVM.PointerType(#=constant=# 4))
            end
        else
            push!(new_types, param)
        end
    end
    new_ft = LLVM.FunctionType(return_type(ft), new_types)
    # created with an empty name to avoid clashing with `f`; renamed once `f` is erased
    new_f = LLVM.Function(mod, "", new_ft)
    linkage!(new_f, linkage(f))
    for (arg, new_arg) in zip(parameters(f), parameters(new_f))
        LLVM.name!(new_arg, LLVM.name(arg))
    end

    # insert addrspacecasts from kernarg (4) back to flat (0) so that the cloned IR
    # (which expects flat pointers) continues to work. The AMDGPU backend's
    # AMDGPULowerKernelArguments traces these casts and produces s_load.
    new_args = LLVM.Value[]
    @dispose builder=IRBuilder() begin
        entry_bb = BasicBlock(new_f, "conversion")
        position!(builder, entry_bb)

        for (i, param) in enumerate(parameters(ft))
            if byref_mask[i] && param isa LLVM.PointerType && addrspace(param) == 0
                # cast back to the original flat pointer type the body expects
                cast = addrspacecast!(builder, parameters(new_f)[i], param)
                push!(new_args, cast)
            else
                push!(new_args, parameters(new_f)[i])
            end
        end

        # clone the original function body, mapping old arguments to either the new
        # arguments or the addrspacecasts created above
        value_map = Dict{LLVM.Value, LLVM.Value}(
            param => new_args[i] for (i, param) in enumerate(parameters(f))
        )
        # also remap self-references (e.g. recursive calls) to the new function
        value_map[f] = new_f
        clone_into!(
            new_f, f; value_map,
            changes = LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges
        )

        # fall through from conversion block to cloned entry
        br!(builder, blocks(new_f)[2])
    end

    # copy parameter attributes AFTER clone_into!, because CloneFunctionInto
    # overwrites all attributes via setAttributes. For byref params, the VMap
    # maps old args to addrspacecast instructions (not Arguments), so LLVM's
    # attribute remapping silently drops them. We must re-add them here.
    for i in 1:length(parameters(ft))
        for attr in collect(parameter_attributes(f, i))
            push!(parameter_attributes(new_f, i), attr)
        end
    end

    # replace the old function: drop constantexpr uses, move metadata uses over,
    # erase it, and give the replacement the original name
    fn = LLVM.name(f)
    prune_constexpr_uses!(f)
    @assert isempty(uses(f))
    replace_metadata_uses!(f, new_f)
    erase!(f)
    LLVM.name!(new_f, fn)

    # clean up the extra conversion block
    @dispose pb=NewPMPassBuilder() begin
        add!(pb, NewPMFunctionPassManager()) do fpm
            add!(fpm, SimplifyCFGPass())
        end
        run!(pb, mod)
    end

    # look the function up by name: the rewritten function now carries `fn`
    return functions(mod)[fn]
end


## LLVM passes

Expand Down
6 changes: 6 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,12 @@ kernel_state_type(@nospecialize(job::CompilerJob)) = Nothing
# Does the target need to pass kernel arguments by value?
pass_by_value(@nospecialize(job::CompilerJob)) = true

# Should the target use byref instead of byval + lower_byval for kernel arguments?
# When true, aggregate arguments are passed as pointers carrying the byref attribute,
# which lets the backend load individual fields directly from the argument memory
# (e.g. the kernarg segment on AMDGPU) instead of materializing the entire struct as a
# first-class aggregate. Targets opt in by overriding this method; the default is off.
pass_by_ref(@nospecialize(job::CompilerJob)) = false

# whether pointer is a valid call target
valid_function_pointer(@nospecialize(job::CompilerJob), ptr::Ptr{Cvoid}) = false

Expand Down
6 changes: 5 additions & 1 deletion src/irgen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,11 @@ function irgen(@nospecialize(job::CompilerJob))
for arg in args
if arg.cc == BITS_REF
llvm_typ = convert(LLVMType, arg.typ)
attr = TypeAttribute("byval", llvm_typ)
if pass_by_ref(job)
attr = TypeAttribute("byref", llvm_typ)
else
attr = TypeAttribute("byval", llvm_typ)
end
push!(parameter_attributes(entry, arg.idx), attr)
end
end
Expand Down
158 changes: 158 additions & 0 deletions test/gcn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,122 @@ end
end
end

@testset "kernarg address space for byref parameters" begin
    mod = @eval module $(gensym())
        struct MyStruct
            x::Float64
            y::Float64
        end

        function kernel(s::MyStruct)
            s.x + s.y
            return
        end
    end

    # byref struct params should be ptr addrspace(4) in kernel IR
    # (the mangled symbol in the CHECK line encodes the kernel and struct names,
    # so these identifiers must not be renamed)
    @test @filecheck begin
        check"CHECK: define amdgpu_kernel void @_Z6kernel8MyStruct(ptr addrspace(4)"
        GCN.code_llvm(mod.kernel, Tuple{mod.MyStruct}; dump_module=true, kernel=true)
    end

    # non-kernel should NOT have addrspace(4): the rewrite only runs for kernels
    @test @filecheck begin
        check"CHECK-NOT: addrspace(4)"
        GCN.code_llvm(mod.kernel, Tuple{mod.MyStruct}; dump_module=true, kernel=false)
    end
end

@testset "byref attribute preserved on kernarg parameters" begin
    mod = @eval module $(gensym())
        struct LargeStruct
            a::Float64
            b::Float64
            c::Float64
            d::Float64
        end

        function kernel(s::LargeStruct, out::Ptr{Float64})
            unsafe_store!(out, s.a + s.b + s.c + s.d)
            return
        end
    end

    # the byref attribute must survive the addrspace rewrite (clone_into! can drop it,
    # which is why add_kernarg_address_spaces! re-adds parameter attributes afterwards)
    @test @filecheck begin
        check"CHECK: byref"
        check"CHECK: addrspace(4)"
        GCN.code_llvm(mod.kernel, Tuple{mod.LargeStruct, Ptr{Float64}};
                      dump_module=true, kernel=true)
    end
end

@testset "mixed byref and scalar kernel parameters" begin
    mod = @eval module $(gensym())
        struct Params
            x::Float64
            y::Float64
        end

        function kernel(a::Float64, s::Params, out::Ptr{Float64})
            unsafe_store!(out, a + s.x + s.y)
            return
        end
    end

    # scalar Float64 should NOT be in addrspace(4),
    # only the struct byref param should be.
    # NOTE: Ptr{Float64} is lowered to i64 on Julia ≤1.11 and ptr on Julia 1.12+,
    # hence the regex alternation in the last CHECK-SAME.
    @test @filecheck begin
        check"CHECK: define amdgpu_kernel void"
        check"CHECK-SAME: double"
        check"CHECK-SAME: ptr addrspace(4)"
        check"CHECK-SAME: {{(i64|ptr)}}"
        GCN.code_llvm(mod.kernel, Tuple{Float64, mod.Params, Ptr{Float64}};
                      dump_module=true, kernel=true)
    end
end

@testset "add_kernarg_address_spaces! rewrites IR correctly" begin
    mod = @eval module $(gensym())
        struct KernelArgs
            x::Float64
            y::Float64
            z::Float64
        end

        function kernel(s::KernelArgs, scale::Float64, out::Ptr{Float64})
            unsafe_store!(out, (s.x + s.y + s.z) * scale)
            return
        end
    end

    job, _ = GCN.create_job(mod.kernel, Tuple{mod.KernelArgs, Float64, Ptr{Float64}};
                            kernel=true)
    JuliaContext() do ctx
        ir, meta = GPUCompiler.compile(:llvm, job)

        kernel_fn = meta.entry
        param_types = parameters(function_type(kernel_fn))

        # the struct byref param should have been rewritten to ptr addrspace(4)
        is_kernarg_ptr(p) = p isa LLVM.PointerType && addrspace(p) == 4
        @test any(is_kernarg_ptr, param_types)

        # the remaining params (double, and i64/ptr for Ptr{Float64}) must stay out
        # of addrspace(4). Ptr{Float64} is i64 on Julia ≤1.11, ptr on 1.12+.
        @test !isempty(filter(!is_kernarg_ptr, param_types))

        # the byref attribute must still be present after the rewrite
        @test occursin("byref", string(ir))

        dispose(ir)
    end
end

@testset "https://github.com/JuliaGPU/AMDGPU.jl/issues/846" begin
ir, rt = GCN.code_typed((Tuple{Tuple{Val{4}}, Tuple{Float32}},); always_inline=true) do t
t[1]
Expand All @@ -49,6 +165,48 @@ end
############################################################################################
@testset "assembly" begin

@testset "s_load for kernarg struct access" begin
    mod = @eval module $(gensym())
        struct MyStruct
            x::Float64
            y::Float64
        end

        function kernel(s::MyStruct, out::Ptr{Float64})
            unsafe_store!(out, s.x + s.y)
            return
        end
    end

    # struct field loads from kernarg should use scalar loads (s_load), not flat_load:
    # this is the end-to-end payoff of rewriting byref params into addrspace(4)
    @test @filecheck begin
        check"CHECK: s_load_dwordx"
        check"CHECK-NOT: flat_load"
        GCN.code_native(mod.kernel, Tuple{mod.MyStruct, Ptr{Float64}}; kernel=true)
    end
end

@testset "no scratch spills for small struct kernarg" begin
    mod = @eval module $(gensym())
        struct SmallStruct
            x::Float64
            y::Float64
        end

        function kernel(s::SmallStruct, out::Ptr{Float64})
            unsafe_store!(out, s.x + s.y)
            return
        end
    end

    # a small struct kernel should not need scratch memory;
    # .private_segment_fixed_size reports the per-workitem scratch allocation
    # in the kernel metadata, and must be zero here
    @test @filecheck begin
        check"CHECK: .private_segment_fixed_size: 0"
        GCN.code_native(mod.kernel, Tuple{mod.SmallStruct, Ptr{Float64}};
                        dump_module=true, kernel=true)
    end
end

@testset "skip scalar trap" begin
mod = @eval module $(gensym())
workitem_idx_x() = ccall("llvm.amdgcn.workitem.id.x", llvmcall, Int32, ())
Expand Down
Loading