Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "GPUCompiler"
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
version = "1.13.1"
version = "1.13.2"
authors = ["Tim Besard <tim.besard@gmail.com>"]

[workspace]
Expand Down
114 changes: 105 additions & 9 deletions src/ptx.jl
Original file line number Diff line number Diff line change
Expand Up @@ -262,10 +262,12 @@ function optimize_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
# TODO: Use the registered target passes (JuliaGPU/GPUCompiler.jl#450)
@dispose pb=NewPMPassBuilder() begin
register!(pb, NVVMReflectPass())
register!(pb, PTXRSqrtFastPass())
register!(pb, PTXFDivFastPass())
register!(pb, PTXFSqrtFastPass())

add!(pb, NVVMReflectPass())
add!(pb, PTXRSqrtFastPass())
add!(pb, PTXFDivFastPass())
add!(pb, PTXFSqrtFastPass())

Expand Down Expand Up @@ -571,23 +573,117 @@ function f32_ftz(f::LLVM.Function)
return false
end

# Both passes below rewrite `afn`-flagged ops to NVPTX' fast lowerings.
# All three passes below rewrite `afn`-flagged ops to NVPTX's fast lowerings.
# `apply_fastmath!` propagates job-wide `target.fastmath=true` as per-
# instruction `afn`, so a single flag check covers both per-call `@fastmath`
# and the job toggle. We emit NVPTX intrinsics by name (rather than libdevice
# `__nv_*`) so this doesn't depend on which libdevice symbols got linked in.
#
# Both passes are temporary backports for LLVM 18:
# - `PTXFSqrtFastPass` is fully redundant on LLVM 21+: `usePrecSqrtF32` then
# honors the per-instruction `afn` and the function `unsafe-fp-math`
# attribute, so `DAGCombiner::visitFSQRT` → `NVPTXTargetLowering::getSqrtEstimate`
# emits the f32 `sqrt.approx{,.ftz}` and f64 `rcp(rsqrt(x))` sequences
# itself. LLVM 18's `usePrecSqrtF32` only consults `TargetMachine.Options.UnsafeFPMath`,
# which is unreachable through LLVM.jl.
# - `PTXFDivFastPass`'s f32 path is similarly redundant on LLVM 21+;
# `PTXRSqrtFastPass` runs first: the rsqrt pattern (`fdiv afn 1.0, sqrt afn x`)
# spans an fdiv and a sqrt, so it has to claim both operands before the per-op
# passes below eat them. NVPTX has native `rsqrt.approx.{f,d}` for both f32 and
# f64, so this is a single-instruction lowering for both types.
#
# `PTXFDivFastPass` / `PTXFSqrtFastPass` are temporary backports for LLVM 18:
# - `PTXFSqrtFastPass`: on LLVM 21+ `usePrecSqrtF32` honors per-instruction
# `afn` + the `unsafe-fp-math` attribute, so `DAGCombiner::visitFSQRT` →
# `NVPTXTargetLowering::getSqrtEstimate` emits the f32 `sqrt.approx{,.ftz}`
# and f64 `rcp(rsqrt(x))` sequences itself. LLVM 18's `usePrecSqrtF32` only
# consults `TargetMachine.Options.UnsafeFPMath`, which is unreachable
# through LLVM.jl.
# - `PTXFDivFastPass`'s f32 path is similarly fixed on LLVM 21+;
# `getDivF32Level` there honors `afn` + the function attribute. The f64
# path stays needed until NVPTX gains a `getRecipEstimate` hook (filed
# upstream).
# On LLVM 21+ both passes (and the f32 path of `PTXRSqrtFastPass`) can be
# dropped together — they have to leave the pipeline as a set, because as
# long as `PTXFDivFastPass` runs and rewrites `fdiv afn 1.0, sqrt(x)` into
# `nvvm.div.approx.f(1.0, ...)`, the rsqrt tablegen pattern can't match.

