Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,23 @@ projects = ["test", "examples"]

[deps]
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
CompilerCaching = "9db33cc3-5358-4881-8759-fa4194144afd"
CUDA_Compiler_jll = "d1e2174e-dfdc-576e-b43e-73b79eb1aca8"
CUDA_Tile_jll = "2068806d-a867-5dbd-af0e-42c2eb5d895d"
CompilerCaching = "9db33cc3-5358-4881-8759-fa4194144afd"
IRStructurizer = "93e32bba-5bb8-402b-805d-ffb066edee93"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DLFP8Types = "f4c16678-4a16-415b-82ef-ed337c5d6c7c"

[sources]
CompilerCaching = {url = "https://github.com/maleadt/CompilerCaching.jl", rev="main"}
IRStructurizer = {url = "https://github.com/maleadt/IRStructurizer.jl", rev = "main"}

[extensions]
CUDAExt = "CUDA"
DLFP8TypesExt = "DLFP8Types"

[compat]
julia = "1.11"
BFloat16s = "0.6"
CompilerCaching = "0.1"
CUDA_Compiler_jll = "0.4"
CUDA_Tile_jll = "13.1"
CompilerCaching = "0.1"
IRStructurizer = "0.1"
julia = "1.11"
8 changes: 4 additions & 4 deletions src/compiler/codegen/expressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ function emit_expr!(ctx::CGCtx, expr::Expr, @nospecialize(result_type))
elseif expr.head === :foreigncall
throw(IRError("Foreign calls not supported in Tile IR"))
elseif expr.head === :boundscheck
return nothing
# Bounds checking is always disabled in Tile IR kernels.
# Emit false so IfOps referencing this SSA can resolve the condition.
return emit_constant!(ctx, false, Bool)
else
@warn "Unhandled expression head" expr.head expr
return nothing
Expand Down Expand Up @@ -79,9 +81,7 @@ function emit_call!(ctx::CGCtx, expr::Expr, @nospecialize(result_type))
func = get_constant(ctx, args[1])
call_args = args[2:end]

# TODO: This is normally dynamic dispatch, which we should allow.
# However, we currently trigger this when emitting Julia intrinsics.
# We should switch to our own intrinsics entirely, which are only invoked.
# We enter here for dynamic dispatch, but also for all intrinsic functions.

