Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions src/gcn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -255,14 +255,4 @@ function lower_throw_extra!(mod::LLVM.Module)
return changed
end

# emit a call to the `llvm.trap` intrinsic at the builder's current position.
#
# the intrinsic is looked up by name in `mod`; when the module does not declare it yet, a
# declaration with a `void ()` signature is added first. `job` is only used for dispatch,
# and `inst` is not used by this method.
function emit_trap!(job::CompilerJob{GCNCompilerTarget}, builder, mod, inst)
    trap_ft = LLVM.FunctionType(LLVM.VoidType())
    fns = functions(mod)
    trap = haskey(fns, "llvm.trap") ? fns["llvm.trap"] :
                                      LLVM.Function(mod, "llvm.trap", trap_ft)
    return call!(builder, trap_ft, trap)
end

# jobs targeting GCN always report that vectorization is allowed.
function can_vectorize(job::CompilerJob{GCNCompilerTarget})
    return true
end
261 changes: 257 additions & 4 deletions src/irgen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,9 +231,19 @@ end

# report an exception in a GPU-compatible manner
#
# the exact behavior depends on the debug level. in all cases, a `trap` will be emitted. on
# debug level 1, the exception name will be printed, and on debug level 2 the individual
# stack frames (as recovered from the LLVM debug information) will be printed as well.
# the exact behavior depends on the debug level. in all cases, a `trap` is emitted. on debug
# level 1, the exception name is printed, and on debug level 2 the individual stack frames (as
# recovered from the LLVM debug information) are printed as well.
#
# the `trap` here is *not* the final lowering of the exception: some targets cannot tolerate a
# hardware trap (on Apple M1 compute a `trap` wedges the whole GPU, JuliaGPU/Metal.jl#433; and
# SPIR-V/PoCL have no abort), so those backends strip it post-optimization in
# `lower_unreachable_control_flow!` and let the lane exit via a clean `ret`. the trap must
# nonetheless survive through `optimize!`: it is `noreturn`, and that is what stops InstCombine's
# `removeInstructionsBeforeUnreachable` (which erases instructions preceding an `unreachable`
# while `!mayThrow() && willReturn()`) from deleting the `signal_exception` call below and
# folding away the guarding bounds-check branch. so the trap is the optimizer-correctness guard;
# do not move its removal earlier than post-`optimize!`.
function emit_exception!(builder, name, inst)
job = current_job::CompilerJob
bb = position(builder)
Expand Down Expand Up @@ -264,7 +274,8 @@ function emit_exception!(builder, name, inst)
end
end

# signal the exception
# signal the exception to the host (backend-specific: writes a `KernelState` mailbox).
# the host reads this mailbox after synchronizing.
call!(builder, Runtime.get(:signal_exception))

emit_trap!(job, builder, mod, inst)
Expand All @@ -281,6 +292,248 @@ function emit_trap!(@nospecialize(job::CompilerJob), builder, mod, inst)
end


## unreachable control flow handling

# check if a function contains unreachable control flow
# (`unreachable` terminator or `trap` call)
# return `true` as soon as `f` is found to contain either an `unreachable` instruction or a
# call whose callee is the function named `llvm.trap`; return `false` when neither occurs.
function has_unreachable_control_flow(f::LLVM.Function)
    for block in blocks(f)
        for instr in instructions(block)
            instr isa LLVM.UnreachableInst && return true

            # anything that is not a call cannot be a trap
            instr isa LLVM.CallInst || continue
            target = called_operand(instr)
            if target isa LLVM.Function && name(target) == "llvm.trap"
                return true
            end
        end
    end
    return false
end

