From c2593d509ced58f5468a6254bd2ebe8d110f22f0 Mon Sep 17 00:00:00 2001
From: Tim Besard
Date: Wed, 20 May 2026 09:49:44 +0200
Subject: [PATCH 1/2] Add a pass to apply fastmath attributes.

---
 src/optim.jl | 29 +++++++++++++++++++++++++++++
 src/ptx.jl   | 20 ++++++++++++++++++++
 2 files changed, 49 insertions(+)

diff --git a/src/optim.jl b/src/optim.jl
index 8bdb110e..8ecf4dc0 100644
--- a/src/optim.jl
+++ b/src/optim.jl
@@ -1,5 +1,34 @@
 # LLVM IR optimization
 
+"""
+    apply_fastmath!(mod::LLVM.Module)
+
+Apply fast-math semantics to every function definition in `mod` — as if every
+floating-point operation were wrapped in `@fastmath`. Sets `unsafe-fp-math="true"`
+as a function attribute and turns on all fast-math flags on every instruction
+that can carry them. Declarations are skipped. Returns `nothing`.
+
+Back-ends should call this from `finish_linked_module!` when their target has
+fast math enabled. Both the function attribute and the per-instruction flags
+are set because not every codegen hook reads both (e.g. LLVM 18's NVPTX
+`usePrecSqrtF32` only consults `TargetMachine.Options.UnsafeFPMath`, which
+isn't reachable through LLVM.jl, so flagging the instructions is the
+portable path), and flagging both leaves no path that silently keeps the
+slow lowering.
+"""
+function apply_fastmath!(mod::LLVM.Module)
+    for f in functions(mod)
+        isdeclaration(f) && continue  # declarations have no body to annotate
+        push!(function_attributes(f), StringAttribute("unsafe-fp-math", "true"))
+        for bb in blocks(f), inst in instructions(bb)
+            if Bool(LLVM.API.LLVMCanValueUseFastMathFlags(inst))
+                fast_math!(inst; all=true)
+            end
+        end
+    end
+    return
+end
+
 # Pick the peephole pass according to `optimization_options(job).instcombine`. Defaults to
 # `InstCombinePass` to match LLVM's standard pipeline; `InstSimplifyPass` is the fallback
 # for back-ends that need only the simplification subset.
diff --git a/src/ptx.jl b/src/ptx.jl
index aac856b7..e5763e8d 100644
--- a/src/ptx.jl
+++ b/src/ptx.jl
@@ -236,6 +236,26 @@ function finish_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
     return entry
 end
 
+function finish_linked_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
+                               mod::LLVM.Module)
+    # Propagate `target.fastmath` as `@fastmath`-everywhere semantics
+    # (mirrors nvcc's `--use_fast_math`). Done post-link so that bodies pulled
+    # in from libdevice and the runtime also get the flags.
+    if job.config.target.fastmath
+        apply_fastmath!(mod)
+        # Additionally request FTZ on f32: NVPTX's `useF32FTZ` reads the
+        # `denormal-fp-math-f32` function attribute to pick the FTZ variants
+        # for fdiv/fsqrt/etc.
+        for f in functions(mod)
+            isdeclaration(f) && continue
+            push!(function_attributes(f),
+                  StringAttribute("denormal-fp-math-f32",
+                                  "preserve-sign,preserve-sign"))
+        end
+    end
+    return
+end
+
 function optimize_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
                           mod::LLVM.Module)
     tm = llvm_machine(job.config.target)
From af1dcc3dbd1042b9851909970b9fde0fcd25659d Mon Sep 17 00:00:00 2001
From: Tim Besard
Date: Wed, 20 May 2026 10:04:48 +0200
Subject: [PATCH 2/2] Add a test.

---
 test/helpers/ptx.jl |  4 +++-
 test/ptx.jl         | 40 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 43 insertions(+), 1 deletion(-)