# Rewrite `fdiv afn 1.0, sqrt afn(x)` to `nvvm.rsqrt.approx.{f,d}(x)`. Must run
# before `PTXFDivFastPass` (which would rewrite the fdiv to `nvvm.div.approx.f`,
# defeating ISel pattern-matching) and `PTXFSqrtFastPass` (which for f64
# expands sqrt into `rcp(rsqrt(...))`, hiding the pattern entirely).
#
# Why we can't rely on LLVM upstream:
# - f32: NVPTX has tablegen patterns (`NVPTXIntrinsics.td`, `doRsqrtOpt`) that
# match `fdiv 1.0, sqrt_approx(x)` → `rsqrt.approx.f32` — but they landed in
# LLVM 19 (so absent on our LLVM 18 floor), and even on LLVM 21+ they only
# fire if the fdiv is still a generic `fdiv`. PTXFDivFastPass kills that.
# - f64: no upstream fold exists at all. NVPTX doesn't override
# `getRecipEstimateSqrtEnabled`, so the DAGCombiner's generic rsqrt path is
# disabled, and there's no f64 equivalent of the f32 tablegen patterns.
# `rsqrt.approx.f64` is a real instruction; it just isn't selected for
# `1/sqrt(x)` upstream.
# Module pass body: fold every scalar `fdiv afn 1.0, (call afn @llvm.sqrt.f{32,64}(x))`
# in `mod` into a single call to `llvm.nvvm.rsqrt.approx{,.ftz}.f` (f32) or
# `llvm.nvvm.rsqrt.approx.d` (f64). Returns `true` iff at least one `fdiv` was
# rewritten, `false` otherwise.
function ptx_rsqrt_fast!(mod::LLVM.Module)
    changed = false
    @tracepoint "ptx-rsqrt-fast" begin

    float_ty = LLVM.FloatType()
    double_ty = LLVM.DoubleType()

    # Phase 1: gather every matching (fdiv, sqrt call, is-f32) triple up front,
    # so that the IR is never mutated while it is being walked.
    matches = Tuple{LLVM.FDivInst, LLVM.CallInst, Bool}[]
    for fn in functions(mod)
        for block in blocks(fn)
            for inst in instructions(block)
                if !(inst isa LLVM.FDivInst)
                    continue
                end

                # only scalar f32 / f64 divisions carrying `afn` qualify
                single = LLVM.value_type(inst) == float_ty
                double = LLVM.value_type(inst) == double_ty
                if !(single || double)
                    continue
                end
                if !LLVM.fast_math(inst).afn
                    continue
                end

                # the dividend has to be the literal constant 1.0
                numerator = operands(inst)[1]
                if !(numerator isa LLVM.ConstantFP)
                    continue
                end
                if convert(Float64, numerator) != 1.0
                    continue
                end

                # the divisor has to be a direct, `afn`-flagged call to the
                # `llvm.sqrt` intrinsic of the same width
                denominator = operands(inst)[2]
                if !(denominator isa LLVM.CallInst)
                    continue
                end
                target = LLVM.called_operand(denominator)
                if !(target isa LLVM.Function)
                    continue
                end
                wanted = single ? "llvm.sqrt.f32" : "llvm.sqrt.f64"
                if LLVM.name(target) != wanted
                    continue
                end
                if !LLVM.fast_math(denominator).afn
                    continue
                end

                push!(matches, (inst, denominator, single))
            end
        end
    end
    if isempty(matches)
        return false
    end

    # Phase 2: look up the NVVM rsqrt intrinsics by name, declaring them in the
    # module when they are not present yet.
    mod_fns = functions(mod)
    function intrinsic(name, ft)
        if haskey(mod_fns, name)
            return mod_fns[name]
        end
        return LLVM.Function(mod, name, ft)
    end
    f32_ft = LLVM.FunctionType(float_ty, [float_ty])
    rsqrt_plain = intrinsic("llvm.nvvm.rsqrt.approx.f", f32_ft)
    rsqrt_ftz = intrinsic("llvm.nvvm.rsqrt.approx.ftz.f", f32_ft)
    f64_ft = LLVM.FunctionType(double_ty, [double_ty])
    rsqrt_double = intrinsic("llvm.nvvm.rsqrt.approx.d", f64_ft)

    # Phase 3: emit the intrinsic call right before each fdiv and swap it in.
    @dispose irb=IRBuilder() begin
        for (div_inst, sqrt_inst, single) in matches
            radicand = operands(sqrt_inst)[1]
            position!(irb, div_inst)

            if single
                # the f32 variant is chosen per enclosing function via `f32_ftz`
                owner = LLVM.parent(LLVM.parent(div_inst))
                chosen = f32_ftz(owner) ? rsqrt_ftz : rsqrt_plain
                new_val = call!(irb, f32_ft, chosen, [radicand])
            else
                new_val = call!(irb, f64_ft, rsqrt_double, [radicand])
            end

            replace_uses!(div_inst, new_val)
            erase!(div_inst)
            # The sqrt result may feed other instructions; drop the call only
            # once nothing refers to it any more.
            isempty(uses(sqrt_inst)) && erase!(sqrt_inst)
            changed = true
        end
    end

    end # @tracepoint
    return changed
