diff --git a/src/metal.jl b/src/metal.jl
index 7cf9115d..9f2541f1 100644
--- a/src/metal.jl
+++ b/src/metal.jl
@@ -171,6 +171,11 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L
     entry = add_parameter_address_spaces!(job, mod, entry)
     entry = add_global_address_spaces!(job, mod, entry)
 
+    # narrow generic pointer parameters whose callers all pass a specific-AS pointer, so
+    # the constant globals read by out-of-line runtime functions (e.g. the exception
+    # reporters) load from the constant space rather than crashing Metal's validator.
+    propagate_argument_address_spaces!(mod)
+
     # propagate specific address spaces through addrspacecast chains introduced
     # by the rewrites above, so that loads/stores happen in the right address
     # space (e.g. constant globals in addrspace 2 rather than via a cast to 0,
@@ -442,6 +447,225 @@ function add_global_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.M
 end
 
 
+# interprocedural address-space narrowing
+#
+# `InferAddressSpaces` rewrites a generic (flat) load/store into a concrete address space
+# when it can trace the pointer back to an `addrspacecast` from that space, but only within
+# one function. A pointer crossing a call boundary as a generic parameter loses that
+# provenance: a constant global passed to an out-of-line runtime function (the exception
+# reporters take `Ptr` arguments) arrives generic and is read with a generic-space load,
+# which crashes Metal's shader validator.
+#
+# This pass is the interprocedural complement. When every caller passes the same kind of
+# value for a generic pointer parameter, `addrspacecast(<specific AS> ->
+# generic)`, it retargets the parameter to that space, drops the casts at the call sites,
+# and casts back to generic on entry so the body is unchanged. That only relocates a
+# side-effect-free cast across the boundary, so it is trivially correct; the following
+# `InferAddressSpaces` run folds the entry cast away. The source need not be a constant
+# global; any pointer with a known address space qualifies, so any back-end can run it.
+#
+# Narrowing one function makes its body forward an `addrspacecast`-from-specific to the
+# functions it calls, exposing them in turn. We therefore iterate to a fixed point so a
+# constant reaches an arbitrarily deep callee (e.g. an exception reporter that delegates to
+# another) regardless of the order functions are visited in. This terminates: each sweep
+# that changes anything strictly reduces the number of generic pointer parameters in the
+# module, and narrowing never introduces a new one.
+
+# If `v` is an `addrspacecast` (instruction or constant expression) of a pointer from a
+# specific (non-generic) address space to the generic one (0), return that source pointer;
+# otherwise `nothing`.
+function addrspacecast_to_generic_source(@nospecialize(v))
+    (v isa LLVM.Instruction || v isa LLVM.ConstantExpr) || return nothing
+    opcode(v) == LLVM.API.LLVMAddrSpaceCast || return nothing
+    addrspace(value_type(v)) == 0 || return nothing
+    src = operands(v)[1]
+    (value_type(src) isa LLVM.PointerType && addrspace(value_type(src)) != 0) ||
+        return nothing
+    return src
+end
+
+# repeat narrowing sweeps until one changes nothing; returns whether any sweep did.
+function propagate_argument_address_spaces!(mod::LLVM.Module)
+    changed = false
+    while propagate_argument_address_spaces_once!(mod)
+        changed = true
+    end
+    return changed
+end
+
+# a single narrowing sweep over the module; returns whether anything changed.
+function propagate_argument_address_spaces_once!(mod::LLVM.Module)
+    changed = false
+    for f in collect(functions(mod))
+        isempty(blocks(f)) && continue # only functions we can rewrite (have a body)
+
+        # rewriting a signature is only sound with no callers outside the module, so require
+        # local (internal/private) linkage. by `finish_ir!` the pipeline has internalized
+        # everything but the kernel entrypoints, so the runtime helpers we target qualify.
+        linkage(f) in (LLVM.API.LLVMInternalLinkage, LLVM.API.LLVMPrivateLinkage) || continue
+
+        param_types = parameters(function_type(f))
+
+        # collect call sites; bail unless every use is a direct call we can update
+        callsites = LLVM.CallInst[]
+        only_calls = true
+        for use in uses(f)
+            v = user(use)
+            # `f` must occur only as the callee: a call that also passes `f` as an operand
+            # takes its address, and would otherwise be collected once per use.
+            if v isa LLVM.CallInst && called_operand(v) == f &&
+               count(==(f), operands(v)) == 1
+                push!(callsites, v)
+            else
+                only_calls = false
+                break
+            end
+        end
+        (only_calls && !isempty(callsites)) || continue
+
+        # for each generic pointer parameter, find the address space its callers agree on
+        new_addrspaces = fill(-1, length(param_types))
+        for (i, pty) in enumerate(param_types)
+            (pty isa LLVM.PointerType && addrspace(pty) == 0) || continue
+            as = -1
+            for cs in callsites
+                src = addrspacecast_to_generic_source(arguments(cs)[i])
+                if src === nothing
+                    as = -1; break
+                end
+                src_as = addrspace(value_type(src))
+                as == -1 ? (as = src_as) : (as == src_as || (as = -1; break))
+            end
+            as > 0 && (new_addrspaces[i] = as)
+        end
+        any(>=(0), new_addrspaces) || continue
+
+        narrow_pointer_parameters!(mod, f, new_addrspaces, callsites)
+        changed = true
+    end
+    return changed
+end
+
+# copy the call-site attributes (function/return/per-argument) from `src` onto `dst`. the
+# narrowing keeps argument positions unchanged, so they map across one-to-one.
+function copy_callsite_attributes!(dst::LLVM.CallInst, src::LLVM.CallInst)
+    for attr in collect(function_attributes(src))
+        push!(function_attributes(dst), attr)
+    end
+    for attr in collect(return_attributes(src))
+        push!(return_attributes(dst), attr)
+    end
+    for i in eachindex(arguments(src))
+        for attr in collect(argument_attributes(src, i))
+            push!(argument_attributes(dst, i), attr)
+        end
+    end
+    return dst
+end
+
+# rewrite a single call so it targets `new_f`/`new_ft`, passing the un-casted source value
+# for each retargeted argument (and the original argument otherwise). Preserves calling
+# convention, operand bundles and attributes; replaces and erases the old call.
+function rewrite_narrowed_call!(builder::IRBuilder, cs::LLVM.CallInst,
+                                new_f::LLVM.Function, new_ft::LLVM.FunctionType,
+                                new_addrspaces::Vector{Int})
+    position!(builder, cs)
+    # arguments past the fixed parameters (variadic calls) are forwarded untouched
+    new_args = LLVM.Value[get(new_addrspaces, i, -1) >= 0 ?
+                          addrspacecast_to_generic_source(arg) : arg
+                          for (i, arg) in enumerate(arguments(cs))]
+    new_call = call!(builder, new_ft, new_f, new_args, operand_bundles(cs))
+    callconv!(new_call, callconv(cs))
+    copy_callsite_attributes!(new_call, cs)
+    replace_uses!(cs, new_call)
+    erase!(cs)
+    return new_call
+end
+
+# Clone `f` with the pointer parameters listed in `new_addrspaces` (index => address space,
+# `-1` to leave alone) retargeted to those address spaces, casting each retargeted parameter
+# back to generic on entry so the cloned body is unchanged. Rewrite `callsites` to pass the
+# un-casted source value for each retargeted argument; recursive self-calls are handled too.
+function narrow_pointer_parameters!(mod::LLVM.Module, f::LLVM.Function,
+                                    new_addrspaces::Vector{Int}, callsites)
+    ft = function_type(f)
+    retarget(pty::LLVM.PointerType, as::Integer) =
+        supports_typed_pointers(context()) ? LLVM.PointerType(eltype(pty), as) :
+                                             LLVM.PointerType(as)
+    new_types = LLVM.LLVMType[new_addrspaces[i] >= 0 ?
+                              retarget(param_typ::LLVM.PointerType, new_addrspaces[i]) :
+                              param_typ
+                              for (i, param_typ) in enumerate(parameters(ft))]
+    # keep variadic-ness, or calls passing extra arguments would no longer type-check
+    new_ft = LLVM.FunctionType(return_type(ft), new_types; vararg=isvararg(ft))
+
+    new_f = LLVM.Function(mod, "", new_ft)
+    linkage!(new_f, linkage(f))
+    callconv!(new_f, callconv(f))
+    for (old_arg, new_arg) in zip(parameters(f), parameters(new_f))
+        LLVM.name!(new_arg, LLVM.name(old_arg))
+    end
+
+    # cast each retargeted parameter back to generic so the cloned body keeps using it
+    # unchanged (InferAddressSpaces folds the cast away afterwards)
+    @dispose builder=IRBuilder() begin
+        entry = BasicBlock(new_f, "conversion")
+        position!(builder, entry)
+        new_args = LLVM.Value[]
+        for (i, param_typ) in enumerate(parameters(ft))
+            if new_addrspaces[i] >= 0
+                push!(new_args, addrspacecast!(builder, parameters(new_f)[i], param_typ))
+            else
+                push!(new_args, parameters(new_f)[i])
+            end
+        end
+
+        value_map = Dict{LLVM.Value, LLVM.Value}(
+            param => new_args[i] for (i, param) in enumerate(parameters(f)))
+        value_map[f] = new_f
+        clone_into!(new_f, f; value_map,
+                    changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges)
+
+        br!(builder, blocks(new_f)[2]) # fall through to the cloned entry block
+    end
+
+    # `clone_into!` copies a parameter's attributes only when it maps to a new argument; the
+    # retargeted ones map to the entry addrspacecast instead, so theirs are dropped. Reattach
+    # them; they stay valid on the narrowed pointer, and non-retargeted params keep theirs.
+    for i in eachindex(new_addrspaces)
+        new_addrspaces[i] >= 0 || continue
+        for attr in collect(parameter_attributes(f, i))
+            push!(parameter_attributes(new_f, i), attr)
+        end
+    end
+
+    # a (directly) recursive `f` has self-calls that cloning retargeted to `new_f` but left
+    # with the old signature; collect them from the clone for rewriting. collect first, since
+    # the rewritten calls also target `new_f` and must not be revisited.
+    self_calls = LLVM.CallInst[]
+    for bb in blocks(new_f), inst in instructions(bb)
+        inst isa LLVM.CallInst && called_operand(inst) == new_f && push!(self_calls, inst)
+    end
+
+    # rewrite call sites to pass the un-casted source value for each retargeted argument
+    @dispose builder=IRBuilder() begin
+        for cs in callsites
+            rewrite_narrowed_call!(builder, cs, new_f, new_ft, new_addrspaces)
+        end
+        for cs in self_calls
+            rewrite_narrowed_call!(builder, cs, new_f, new_ft, new_addrspaces)
+        end
+    end
+
+    fn = LLVM.name(f)
+    @assert isempty(uses(f)) # every use was a call site we just rewrote
+    replace_metadata_uses!(f, new_f)
+    erase!(f)
+    LLVM.name!(new_f, fn)
+    return new_f
+end
+
+
 # value-to-reference conversion
 #
 # Metal doesn't support passing values, so we need to convert those to references instead