# force-inline every function with unreachable control flow into kernels, so that
# `lower_unreachable_control_flow!` can rewrite it into a `ret` soundly.
#
# this is a fixpoint iteration based on `has_unreachable_control_flow`: each round marks the
# functions that currently contain unreachable control flow and inlines them, which exposes it
# in their callers, until it has all been hoisted up into the kernels. this naturally handles
# the `kernel → A → B` case where only `B` traps: `A` is marked once `B` is inlined into it,
# without us having to reason about call-graph paths.
# repeatedly mark non-kernel functions that contain unreachable control flow as
# `alwaysinline` (dropping `noinline`) and run the always-inliner over the module, until a
# round marks nothing new. returns `true` if the inliner was run at least once.
function inline_unreachable_control_flow!(@nospecialize(job::CompilerJob), mod::LLVM.Module)
    inlined_any = false
    always_inline = EnumAttribute("alwaysinline", 0)
    no_inline = EnumAttribute("noinline", 0)
    entry_points = kernels(mod)

    @tracepoint "inline unreachable control flow" begin
        while true
            # marking phase: tag every candidate that is not tagged yet
            newly_marked = false
            for fn in functions(mod)
                if isdeclaration(fn)
                    continue
                end
                # kernels are never inlined, and a function without call sites has
                # nothing to be inlined into.
                if fn in entry_points || isempty(uses(fn))
                    continue
                end
                fn_attrs = function_attributes(fn)
                # already tagged in an earlier round: skipping it here is what lets the
                # iteration terminate when the inliner could not get rid of the function.
                if always_inline in collect(fn_attrs)
                    continue
                end
                if !has_unreachable_control_flow(fn)
                    continue
                end

                delete!(fn_attrs, no_inline)
                push!(fn_attrs, always_inline)
                newly_marked = true
            end

            # fixpoint reached: nothing left to hoist
            if !newly_marked
                break
            end

            # inlining phase: fold the tagged functions into their callers, which may
            # expose unreachable control flow there for the next round.
            @dispose pb=NewPMPassBuilder() begin
                add!(pb, AlwaysInlinerPass())
                run!(pb, mod, llvm_machine(job.config.target))
            end
            inlined_any = true
        end
    end

    return inlined_any
end

# lower `trap` to a clean return to get rid of `unreachable` and `noreturn`
#
# this is for compatibility with back-ends that don't support (SPIR-V) or have
# problems with `trap` (Metal on Apple M1 and M2). note that the rewrite is not
# entirely correct: barriers may deadlock if a participating lane has exited.
# however, it's generally not possible to do better without hardware support.
# module-level driver: hoist throwing code into the kernels, rewrite `trap`/`unreachable`
# in the kernels into a clean return, and remove the attributes and declaration that would
# let a back-end reintroduce them. returns whether the inlining, the per-kernel rewrite or
# the erasure of the `llvm.trap` declaration changed anything.
function lower_unreachable_control_flow!(@nospecialize(job::CompilerJob), mod::LLVM.Module)
    changed = false
    @tracepoint "lower unreachable control flow" begin
        # step 1: inline every function with unreachable control flow into its kernel(s).
        # rewriting `unreachable` into `ret` is only meaningful in a kernel, where the
        # `ret` exits to the host rather than resuming a caller.
        if inline_unreachable_control_flow!(job, mod)
            changed = true
        end

        # step 2: defensive global DCE. the regular cleanup ran before this point and
        # will not see anything the inlining above left behind.
        @dispose pb=NewPMPassBuilder() begin
            add!(pb, GlobalDCEPass())
            run!(pb, mod, llvm_machine(job.config.target))
        end

        # step 3: rewrite the kernels, and only the kernels. a non-kernel function that
        # still has callers and still contains `trap`/`unreachable` could not be hoisted
        # (likely recursive or address-taken); turning its `unreachable` into a `ret`
        # would resume execution in the caller, so it is left untouched and reported.
        kernel_fns = kernels(mod)
        for f in functions(mod)
            if isdeclaration(f)
                continue
            end
            if f in kernel_fns
                if lower_unreachable_control_flow!(f)
                    changed = true
                end
                continue
            end
            if has_unreachable_control_flow(f) && !isempty(uses(f))
                @safe_warn "Cannot lower unreachable control flow in '$(name(f))': it has callers but could not be inlined into a kernel (it is likely recursive or address-taken). Leaving its trap/unreachable in place; this may not be supported by the back-end."
            end
        end

        # step 4: strip `noreturn` from every function and every call site in the
        # module, including declarations and functions skipped above, so that a stale
        # attribute cannot be used by the back-end to re-derive an `unreachable`.
        noreturn_attr = EnumAttribute("noreturn", 0)
        for fn in functions(mod)
            delete!(function_attributes(fn), noreturn_attr)
            for block in blocks(fn)
                for instr in instructions(block)
                    if instr isa LLVM.CallInst
                        delete!(function_attributes(instr), noreturn_attr)
                    end
                end
            end
        end

        # step 5: drop the `llvm.trap` declaration, but only once nothing uses it.
        module_fns = functions(mod)
        if haskey(module_fns, "llvm.trap")
            trap_decl = module_fns["llvm.trap"]
            if isempty(uses(trap_decl))
                erase!(trap_decl)
                changed = true
            end
        end
    end
    return changed
