diff --git a/Project.toml b/Project.toml index 9147b6f7..c64b6df0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "GPUCompiler" uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" -version = "1.13.1" +version = "1.13.2" authors = ["Tim Besard "] [workspace] diff --git a/src/ptx.jl b/src/ptx.jl index d2c97b21..4a977371 100644 --- a/src/ptx.jl +++ b/src/ptx.jl @@ -262,10 +262,12 @@ function optimize_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}), # TODO: Use the registered target passes (JuliaGPU/GPUCompiler.jl#450) @dispose pb=NewPMPassBuilder() begin register!(pb, NVVMReflectPass()) + register!(pb, PTXRSqrtFastPass()) register!(pb, PTXFDivFastPass()) register!(pb, PTXFSqrtFastPass()) add!(pb, NVVMReflectPass()) + add!(pb, PTXRSqrtFastPass()) add!(pb, PTXFDivFastPass()) add!(pb, PTXFSqrtFastPass()) @@ -571,23 +573,117 @@ function f32_ftz(f::LLVM.Function) return false end -# Both passes below rewrite `afn`-flagged ops to NVPTX' fast lowerings. +# All three passes below rewrite `afn`-flagged ops to NVPTX' fast lowerings. # `apply_fastmath!` propagates job-wide `target.fastmath=true` as per- # instruction `afn`, so a single flag check covers both per-call `@fastmath` # and the job toggle. We emit NVPTX intrinsics by name (rather than libdevice # `__nv_*`) so this doesn't depend on which libdevice symbols got linked in. # -# Both passes are temporary backports for LLVM 18: -# - `PTXFSqrtFastPass` is fully redundant on LLVM 21+: `usePrecSqrtF32` then -# honors the per-instruction `afn` and the function `unsafe-fp-math` -# attribute, so `DAGCombiner::visitFSQRT` → `NVPTXTargetLowering::getSqrtEstimate` -# emits the f32 `sqrt.approx{,.ftz}` and f64 `rcp(rsqrt(x))` sequences -# itself. LLVM 18's `usePrecSqrtF32` only consults `TargetMachine.Options.UnsafeFPMath`, -# which is unreachable through LLVM.jl. 
-# - `PTXFDivFastPass`'s f32 path is similarly redundant on LLVM 21+; +# `PTXRSqrtFastPass` runs first: the rsqrt pattern (`fdiv afn 1.0, sqrt afn(x)`) +# spans an fdiv and a sqrt, so it has to claim both instructions before the per-op +# passes below eat them. NVPTX has native `rsqrt.approx.{f32,f64}` for both f32 and +# f64, so this is a single-instruction lowering for both types. +# +# `PTXFDivFastPass` / `PTXFSqrtFastPass` are temporary backports for LLVM 18: +# - `PTXFSqrtFastPass`: on LLVM 21+ `usePrecSqrtF32` honors per-instruction +# `afn` + the `unsafe-fp-math` attribute, so `DAGCombiner::visitFSQRT` → +# `NVPTXTargetLowering::getSqrtEstimate` emits the f32 `sqrt.approx{,.ftz}` +# and f64 `rcp(rsqrt(x))` sequences itself. LLVM 18's `usePrecSqrtF32` only +# consults `TargetMachine.Options.UnsafeFPMath`, which is unreachable +# through LLVM.jl. +# - `PTXFDivFastPass`'s f32 path is similarly fixed on LLVM 21+; # `getDivF32Level` there honors `afn` + the function attribute. The f64 # path stays needed until NVPTX gains a `getRecipEstimate` hook (filed # upstream). +# On LLVM 21+ both passes (and the f32 path of `PTXRSqrtFastPass`) can be +# dropped together — they have to leave the pipeline as a set, because as +# long as `PTXFDivFastPass` runs and rewrites `fdiv afn 1.0, sqrt(x)` into +# `nvvm.div.approx.f(1.0, ...)`, the rsqrt tablegen pattern can't match. + +# Rewrite `fdiv afn 1.0, sqrt afn(x)` to `nvvm.rsqrt.approx.{f,d}(x)`. Must run +# before `PTXFDivFastPass` (which would rewrite the fdiv to `nvvm.div.approx.f`, +# defeating ISel pattern-matching) and `PTXFSqrtFastPass` (which for f64 +# expands sqrt into `rcp(rsqrt(...))`, hiding the pattern entirely). 
# Why we can't rely on LLVM upstream:
# - f32: NVPTX has tablegen patterns (`NVPTXIntrinsics.td`, `doRsqrtOpt`) that
#   match `fdiv 1.0, sqrt_approx(x)` → `rsqrt.approx.f32` — but they landed in
#   LLVM 19 (so absent on our LLVM 18 floor), and even on LLVM 21+ they only
#   fire if the fdiv is still a generic `fdiv`. PTXFDivFastPass kills that.
# - f64: no upstream fold exists at all. NVPTX doesn't override
#   `getRecipEstimateSqrtEnabled`, so the DAGCombiner's generic rsqrt path is
#   disabled, and there's no f64 equivalent of the f32 tablegen patterns.
#   `rsqrt.approx.f64` is a real instruction; it just isn't selected for
#   `1/sqrt(x)` upstream.
#
# Module pass body. An `fdiv` is rewritten only when ALL of the following hold:
# - its result type is scalar `float` or `double`;
# - the fdiv itself carries the `afn` fast-math flag;
# - its numerator is the floating-point constant 1.0;
# - its denominator is a direct call to `llvm.sqrt.f32` / `llvm.sqrt.f64`
#   (matching the fdiv's type) that also carries `afn`.
# Each match is replaced by a call to `llvm.nvvm.rsqrt.approx{,.ftz}.f` (f32; the
# `.ftz` variant is picked per enclosing function by `f32_ftz`) or
# `llvm.nvvm.rsqrt.approx.d` (f64) on the sqrt's argument.
#
# Returns `true` iff at least one fdiv was rewritten.
function ptx_rsqrt_fast!(mod::LLVM.Module)
    changed = false
    @tracepoint "ptx-rsqrt-fast" begin

    f32 = LLVM.FloatType()
    f64 = LLVM.DoubleType()

    # collect first to avoid mutation-during-iteration
    to_replace = Tuple{LLVM.FDivInst, LLVM.CallInst, Bool}[]
    for f in functions(mod), bb in blocks(f), inst in instructions(bb)
        inst isa LLVM.FDivInst || continue
        ty = LLVM.value_type(inst)
        is_f32 = ty == f32
        (is_f32 || ty == f64) || continue
        LLVM.fast_math(inst).afn || continue

        # numerator must be the constant 1.0
        lhs = operands(inst)[1]
        lhs isa LLVM.ConstantFP || continue
        convert(Float64, lhs) == 1.0 || continue

        # denominator must be an `afn`-flagged `llvm.sqrt.f{32,64}` call
        rhs = operands(inst)[2]
        rhs isa LLVM.CallInst || continue
        callee = LLVM.called_operand(rhs)
        callee isa LLVM.Function || continue
        expected = is_f32 ? "llvm.sqrt.f32" : "llvm.sqrt.f64"
        LLVM.name(callee) == expected || continue
        LLVM.fast_math(rhs).afn || continue

        push!(to_replace, (inst, rhs, is_f32))
    end
    isempty(to_replace) && return false

    # Intrinsics are declared on first use (and looked up afterwards), so the
    # module only gains declarations for the variants that are actually called.
    fns = functions(mod)
    declare(intrinsic, ft) =
        haskey(fns, intrinsic) ? fns[intrinsic] : LLVM.Function(mod, intrinsic, ft)
    f32_ft = LLVM.FunctionType(f32, [f32])
    f64_ft = LLVM.FunctionType(f64, [f64])

    @dispose builder=IRBuilder() begin
        for (fdiv, sqrt_call, is_f32) in to_replace
            x = operands(sqrt_call)[1]
            position!(builder, fdiv)

            replacement = if is_f32
                # the enclosing function decides whether the ftz variant is used
                f = LLVM.parent(LLVM.parent(fdiv))
                intrinsic = f32_ftz(f) ? "llvm.nvvm.rsqrt.approx.ftz.f" :
                                         "llvm.nvvm.rsqrt.approx.f"
                call!(builder, f32_ft, declare(intrinsic, f32_ft), [x])
            else
                call!(builder, f64_ft, declare("llvm.nvvm.rsqrt.approx.d", f64_ft), [x])
            end

            replace_uses!(fdiv, replacement)
            erase!(fdiv)
            # sqrt may still be used elsewhere (including by a later entry of
            # `to_replace`); only clean it up if dead now.
            if isempty(uses(sqrt_call))
                erase!(sqrt_call)
            end
            changed = true
        end
    end

    end # @tracepoint
    return changed