@static if isdefined(Core, :throw_methoderror)
if func === Core.throw_methoderror
Expand Down
9 changes: 7 additions & 2 deletions src/compiler/codegen/statements.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,14 @@ function emit_statement!(ctx::CGCtx, @nospecialize(stmt), ssa_idx::Int, @nospeci
# PiNode is a type narrowing assertion - store the resolved value
tv = emit_value!(ctx, stmt)
elseif stmt === nothing
# No-op
# Dead code elimination artifact — no value to register
else
@warn "Unhandled statement type" typeof(stmt) stmt
# Literal values from constant folding or concrete eval.
# Try emit_constant! first (numbers/ghost types), fall back to emit_value!.
tv = emit_constant!(ctx, stmt, result_type)
if tv === nothing
tv = emit_value!(ctx, stmt)
end
end

# Store result by original Julia SSA index
Expand Down
55 changes: 38 additions & 17 deletions src/compiler/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,20 @@ CC.may_compress(::cuTileInterpreter) = true
CC.may_discard_trees(::cuTileInterpreter) = false

#=============================================================================
Custom return-type inference (tfuncs) for intrinsics
Custom inference for intrinsics
=============================================================================#

# Per-intrinsic return type overrides using multiple dispatch.
# Per-intrinsic return type overrides.
# Returns nothing when no override applies (fallback).
# Concrete per-intrinsic methods are defined in intrinsics/ (after the
# Intrinsics module exists).
tfunc(@nospecialize(f), argtypes::Vector{Any}) = nothing
tfunc(𝕃, @nospecialize(f), @nospecialize args...) = nothing

# Effect override hook, looked up per intrinsic function `f`.
# This generic method is the fallback: it yields `nothing`, which callers
# interpret as "no override, keep the inferred `effects`".
function efunc(@nospecialize(f), effects::CC.Effects)
    return nothing
end

# Return `true` only when `f` is a function whose parent module is `Intrinsics`.
# Calls to such functions get NoCallInfo() so they stay as Expr(:call) rather
# than Expr(:invoke).
function isintrinsic(@nospecialize(f))
    # Anything that is not a function can never be an intrinsic.
    f isa Function || return false
    return parentmodule(f) === Intrinsics
end

#=============================================================================
Subprogram inference for reduce/scan
Expand Down Expand Up @@ -172,18 +178,23 @@ end
result = @invoke CC.abstract_call_known(interp::CC.AbstractInterpreter, f::Any,
arginfo::CC.ArgInfo, si::CC.StmtInfo, vtypes::Union{CC.VarTable,Nothing},
sv::CC.InferenceState, max_methods::Int)
rt_override = tfunc(f, arginfo.argtypes)
is_intr = isintrinsic(f)
𝕃 = CC.typeinf_lattice(interp)
rt_override = tfunc(𝕃, f, arginfo.argtypes[2:end]...)
subprog = _infer_subprogram(interp, f, arginfo, si, vtypes, sv)
rt_override === nothing && subprog === nothing && return result
!is_intr && rt_override === nothing && subprog === nothing && return result
wrapped = CC.Future{CC.CallMeta}()
push!(sv.tasks, function (interp′, sv′)
isready(result) || return false
subprog !== nothing && !isready(subprog) && return false
cm = result[]
sp = subprog !== nothing ? subprog[] : nothing
rt = rt_override !== nothing ? rt_override : cm.rt
info = sp !== nothing ? SubprogramCallInfo(cm.info, sp.info) : cm.info
wrapped[] = CC.CallMeta(rt, cm.exct, cm.effects, info, cm.refinements)
efunc_override = is_intr ? efunc(f, cm.effects) : nothing
effects = efunc_override !== nothing ? efunc_override : cm.effects
info = is_intr ? CC.NoCallInfo() : cm.info
info = sp !== nothing ? SubprogramCallInfo(info, sp.info) : info
wrapped[] = CC.CallMeta(rt, cm.exct, effects, info, cm.refinements)
return true
end)
return wrapped
Expand All @@ -195,18 +206,23 @@ elseif isdefined(CC, :Future) # 1.12–1.13
result = @invoke CC.abstract_call_known(interp::CC.AbstractInterpreter, f::Any,
arginfo::CC.ArgInfo, si::CC.StmtInfo,
sv::CC.InferenceState, max_methods::Int)
rt_override = tfunc(f, arginfo.argtypes)
is_intr = isintrinsic(f)
𝕃 = CC.typeinf_lattice(interp)
rt_override = tfunc(𝕃, f, arginfo.argtypes[2:end]...)
subprog = _infer_subprogram(interp, f, arginfo, si, nothing, sv)
rt_override === nothing && subprog === nothing && return result
!is_intr && rt_override === nothing && subprog === nothing && return result
wrapped = CC.Future{CC.CallMeta}()
push!(sv.tasks, function (interp′, sv′)
isready(result) || return false
subprog !== nothing && !isready(subprog) && return false
cm = result[]
sp = subprog !== nothing ? subprog[] : nothing
rt = rt_override !== nothing ? rt_override : cm.rt
info = sp !== nothing ? SubprogramCallInfo(cm.info, sp.info) : cm.info
wrapped[] = CC.CallMeta(rt, cm.exct, cm.effects, info, cm.refinements)
efunc_override = is_intr ? efunc(f, cm.effects) : nothing
effects = efunc_override !== nothing ? efunc_override : cm.effects
info = is_intr ? CC.NoCallInfo() : cm.info
info = sp !== nothing ? SubprogramCallInfo(info, sp.info) : info
wrapped[] = CC.CallMeta(rt, cm.exct, effects, info, cm.refinements)
return true
end)
return wrapped
Expand All @@ -219,10 +235,15 @@ else # 1.11: synchronous, edges auto-tracked via stmt_edges
arginfo::CC.ArgInfo, si::CC.StmtInfo,
sv::CC.AbsIntState, max_methods::Int)
_infer_subprogram(interp, f, arginfo, si, nothing, sv) # side-effect only
rt_override = tfunc(f, arginfo.argtypes)
if rt_override !== nothing
return CC.CallMeta(rt_override, result.exct, result.effects,
result.info)
is_intr = isintrinsic(f)
𝕃 = CC.typeinf_lattice(interp)
rt_override = tfunc(𝕃, f, arginfo.argtypes[2:end]...)
rt = rt_override !== nothing ? rt_override : result.rt
efunc_override = is_intr ? efunc(f, result.effects) : nothing
effects = efunc_override !== nothing ? efunc_override : result.effects
info = is_intr ? CC.NoCallInfo() : result.info
if is_intr || rt_override !== nothing
return CC.CallMeta(rt, result.exct, effects, info)
end
return result
end
Expand Down
51 changes: 33 additions & 18 deletions src/compiler/intrinsics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,45 @@

