diff --git a/src/gcn.jl b/src/gcn.jl
index 146d9a33..e310b5c5 100644
--- a/src/gcn.jl
+++ b/src/gcn.jl
@@ -40,6 +40,8 @@ runtime_slug(job::CompilerJob{GCNCompilerTarget}) = "gcn-$(job.config.target.dev)"
 const gcn_intrinsics = () # TODO: ("vprintf", "__assertfail", "malloc", "free")
 isintrinsic(::CompilerJob{GCNCompilerTarget}, fn::String) = in(fn, gcn_intrinsics)
 
+pass_by_ref(@nospecialize(job::CompilerJob{GCNCompilerTarget})) = true
+
 function finish_module!(@nospecialize(job::CompilerJob{GCNCompilerTarget}),
                         mod::LLVM.Module, entry::LLVM.Function)
     lower_throw_extra!(mod)
@@ -47,14 +49,152 @@ function finish_module!(@nospecialize(job::CompilerJob{GCNCompilerTarget}),
     if job.config.kernel
         # calling convention
         callconv!(entry, LLVM.API.LLVMAMDGPUKERNELCallConv)
-
-        # work around bad byval codegen (JuliaGPU/GPUCompiler.jl#92)
-        entry = lower_byval(job, mod, entry)
     end
 
     return entry
 end
 
+function finish_ir!(
+        @nospecialize(job::CompilerJob{GCNCompilerTarget}), mod::LLVM.Module,
+        entry::LLVM.Function
+    )
+    if job.config.kernel
+        entry = add_kernarg_address_spaces!(job, mod, entry)
+
+        # optimize after address space rewriting: propagate addrspace(4) through
+        # the addrspacecast chains, then clean up newly-exposed opportunities
+        tm = llvm_machine(job.config.target)
+        @dispose pb=NewPMPassBuilder() begin
+            add!(pb, NewPMFunctionPassManager()) do fpm
+                add!(fpm, InferAddressSpacesPass())
+                add!(fpm, SROAPass())
+                add!(fpm, InstCombinePass())
+                add!(fpm, EarlyCSEPass())
+                add!(fpm, SimplifyCFGPass())
+            end
+            run!(pb, mod, tm)
+        end
+    end
+    return entry
+end
+
+# Rewrite byref kernel parameters from flat (addrspace 0) to constant (addrspace 4).
+#
+# On AMDGPU, kernel arguments reside in the constant address space (addrspace 4),
+# which is scalar-loadable via s_load. Julia initially emits byref parameters as
+# pointers in addrspace(11) (tracked/derived), but RemoveJuliaAddrspacesPass strips
+# all non-integral address spaces to flat (addrspace 0) during optimization. This pass
+# restores addrspace(4) on byref parameters so that the backend can emit s_load
+# instead of flat_load for struct field accesses.
+#
+# NOTE: must run after optimization, where RemoveJuliaAddrspacesPass has already
+# converted Julia's addrspace(11) to flat (addrspace 0) on these parameters.
+function add_kernarg_address_spaces!(
+        @nospecialize(job::CompilerJob), mod::LLVM.Module,
+        f::LLVM.Function
+    )
+    ft = function_type(f)
+
+    # find the byref parameters by checking for the byref attribute directly,
+    # rather than re-classifying arguments (which can fail on typed-pointer LLVM
+    # due to element type mismatches in classify_arguments assertions).
+    byref_kind = LLVM.API.LLVMGetEnumAttributeKindForName("byref", 5)
+    byref_mask = BitVector(undef, length(parameters(ft)))
+    for i in 1:length(parameters(ft))
+        attrs = collect(parameter_attributes(f, i))
+        byref_mask[i] = any(a -> a isa TypeAttribute && kind(a) == byref_kind, attrs)
+    end
+
+    # check if any flat pointer byref params need rewriting
+    needs_rewrite = false
+    for (i, param) in enumerate(parameters(ft))
+        if byref_mask[i] && param isa LLVM.PointerType && addrspace(param) == 0
+            needs_rewrite = true
+            break
+        end
+    end
+    needs_rewrite || return f
+
+    # generate the new function type with constant address space on byref params
+    new_types = LLVMType[]
+    for (i, param) in enumerate(parameters(ft))
+        if byref_mask[i] && param isa LLVM.PointerType && addrspace(param) == 0
+            if supports_typed_pointers(context())
+                push!(new_types, LLVM.PointerType(eltype(param), #=constant=# 4))
+            else
+                push!(new_types, LLVM.PointerType(#=constant=# 4))
+            end
+        else
+            push!(new_types, param)
+        end
+    end
+    new_ft = LLVM.FunctionType(return_type(ft), new_types)
+    new_f = LLVM.Function(mod, "", new_ft)
+    linkage!(new_f, linkage(f))
+    for (arg, new_arg) in zip(parameters(f), parameters(new_f))
+        LLVM.name!(new_arg, LLVM.name(arg))
+    end
+
+    # insert addrspacecasts from kernarg (4) back to flat (0) so that the cloned IR
+    # (which expects flat pointers) continues to work. The AMDGPU backend's
+    # AMDGPULowerKernelArguments traces these casts and produces s_load.
+    new_args = LLVM.Value[]
+    @dispose builder=IRBuilder() begin
+        entry_bb = BasicBlock(new_f, "conversion")
+        position!(builder, entry_bb)
+
+        for (i, param) in enumerate(parameters(ft))
+            if byref_mask[i] && param isa LLVM.PointerType && addrspace(param) == 0
+                cast = addrspacecast!(builder, parameters(new_f)[i], param)
+                push!(new_args, cast)
+            else
+                push!(new_args, parameters(new_f)[i])
+            end
+        end
+
+        # clone the original function body
+        value_map = Dict{LLVM.Value, LLVM.Value}(
+            param => new_args[i] for (i, param) in enumerate(parameters(f))
+        )
+        value_map[f] = new_f
+        clone_into!(
+            new_f, f; value_map,
+            changes = LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges
+        )
+
+        # fall through from conversion block to cloned entry
+        br!(builder, blocks(new_f)[2])
+    end
+
+    # copy parameter attributes AFTER clone_into!, because CloneFunctionInto
+    # overwrites all attributes via setAttributes. For byref params, the VMap
+    # maps old args to addrspacecast instructions (not Arguments), so LLVM's
+    # attribute remapping silently drops them. We must re-add them here.
+    for i in 1:length(parameters(ft))
+        for attr in collect(parameter_attributes(f, i))
+            push!(parameter_attributes(new_f, i), attr)
+        end
+    end
+
+    # replace the old function
+    fn = LLVM.name(f)
+    prune_constexpr_uses!(f)
+    @assert isempty(uses(f))
+    replace_metadata_uses!(f, new_f)
+    erase!(f)
+    LLVM.name!(new_f, fn)
+
+    # clean up the extra conversion block
+    @dispose pb=NewPMPassBuilder() begin
+        add!(pb, NewPMFunctionPassManager()) do fpm
+            add!(fpm, SimplifyCFGPass())
+        end
+        run!(pb, mod)
+    end
+
+    return functions(mod)[fn]
+end
+
 
 ## LLVM passes
 
diff --git a/src/interface.jl b/src/interface.jl
index 21ddcf57..fc65c888 100644
--- a/src/interface.jl
+++ b/src/interface.jl
@@ -272,6 +272,12 @@ kernel_state_type(@nospecialize(job::CompilerJob)) = Nothing
 
 # Does the target need to pass kernel arguments by value?
 pass_by_value(@nospecialize(job::CompilerJob)) = true
 
+# Should the target use byref instead of byval+lower_byval for kernel arguments?
+# When true, aggregate arguments are passed as pointers with the byref attribute,
+# allowing the backend to load fields directly from the argument memory (e.g. kernarg
+# segment on AMDGPU) instead of materializing the entire struct via first-class aggregates.
+pass_by_ref(@nospecialize(job::CompilerJob)) = false
+
 # whether pointer is a valid call target
 valid_function_pointer(@nospecialize(job::CompilerJob), ptr::Ptr{Cvoid}) = false
diff --git a/src/irgen.jl b/src/irgen.jl
index 5149e9f0..744904e5 100644
--- a/src/irgen.jl
+++ b/src/irgen.jl
@@ -94,7 +94,11 @@ function irgen(@nospecialize(job::CompilerJob))
     for arg in args
         if arg.cc == BITS_REF
             llvm_typ = convert(LLVMType, arg.typ)
-            attr = TypeAttribute("byval", llvm_typ)
+            if pass_by_ref(job)
+                attr = TypeAttribute("byref", llvm_typ)
+            else
+                attr = TypeAttribute("byval", llvm_typ)
+            end
             push!(parameter_attributes(entry, arg.idx), attr)
         end
     end
diff --git a/test/gcn.jl b/test/gcn.jl
index 95641a44..5b49cf59 100644
--- a/test/gcn.jl
+++ b/test/gcn.jl
@@ -37,6 +37,122 @@ end
     end
 end
 
+@testset "kernarg address space for byref parameters" begin
+    mod = @eval module $(gensym())
+        struct MyStruct
+            x::Float64
+            y::Float64
+        end
+
+        function kernel(s::MyStruct)
+            s.x + s.y
+            return
+        end
+    end
+
+    # byref struct params should be ptr addrspace(4) in kernel IR
+    @test @filecheck begin
+        check"CHECK: define amdgpu_kernel void @_Z6kernel8MyStruct(ptr addrspace(4)"
+        GCN.code_llvm(mod.kernel, Tuple{mod.MyStruct}; dump_module=true, kernel=true)
+    end
+
+    # non-kernel should NOT have addrspace(4)
+    @test @filecheck begin
+        check"CHECK-NOT: addrspace(4)"
+        GCN.code_llvm(mod.kernel, Tuple{mod.MyStruct}; dump_module=true, kernel=false)
+    end
+end
+
+@testset "byref attribute preserved on kernarg parameters" begin
+    mod = @eval module $(gensym())
+        struct LargeStruct
+            a::Float64
+            b::Float64
+            c::Float64
+            d::Float64
+        end
+
+        function kernel(s::LargeStruct, out::Ptr{Float64})
+            unsafe_store!(out, s.a + s.b + s.c + s.d)
+            return
+        end
+    end
+
+    # the byref attribute must survive the addrspace rewrite (clone_into! can drop it)
+    @test @filecheck begin
+        check"CHECK: byref"
+        check"CHECK: addrspace(4)"
+        GCN.code_llvm(mod.kernel, Tuple{mod.LargeStruct, Ptr{Float64}};
+                      dump_module=true, kernel=true)
+    end
+end
+
+@testset "mixed byref and scalar kernel parameters" begin
+    mod = @eval module $(gensym())
+        struct Params
+            x::Float64
+            y::Float64
+        end
+
+        function kernel(a::Float64, s::Params, out::Ptr{Float64})
+            unsafe_store!(out, a + s.x + s.y)
+            return
+        end
+    end
+
+    # scalar Float64 should NOT be in addrspace(4),
+    # only the struct byref param should be.
+    # NOTE: Ptr{Float64} is lowered to i64 on Julia ≤1.11 and ptr on Julia 1.12+.
+    @test @filecheck begin
+        check"CHECK: define amdgpu_kernel void"
+        check"CHECK-SAME: double"
+        check"CHECK-SAME: ptr addrspace(4)"
+        check"CHECK-SAME: {{(i64|ptr)}}"
+        GCN.code_llvm(mod.kernel, Tuple{Float64, mod.Params, Ptr{Float64}};
+                      dump_module=true, kernel=true)
+    end
+end
+
+@testset "add_kernarg_address_spaces! rewrites IR correctly" begin
+    mod = @eval module $(gensym())
+        struct KernelArgs
+            x::Float64
+            y::Float64
+            z::Float64
+        end
+
+        function kernel(s::KernelArgs, scale::Float64, out::Ptr{Float64})
+            unsafe_store!(out, (s.x + s.y + s.z) * scale)
+            return
+        end
+    end
+
+    job, _ = GCN.create_job(mod.kernel, Tuple{mod.KernelArgs, Float64, Ptr{Float64}};
+                            kernel=true)
+    JuliaContext() do ctx
+        ir, meta = GPUCompiler.compile(:llvm, job)
+
+        entry = meta.entry
+        ft = function_type(entry)
+        params = parameters(ft)
+
+        # the struct byref param should be ptr addrspace(4)
+        has_as4 = any(p -> p isa LLVM.PointerType && addrspace(p) == 4, params)
+        @test has_as4
+
+        # non-struct params (double, and i64/ptr for Ptr{Float64}) should NOT
+        # be in addrspace(4). Ptr{Float64} is i64 on Julia ≤1.11, ptr on 1.12+.
+        non_byref = filter(p -> !(p isa LLVM.PointerType && addrspace(p) == 4), params)
+        @test !isempty(non_byref)  # double (and i64 or ptr) params
+
+        # byref attribute must be present
+        ir_str = string(ir)
+        @test occursin("byref", ir_str)
+
+        dispose(ir)
+    end
+end
+
 @testset "https://github.com/JuliaGPU/AMDGPU.jl/issues/846" begin
     ir, rt = GCN.code_typed((Tuple{Tuple{Val{4}}, Tuple{Float32}},); always_inline=true) do t
         t[1]
@@ -49,6 +165,48 @@ end
 ############################################################################################
 
 @testset "assembly" begin
+@testset "s_load for kernarg struct access" begin
+    mod = @eval module $(gensym())
+        struct MyStruct
+            x::Float64
+            y::Float64
+        end
+
+        function kernel(s::MyStruct, out::Ptr{Float64})
+            unsafe_store!(out, s.x + s.y)
+            return
+        end
+    end
+
+    # struct field loads from kernarg should use s_load, not flat_load
+    @test @filecheck begin
+        check"CHECK: s_load_dwordx"
+        check"CHECK-NOT: flat_load"
+        GCN.code_native(mod.kernel, Tuple{mod.MyStruct, Ptr{Float64}}; kernel=true)
+    end
+end
+
+@testset "no scratch spills for small struct kernarg" begin
+    mod = @eval module $(gensym())
+        struct SmallStruct
+            x::Float64
+            y::Float64
+        end
+
+        function kernel(s::SmallStruct, out::Ptr{Float64})
+            unsafe_store!(out, s.x + s.y)
+            return
+        end
+    end
+
+    # a small struct kernel should not need scratch memory
+    @test @filecheck begin
+        check"CHECK: .private_segment_fixed_size: 0"
+        GCN.code_native(mod.kernel, Tuple{mod.SmallStruct, Ptr{Float64}};
+                        dump_module=true, kernel=true)
+    end
+end
+
 @testset "skip scalar trap" begin
     mod = @eval module $(gensym())
         workitem_idx_x() = ccall("llvm.amdgcn.workitem.id.x", llvmcall, Int32, ())