diff --git a/test/helpers/ptx.jl b/test/helpers/ptx.jl
index 634b59de..4b00f707 100644
--- a/test/helpers/ptx.jl
+++ b/test/helpers/ptx.jl
@@ -39,11 +39,13 @@ function create_job(@nospecialize(func), @nospecialize(types);
                     cap=v"7.0", ptx=v"6.0", feature_set=:baseline,
                     minthreads=nothing, maxthreads=nothing,
                     blocks_per_sm=nothing, maxregs=nothing,
+                    fastmath=false,
                     kwargs...)
     config_kwargs, kwargs = split_kwargs(kwargs, GPUCompiler.CONFIG_KWARGS)
     source = methodinstance(typeof(func), Base.to_tuple_type(types), Base.get_world_counter())
     target = PTXCompilerTarget(; cap, ptx, feature_set,
-                               minthreads, maxthreads, blocks_per_sm, maxregs)
+                               minthreads, maxthreads, blocks_per_sm, maxregs,
+                               fastmath)
     params = CompilerParams()
     config = CompilerConfig(target, params; kernel=false, config_kwargs...)
     CompilerJob(source, config), kwargs
diff --git a/test/ptx.jl b/test/ptx.jl
index dd89d61a..2ae4a150 100644
--- a/test/ptx.jl
+++ b/test/ptx.jl
@@ -422,6 +422,46 @@ end
     PTX.code_native(devnull, mod.kernel, Tuple{Float32,Ptr{Float32}})
 end
 
+@testset "fastmath" begin
+    # `fastmath=true` on the target should call `apply_fastmath!` from
+    # `finish_linked_module!`, stamping `unsafe-fp-math` + fast-math flags on
+    # every FP op, and additionally setting `denormal-fp-math-f32` so NVPTX
+    # picks the FTZ variants. Verify both pieces — IR-level attributes and
+    # PTX-level instruction selection — both with and without the flag.
+    mod = @eval module $(gensym())
+        kernel(x, out) = (unsafe_store!(out, sqrt(unsafe_load(x))); return)
+    end
+
+    # without fastmath: neither attribute is present and sqrt stays precise (`sqrt.rn`)
+    @test @filecheck begin
+        @check_label "define void @{{(julia|j)_kernel_[0-9]+}}"
+        @check_not "unsafe-fp-math"
+        @check_not "denormal-fp-math-f32"
+        @check "call float @llvm.sqrt.f32"
+        PTX.code_llvm(mod.kernel, Tuple{Ptr{Float32},Ptr{Float32}}; dump_module=true)
+    end
+    @test @filecheck begin
+        @check "sqrt.rn.f32"
+        @check_not "sqrt.approx"
+        PTX.code_native(mod.kernel, Tuple{Ptr{Float32},Ptr{Float32}})
+    end
+
+    # with fastmath: the entry function carries both attributes, the sqrt call
+    # picks up fast-math flags, and PTX selects the approx+ftz variant.
+    @test @filecheck begin
+        @check_label "define void @{{(julia|j)_kernel_[0-9]+}}"
+        @check "call fast float @llvm.sqrt.f32"
+        @check "\"denormal-fp-math-f32\"=\"preserve-sign,preserve-sign\""
+        @check "\"unsafe-fp-math\"=\"true\""
+        PTX.code_llvm(mod.kernel, Tuple{Ptr{Float32},Ptr{Float32}};
+                      dump_module=true, fastmath=true)
+    end
+    @test @filecheck begin
+        @check "sqrt.approx.ftz.f32"
+        PTX.code_native(mod.kernel, Tuple{Ptr{Float32},Ptr{Float32}}; fastmath=true)
+    end
+end
+
 @testset "feature_set" begin
     # PTXCompilerTarget.feature_set controls the suffix on the LLVM CPU name, which is
     # what the NVPTX backend uses to flip `hasArchAccelFeatures()`. Verify it makes its