diff --git a/README.md b/README.md
index 52dc333..92bb9ee 100644
--- a/README.md
+++ b/README.md
@@ -416,11 +416,17 @@ b = ct.load(B, (expert_id, k, bid_n), (1, TILE_K, TILE_N))
 
 ## Differences from Julia
 
-### Float-to-integer conversion truncates
+### Some operations are non-throwing
 
-Inside cuTile kernels, `Int32(x::Float32)` and similar float-to-integer constructors
-truncate toward zero (like C-style casts), rather than throwing `InexactError` as in
-standard Julia. This matches the behavior of GPU hardware and cuTile Python's `ct.astype`.
+cuTile kernels cannot throw Julia exceptions. Operations that would throw in
+standard Julia silently produce truncated or wrapped results instead:
+
+- **Float-to-integer conversions:** `Int32(x)`, `trunc(Int32, x)`, and
+  `round(Int32, x, RoundToZero)` truncate toward zero rather than throwing
+  `InexactError` for non-integer or out-of-range values. Use `unsafe_trunc`
+  to make the non-throwing conversion explicit.
+
+Assertions may be added in the future to catch these cases during testing.
 
 ## Limitations
 
diff --git a/src/language/broadcast.jl b/src/language/broadcast.jl
index ab66afd..93cf8d8 100644
--- a/src/language/broadcast.jl
+++ b/src/language/broadcast.jl
@@ -1,7 +1,7 @@
 # Broadcasting Infrastructure for Tiles
 #
 # Defines the broadcast style and shape computation for Tile types.
-# All broadcasted operations are materialized via copy → map.
+# All broadcasted operations are materialized via copy.
 
 import Base.Broadcast: BroadcastStyle, Broadcasted, broadcastable, broadcast_shape
 
@@ -24,7 +24,19 @@ Base.Broadcast.broadcastable(t::Tile) = t
 
 
 #=============================================================================
- Broadcast materialization via copy + map
+ Ghost wrapper for Type values in broadcasting
+=============================================================================#
+
+# Replaces Julia's RefValue{Type{T}} wrapping, which the cuTile compiler can't construct.
+# The value is encoded in the type parameter; no runtime representation is needed.
+struct TypeRef{T} end
+
+Base.Broadcast.BroadcastStyle(::Type{<:TypeRef}) = Base.Broadcast.DefaultArrayStyle{0}()
+Base.Broadcast.broadcastable(a::TypeRef) = a
+
+
+#=============================================================================
+ Broadcast materialization via copy
 =============================================================================#
 
 # Tile is a ghost type with no storage, so axes/size are meaningless.
@@ -32,7 +44,7 @@ Base.Broadcast.broadcastable(t::Tile) = t
 @inline Base.Broadcast.instantiate(bc::Broadcasted{TileStyle}) = bc
 
 # Recursively materialize nested Broadcasted nodes,
-# promote scalars to Tiles, broadcast to a common shape, then apply via map.
+# promote scalars to Tiles, broadcast to a common shape, then apply f.
 # This handles all element-wise operations: scalar @overlay methods provide
 # the implementation for overlaid ops, while Julia's native scalar functions
 # (compiled to Core intrinsics) handle the rest. Mixed-type and type-changing
@@ -40,10 +52,10 @@ Base.Broadcast.broadcastable(t::Tile) = t
 # in operations.jl.
 @inline function Base.copy(bc::Broadcasted{TileStyle})
     args = _materialize_args(bc.args)
-    tiles = _promote_to_tiles(args...)
-    S = _broadcast_shapes(tiles...)
-    broadcasted = _broadcast_all(S, tiles...)
-    map(bc.f, broadcasted...)
+    promoted = _promote_to_tiles(args...)
+    S = _broadcast_shapes(promoted...)
+    broadcasted = _broadcast_all(S, promoted...)
+    _apply_broadcast(bc.f, broadcasted...)
 end
 
 # Recursively materialize nested Broadcasted nodes into concrete Tiles.
@@ -63,19 +75,46 @@ end
 # using its own type (e.g., 0.0f0 → Tile(Float32(0.0))), preserving the
 # type that Julia's broadcast promotion chose. This avoids the pitfall of
 # using the first Tile's eltype (which could be Bool for ifelse conditions).
+# TypeRef arguments pass through unchanged; they carry no tile shape.
 @inline _promote_to_tiles() = ()
 @inline _promote_to_tiles(a::Tile, rest...) = (a, _promote_to_tiles(rest...)...)
 @inline _promote_to_tiles(a::T, rest...) where {T <: Number} =
     (Tile(a), _promote_to_tiles(rest...)...)
