Skip to content
218 changes: 218 additions & 0 deletions src/metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,11 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L
entry = add_parameter_address_spaces!(job, mod, entry)
entry = add_global_address_spaces!(job, mod, entry)

# narrow generic pointer parameters whose callers all pass a specific-AS pointer, so
# the constant globals read by out-of-line runtime functions (e.g. the exception
# reporters) load from the constant space rather than crashing Metal's validator.
propagate_argument_address_spaces!(mod)

# propagate specific address spaces through addrspacecast chains introduced
# by the rewrites above, so that loads/stores happen in the right address
# space (e.g. constant globals in addrspace 2 rather than via a cast to 0,
Expand Down Expand Up @@ -442,6 +447,219 @@ function add_global_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.M
end


# interprocedural address-space narrowing
#
# `InferAddressSpaces` rewrites a generic (flat) load/store into a concrete address space
# when it can trace the pointer back to an `addrspacecast` from that space, but only within
# one function. A pointer crossing a call boundary as a generic parameter loses that
# provenance: a constant global passed to an out-of-line runtime function (the exception
# reporters take `Ptr` arguments) arrives generic and is read with a generic-space load,
# which crashes Metal's shader validator.
#
# This pass is the interprocedural complement. When every caller passes the same kind of
# value for a generic pointer parameter, `addrspacecast(<ptr in a specific space> ->
# generic)`, it retargets the parameter to that space, drops the casts at the call sites,
# and casts back to generic on entry so the body is unchanged. That only relocates a
# side-effect-free cast across the boundary, so it is trivially correct; the following
# `InferAddressSpaces` run folds the entry cast away. The source need not be a constant
# global; any pointer with a known address space qualifies, so any back-end can run it.
#
# Narrowing one function makes its body forward an `addrspacecast`-from-specific to the
# functions it calls, exposing them in turn. We therefore iterate to a fixed point so a
# constant reaches an arbitrarily deep callee (e.g. an exception reporter that delegates to
# another) regardless of the order functions are visited in. This terminates: each sweep
# that changes anything strictly reduces the number of generic pointer parameters in the
# module, and narrowing never introduces a new one.

# If `v` is an `addrspacecast` (instruction or constant expression) of a pointer from a
# specific (non-generic) address space to the generic one, return that source pointer;
# otherwise `nothing`.
function addrspacecast_to_generic_source(@nospecialize(v))
(v isa LLVM.Instruction || v isa LLVM.ConstantExpr) || return nothing
opcode(v) == LLVM.API.LLVMAddrSpaceCast || return nothing
addrspace(value_type(v)) == 0 || return nothing
src = operands(v)[1]
(value_type(src) isa LLVM.PointerType && addrspace(value_type(src)) != 0) ||
return nothing
return src
end

function propagate_argument_address_spaces!(mod::LLVM.Module)
changed = false
while propagate_argument_address_spaces_once!(mod)
changed = true
end
return changed
end

# a single narrowing sweep over the module; returns whether anything changed.
function propagate_argument_address_spaces_once!(mod::LLVM.Module)
changed = false
for f in collect(functions(mod))
isempty(blocks(f)) && continue # only functions we can rewrite (have a body)

# rewriting a signature is only sound with no callers outside the module, so require
# local (internal/private) linkage. by `finish_ir!` the pipeline has internalized
# everything but the kernel entrypoints, so the runtime helpers we target qualify.
linkage(f) in (LLVM.API.LLVMInternalLinkage, LLVM.API.LLVMPrivateLinkage) || continue

param_types = parameters(function_type(f))

# collect call sites; bail unless every use is a direct call we can update
callsites = LLVM.CallInst[]
only_calls = true
for use in uses(f)
v = user(use)
if v isa LLVM.CallInst && called_operand(v) == f
push!(callsites, v)
else
only_calls = false
break
end
end
(only_calls && !isempty(callsites)) || continue

# for each generic pointer parameter, find the address space its callers agree on
new_addrspaces = fill(-1, length(param_types))
for (i, pty) in enumerate(param_types)
(pty isa LLVM.PointerType && addrspace(pty) == 0) || continue
as = -1
for cs in callsites
src = addrspacecast_to_generic_source(arguments(cs)[i])
if src === nothing
as = -1; break
end
src_as = addrspace(value_type(src))
as == -1 ? (as = src_as) : (as == src_as || (as = -1; break))
end
as > 0 && (new_addrspaces[i] = as)
end
any(>=(0), new_addrspaces) || continue

narrow_pointer_parameters!(mod, f, new_addrspaces, callsites)
changed = true
end
return changed
end