end
# New-pass-manager module pass wrapping `ptx_rsqrt_fast!` under the name
# "ptx-rsqrt-fast".
PTXRSqrtFastPass() = NewPMModulePass("ptx-rsqrt-fast", ptx_rsqrt_fast!)

# Rewrite `afn`-flagged `fdiv`:
# - f32 → `llvm.nvvm.div.approx{,.ftz}.f`.
Expand Down
47 changes: 47 additions & 0 deletions test/ptx.jl
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,53 @@ end
end
end

@testset "fastmath rsqrt" begin
    # `PTXRSqrtFastPass` pattern-matches `fdiv afn 1.0, sqrt afn(x)` and folds
    # it to `nvvm.rsqrt.approx.{f,d}`, so high-level `@fastmath 1/sqrt(x)` and
    # `@fastmath inv(sqrt(x))` lower to a single `rsqrt.approx` instruction
    # rather than `sqrt.approx + div.approx` (f32) or expansion into
    # `rcp(rsqrt(...))` and a Newton step (f64). Without afn on both operands,
    # the pattern doesn't fire — folding would change precision.
    mod = @eval module $(gensym())
        rsqrt32_fast(x::Float32) = @fastmath 1f0 / sqrt(x)
        rsqrt64_fast(x::Float64) = @fastmath 1.0 / sqrt(x)
        rsqrt32(x::Float32) = 1f0 / sqrt(x)
        rsqrt64(x::Float64) = 1.0 / sqrt(x)
    end

    # per-call `@fastmath`, f32: non-ftz intrinsic by default
    @test @filecheck begin
        @check "rsqrt.approx.f32"
        @check_not "sqrt.approx"
        @check_not "div.approx"
        PTX.code_native(mod.rsqrt32_fast, Tuple{Float32})
    end
    # per-call `@fastmath` plus the job-wide toggle, f32: ftz variant
    @test @filecheck begin
        @check "rsqrt.approx.ftz.f32"
        @check_not "sqrt.approx"
        @check_not "div.approx"
        PTX.code_native(mod.rsqrt32_fast, Tuple{Float32}; fastmath=true)
    end
    # per-call `@fastmath`, f64
    @test @filecheck begin
        @check "rsqrt.approx.f64"
        @check_not "rcp.approx"
        PTX.code_native(mod.rsqrt64_fast, Tuple{Float64})
    end
    @test @filecheck begin
        # job-wide fastmath stamps afn on all FP ops, so the pattern still fires
        @check "rsqrt.approx.f64"
        @check_not "rcp.approx"
        PTX.code_native(mod.rsqrt64, Tuple{Float64}; fastmath=true)
    end

    # Without afn, plain `1/sqrt(x)` must NOT fold to rsqrt: it would change
    # precision. This is checked for both widths; the non-fast code keeps the
    # correctly-rounded `sqrt.rn` (followed by a `div.rn`).
    @test @filecheck begin
        @check "sqrt.rn"
        @check_not "rsqrt.approx"
        PTX.code_native(mod.rsqrt32, Tuple{Float32})
    end
    @test @filecheck begin
        @check "sqrt.rn.f64"
        @check_not "rsqrt.approx"
        PTX.code_native(mod.rsqrt64, Tuple{Float64})
    end
end

@testset "feature_set" begin
# PTXCompilerTarget.feature_set controls the suffix on the LLVM CPU name, which is
# what the NVPTX backend uses to flip `hasArchAccelFeatures()`. Verify it makes its
Expand Down