+@inline _promote_to_tiles(a::TypeRef, rest...) = (a, _promote_to_tiles(rest...)...)
 
 # Compute combined broadcast shape across all Tile arguments via tuple peeling.
 # Shape is always a tuple TYPE (e.g., Tuple{16, 32}). Convert to value for broadcast_shape.
+# TypeRef arguments are skipped; they have no shape.
 @inline _tile_shape(t::Tile) = size(t)
 @inline _broadcast_shapes(t::Tile) = _tile_shape(t)
-@inline _broadcast_shapes(t::Tile, rest::Tile...) =
+@inline _broadcast_shapes(t::Tile, rest...) =
     broadcast_shape(_tile_shape(t), _broadcast_shapes(rest...))
+@inline _broadcast_shapes(::TypeRef, rest...) = _broadcast_shapes(rest...)
+@inline _broadcast_shapes(::TypeRef) = ()
 
 # Broadcast all tiles to shape S via tuple peeling.
+# TypeRef arguments pass through unchanged.
 @inline _broadcast_all(S::Tuple) = ()
-@inline _broadcast_all(S::Tuple, a::Tile, rest::Tile...) =
+@inline _broadcast_all(S::Tuple, a::Tile, rest...) =
     (broadcast_to(a, S), _broadcast_all(S, rest...)...)
+@inline _broadcast_all(S::Tuple, a::TypeRef, rest...) =
+    (a, _broadcast_all(S, rest...)...)
+
+# Convert args to scalars, apply f, wrap result back into a Tile.
+@inline function _apply_broadcast(f, args...)
+    scalar_args, S = _to_scalars(args...)
+    Intrinsics.from_scalar(f(scalar_args...), S)
+end
+
+# Reinterpret Tile arguments as scalars for broadcast application.
+# Skip and extract TypeRef arguments.
+# Returns (scalar_args_tuple, S) where S is the shape from the first Tile.
+@inline _to_scalars(t::Tile{<:Any,S}) where S = ((Intrinsics.to_scalar(t),), S)
+@inline function _to_scalars(t::Tile{<:Any,S}, rest...) where S
+    rest_scalars, _ = _to_scalars(rest...)
+    ((Intrinsics.to_scalar(t), rest_scalars...), S)
+end
+@inline function _to_scalars(::TypeRef{T}, rest...) where T
+    rest_scalars, S = _to_scalars(rest...)
+    ((T, rest_scalars...), S)
+end
diff --git a/src/language/overlays.jl b/src/language/overlays.jl
index 7094634..73debc8 100644
--- a/src/language/overlays.jl
+++ b/src/language/overlays.jl
@@ -8,6 +8,14 @@ macro overlay(ex)
 end
 
 
+#=============================================================================
+ Broadcasting
+=============================================================================#
+
+# Route Type values through TypeRef instead of RefValue (which can't be constructed in Tile IR).
+@overlay Base.Broadcast.broadcastable(::Type{T}) where T = TypeRef{T}()
+
+
 #=============================================================================
  Type Conversions
 =============================================================================#
@@ -51,13 +59,13 @@ end
     sizeof(S) > sizeof(T) ? Intrinsics.exti(x, S, SignednessUnsigned) :
     sizeof(S) < sizeof(T) ? Intrinsics.trunci(x, S) : x
 
-# Float to float (specific type pairs)
+# Float to float
 for T in Floats, S in Floats
     T === S && continue
     @eval @overlay $T(x::$S) = Intrinsics.ftof(x, $T)
 end
 
-# Integer to float (specific type pairs)
+# Integer to float
 for F in Floats
     for I in SignedInts
         @eval @overlay $F(x::$I) = Intrinsics.itof(x, $F, SignednessSigned)
@@ -78,12 +86,26 @@ for F in Floats
     end
 end
 
-# Float to integer (direct constructor - truncates like C-style cast)
+# Float to integer (round with RoundToZero)
+for F in Floats, I in (SignedInts..., UnsignedInts...)
+    @eval @overlay function Base.round(::Type{$I}, x::$F, ::Base.Rounding.RoundingMode{:ToZero})
+        # TODO: assert that x is within bounds etc
+        unsafe_trunc($I, x)
+    end
+end
+
+# Float to integer (direct constructor)
 for F in Floats
     for I in SignedInts