# copy the call-site attributes (function/return/per-argument) from `src` onto `dst`. the
# narrowing keeps argument positions unchanged, so they map across one-to-one.
function copy_callsite_attributes!(dst::LLVM.CallInst, src::LLVM.CallInst)
for attr in collect(function_attributes(src))
push!(function_attributes(dst), attr)
end
for attr in collect(return_attributes(src))
push!(return_attributes(dst), attr)
end
for i in 1:length(arguments(src))
for attr in collect(argument_attributes(src, i))
push!(argument_attributes(dst, i), attr)
end
end
return dst
end

# rewrite a single call so it targets `new_f`/`new_ft`, passing the un-casted source value
# for each retargeted argument (and the original argument otherwise). Preserves calling
# convention, operand bundles and attributes; replaces and erases the old call.
function rewrite_narrowed_call!(builder::IRBuilder, cs::LLVM.CallInst,
new_f::LLVM.Function, new_ft::LLVM.FunctionType,
new_addrspaces::Vector{Int})
position!(builder, cs)
new_args = LLVM.Value[new_addrspaces[i] >= 0 ?
addrspacecast_to_generic_source(arg) : arg
for (i, arg) in enumerate(arguments(cs))]
new_call = call!(builder, new_ft, new_f, new_args, operand_bundles(cs))
callconv!(new_call, callconv(cs))
copy_callsite_attributes!(new_call, cs)
replace_uses!(cs, new_call)
erase!(cs)
return new_call
end

# Clone `f` with the pointer parameters listed in `new_addrspaces` (index => address space,
# `-1` to leave alone) retargeted to those address spaces, casting each retargeted parameter
# back to generic on entry so the cloned body is unchanged. Rewrite `callsites` to pass the
# un-casted source value for each retargeted argument; recursive self-calls are handled too.
function narrow_pointer_parameters!(mod::LLVM.Module, f::LLVM.Function,
new_addrspaces::Vector{Int}, callsites)
ft = function_type(f)
retarget(pty::LLVM.PointerType, as::Integer) =
supports_typed_pointers(context()) ? LLVM.PointerType(eltype(pty), as) :
LLVM.PointerType(as)
new_types = LLVM.LLVMType[new_addrspaces[i] >= 0 ?
retarget(param_typ::LLVM.PointerType, new_addrspaces[i]) :
param_typ
for (i, param_typ) in enumerate(parameters(ft))]
new_ft = LLVM.FunctionType(return_type(ft), new_types)

new_f = LLVM.Function(mod, "", new_ft)
linkage!(new_f, linkage(f))
callconv!(new_f, callconv(f))
for (old_arg, new_arg) in zip(parameters(f), parameters(new_f))
LLVM.name!(new_arg, LLVM.name(old_arg))
end

# cast each retargeted parameter back to generic so the cloned body keeps using it
# unchanged (InferAddressSpaces folds the cast away afterwards)
@dispose builder=IRBuilder() begin
entry = BasicBlock(new_f, "conversion")
position!(builder, entry)
new_args = LLVM.Value[]
for (i, param_typ) in enumerate(parameters(ft))
if new_addrspaces[i] >= 0
push!(new_args, addrspacecast!(builder, parameters(new_f)[i], param_typ))
else
push!(new_args, parameters(new_f)[i])
end
end

value_map = Dict{LLVM.Value, LLVM.Value}(
param => new_args[i] for (i, param) in enumerate(parameters(f)))
value_map[f] = new_f
clone_into!(new_f, f; value_map,
changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges)

br!(builder, blocks(new_f)[2]) # fall through to the cloned entry block
end

# `clone_into!` copies a parameter's attributes only when it maps to a new argument; the
# retargeted ones map to the entry addrspacecast instead, so theirs are dropped. Reattach
# them; they stay valid on the narrowed pointer, and non-retargeted params keep theirs.
for i in 1:length(new_addrspaces)
new_addrspaces[i] >= 0 || continue
for attr in collect(parameter_attributes(f, i))
push!(parameter_attributes(new_f, i), attr)
end
end

# a (directly) recursive `f` has self-calls that cloning retargeted to `new_f` but left
# with the old signature; collect them from the clone for rewriting. collect first, since
# the rewritten calls also target `new_f` and must not be revisited.
self_calls = LLVM.CallInst[]
for bb in blocks(new_f), inst in instructions(bb)
inst isa LLVM.CallInst && called_operand(inst) == new_f && push!(self_calls, inst)
end

# rewrite call sites to pass the un-casted source value for each retargeted argument
@dispose builder=IRBuilder() begin
for cs in callsites
rewrite_narrowed_call!(builder, cs, new_f, new_ft, new_addrspaces)
end
for cs in self_calls
rewrite_narrowed_call!(builder, cs, new_f, new_ft, new_addrspaces)
end
end

fn = LLVM.name(f)
@assert isempty(uses(f)) # every use was a call site we just rewrote
replace_metadata_uses!(f, new_f)
erase!(f)
LLVM.name!(new_f, fn)
return new_f
end


# value-to-reference conversion
#
# Metal doesn't support passing values, so we need to convert those to references instead
Expand Down
Loading
Loading