diff --git a/test/metal.jl b/test/metal.jl
index 87539b26..ae837ccd 100644
--- a/test/metal.jl
+++ b/test/metal.jl
@@ -208,4 +208,177 @@ end
 
 end
 
+@testset "argument address-space narrowing" begin
+    # pointer type in address space `as`, typed- and opaque-pointer compatible
+    asptr(as) = supports_typed_pointers() ? LLVM.PointerType(LLVM.Int8Type(), as) :
+                                            LLVM.PointerType(as)
+
+    # build a module with an internal `callee` that loads through a generic (AS 0) pointer
+    # parameter, reached from one `caller` per entry in `caller_src_as`, each passing a
+    # constant global in that address space cast to generic.
+    function narrowing_module(caller_src_as::Vector{Int};
+                              callee_linkage=LLVM.API.LLVMInternalLinkage,
+                              recursive=false, address_taken=false)
+        mod = LLVM.Module("test")
+        i8 = LLVM.Int8Type()
+        callee_ft = LLVM.FunctionType(i8, LLVM.LLVMType[asptr(0)])
+        callee = LLVM.Function(mod, "callee", callee_ft)
+        linkage!(callee, callee_linkage)
+        @dispose builder=IRBuilder() begin
+            position!(builder, BasicBlock(callee, "entry"))
+            v = load!(builder, i8, parameters(callee)[1])
+            if recursive
+                # a (would-be infinite) self-call passing a constant global, only to
+                # exercise the recursion path; not meant to run.
+                g = GlobalVariable(mod, i8, "gself", caller_src_as[1])
+                initializer!(g, ConstantInt(i8, 7)); constant!(g, true)
+                call!(builder, callee_ft, callee, [const_addrspacecast(g, asptr(0))])
+            end
+            ret!(builder, v)
+        end
+        for (n, as) in enumerate(caller_src_as)
+            g = GlobalVariable(mod, i8, "g$n", as)
+            initializer!(g, ConstantInt(i8, n)); constant!(g, true)
+            caller = LLVM.Function(mod, "caller$n", LLVM.FunctionType(i8, LLVM.LLVMType[]))
+            linkage!(caller, LLVM.API.LLVMInternalLinkage)
+            @dispose builder=IRBuilder() begin
+                position!(builder, BasicBlock(caller, "entry"))
+                ret!(builder, call!(builder, callee_ft, callee,
+                                    [const_addrspacecast(g, asptr(0))]))
+            end
+        end
+        if address_taken
+            # a non-call use of the callee: stash its address in a global
+            initializer!(GlobalVariable(mod, value_type(callee), "fp"), callee)
+        end
+        return mod
+    end
+
+    callee_param_as(mod) = addrspace(parameters(function_type(functions(mod)["callee"]))[1])
+    function calls_to(mod, fname)
+        f = functions(mod)[fname]
+        [inst for g in functions(mod) for bb in blocks(g) for inst in instructions(bb)
+         if inst isa LLVM.CallInst && called_operand(inst) == f]
+    end
+
+    # all callers agree -> the parameter is narrowed; attributes survive; IR stays valid
+    Context() do ctx
+        mod = narrowing_module([2, 2])
+        callee = functions(mod)["callee"]
+        push!(parameter_attributes(callee, 1), EnumAttribute("nonnull", 0))
+        push!(function_attributes(callee), EnumAttribute("nounwind", 0))
+
+        @test GPUCompiler.propagate_argument_address_spaces!(mod)
+        @test callee_param_as(mod) == 2
+        @test all(c -> addrspace(value_type(arguments(c)[1])) == 2, calls_to(mod, "callee"))
+
+        callee = functions(mod)["callee"]
+        @test kind(EnumAttribute("nonnull", 0)) in kind.(collect(parameter_attributes(callee, 1)))
+        @test kind(EnumAttribute("nounwind", 0)) in kind.(collect(function_attributes(callee)))
+        @test (verify(mod); true)
+    end
+
+    # callers disagree on the source address space -> left alone
+    Context() do ctx
+        mod = narrowing_module([2, 1])
+        @test !GPUCompiler.propagate_argument_address_spaces!(mod)
+        @test callee_param_as(mod) == 0
+    end
+
+    # the callee's address is taken (a non-call use) -> left alone
+    Context() do ctx
+        mod = narrowing_module([2]; address_taken=true)
+        @test !GPUCompiler.propagate_argument_address_spaces!(mod)
+        @test callee_param_as(mod) == 0
+    end
+
+    # externally-visible callee -> left alone (its signature may be observed elsewhere)
+    Context() do ctx
+        mod = narrowing_module([2]; callee_linkage=LLVM.API.LLVMExternalLinkage)
+        @test !GPUCompiler.propagate_argument_address_spaces!(mod)
+        @test callee_param_as(mod) == 0
+    end
+
+    # a self-recursive callee is narrowed and the self-call rewritten to stay well-typed:
+    # every call to it (recursive included) must now pass the constant-space pointer
+    Context() do ctx
+        mod = narrowing_module([2]; recursive=true)
+        @test GPUCompiler.propagate_argument_address_spaces!(mod)
+        @test callee_param_as(mod) == 2
+        @test length(calls_to(mod, "callee")) == 2
+        @test all(c -> addrspace(value_type(arguments(c)[1])) == 2, calls_to(mod, "callee"))
+        @test (verify(mod); true)
+    end
+
+    # the source need not be a global: a device pointer (AS 1) threaded through a helper
+    # as a generic pointer is narrowed to AS 1 just the same
+    Context() do ctx
+        mod = LLVM.Module("test")
+        i8 = LLVM.Int8Type()
+        callee_ft = LLVM.FunctionType(i8, LLVM.LLVMType[asptr(0)])
+        callee = LLVM.Function(mod, "callee", callee_ft)
+        linkage!(callee, LLVM.API.LLVMInternalLinkage)
+        @dispose builder=IRBuilder() begin
+            position!(builder, BasicBlock(callee, "entry"))
+            ret!(builder, load!(builder, i8, parameters(callee)[1]))
+        end
+        caller = LLVM.Function(mod, "caller", LLVM.FunctionType(i8, LLVM.LLVMType[asptr(1)]))
+        linkage!(caller, LLVM.API.LLVMInternalLinkage)
+        @dispose builder=IRBuilder() begin
+            position!(builder, BasicBlock(caller, "entry"))
+            gen = addrspacecast!(builder, parameters(caller)[1], asptr(0))
+            ret!(builder, call!(builder, callee_ft, callee, [gen]))
+        end
+
+        @test GPUCompiler.propagate_argument_address_spaces!(mod)
+        @test callee_param_as(mod) == 1
+        @test (verify(mod); true)
+    end
+
+    # a two-level delegation chain (caller -> mid -> leaf) needs the fixpoint: one sweep
+    # narrows `mid` (its caller passes a constant global), which only then exposes `leaf`,
+    # since `mid` now forwards an addrspacecast-from-constant. iterate until both narrow.
+    Context() do ctx
+        mod = LLVM.Module("test")
+        i8 = LLVM.Int8Type()
+        ft = LLVM.FunctionType(i8, LLVM.LLVMType[asptr(0)])
+        param_as(name) = addrspace(parameters(function_type(functions(mod)[name]))[1])
+
+        # leaf: loads through its generic pointer parameter
+        leaf = LLVM.Function(mod, "leaf", ft)
+        linkage!(leaf, LLVM.API.LLVMInternalLinkage)
+        @dispose builder=IRBuilder() begin
+            position!(builder, BasicBlock(leaf, "entry"))
+            ret!(builder, load!(builder, i8, parameters(leaf)[1]))
+        end
+
+        # mid: forwards its generic pointer parameter to leaf
+        mid = LLVM.Function(mod, "mid", ft)
+        linkage!(mid, LLVM.API.LLVMInternalLinkage)
+        @dispose builder=IRBuilder() begin
+            position!(builder, BasicBlock(mid, "entry"))
+            ret!(builder, call!(builder, ft, leaf, [parameters(mid)[1]]))
+        end
+
+        # caller: passes a constant global (AS 2) cast to generic into mid
+        g = GlobalVariable(mod, i8, "g", 2)
+        initializer!(g, ConstantInt(i8, 1)); constant!(g, true)
+        caller = LLVM.Function(mod, "caller", LLVM.FunctionType(i8, LLVM.LLVMType[]))
+        linkage!(caller, LLVM.API.LLVMInternalLinkage)
+        @dispose builder=IRBuilder() begin
+            position!(builder, BasicBlock(caller, "entry"))
+            ret!(builder, call!(builder, ft, mid, [const_addrspacecast(g, asptr(0))]))
+        end
+
+        # a single sweep reaches only `mid`; the fixpoint must then narrow `leaf` too
+        @test GPUCompiler.propagate_argument_address_spaces_once!(mod)
+        @test param_as("mid") == 2
+        @test param_as("leaf") == 0
+
+        @test GPUCompiler.propagate_argument_address_spaces!(mod)
+        @test param_as("leaf") == 2
+        @test (verify(mod); true)
+    end
+end
+
 end