Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions src/ptx.jl
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,10 @@ function optimize_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
# TODO: Use the registered target passes (JuliaGPU/GPUCompiler.jl#450)
@dispose pb=NewPMPassBuilder() begin
register!(pb, NVVMReflectPass())
register!(pb, PTXFDivFastPass())

add!(pb, NVVMReflectPass())
add!(pb, PTXFDivFastPass())

add!(pb, NewPMFunctionPassManager()) do fpm
# needed by GemmKernels.jl-like code
Expand Down Expand Up @@ -555,3 +557,86 @@ function nvvm_reflect!(mod::LLVM.Module)
return changed
end
NVVMReflectPass() = NewPMModulePass("custom-nvvm-reflect", nvvm_reflect!)

# Same source NVPTX' `useF32FTZ` reads — `apply_fastmath!` sets it when
# `target.fastmath=true`. Used here to pick FTZ variants for the f32 rewrite.
function f32_ftz(f::LLVM.Function)
for attr in collect(LLVM.function_attributes(f))
attr isa LLVM.StringAttribute || continue
LLVM.kind(attr) == "denormal-fp-math-f32" || continue
return startswith(LLVM.value(attr), "preserve-sign")
end
return false
end

# Rewrite `afn`-flagged `fdiv` to NVPTX' fast lowerings. `apply_fastmath!`
# propagates job-wide `target.fastmath=true` as per-instruction `afn`, so the
# single flag check covers both per-call `@fastmath` and the job toggle. We
# emit NVPTX intrinsics directly (rather than libdevice `__nv_*`) so this
# doesn't depend on which libdevice symbols got linked in.
#
# - f32 → `llvm.nvvm.div.approx{,.ftz}.f`. Redundant on LLVM 21+, where
# `getDivF32Level` honors `afn`; LLVM 18 only consults
# `TargetMachine.Options.UnsafeFPMath`, which is unreachable through LLVM.jl.
# - f64 → `rcp.approx.ftz.d` + one Newton step (NVPTX has no fast f64 fdiv).
function ptx_fdiv_fast!(mod::LLVM.Module)
changed = false
@tracepoint "ptx-fdiv-fast" begin

f32 = LLVM.FloatType()
f64 = LLVM.DoubleType()

# collect first to avoid mutation-during-iteration
to_replace = Tuple{LLVM.FDivInst, Bool}[]
for f in functions(mod), bb in blocks(f), inst in instructions(bb)
inst isa LLVM.FDivInst || continue
is_f32 = LLVM.value_type(inst) == f32
is_f64 = LLVM.value_type(inst) == f64
(is_f32 || is_f64) || continue
LLVM.fast_math(inst).afn || continue
push!(to_replace, (inst, is_f32))
end
isempty(to_replace) && return false

# declare intrinsics by name so LLVM keeps the exact non-overloaded names;
# LLVM.Intrinsic + type params would mangle to *.f64, unrecognized by NVPTX.
fns = functions(mod)
declare(name, ft) = haskey(fns, name) ? fns[name] : LLVM.Function(mod, name, ft)
f32_ft = LLVM.FunctionType(f32, [f32, f32])
div_f32 = declare("llvm.nvvm.div.approx.f", f32_ft)
div_f32_ftz = declare("llvm.nvvm.div.approx.ftz.f", f32_ft)
rcp_ft = LLVM.FunctionType(f64, [f64])
rcp_f64 = declare("llvm.nvvm.rcp.approx.ftz.d", rcp_ft)
fma_ft = LLVM.FunctionType(f64, [f64, f64, f64])
fma_f64 = declare("llvm.fma.f64", fma_ft)
one_f64 = ConstantFP(f64, 1.0)

@dispose builder=IRBuilder() begin
for (inst, is_f32) in to_replace
lhs, rhs = operands(inst)[1], operands(inst)[2]
position!(builder, inst)

replacement = if is_f32
# TODO: drop f32 path once we require LLVM 21+.
f = LLVM.parent(LLVM.parent(inst))
call!(builder, f32_ft, f32_ftz(f) ? div_f32_ftz : div_f32, [lhs, rhs])
else
inv_y = call!(builder, rcp_ft, rcp_f64, [rhs])
neg_rhs = fneg!(builder, rhs)
# Newton refinement, matching CUDA.jl's `FastMath.inv_fast(::Float64)`
e = call!(builder, fma_ft, fma_f64, [inv_y, neg_rhs, one_f64])
e = call!(builder, fma_ft, fma_f64, [e, e, e])
inv_ref = call!(builder, fma_ft, fma_f64, [e, inv_y, inv_y])
fmul!(builder, lhs, inv_ref)
end

replace_uses!(inst, replacement)
erase!(inst)
changed = true
end
end

end # @tracepoint
return changed
end
PTXFDivFastPass() = NewPMModulePass("ptx-fdiv-fast", ptx_fdiv_fast!)
33 changes: 33 additions & 0 deletions test/ptx.jl
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,39 @@ end
end
end

@testset "fastmath division" begin
# `PTXFDivFastPass` rewrites `afn`-flagged fdiv. f32 → `div.approx{,.ftz}.f32`
# (filling in for LLVM 18, whose `getDivF32Level` doesn't honor per-call
# `afn`); f64 → `rcp.approx.ftz.d` + Newton refinement (NVPTX has no fast
# f64 fdiv lowering). Job-wide `fastmath=true` reaches this through the
# per-instruction flags `apply_fastmath!` stamps in `finish_linked_module!`.
mod_fast = @eval module $(gensym())
kernel_f32(x::Float32, y::Float32) = @fastmath x / y
kernel_f64(x::Float64, y::Float64) = @fastmath x / y
end
mod_precise = @eval module $(gensym())
kernel_f32(x::Float32, y::Float32) = x / y
kernel_f64(x::Float64, y::Float64) = x / y
end

@test @filecheck begin
@check "div.approx.f32"
PTX.code_native(mod_fast.kernel_f32, Tuple{Float32, Float32})
end
@test @filecheck begin
@check_not "div.approx"
PTX.code_native(mod_precise.kernel_f32, Tuple{Float32, Float32})
end
@test @filecheck begin
@check "rcp.approx.ftz.f64"
PTX.code_native(mod_fast.kernel_f64, Tuple{Float64, Float64})
end
@test @filecheck begin
@check_not "rcp.approx"
PTX.code_native(mod_precise.kernel_f64, Tuple{Float64, Float64})
end
end

@testset "feature_set" begin
# PTXCompilerTarget.feature_set controls the suffix on the LLVM CPU name, which is
# what the NVPTX backend uses to flip `hasArchAccelFeatures()`. Verify it makes its
Expand Down
Loading