From 579fe7fe87e42546d9907f04c1b2eaf6b0b62073 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Tue, 19 May 2026 15:19:30 +0200 Subject: [PATCH 1/2] Add PTXFDivFastPass to lower fdiv fast to NVPTX approximate division MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The LLVM NVPTX backend handles fdiv fast for Float32 (→ div.approx.ftz.f32) but has no fast path for Float64. This IR-level pass covers both: - Float32: replaces fdiv with __nv_fast_fdividef (libdevice) - Float64: replaces fdiv with rcp.approx.ftz.d + Newton refinement, matching CUDA.jl's inv_fast(::Float64) algorithm The pass fires when the instruction carries the afn fast-math flag (set by @fastmath) or when target.fastmath=true. It follows the NVVMReflectPass pattern already in ptx.jl. Co-Authored-By: Claude Sonnet 4.6 --- src/ptx.jl | 59 +++++++++++++++++++++++++++++++++++++++++++++++++++++ test/ptx.jl | 35 +++++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+) diff --git a/src/ptx.jl b/src/ptx.jl index e5763e8d..913f33a8 100644 --- a/src/ptx.jl +++ b/src/ptx.jl @@ -262,8 +262,10 @@ function optimize_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}), # TODO: Use the registered target passes (JuliaGPU/GPUCompiler.jl#450) @dispose pb=NewPMPassBuilder() begin register!(pb, NVVMReflectPass()) + register!(pb, PTXFDivF64FastPass()) add!(pb, NVVMReflectPass()) + add!(pb, PTXFDivF64FastPass()) add!(pb, NewPMFunctionPassManager()) do fpm # needed by GemmKernels.jl-like code @@ -555,3 +557,60 @@ function nvvm_reflect!(mod::LLVM.Module) return changed end NVVMReflectPass() = NewPMModulePass("custom-nvvm-reflect", nvvm_reflect!) + +# Rewrite `afn`-flagged f64 `fdiv` to `rcp.approx.ftz.d` + one-step Newton +# refinement, matching CUDA.jl's `FastMath.inv_fast(::Float64)`. NVPTX has no +# fast f64 fdiv lowering of its own; f32 is left to the backend, which picks +# `div.approx.ftz.f32` for `afn`-flagged f32 fdivs. Job-wide `fastmath=true` +# reaches this pass through the per-instruction flags that `apply_fastmath!` +# already stamped on every FP op in `finish_linked_module!`. +function ptx_fdiv_f64_fast!(mod::LLVM.Module) + changed = false + @tracepoint "ptx-fdiv-f64-fast" begin + + f64 = LLVM.DoubleType() + + # collect first to avoid mutation-during-iteration + to_replace = LLVM.FDivInst[] + for f in functions(mod), bb in blocks(f), inst in instructions(bb) + inst isa LLVM.FDivInst || continue + LLVM.value_type(inst) == f64 || continue + LLVM.fast_math(inst).afn || continue + push!(to_replace, inst) + end + isempty(to_replace) && return false + + # declare rcp by name so LLVM keeps the exact (non-overloaded) intrinsic name; + # LLVM.Intrinsic + type params would mangle to *.f64, unrecognized by NVPTX. + fns = functions(mod) + rcp_ft = LLVM.FunctionType(f64, [f64]) + rcp_fn = haskey(fns, "llvm.nvvm.rcp.approx.ftz.d") ? + fns["llvm.nvvm.rcp.approx.ftz.d"] : LLVM.Function(mod, "llvm.nvvm.rcp.approx.ftz.d", rcp_ft) + fma_ft = LLVM.FunctionType(f64, [f64, f64, f64]) + fma_fn = haskey(fns, "llvm.fma.f64") ? + fns["llvm.fma.f64"] : LLVM.Function(mod, "llvm.fma.f64", fma_ft) + one_f64 = ConstantFP(f64, 1.0) + + @dispose builder=IRBuilder() begin + for inst in to_replace + lhs, rhs = operands(inst)[1], operands(inst)[2] + position!(builder, inst) + + inv_y = call!(builder, rcp_ft, rcp_fn, [rhs]) + neg_rhs = fneg!(builder, rhs) + # Newton refinement matching CUDA.jl's inv_fast(::Float64) + e = call!(builder, fma_ft, fma_fn, [inv_y, neg_rhs, one_f64]) + e = call!(builder, fma_ft, fma_fn, [e, e, e]) + inv_ref = call!(builder, fma_ft, fma_fn, [e, inv_y, inv_y]) + replacement = fmul!(builder, lhs, inv_ref) + + replace_uses!(inst, replacement) + erase!(inst) + changed = true + end + end + + end # @tracepoint + return changed +end +PTXFDivF64FastPass() = NewPMModulePass("ptx-fdiv-f64-fast", ptx_fdiv_f64_fast!) diff --git a/test/ptx.jl b/test/ptx.jl index 2ae4a150..4f512e56 100644 --- a/test/ptx.jl +++ b/test/ptx.jl @@ -462,6 +462,41 @@ end end end +@testset "fastmath division" begin + # NVPTX has no fast f64 fdiv lowering, so we rewrite `afn`-flagged f64 + # divs to `rcp.approx.ftz.d` + Newton refinement. f32 is left to the + # backend (`apply_fastmath!` already stamps `afn`, and NVPTX' own + # `getDivF32Level` then emits `div.approx.ftz.f32`). + mod_fast = @eval module $(gensym()) + kernel_f32(x::Float32, y::Float32) = @fastmath x / y + kernel_f64(x::Float64, y::Float64) = @fastmath x / y + end + mod_precise = @eval module $(gensym()) + kernel_f32(x::Float32, y::Float32) = x / y + kernel_f64(x::Float64, y::Float64) = x / y + end + + # f32 fast: backend handles `afn`-flagged fdiv natively. + @test @filecheck begin + @check "div.approx.ftz.f32" + PTX.code_native(mod_fast.kernel_f32, Tuple{Float32, Float32}) + end + @test @filecheck begin + @check_not "div.approx" + PTX.code_native(mod_precise.kernel_f32, Tuple{Float32, Float32}) + end + + # f64 fast: pass rewrites to rcp.approx.ftz + Newton refinement. + @test @filecheck begin + @check "rcp.approx.ftz.f64" + PTX.code_native(mod_fast.kernel_f64, Tuple{Float64, Float64}) + end + @test @filecheck begin + @check_not "rcp.approx" + PTX.code_native(mod_precise.kernel_f64, Tuple{Float64, Float64}) + end +end + @testset "feature_set" begin # PTXCompilerTarget.feature_set controls the suffix on the LLVM CPU name, which is # what the NVPTX backend uses to flip `hasArchAccelFeatures()`. Verify it makes its From 8507a1c600f7d02b5fe9d4262c84818b1d1455dc Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 20 May 2026 11:38:36 +0200 Subject: [PATCH 2/2] PTX: simplify fdiv-fast pass. Drop the `target.fastmath` check (`apply_fastmath!` stamps `afn` already), and emit NVPTX intrinsics directly so the f32 rewrite doesn't depend on libdevice being linked. f32 picks the FTZ variant from the function's `denormal-fp-math-f32` attribute. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/ptx.jl | 80 +++++++++++++++++++++++++++++++++++------------------ test/ptx.jl | 14 ++++------ 2 files changed, 59 insertions(+), 35 deletions(-) diff --git a/src/ptx.jl b/src/ptx.jl index 913f33a8..ebfce4f7 100644 --- a/src/ptx.jl +++ b/src/ptx.jl @@ -262,10 +262,10 @@ function optimize_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}), # TODO: Use the registered target passes (JuliaGPU/GPUCompiler.jl#450) @dispose pb=NewPMPassBuilder() begin register!(pb, NVVMReflectPass()) - register!(pb, PTXFDivF64FastPass()) + register!(pb, PTXFDivFastPass()) add!(pb, NVVMReflectPass()) - add!(pb, PTXFDivF64FastPass()) + add!(pb, PTXFDivFastPass()) add!(pb, NewPMFunctionPassManager()) do fpm # needed by GemmKernels.jl-like code @@ -558,51 +558,77 @@ function nvvm_reflect!(mod::LLVM.Module) end NVVMReflectPass() = NewPMModulePass("custom-nvvm-reflect", nvvm_reflect!) -# Rewrite `afn`-flagged f64 `fdiv` to `rcp.approx.ftz.d` + one-step Newton -# refinement, matching CUDA.jl's `FastMath.inv_fast(::Float64)`. NVPTX has no -# fast f64 fdiv lowering of its own; f32 is left to the backend, which picks -# `div.approx.ftz.f32` for `afn`-flagged f32 fdivs. Job-wide `fastmath=true` -# reaches this pass through the per-instruction flags that `apply_fastmath!` -# already stamped on every FP op in `finish_linked_module!`. -function ptx_fdiv_f64_fast!(mod::LLVM.Module) +# Same source NVPTX' `useF32FTZ` reads — `apply_fastmath!` sets it when +# `target.fastmath=true`. Used here to pick FTZ variants for the f32 rewrite. +function f32_ftz(f::LLVM.Function) + for attr in collect(LLVM.function_attributes(f)) + attr isa LLVM.StringAttribute || continue + LLVM.kind(attr) == "denormal-fp-math-f32" || continue + return startswith(LLVM.value(attr), "preserve-sign") + end + return false +end + +# Rewrite `afn`-flagged `fdiv` to NVPTX' fast lowerings. `apply_fastmath!` +# propagates job-wide `target.fastmath=true` as per-instruction `afn`, so the +# single flag check covers both per-call `@fastmath` and the job toggle. We +# emit NVPTX intrinsics directly (rather than libdevice `__nv_*`) so this +# doesn't depend on which libdevice symbols got linked in. +# +# - f32 → `llvm.nvvm.div.approx{,.ftz}.f`. Redundant on LLVM 21+, where +# `getDivF32Level` honors `afn`; LLVM 18 only consults +# `TargetMachine.Options.UnsafeFPMath`, which is unreachable through LLVM.jl. +# - f64 → `rcp.approx.ftz.d` + one Newton step (NVPTX has no fast f64 fdiv). +function ptx_fdiv_fast!(mod::LLVM.Module) changed = false - @tracepoint "ptx-fdiv-f64-fast" begin + @tracepoint "ptx-fdiv-fast" begin + f32 = LLVM.FloatType() f64 = LLVM.DoubleType() # collect first to avoid mutation-during-iteration - to_replace = LLVM.FDivInst[] + to_replace = Tuple{LLVM.FDivInst, Bool}[] for f in functions(mod), bb in blocks(f), inst in instructions(bb) inst isa LLVM.FDivInst || continue - LLVM.value_type(inst) == f64 || continue + is_f32 = LLVM.value_type(inst) == f32 + is_f64 = LLVM.value_type(inst) == f64 + (is_f32 || is_f64) || continue LLVM.fast_math(inst).afn || continue - push!(to_replace, inst) + push!(to_replace, (inst, is_f32)) end isempty(to_replace) && return false - # declare rcp by name so LLVM keeps the exact (non-overloaded) intrinsic name; + # declare intrinsics by name so LLVM keeps the exact non-overloaded names; # LLVM.Intrinsic + type params would mangle to *.f64, unrecognized by NVPTX. fns = functions(mod) + declare(name, ft) = haskey(fns, name) ? fns[name] : LLVM.Function(mod, name, ft) + f32_ft = LLVM.FunctionType(f32, [f32, f32]) + div_f32 = declare("llvm.nvvm.div.approx.f", f32_ft) + div_f32_ftz = declare("llvm.nvvm.div.approx.ftz.f", f32_ft) rcp_ft = LLVM.FunctionType(f64, [f64]) - rcp_fn = haskey(fns, "llvm.nvvm.rcp.approx.ftz.d") ? - fns["llvm.nvvm.rcp.approx.ftz.d"] : LLVM.Function(mod, "llvm.nvvm.rcp.approx.ftz.d", rcp_ft) + rcp_f64 = declare("llvm.nvvm.rcp.approx.ftz.d", rcp_ft) fma_ft = LLVM.FunctionType(f64, [f64, f64, f64]) - fma_fn = haskey(fns, "llvm.fma.f64") ? - fns["llvm.fma.f64"] : LLVM.Function(mod, "llvm.fma.f64", fma_ft) + fma_f64 = declare("llvm.fma.f64", fma_ft) one_f64 = ConstantFP(f64, 1.0) @dispose builder=IRBuilder() begin - for inst in to_replace + for (inst, is_f32) in to_replace lhs, rhs = operands(inst)[1], operands(inst)[2] position!(builder, inst) - inv_y = call!(builder, rcp_ft, rcp_fn, [rhs]) - neg_rhs = fneg!(builder, rhs) - # Newton refinement matching CUDA.jl's inv_fast(::Float64) - e = call!(builder, fma_ft, fma_fn, [inv_y, neg_rhs, one_f64]) - e = call!(builder, fma_ft, fma_fn, [e, e, e]) - inv_ref = call!(builder, fma_ft, fma_fn, [e, inv_y, inv_y]) - replacement = fmul!(builder, lhs, inv_ref) + replacement = if is_f32 + # TODO: drop f32 path once we require LLVM 21+. + f = LLVM.parent(LLVM.parent(inst)) + call!(builder, f32_ft, f32_ftz(f) ? div_f32_ftz : div_f32, [lhs, rhs]) + else + inv_y = call!(builder, rcp_ft, rcp_f64, [rhs]) + neg_rhs = fneg!(builder, rhs) + # Newton refinement, matching CUDA.jl's `FastMath.inv_fast(::Float64)` + e = call!(builder, fma_ft, fma_f64, [inv_y, neg_rhs, one_f64]) + e = call!(builder, fma_ft, fma_f64, [e, e, e]) + inv_ref = call!(builder, fma_ft, fma_f64, [e, inv_y, inv_y]) + fmul!(builder, lhs, inv_ref) + end replace_uses!(inst, replacement) erase!(inst) @@ -613,4 +639,4 @@ function ptx_fdiv_f64_fast!(mod::LLVM.Module) end # @tracepoint return changed end -PTXFDivF64FastPass() = NewPMModulePass("ptx-fdiv-f64-fast", ptx_fdiv_f64_fast!) +PTXFDivFastPass() = NewPMModulePass("ptx-fdiv-fast", ptx_fdiv_fast!) diff --git a/test/ptx.jl b/test/ptx.jl index 4f512e56..34466ef3 100644 --- a/test/ptx.jl +++ b/test/ptx.jl @@ -463,10 +463,11 @@ end end @testset "fastmath division" begin - # NVPTX has no fast f64 fdiv lowering, so we rewrite `afn`-flagged f64 - # divs to `rcp.approx.ftz.d` + Newton refinement. f32 is left to the - # backend (`apply_fastmath!` already stamps `afn`, and NVPTX' own - # `getDivF32Level` then emits `div.approx.ftz.f32`). + # `PTXFDivFastPass` rewrites `afn`-flagged fdiv. f32 → `div.approx{,.ftz}.f32` + # (filling in for LLVM 18, whose `getDivF32Level` doesn't honor per-call + # `afn`); f64 → `rcp.approx.ftz.d` + Newton refinement (NVPTX has no fast + # f64 fdiv lowering). Job-wide `fastmath=true` reaches this through the + # per-instruction flags `apply_fastmath!` stamps in `finish_linked_module!`. mod_fast = @eval module $(gensym()) kernel_f32(x::Float32, y::Float32) = @fastmath x / y kernel_f64(x::Float64, y::Float64) = @fastmath x / y @@ -476,17 +477,14 @@ end kernel_f64(x::Float64, y::Float64) = x / y end - # f32 fast: backend handles `afn`-flagged fdiv natively. @test @filecheck begin - @check "div.approx.ftz.f32" + @check "div.approx.f32" PTX.code_native(mod_fast.kernel_f32, Tuple{Float32, Float32}) end @test @filecheck begin @check_not "div.approx" PTX.code_native(mod_precise.kernel_f32, Tuple{Float32, Float32}) end - - # f64 fast: pass rewrites to rcp.approx.ftz + Newton refinement. @test @filecheck begin @check "rcp.approx.ftz.f64" PTX.code_native(mod_fast.kernel_f64, Tuple{Float64, Float64})