diff --git a/src/gcn.jl b/src/gcn.jl index 8cc0ef56..52e52a2f 100644 --- a/src/gcn.jl +++ b/src/gcn.jl @@ -255,14 +255,4 @@ function lower_throw_extra!(mod::LLVM.Module) return changed end -function emit_trap!(job::CompilerJob{GCNCompilerTarget}, builder, mod, inst) - trap_ft = LLVM.FunctionType(LLVM.VoidType()) - trap = if haskey(functions(mod), "llvm.trap") - functions(mod)["llvm.trap"] - else - LLVM.Function(mod, "llvm.trap", trap_ft) - end - call!(builder, trap_ft, trap) -end - can_vectorize(job::CompilerJob{GCNCompilerTarget}) = true diff --git a/src/irgen.jl b/src/irgen.jl index 2c2eff96..ec74ce6a 100644 --- a/src/irgen.jl +++ b/src/irgen.jl @@ -231,9 +231,19 @@ end # report an exception in a GPU-compatible manner # -# the exact behavior depends on the debug level. in all cases, a `trap` will be emitted, On -# debug level 1, the exception name will be printed, and on debug level 2 the individual -# stack frames (as recovered from the LLVM debug information) will be printed as well. +# the exact behavior depends on the debug level. in all cases, a `trap` is emitted. on debug +# level 1, the exception name is printed, and on debug level 2 the individual stack frames (as +# recovered from the LLVM debug information) are printed as well. +# +# the `trap` here is *not* the final lowering of the exception: some targets cannot tolerate a +# hardware trap (on Apple M1 compute a `trap` wedges the whole GPU, JuliaGPU/Metal.jl#433; and +# SPIR-V/PoCL have no abort), so those backends strip it post-optimization in +# `lower_unreachable_control_flow!` and let the lane exit via a clean `ret`. the trap must +# nonetheless survive through `optimize!`: it is `noreturn`, and that is what stops InstCombine's +# `removeInstructionsBeforeUnreachable` (which erases instructions preceding an `unreachable` +# while `!mayThrow() && willReturn()`) from deleting the `signal_exception` call below and +# folding away the guarding bounds-check branch. 
so the trap is the optimizer-correctness guard; +# do not move its removal earlier than post-`optimize!`. function emit_exception!(builder, name, inst) job = current_job::CompilerJob bb = position(builder) @@ -264,7 +274,8 @@ function emit_exception!(builder, name, inst) end end - # signal the exception + # signal the exception to the host (backend-specific: writes a `KernelState` mailbox). + # the host reads this mailbox after synchronizing. call!(builder, Runtime.get(:signal_exception)) emit_trap!(job, builder, mod, inst) @@ -281,6 +292,248 @@ function emit_trap!(@nospecialize(job::CompilerJob), builder, mod, inst) end +## unreachable control flow handling + +# check if a function contains unreachable control flow +# (an `unreachable` terminator, or a direct call to the `llvm.trap` intrinsic) +function has_unreachable_control_flow(f::LLVM.Function) + for bb in blocks(f), inst in instructions(bb) + if isa(inst, LLVM.UnreachableInst) + return true + end + if isa(inst, LLVM.CallInst) + callee = called_operand(inst) + if isa(callee, LLVM.Function) && name(callee) == "llvm.trap" + return true + end + end + end + return false +end + +# force-inline every function with unreachable control flow into kernels, so that +# `lower_unreachable_control_flow!` can rewrite it into a `ret` soundly. +# +# this is a fixpoint iteration based on `has_unreachable_control_flow`: each round marks the +# functions that currently contain unreachable control flow and inlines them, which exposes it +# in their callers, until it has all been hoisted up into the kernels. this naturally handles +# the `kernel → A → B` case where only `B` traps: `A` is marked once `B` is inlined into it, +# without us having to reason about call-graph paths.
+function inline_unreachable_control_flow!(@nospecialize(job::CompilerJob), mod::LLVM.Module) + changed = false + alwaysinline_attr = EnumAttribute("alwaysinline", 0) + noinline_attr = EnumAttribute("noinline", 0) + kernel_fns = kernels(mod) + + @tracepoint "inline unreachable control flow" begin + while true + marked = false + for f in functions(mod) + isdeclaration(f) && continue + # skip kernels (never inlined) and functions without uses (nothing to inline into); + # further down, skip `alwaysinline` ones too, so each is marked once and we terminate. + (f in kernel_fns || isempty(uses(f))) && continue + attrs = function_attributes(f) + alwaysinline_attr in collect(attrs) && continue + has_unreachable_control_flow(f) || continue + + delete!(attrs, noinline_attr) + push!(attrs, alwaysinline_attr) + marked = true + end + marked || break + + @dispose pb=NewPMPassBuilder() begin + add!(pb, AlwaysInlinerPass()) + run!(pb, mod, llvm_machine(job.config.target)) + end + changed = true + end + end + + return changed +end + +# lower `trap` to a clean return to get rid of `unreachable` and `noreturn` +# +# this is for compatibility with back-ends that don't support (SPIR-V) or have +# problems with `trap` (Metal on Apple M1 and M2). note that the rewrite is not +# entirely correct: barriers may deadlock if a participating lane has exited. +# however, it's generally not possible to do better without hardware support. +function lower_unreachable_control_flow!(@nospecialize(job::CompilerJob), mod::LLVM.Module) + changed = false + @tracepoint "lower unreachable control flow" begin + + # the rewrite below only makes sense in a kernel: a function whose `ret` exits to the host rather + # than to a caller. kernels are the only thing we emit, and their top-level `ret` is what we rely + # on here; everything else is a callee that the inlining below folds into its kernel(s).
+ + # hoist every throwing function up into its kernel(s) first, so that each `unreachable` we + # rewrite below belongs to a kernel whose `ret` is a genuine exit (see the comment above). + changed |= inline_unreachable_control_flow!(job, mod) + + # defensively drop any dead leftovers before the back-end sees them. `AlwaysInlinerPass` already + # erases the throwing helpers it fully inlines (they are `internal`, hence discardable), so in + # practice this is a no-op; it is here only to catch dead remnants of partial inlining, since the + # regular `cleanup` DCE ran before `finish_ir!` and won't see anything produced above. + @dispose pb=NewPMPassBuilder() begin + add!(pb, GlobalDCEPass()) + run!(pb, mod, llvm_machine(job.config.target)) + end + + # lower the unreachable control flow, but *only* in the kernels: there, turning an `unreachable` + # into a `ret` is a genuine exit. we deliberately do not touch any other function: one that still + # contains `unreachable`/`trap` after the inlining above is one we couldn't hoist into a kernel + # (recursive or address-taken throwing code), and rewriting its `unreachable` into a `ret` would + # silently resume execution in the caller instead of exiting. we leave it as-is — keeping its + # `trap`/`unreachable`, which the back-end may reject, but that honestly surfaces an unsupported + # construct instead of quietly miscompiling it — and warn. + kernel_fns = kernels(mod) + for f in functions(mod) + isdeclaration(f) && continue + if f in kernel_fns + changed |= lower_unreachable_control_flow!(f) + elseif has_unreachable_control_flow(f) && !isempty(uses(f)) + @safe_warn "Cannot lower unreachable control flow in '$(name(f))': it has callers but could not be inlined into a kernel (it is likely recursive or address-taken). Leaving its trap/unreachable in place; this may not be supported by the back-end." + end + end + + # scrub every `noreturn` attribute (functions *and* call sites), module-wide.
after the rewrite + # above the entry points no longer trap or run off into `unreachable`, but `noreturn` is a + # cached fact that outlives the instructions it was derived from — and a stale `noreturn` lets + # a trusting back-end (Metal's AIR optimizer, the SPIR-V translator) re-derive an + # `unreachable`/`OpUnreachable`/trap right after the call and undo our work. we do this here, + # not per-function, to also reach functions the rewrite skipped: `noreturn` declarations the + # kernel calls, and genuinely-`noreturn` functions (e.g. infinite loops) we left out-of-line. + # dropping it is always safe — it only relaxes an optimization hint; the back-end may re-infer + # it on a function that really never returns, but with no trap to reconstruct that is harmless. + noreturn_attr = EnumAttribute("noreturn", 0) + for f in functions(mod) + delete!(function_attributes(f), noreturn_attr) + for bb in blocks(f), inst in instructions(bb) + isa(inst, LLVM.CallInst) && delete!(function_attributes(inst), noreturn_attr) + end + end + + # erase the now-unused `llvm.trap` declaration. guarded by `isempty(uses(...))` so we only + # ever drop it when the calls above are gone (other backends create their own `llvm.trap` + # and never invoke this pass, so theirs is untouched). + if haskey(functions(mod), "llvm.trap") + trap = functions(mod)["llvm.trap"] + if isempty(uses(trap)) + erase!(trap) + changed = true + end + end + + end + return changed +end + +function lower_unreachable_control_flow!(f::LLVM.Function) + changed = false + + # Pass 1: strip every direct `llvm.trap` call, wherever it sits (not only before `unreachable`). + for bb in blocks(f), inst in collect(instructions(bb)) + if isa(inst, LLVM.CallInst) + callee = called_operand(inst) + if isa(callee, LLVM.Function) && name(callee) == "llvm.trap" + erase!(inst) + changed = true + end + end + end + + # Pass 2: lower every `unreachable` terminator to a branch to a return + # block. this also covers `unreachable` not preceded by a trap.
+ unreachables = Instruction[] + exit_blocks = BasicBlock[] + for bb in blocks(f), inst in instructions(bb) + if isa(inst, LLVM.UnreachableInst) + push!(unreachables, inst) + end + if isa(inst, LLVM.RetInst) + push!(exit_blocks, bb) + end + end + isempty(unreachables) && return changed + + @dispose builder=IRBuilder() begin + local return_block + if isempty(exit_blocks) + # the function has no normal return (e.g. a kernel whose only path is a `throw`). + # synthesize a return block so we can turn the `unreachable` into a clean return. + return_block = BasicBlock(f, "ret") + position!(builder, return_block) + rt = return_type(function_type(f)) + if rt == LLVM.VoidType() + ret!(builder) + else + ret!(builder, UndefValue(rt)) + end + else + # if we have multiple exit blocks, take the last one, which is hopefully the least + # divergent (assuming divergent control flow is the root of the problem here). + exit_block = last(exit_blocks) + ret = terminator(exit_block) + + # create a return block with only the return instruction, so that we only have to + # care about any values returned, and not about any other SSA value in the block. + if first(instructions(exit_block)) == ret + # we can reuse the exit block if it only contains the return + return_block = exit_block + else + # split the exit block right before the ret + return_block = BasicBlock(f, "ret") + move_after(return_block, exit_block) + + # emit a branch + position!(builder, ret) + br!(builder, return_block) + + # move the return + remove!(ret) + position!(builder, return_block) + insert!(builder, ret) + end + + # when returning a value, add a phi node to the return block, so that we can later + # add incoming undef values when branching from `unreachable` blocks + if !isempty(operands(ret)) + position!(builder, ret) + # XXX: support aggregate returns?
+ val = only(operands(ret)) + phi = phi!(builder, value_type(val)) + for pred in predecessors(return_block) + push!(incoming(phi), (val, pred)) + end + operands(ret)[1] = phi + end + end + + # replace the unreachable with a branch to the return block + for unreachable in unreachables + bb = LLVM.parent(unreachable) + + position!(builder, unreachable) + br!(builder, return_block) + erase!(unreachable) + + # patch up any phi nodes in the return block + for inst in instructions(return_block) + if isa(inst, LLVM.PHIInst) + undef = UndefValue(value_type(inst)) + vals = incoming(inst) + push!(vals, (undef, bb)) + end + end + end + end + + return true +end + + ## kernel promotion @enum ArgumentCC begin diff --git a/src/metal.jl b/src/metal.jl index 3452fe1c..7cf9115d 100644 --- a/src/metal.jl +++ b/src/metal.jl @@ -157,41 +157,6 @@ function validate_ir(job::CompilerJob{MetalCompilerTarget}, mod::LLVM.Module) errors end -# hide `noreturn` function attributes, which cause issues with the back-end compiler, -# probably because of thread-divergent control flow as we've encountered with CUDA. -# note that it isn't enough to remove the function attribute, because the Metal LLVM -# compiler re-optimizes and will rediscover the property. to avoid this, we inline -# all functions that are marked noreturn, i.e., until LLVM cannot rediscover it.
-function hide_noreturn!(job::CompilerJob, mod::LLVM.Module) - noreturn_attr = EnumAttribute("noreturn", 0) - noinline_attr = EnumAttribute("noinline", 0) - alwaysinline_attr = EnumAttribute("alwaysinline", 0) - - any_noreturn = false - for f in functions(mod) - attrs = function_attributes(f) - if noreturn_attr in collect(attrs) - delete!(attrs, noreturn_attr) - delete!(attrs, noinline_attr) - push!(attrs, alwaysinline_attr) - any_noreturn = true - end - end - any_noreturn || return false - - @dispose pb=NewPMPassBuilder() begin - LLVM.target_transform_info!(pb, MetalTTI()) - add!(pb, AlwaysInlinerPass()) - add!(pb, NewPMFunctionPassManager()) do fpm - add!(fpm, SimplifyCFGPass()) - add!(fpm, instcombine_pass(job)) - end - run!(pb, mod) - end - - return true -end - function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::LLVM.Module, entry::LLVM.Function) entry_fn = LLVM.name(entry) @@ -227,20 +192,20 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L add_module_metadata!(job, mod) end - # JuliaGPU/Metal.jl#113 - hide_noreturn!(job, mod) - - # rewrite unreachable control flow into clean returns. two Apple-specific reasons: - # - JuliaGPU/Metal.jl#370: divergent `unreachable` crashes the back-end (pre-macOS 15). - # - JuliaGPU/Metal.jl#433: device-side exceptions lower to `llvm.trap`, but a compute - # trap wedges the whole Apple GPU (no compute watchdog; only a reboot clears it). - # `replace_unreachable!` strips the trap and turns the throw into a return instead. + # strip device-side `trap`s and rewrite `unreachable` into clean returns + # (JuliaGPU/Metal.jl#433, JuliaGPU/Metal.jl#370). this runs post-`optimize!`, after the trap + # has finished serving as the optimizer guard; the pass force-inlines throwing functions into + # the kernel first so the rewrite is sound, then scrubs every `noreturn` attribute.
# - # `hide_noreturn!` above must still run first: it drops the `noreturn` attribute (which - # the back-end would otherwise rediscover, #113) and inlines such functions. - for f in functions(mod) - replace_unreachable!(job, f) - end + # this also subsumes the old `hide_noreturn!` workaround for JuliaGPU/Metal.jl#113 (kernel hangs + # from divergent `noreturn` control flow on older macOS). that bug reduced to a `noinline` + # helper of the shape `trap; unreachable` called divergently, and `hide_noreturn!` worked by + # force-inlining it; this pass inlines the same helper (keying on the `trap`/`unreachable` it + # contains, not the attribute), rewrites its `unreachable` into a clean branch-to-`ret`, and + # drops the `noreturn`, leaving nothing divergent for the back-end to choke on. (the only + # `noreturn` shape it doesn't inline is a genuine infinite loop — but inlining can't make that + # return either, so `hide_noreturn!` never fixed that case to begin with.) + lower_unreachable_control_flow!(job, mod) # lower LLVM intrinsics that AIR doesn't support changed = false @@ -1114,119 +1079,3 @@ function annotate_air_intrinsics!(@nospecialize(job::CompilerJob), mod::LLVM.Mod return changed end - -# replace unreachable control flow (and the trap that precedes it) with a return. -# -# two reasons: -# - before macOS 15, code generated by Julia 1.11 causes compilation failures in the -# back-end: the reduced example contains unreachable control flow executed divergently, -# similar to what we hit on NVIDIA, but causing crashes instead of miscompiles (#370). -# - device-side exceptions lower to a `llvm.trap` followed by `unreachable`, but a compute -# trap wedges the whole Apple GPU (no watchdog; reboot to clear, JuliaGPU/Metal.jl#433).
-# -# so we replace `unreachable` (and any immediately preceding `llvm.trap`) by a branch to a -# return block — reusing the function's existing `ret`, or synthesizing one (`ret void`, or -# `ret undef` for value-returning functions) when the function _only_ contains `unreachable`. -# -# this returns from *this function* only (returning undef to the caller), not the whole -# kernel; it is not a true abort. a `threadgroup_barrier` between the throw and the return -# is still skipped by the faulting lane and will deadlock — but that already wedges today -# via the trap, so this is no worse, and it fixes the common (barrier-free) case. swallowed -# exceptions should be surfaced separately via a `signal_exception` host-visible flag. -function replace_unreachable!(@nospecialize(job::CompilerJob), f::LLVM.Function) - # find unreachable instructions and exit blocks - unreachables = Instruction[] - exit_blocks = BasicBlock[] - for bb in blocks(f), inst in instructions(bb) - if isa(inst, LLVM.UnreachableInst) - push!(unreachables, inst) - end - if isa(inst, LLVM.RetInst) - push!(exit_blocks, bb) - end - end - isempty(unreachables) && return false - - @dispose builder=IRBuilder() begin - local return_block - if isempty(exit_blocks) - # the function has no normal return (e.g. a kernel whose only path - # is a `throw`, which lowers to trap + unreachable). synthesize a - # return block so we can still strip the trap and turn the - # `unreachable` into a clean return. - return_block = BasicBlock(f, "ret") - position!(builder, return_block) - rt = return_type(function_type(f)) - if rt == LLVM.VoidType() - ret!(builder) - else - ret!(builder, UndefValue(rt)) - end - else - # if we have multiple exit blocks, take the last one, which is hopefully the least - # divergent (assuming divergent control flow is the root of the problem here). 
- exit_block = last(exit_blocks) - ret = terminator(exit_block) - - # create a return block with only the return instruction, so that we only have to - # care about any values returned, and not about any other SSA value in the block. - if first(instructions(exit_block)) == ret - # we can reuse the exit block if it only contains the return - return_block = exit_block - else - # split the exit block right before the ret - return_block = BasicBlock(f, "ret") - move_after(return_block, exit_block) - - # emit a branch - position!(builder, ret) - br!(builder, return_block) - - # move the return - remove!(ret) - position!(builder, return_block) - insert!(builder, ret) - end - - # when returning a value, add a phi node to the return block, so that we can later - # add incoming undef values when branching from `unreachable` blocks - if !isempty(operands(ret)) - position!(builder, ret) - # XXX: support aggregate returns? - val = only(operands(ret)) - phi = phi!(builder, value_type(val)) - for pred in predecessors(return_block) - push!(incoming(phi), (val, pred)) - end - operands(ret)[1] = phi - end - end - - # replace the unreachable with a branch to the return block - for unreachable in unreachables - bb = LLVM.parent(unreachable) - - # remove preceding traps to avoid reconstructing unreachable control flow - prev = previnst(unreachable) - if isa(prev, LLVM.CallInst) && name(called_operand(prev)) == "llvm.trap" - erase!(prev) - end - - # replace the unreachable with a branch to the return block - position!(builder, unreachable) - br!(builder, return_block) - erase!(unreachable) - - # patch up any phi nodes in the return block - for inst in instructions(return_block) - if isa(inst, LLVM.PHIInst) - undef = UndefValue(value_type(inst)) - vals = incoming(inst) - push!(vals, (undef, bb)) - end - end - end - end - - return true -end diff --git a/src/spirv.jl b/src/spirv.jl index ef96f4c4..de1a76e8 100644 --- a/src/spirv.jl +++ b/src/spirv.jl @@ -96,6 +96,13 @@ end function 
finish_ir!(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module, entry::LLVM.Function) + # SPIR-V has no `trap` and no mechanism to abort a compute kernel (OpKill is fragment-only), + # so strip the device-side `trap` and lower `unreachable` to a clean `ret`. running this here + # (post-`optimize!`) is the correct spot: the trap has finished serving as the optimizer + # guard (see `emit_exception!`), and turning `unreachable` into `ret` also avoids emitting + # OpUnreachable (UB if reached), which PoCL and friends handle poorly. + lower_unreachable_control_flow!(job, mod) + # convert the kernel state argument to a byval reference if job.config.kernel state = kernel_state_type(job) @@ -135,10 +142,6 @@ end # The SPIRV Tools don't handle Julia's debug info, rejecting DW_LANG_Julia... strip_debuginfo!(mod) - # SPIR-V does not support trap, and has no mechanism to abort compute kernels - # (OpKill is only available in fragment execution mode) - rm_trap!(mod) - # the LLVM to SPIR-V translator does not support the freeze instruction # (SPIRV-LLVM-Translator#1140) rm_freeze!(mod) @@ -244,31 +247,6 @@ end ## LLVM passes -# remove llvm.trap and its uses from a module -function rm_trap!(mod::LLVM.Module) - job = current_job::CompilerJob - changed = false - @tracepoint "remove trap" begin - - if haskey(functions(mod), "llvm.trap") - trap = functions(mod)["llvm.trap"] - - for use in uses(trap) - val = user(use) - if isa(val, LLVM.CallInst) - erase!(val) - changed = true - end - end - - @compiler_assert isempty(uses(trap)) job - erase!(trap) - end - - end - return changed -end - # remove freeze and replace uses by the original value # (KhronosGroup/SPIRV-LLVM-Translator#1140) function rm_freeze!(mod::LLVM.Module) diff --git a/test/spirv.jl b/test/spirv.jl index f5ef53fa..b5006194 100644 --- a/test/spirv.jl +++ b/test/spirv.jl @@ -121,12 +121,51 @@ end end end + # at the IR level, `lower_unreachable_control_flow!` must have stripped the device-side + # `llvm.trap` and lowered 
the throw's `unreachable` into a clean `ret`. + @test @filecheck begin + @check_label "define spir_kernel void @_Z6kernel" + @check_not "llvm.trap" + @check_not "unreachable" + @check "ret void" + SPIRV.code_llvm(mod.kernel, Tuple{Bool}; backend, kernel=true) + end + + # and at the SPIR-V level, no `OpUnreachable` (UB if reached) should survive. @test @filecheck begin @check "%_Z6kernel4Bool = OpFunction %void None" + @check_not "OpUnreachable" SPIRV.code_native(mod.kernel, Tuple{Bool}; backend, kernel=true) end end +@testset "inlining of throwing callees" begin + mod = @eval module $(gensym()) + @noinline function guard(x) + x || error() + return + end + function kernel(x) + guard(x) + return + end + end + + # `guard` throws on one path and returns on the other; rewriting its `unreachable` into a + # `ret` is only sound if `guard` is inlined into the kernel first (otherwise the kernel would + # resume after the call on the throwing path). even though `guard` is `@noinline`, the lowering + # must have force-inlined it: the throw's `signal_exception` now lives in the kernel's own body + # (it would sit in `guard` had it stayed out-of-line), with the trap/unreachable lowered away. + @test @filecheck begin + @check_label "define spir_kernel void @_Z6kernel" + @check "gpu_signal_exception" + @check_not "llvm.trap" + @check_not "unreachable" + @check "ret void" + SPIRV.code_llvm(mod.kernel, Tuple{Bool}; backend, kernel=true) + end +end + end @testset "replace i128 allocas" begin