end

# per-function rewrite: erase every call to `llvm.trap`, then replace every `unreachable`
# terminator by a branch to a block that returns. returns `true` if `f` was modified.
function lower_unreachable_control_flow!(f::LLVM.Function)
    modified = false

    # sweep 1: remove all `llvm.trap` calls, whatever surrounds them. the instruction
    # list of each block is snapshotted because we erase from it while walking.
    for block in blocks(f)
        for instr in collect(instructions(block))
            instr isa LLVM.CallInst || continue
            target = called_operand(instr)
            if target isa LLVM.Function && name(target) == "llvm.trap"
                erase!(instr)
                modified = true
            end
        end
    end

    # sweep 2: gather the `unreachable` terminators to rewrite (including those that
    # were never preceded by a trap) and the blocks that end in a `ret`.
    dead_ends = Instruction[]
    returning_blocks = BasicBlock[]
    for block in blocks(f), instr in instructions(block)
        if instr isa LLVM.UnreachableInst
            push!(dead_ends, instr)
        elseif instr isa LLVM.RetInst
            push!(returning_blocks, block)
        end
    end
    if isempty(dead_ends)
        return modified
    end

    @dispose builder=IRBuilder() begin
        local landing
        if !isempty(returning_blocks)
            # with several returning blocks, pick the last one, which is hopefully the
            # least divergent (assuming divergence is the root of the problem here).
            chosen = last(returning_blocks)
            ret_inst = terminator(chosen)

            # we want a landing block that holds nothing but the `ret`, so that only the
            # returned value matters and no other SSA value of the block is involved.
            if first(instructions(chosen)) == ret_inst
                # the block consists of just the return: use it directly
                landing = chosen
            else
                # otherwise split it right in front of the return
                landing = BasicBlock(f, "ret")
                move_after(landing, chosen)

                # the old block now falls through to the landing block ...
                position!(builder, ret_inst)
                br!(builder, landing)

                # ... which receives the original return instruction
                remove!(ret_inst)
                position!(builder, landing)
                insert!(builder, ret_inst)
            end

            # for a value-returning function, route the returned value through a phi in
            # the landing block; the former `unreachable` blocks contribute undef below.
            if !isempty(operands(ret_inst))
                position!(builder, ret_inst)
                # XXX: support aggregate returns?
                returned = only(operands(ret_inst))
                merged = phi!(builder, value_type(returned))
                for pred in predecessors(landing)
                    push!(incoming(merged), (returned, pred))
                end
                operands(ret_inst)[1] = merged
            end
        else
            # no normal return anywhere (e.g. a kernel whose only path is a `throw`):
            # create a fresh block that returns void, or undef for a non-void function.
            landing = BasicBlock(f, "ret")
            position!(builder, landing)
            rt = return_type(function_type(f))
            if rt == LLVM.VoidType()
                ret!(builder)
            else
                ret!(builder, UndefValue(rt))
            end
        end

        # finally, redirect each `unreachable` to the landing block
        for dead_end in dead_ends
            origin = LLVM.parent(dead_end)

            position!(builder, dead_end)
            br!(builder, landing)
            erase!(dead_end)

            # every phi in the landing block needs an incoming value for the new edge
            for instr in instructions(landing)
                instr isa LLVM.PHIInst || continue
                push!(incoming(instr), (UndefValue(value_type(instr)), origin))
            end
        end
    end

    return true
end


## kernel promotion

@enum ArgumentCC begin
Expand Down
Loading
Loading