module Intrinsics

using Base: compilerbarrier, donotdelete
using Base: compilerbarrier, inferencebarrier
using ..cuTile: Tile, TileArray, Constant, TensorView, PartitionView
using ..cuTile: Signedness, SignednessSigned, SignednessUnsigned
using ..cuTile: ComparisonPredicate, CmpLessThan, CmpLessThanOrEqual, CmpGreaterThan, CmpGreaterThanOrEqual, CmpEqual, CmpNotEqual
using ..cuTile: IdentityVal, FloatIdentityVal, IntegerIdentityVal

end

# NOTE: Due to JuliaLang/julia#60583, intrinsics may be called during constant evaluation.
# Because of that, such intrinsics (such as basic arithmetic) need to provide an
# implementation that actually computes a valid result using Julia intrinsics.
#
# Sometimes that's not possible, e.g., because the functionality required for that is
# overlayed by methods calling back into the intrinsic (e.g. `sin`), so for those
# intrinsics we disable constant folding using a `compilerbarrier(:const)`
#
# NOTE: Side-effectful intrinsics (stores, atomics) use `donotdelete(args...)` in their
# bodies to prevent the optimizer from DCE'ing calls. `donotdelete` is a Julia builtin
# with `effect_free=ALWAYS_FALSE`, which inference propagates through the function body.
# `@assume_effects !:effect_free` does NOT work — `override_effects` can only strengthen
# effects (set ALWAYS_TRUE), not weaken them. Spoofing `ipo_effects` via a custom
# `CC.finish!` override is possible but fragile (must race against `finishinfer!` setting
# `use_const_api` based on pre-override effects). `donotdelete` is the simplest correct
# approach.
"""
@intrinsic signature

emit_intrinsic!(ctx::CGCtx, @nospecialize(func), args) = missing
Define a Tile IR intrinsic in the `Intrinsics` module. These intrinsics are
defined to return `Any`, so need additional `tfunc` and `efunc` definitions
to specify their behavior.
"""
macro intrinsic(ex)
    # Placeholder body shared by every intrinsic: a type-level compiler barrier
    # around `nothing`.
    stub = quote
        compilerbarrier(:type, nothing)
    end
    # Build `@noinline function <ex> ... end` by hand, with no source location
    # attached to the macro call.
    noinline_def = Expr(:macrocall, Symbol("@noinline"), nothing,
                        Expr(:function, ex, stub))
    # Quote the definition so that it is evaluated inside `Intrinsics` rather
    # than in the module that invoked the macro.
    quoted = QuoteNode(noinline_def)
    return esc(:(Core.eval(Intrinsics, $quoted)))
end

"""
instanceof_tfunc(lat) -> Type or nothing

Extract `T` from a lattice element representing `Type{T}`.
Simplified version of `Base.Compiler.instanceof_tfunc` that handles `Const(T)`
and `Type{T}` lattice elements. Returns `nothing` when `T` cannot be determined.
"""
function instanceof_tfunc(@nospecialize(lat))
    # A constant lattice element carries its value directly; it is only useful
    # here when that value is itself a type.
    if lat isa CC.Const
        c = lat.val
        c isa Type || return nothing
        return c
    end
    # Otherwise widen to a plain Julia type and look for a concrete `Type{T}`
    # with at least one parameter.
    widened = CC.widenconst(lat)
    widened isa DataType || return nothing
    widened <: Type || return nothing
    params = widened.parameters
    return isempty(params) ? nothing : params[1]
end

# Shared helper for creating load/store optimization hints
function create_optimization_hints(ctx::CGCtx, latency::Union{Int, Nothing}, allow_tma::Bool=true)
Expand All @@ -39,6 +52,8 @@ function create_optimization_hints(ctx::CGCtx, latency::Union{Int, Nothing}, all
return make_load_store_hints(ctx.sm_arch, hints)
end

emit_intrinsic!(ctx::CGCtx, @nospecialize(func), args) = missing

include("intrinsics/core.jl")
include("intrinsics/conversions.jl")
include("intrinsics/arithmetic.jl")
Expand Down
Loading