diff --git a/src/ptx.jl b/src/ptx.jl index e5763e8d..ebfce4f7 100644 --- a/src/ptx.jl +++ b/src/ptx.jl @@ -262,8 +262,10 @@ function optimize_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}), # TODO: Use the registered target passes (JuliaGPU/GPUCompiler.jl#450) @dispose pb=NewPMPassBuilder() begin register!(pb, NVVMReflectPass()) + register!(pb, PTXFDivFastPass()) add!(pb, NVVMReflectPass()) + add!(pb, PTXFDivFastPass()) add!(pb, NewPMFunctionPassManager()) do fpm # needed by GemmKernels.jl-like code @@ -555,3 +557,86 @@ function nvvm_reflect!(mod::LLVM.Module) return changed end NVVMReflectPass() = NewPMModulePass("custom-nvvm-reflect", nvvm_reflect!) + +# Same source NVPTX' `useF32FTZ` reads — `apply_fastmath!` sets it when +# `target.fastmath=true`. Used here to pick FTZ variants for the f32 rewrite. +function f32_ftz(f::LLVM.Function) + for attr in collect(LLVM.function_attributes(f)) + attr isa LLVM.StringAttribute || continue + LLVM.kind(attr) == "denormal-fp-math-f32" || continue + return startswith(LLVM.value(attr), "preserve-sign") + end + return false +end + +# Rewrite `afn`-flagged `fdiv` to NVPTX' fast lowerings. `apply_fastmath!` +# propagates job-wide `target.fastmath=true` as per-instruction `afn`, so the +# single flag check covers both per-call `@fastmath` and the job toggle. We +# emit NVPTX intrinsics directly (rather than libdevice `__nv_*`) so this +# doesn't depend on which libdevice symbols got linked in. +# +# - f32 → `llvm.nvvm.div.approx{,.ftz}.f`. Redundant on LLVM 21+, where +# `getDivF32Level` honors `afn`; LLVM 18 only consults +# `TargetMachine.Options.UnsafeFPMath`, which is unreachable through LLVM.jl. +# - f64 → `rcp.approx.ftz.d` + one Newton step (NVPTX has no fast f64 fdiv). +function ptx_fdiv_fast!(mod::LLVM.Module) + changed = false + @tracepoint "ptx-fdiv-fast" begin + + f32 = LLVM.FloatType() + f64 = LLVM.DoubleType() + + # collect first to avoid mutation-during-iteration + to_replace = Tuple{LLVM.FDivInst, Bool}[] + for f in functions(mod), bb in blocks(f), inst in instructions(bb) + inst isa LLVM.FDivInst || continue + is_f32 = LLVM.value_type(inst) == f32 + is_f64 = LLVM.value_type(inst) == f64 + (is_f32 || is_f64) || continue + LLVM.fast_math(inst).afn || continue + push!(to_replace, (inst, is_f32)) + end + isempty(to_replace) && return false + + # declare intrinsics by name so LLVM keeps the exact non-overloaded names; + # LLVM.Intrinsic + type params would mangle to *.f64, unrecognized by NVPTX. + fns = functions(mod) + declare(name, ft) = haskey(fns, name) ? fns[name] : LLVM.Function(mod, name, ft) + f32_ft = LLVM.FunctionType(f32, [f32, f32]) + div_f32 = declare("llvm.nvvm.div.approx.f", f32_ft) + div_f32_ftz = declare("llvm.nvvm.div.approx.ftz.f", f32_ft) + rcp_ft = LLVM.FunctionType(f64, [f64]) + rcp_f64 = declare("llvm.nvvm.rcp.approx.ftz.d", rcp_ft) + fma_ft = LLVM.FunctionType(f64, [f64, f64, f64]) + fma_f64 = declare("llvm.fma.f64", fma_ft) + one_f64 = ConstantFP(f64, 1.0) + + @dispose builder=IRBuilder() begin + for (inst, is_f32) in to_replace + lhs, rhs = operands(inst)[1], operands(inst)[2] + position!(builder, inst) + + replacement = if is_f32 + # TODO: drop f32 path once we require LLVM 21+. + f = LLVM.parent(LLVM.parent(inst)) + call!(builder, f32_ft, f32_ftz(f) ? div_f32_ftz : div_f32, [lhs, rhs]) + else + inv_y = call!(builder, rcp_ft, rcp_f64, [rhs]) + neg_rhs = fneg!(builder, rhs) + # Newton refinement, matching CUDA.jl's `FastMath.inv_fast(::Float64)` + e = call!(builder, fma_ft, fma_f64, [inv_y, neg_rhs, one_f64]) + e = call!(builder, fma_ft, fma_f64, [e, e, e]) + inv_ref = call!(builder, fma_ft, fma_f64, [e, inv_y, inv_y]) + fmul!(builder, lhs, inv_ref) + end + + replace_uses!(inst, replacement) + erase!(inst) + changed = true + end + end + + end # @tracepoint + return changed +end +PTXFDivFastPass() = NewPMModulePass("ptx-fdiv-fast", ptx_fdiv_fast!) diff --git a/test/ptx.jl b/test/ptx.jl index 2ae4a150..34466ef3 100644 --- a/test/ptx.jl +++ b/test/ptx.jl @@ -462,6 +462,39 @@ end end end +@testset "fastmath division" begin + # `PTXFDivFastPass` rewrites `afn`-flagged fdiv. f32 → `div.approx{,.ftz}.f32` + # (filling in for LLVM 18, whose `getDivF32Level` doesn't honor per-call + # `afn`); f64 → `rcp.approx.ftz.d` + Newton refinement (NVPTX has no fast + # f64 fdiv lowering). Job-wide `fastmath=true` reaches this through the + # per-instruction flags `apply_fastmath!` stamps in `finish_linked_module!`. + mod_fast = @eval module $(gensym()) + kernel_f32(x::Float32, y::Float32) = @fastmath x / y + kernel_f64(x::Float64, y::Float64) = @fastmath x / y + end + mod_precise = @eval module $(gensym()) + kernel_f32(x::Float32, y::Float32) = x / y + kernel_f64(x::Float64, y::Float64) = x / y + end + + @test @filecheck begin + @check "div.approx.f32" + PTX.code_native(mod_fast.kernel_f32, Tuple{Float32, Float32}) + end + @test @filecheck begin + @check_not "div.approx" + PTX.code_native(mod_precise.kernel_f32, Tuple{Float32, Float32}) + end + @test @filecheck begin + @check "rcp.approx.ftz.f64" + PTX.code_native(mod_fast.kernel_f64, Tuple{Float64, Float64}) + end + @test @filecheck begin + @check_not "rcp.approx" + PTX.code_native(mod_precise.kernel_f64, Tuple{Float64, Float64}) + end +end + @testset "feature_set" begin # PTXCompilerTarget.feature_set controls the suffix on the LLVM CPU name, which is # what the NVPTX backend uses to flip `hasArchAccelFeatures()`. Verify it makes its