Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions src/optim.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,34 @@
# LLVM IR optimization

"""
apply_fastmath!(mod::LLVM.Module)

Apply fast-math semantics to every function definition in `mod` — as if every
floating-point operation were wrapped in `@fastmath`. Sets `unsafe-fp-math="true"`
as a function attribute and turns on all fast-math flags on eligible FP
instructions.

Back-ends should call this from `finish_linked_module!` when their target has
fast math enabled. Set both the function attribute and the per-instruction
flags: not every codegen hook reads both (e.g. LLVM 18's NVPTX
`usePrecSqrtF32` only consults `TargetMachine.Options.UnsafeFPMath`, which
isn't reachable through LLVM.jl, so flagging the instructions is the
portable path), and flagging both leaves no path that silently keeps the
slow lowering.
"""
function apply_fastmath!(mod::LLVM.Module)
for f in functions(mod)
isdeclaration(f) && continue
push!(function_attributes(f), StringAttribute("unsafe-fp-math", "true"))
for bb in blocks(f), inst in instructions(bb)
if Bool(LLVM.API.LLVMCanValueUseFastMathFlags(inst))
fast_math!(inst; all=true)
end
end
end
return
end

# Pick the peephole pass according to `optimization_options(job).instcombine`. Defaults to
# `InstCombinePass` to match LLVM's standard pipeline; `InstSimplifyPass` is the fallback
# for back-ends that need only the simplification subset.
Expand Down
20 changes: 20 additions & 0 deletions src/ptx.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,26 @@ function finish_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
return entry
end

function finish_linked_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
mod::LLVM.Module)
# propagate `target.fastmath` as `@fastmath`-everywhere semantics
# (mirrors nvcc's `--use_fast_math`). post-link so that bodies pulled in
# from libdevice and the runtime also get the flags.
if job.config.target.fastmath
apply_fastmath!(mod)
# additionally request FTZ on f32: NVPTX' `useF32FTZ` reads
# `denormal-fp-math-f32` to pick the FTZ variants for
# fdiv/fsqrt/etc.
for f in functions(mod)
isdeclaration(f) && continue
push!(function_attributes(f),
StringAttribute("denormal-fp-math-f32",
"preserve-sign,preserve-sign"))
end
end
return
end

function optimize_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
mod::LLVM.Module)
tm = llvm_machine(job.config.target)
Expand Down
4 changes: 3 additions & 1 deletion test/helpers/ptx.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,13 @@ function create_job(@nospecialize(func), @nospecialize(types);
cap=v"7.0", ptx=v"6.0", feature_set=:baseline,
minthreads=nothing, maxthreads=nothing,
blocks_per_sm=nothing, maxregs=nothing,
fastmath=false,
kwargs...)
config_kwargs, kwargs = split_kwargs(kwargs, GPUCompiler.CONFIG_KWARGS)
source = methodinstance(typeof(func), Base.to_tuple_type(types), Base.get_world_counter())
target = PTXCompilerTarget(; cap, ptx, feature_set,
minthreads, maxthreads, blocks_per_sm, maxregs)
minthreads, maxthreads, blocks_per_sm, maxregs,
fastmath)
params = CompilerParams()
config = CompilerConfig(target, params; kernel=false, config_kwargs...)
CompilerJob(source, config), kwargs
Expand Down
40 changes: 40 additions & 0 deletions test/ptx.jl
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,46 @@ end
PTX.code_native(devnull, mod.kernel, Tuple{Float32,Ptr{Float32}})
end

@testset "fastmath" begin
# `fastmath=true` on the target should call `apply_fastmath!` from
# `finish_linked_module!`, stamping `unsafe-fp-math` + fast-math flags on
# every FP op, and additionally setting `denormal-fp-math-f32` so NVPTX
# picks the FTZ variants. Verify both pieces — IR-level attributes and
# PTX-level instruction selection — with and without the flag.
mod = @eval module $(gensym())
kernel(x, out) = (unsafe_store!(out, sqrt(unsafe_load(x))); return)
end

# without fastmath, no unsafe-fp-math / f32-FTZ, and sqrt stays precise
@test @filecheck begin
@check_label "define void @{{(julia|j)_kernel_[0-9]+}}"
@check_not "unsafe-fp-math"
@check_not "denormal-fp-math-f32"
@check "call float @llvm.sqrt.f32"
PTX.code_llvm(mod.kernel, Tuple{Ptr{Float32},Ptr{Float32}}; dump_module=true)
end
@test @filecheck begin
@check "sqrt.rn.f32"
@check_not "sqrt.approx"
PTX.code_native(mod.kernel, Tuple{Ptr{Float32},Ptr{Float32}})
end

# with fastmath, the entry function carries the attributes, the sqrt call
# picks up fast-math flags, and PTX selects the approx+ftz variant.
@test @filecheck begin
@check_label "define void @{{(julia|j)_kernel_[0-9]+}}"
@check "call fast float @llvm.sqrt.f32"
@check "\"denormal-fp-math-f32\"=\"preserve-sign,preserve-sign\""
@check "\"unsafe-fp-math\"=\"true\""
PTX.code_llvm(mod.kernel, Tuple{Ptr{Float32},Ptr{Float32}};
dump_module=true, fastmath=true)
end
@test @filecheck begin
@check "sqrt.approx.ftz.f32"
PTX.code_native(mod.kernel, Tuple{Ptr{Float32},Ptr{Float32}}; fastmath=true)
end
end

@testset "feature_set" begin
# PTXCompilerTarget.feature_set controls the suffix on the LLVM CPU name, which is
# what the NVPTX backend uses to flip `hasArchAccelFeatures()`. Verify it makes its
Expand Down