-        @eval @overlay $I(x::$F) = Intrinsics.ftoi(x, $I, SignednessSigned)
+        @eval @overlay function $I(x::$F)
+            # TODO: assert that x is within bounds etc
+            unsafe_trunc($I, x)
+        end
     end
     for I in UnsignedInts
-        @eval @overlay $I(x::$F) = Intrinsics.ftoi(x, $I, SignednessUnsigned)
+        @eval @overlay function $I(x::$F)
+            # TODO: assert that x is within bounds etc
+            unsafe_trunc($I, x)
+        end
     end
 end
diff --git a/test/codegen/operations.jl b/test/codegen/operations.jl
index bd822f2..8da55a9 100644
--- a/test/codegen/operations.jl
+++ b/test/codegen/operations.jl
@@ -849,6 +849,45 @@ end
         end
     end
 
+
+    @testset "Type broadcasting" begin
+        # convert.(Float16, tile): Type arg via TypeRef
+        @test @filecheck begin
+            @check_label "entry"
+            code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float16,1,spec1d}}) do a, b
+                pid = ct.bid(1)
+                tile = ct.load(a, pid, (16,))
+                @check "ftof"
+                ct.store(b, pid, convert.(Float16, tile))
+                return
+            end
+        end
+
+        # convert.(Float32, float16_tile): upcast via Type arg
+        @test @filecheck begin
+            @check_label "entry"
+            code_tiled(Tuple{ct.TileArray{Float16,1,spec1d}, ct.TileArray{Float32,1,spec1d}}) do a, b
+                pid = ct.bid(1)
+                tile = ct.load(a, pid, (16,))
+                @check "ftof"
+                ct.store(b, pid, convert.(Float32, tile))
+                return
+            end
+        end
+
+        # unsafe_trunc.(Int32, float32_tile): ftoi via Type arg
+        @test @filecheck begin
+            @check_label "entry"
+            code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Int32,1,spec1d}}) do a, b
+                pid = ct.bid(1)
+                tile = ct.load(a, pid, (16,))
+                @check "ftoi"
+                ct.store(b, pid, unsafe_trunc.(Int32, tile))
+                return
+            end
+        end
+
+    end
 end
 
 #=========================================================================
diff --git a/test/execution/broadcast.jl b/test/execution/broadcast.jl
index 6a63d5f..94dbbf8 100644
--- a/test/execution/broadcast.jl
+++ b/test/execution/broadcast.jl
@@ -685,6 +685,61 @@ end
 end # fma broadcasting
 
+@testset "type argument broadcasting" begin
+
+@testset "convert.(Float16, tile)" begin
+    function convert_f16_kernel(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float16,1})
+        pid = ct.bid(1)
+        tile = ct.load(a, pid, (16,))
+        ct.store(b, pid, convert.(Float16, tile))
+        return
+    end
+
+    n = 1024
+    a = CUDA.rand(Float32, n)
+    b = CUDA.zeros(Float16, n)
+
+    ct.launch(convert_f16_kernel, cld(n, 16), a, b)
+
+    @test Array(b) == Float16.(Array(a))
+end
+
+@testset "convert.(Float32, float16_tile)" begin
+    function convert_f32_kernel(a::ct.TileArray{Float16,1}, b::ct.TileArray{Float32,1})
+        pid = ct.bid(1)
+        tile = ct.load(a, pid, (16,))
+        ct.store(b, pid, convert.(Float32, tile))
+        return
+    end
+
+    n = 1024
+    a = CUDA.rand(Float16, n)
+    b = CUDA.zeros(Float32, n)
+
+    ct.launch(convert_f32_kernel, cld(n, 16), a, b)
+
+    @test Array(b) == Float32.(Array(a))
+end
+
+@testset "unsafe_trunc.(Int32, float_tile)" begin
+    function unsafe_trunc_i32_kernel(a::ct.TileArray{Float32,1}, b::ct.TileArray{Int32,1})
+        pid = ct.bid(1)
+        tile = ct.load(a, pid, (16,))
+        ct.store(b, pid, unsafe_trunc.(Int32, tile))
+        return
+    end
+
+    n = 1024
+    a = CuArray(Float32.(rand(-100:100, n)) .+ 0.7f0)
+    b = CUDA.zeros(Int32, n)
+
+    ct.launch(unsafe_trunc_i32_kernel, cld(n, 16), a, b)
+
+    @test Array(b) == unsafe_trunc.(Int32, Array(a))
+end
+
+end # type argument broadcasting
+
 @testset "multi-arg map" begin
 
 @testset "binary map(+, ...)" begin
     function map_add_kernel(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1},