end
PTXRSqrtFastPass() = NewPMModulePass("ptx-rsqrt-fast", ptx_rsqrt_fast!)
+ mod = @eval module $(gensym()) + rsqrt32_fast(x::Float32) = @fastmath 1f0 / sqrt(x) + rsqrt64_fast(x::Float64) = @fastmath 1.0 / sqrt(x) + rsqrt32(x::Float32) = 1f0 / sqrt(x) + rsqrt64(x::Float64) = 1.0 / sqrt(x) + end + + @test @filecheck begin + @check "rsqrt.approx.f32" + @check_not "sqrt.approx" + @check_not "div.approx" + PTX.code_native(mod.rsqrt32_fast, Tuple{Float32}) + end + @test @filecheck begin + @check "rsqrt.approx.ftz.f32" + @check_not "sqrt.approx" + @check_not "div.approx" + PTX.code_native(mod.rsqrt32_fast, Tuple{Float32}; fastmath=true) + end + @test @filecheck begin + @check "rsqrt.approx.f64" + @check_not "rcp.approx" + PTX.code_native(mod.rsqrt64_fast, Tuple{Float64}) + end + @test @filecheck begin + # job-wide fastmath stamps afn on all FP ops, so the pattern still fires + @check "rsqrt.approx.f64" + @check_not "rcp.approx" + PTX.code_native(mod.rsqrt64, Tuple{Float64}; fastmath=true) + end + + # Without afn, plain `1/sqrt(x)` must NOT fold to rsqrt: it would change + # precision. The non-fast f64 emits `sqrt.rn.f64 + div.rn.f64`. + @test @filecheck begin + @check "sqrt.rn.f64" + @check_not "rsqrt.approx" + PTX.code_native(mod.rsqrt64, Tuple{Float64}) + end +end + @testset "feature_set" begin # PTXCompilerTarget.feature_set controls the suffix on the LLVM CPU name, which is # what the NVPTX backend uses to flip `hasArchAccelFeatures()`. Verify it makes its