Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -416,11 +416,17 @@ b = ct.load(B, (expert_id, k, bid_n), (1, TILE_K, TILE_N))

## Differences from Julia

### Float-to-integer conversion truncates
### Some operations are non-throwing

Inside cuTile kernels, `Int32(x::Float32)` and similar float-to-integer constructors
truncate toward zero (like C-style casts), rather than throwing `InexactError` as in
standard Julia. This matches the behavior of GPU hardware and cuTile Python's `ct.astype`.
cuTile kernels cannot throw Julia exceptions. Operations that would throw in
standard Julia silently produce truncated or wrapped results instead:

- **Float-to-integer conversions:** `Int32(x)`, `trunc(Int32, x)`, and
`round(Int32, x, RoundToZero)` silently truncate toward zero rather than
throwing `InexactError` as they do in standard Julia (the constructor `Int32(x)`
for any non-integer or out-of-range `x`; `trunc` and `round` for out-of-range
`x`). Use `unsafe_trunc` for the explicit non-throwing primitive.

Assertions may be added in the future for testing purposes.


## Limitations
Expand Down
57 changes: 48 additions & 9 deletions src/language/broadcast.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Broadcasting Infrastructure for Tiles
#
# Defines the broadcast style and shape computation for Tile types.
# All broadcasted operations are materialized via copy → map.
# All broadcasted operations are materialized via copy.

import Base.Broadcast: BroadcastStyle, Broadcasted, broadcastable, broadcast_shape

Expand All @@ -24,26 +24,38 @@ Base.Broadcast.broadcastable(t::Tile) = t


#=============================================================================
Broadcast materialization via copy + map
Ghost wrapper for Type values in broadcasting
=============================================================================#

# Field-less stand-in for a `Type` argument inside a broadcast expression.
# It takes the place of Julia's RefValue{Type{T}} wrapping, which the cuTile
# compiler can't construct. The wrapped type lives purely in the parameter `T`,
# so no runtime representation is needed.
struct TypeRef{T}
end

# A TypeRef participates in broadcasting with the zero-dimensional default style.
function Base.Broadcast.BroadcastStyle(::Type{<:TypeRef})
    return Base.Broadcast.DefaultArrayStyle{0}()
end

# A TypeRef is already in broadcastable form, so it is handed back untouched.
function Base.Broadcast.broadcastable(ref::TypeRef)
    return ref
end


#=============================================================================
Broadcast materialization via copy
=============================================================================#

# A Tile is a ghost type without storage, so its axes/size carry no meaning.
# The regular instantiate step calls axes; this method sidesteps it by handing
# the Broadcasted node back exactly as received.
@inline function Base.Broadcast.instantiate(bc::Broadcasted{TileStyle})
    return bc
end

# Recursively materialize nested Broadcasted nodes,
# promote scalars to Tiles, broadcast to a common shape, then apply via map.
# promote scalars to Tiles, broadcast to a common shape, then apply f.
# This handles all element-wise operations: scalar @overlay methods provide
# the implementation for overlaid ops, while Julia's native scalar functions
# (compiled to Core intrinsics) handle the rest. Mixed-type and type-changing
# operations (comparisons, ifelse) are supported by the mixed-type map methods
# in operations.jl.
@inline function Base.copy(bc::Broadcasted{TileStyle})
    # Materialize nested Broadcasted nodes first so every argument is concrete.
    args = _materialize_args(bc.args)
    # Numbers are wrapped into Tiles; Tiles and TypeRefs pass through unchanged.
    promoted = _promote_to_tiles(args...)
    # Common shape across the Tile arguments (TypeRefs contribute no shape).
    S = _broadcast_shapes(promoted...)
    broadcasted = _broadcast_all(S, promoted...)
    # Single application path: the superseded `map(bc.f, ...)` pipeline that used
    # to run before this one was removed — its result was discarded, and it
    # would have been handed TypeRef arguments as well as Tiles.
    return _apply_broadcast(bc.f, broadcasted...)
end

# Recursively materialize nested Broadcasted nodes into concrete Tiles.
Expand All @@ -63,19 +75,46 @@ end
# using its own type (e.g., 0.0f0 → Tile(Float32(0.0))), preserving the
# type that Julia's broadcast promotion chose. This avoids the pitfall of
# using the first Tile's eltype (which could be Bool for ifelse conditions).
# TypeRef arguments pass through unchanged — they carry no tile shape.
# Tuple-peeling promotion: treat the head argument, then recurse on the tail.
# An exhausted argument list yields the empty tuple.
@inline function _promote_to_tiles()
    return ()
end

# Tiles are kept as they are.
@inline function _promote_to_tiles(head::Tile, tail...)
    return (head, _promote_to_tiles(tail...)...)
end

# Numbers are wrapped in a Tile built from the value itself.
@inline function _promote_to_tiles(head::T, tail...) where {T <: Number}
    return (Tile(head), _promote_to_tiles(tail...)...)
end

# TypeRefs are kept as they are.
@inline function _promote_to_tiles(head::TypeRef, tail...)
    return (head, _promote_to_tiles(tail...)...)
end

# Compute combined broadcast shape across all Tile arguments via tuple peeling.
# Shape is always a tuple TYPE (e.g., Tuple{16, 32}). Convert to value for broadcast_shape.
# TypeRef arguments are skipped — they have no shape.
@inline _tile_shape(t::Tile) = size(t)
@inline _broadcast_shapes(t::Tile) = _tile_shape(t)
# `rest` is deliberately untyped (not `rest::Tile...`) so that TypeRef arguments
# may follow a Tile. The stale `rest::Tile...` signature line that preceded this
# method was dropped: with its dangling `=` it captured this definition as its body.
@inline _broadcast_shapes(t::Tile, rest...) =
    broadcast_shape(_tile_shape(t), _broadcast_shapes(rest...))
@inline _broadcast_shapes(::TypeRef, rest...) = _broadcast_shapes(rest...)
# A trailing TypeRef contributes the empty shape, which leaves the shape
# accumulated by broadcast_shape unchanged.
@inline _broadcast_shapes(::TypeRef) = ()

# Broadcast all tiles to shape S via tuple peeling.
# TypeRef arguments pass through unchanged.
# Returns a tuple with one entry per argument, in the original order.
@inline _broadcast_all(S::Tuple) = ()
# `rest` is deliberately untyped (not `rest::Tile...`) so that TypeRef arguments
# may follow a Tile. The stale `rest::Tile...` signature line that preceded this
# method was dropped: with its dangling `=` it captured this definition as its body.
@inline _broadcast_all(S::Tuple, a::Tile, rest...) =
    (broadcast_to(a, S), _broadcast_all(S, rest...)...)
@inline _broadcast_all(S::Tuple, a::TypeRef, rest...) =
    (a, _broadcast_all(S, rest...)...)

# Apply `f` once to the scalar view of the arguments and re-wrap the outcome:
#   1. _to_scalars yields the scalar argument tuple together with the shape,
#   2. `f` is invoked on those scalars,
#   3. Intrinsics.from_scalar wraps the result back into a Tile of that shape.
@inline function _apply_broadcast(f, args...)
    scalars, shape = _to_scalars(args...)
    result = f(scalars...)
    return Intrinsics.from_scalar(result, shape)
end

# Reinterpret Tile arguments as scalars for broadcast application.
# TypeRef{T} arguments are replaced by the type `T` itself.
# Returns (scalar_args_tuple, S) where S is the shape from the first Tile.

# Base case for an exhausted argument list. Without it, a TypeRef in the last
# position (e.g. `f.(tile, Int32)`) recursed into `_to_scalars()` and raised a
# MethodError. The `Tuple{}` shape is only a placeholder: whenever a Tile is
# present anywhere in the argument list, that Tile's own S is the one returned.
@inline _to_scalars() = ((), Tuple{})

@inline _to_scalars(t::Tile{<:Any,S}) where S = ((Intrinsics.to_scalar(t),), S)

# Tile followed by more arguments: this Tile's S wins; the tail's shape is ignored.
@inline function _to_scalars(t::Tile{<:Any,S}, rest...) where S
    rest_scalars, _ = _to_scalars(rest...)
    return ((Intrinsics.to_scalar(t), rest_scalars...), S)
end

# TypeRef: contribute the wrapped type and take the shape from the remaining arguments.
@inline function _to_scalars(::TypeRef{T}, rest...) where T
    rest_scalars, S = _to_scalars(rest...)
    return ((T, rest_scalars...), S)
end
32 changes: 27 additions & 5 deletions src/language/overlays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@ macro overlay(ex)
end


#=============================================================================
Broadcasting
=============================================================================#

# Inside kernels, a Type appearing as a broadcast argument becomes a TypeRef{T}
# rather than a RefValue, because RefValue can't be constructed in Tile IR.
@overlay function Base.Broadcast.broadcastable(::Type{T}) where {T}
    return TypeRef{T}()
end


#=============================================================================
Type Conversions
=============================================================================#
Expand Down Expand Up @@ -51,13 +59,13 @@ end
sizeof(S) > sizeof(T) ? Intrinsics.exti(x, S, SignednessUnsigned) :
sizeof(S) < sizeof(T) ? Intrinsics.trunci(x, S) : x

# Float to float (specific type pairs)
# Float to float
# One constructor overlay per ordered pair of distinct float types:
# T(x::S) forwards to Intrinsics.ftof with T as the target type.
for T in Floats
    for S in Floats
        if T !== S
            @eval @overlay $T(x::$S) = Intrinsics.ftof(x, $T)
        end
    end
end

# Integer to float (specific type pairs)
# Integer to float
for F in Floats
for I in SignedInts
@eval @overlay $F(x::$I) = Intrinsics.itof(x, $F, SignednessSigned)
Expand All @@ -78,12 +86,26 @@ for F in Floats
end
end

# Float to integer (direct constructor - truncates like C-style cast)
# Float to integer (round with RoundToZero)
# round(I, x::F, RoundToZero) overlay for every float type F and every signed or
# unsigned integer type I; each one forwards to unsafe_trunc.
for F in Floats
    for I in (SignedInts..., UnsignedInts...)
        @eval @overlay function Base.round(::Type{$I}, x::$F,
                                           ::Base.Rounding.RoundingMode{:ToZero})
            # TODO: assert that x is within bounds etc
            return unsafe_trunc($I, x)
        end
    end
end

# Float to integer (direct constructor)
# I(x::F) overlay for every float type F and every signed or unsigned integer
# type I; each one forwards to unsafe_trunc.
# The superseded one-line `Intrinsics.ftoi` overlays were removed: each was
# immediately overwritten by the same-signature definition that followed it.
# Signed and unsigned targets now share one loop because their bodies are
# identical; definition order (signed first, then unsigned, per F) is unchanged.
for F in Floats, I in (SignedInts..., UnsignedInts...)
    @eval @overlay function $I(x::$F)
        # TODO: assert that x is within bounds etc
        return unsafe_trunc($I, x)
    end
end
39 changes: 39 additions & 0 deletions test/codegen/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,45 @@
end
end
end

@testset "Type broadcasting" begin
# Each case broadcasts a function whose first argument is a Type over a
# 16-element tile loaded from `a`, stores the result into `b`, and carries a
# `@check` directive naming the conversion op expected for that case.
# NOTE(review): presumably `@check` is matched against the output of
# `code_tiled` by the filecheck helper — confirm against its definition.

# convert.(Float16, tile) — Type arg via TypeRef
@test @filecheck begin
@check_label "entry"
code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float16,1,spec1d}}) do a, b
pid = ct.bid(1)
tile = ct.load(a, pid, (16,))
# Float32 -> Float16: expects a float-to-float op.
@check "ftof"
ct.store(b, pid, convert.(Float16, tile))
return
end
end

# convert.(Float32, float16_tile) — upcast via Type arg
@test @filecheck begin
@check_label "entry"
code_tiled(Tuple{ct.TileArray{Float16,1,spec1d}, ct.TileArray{Float32,1,spec1d}}) do a, b
pid = ct.bid(1)
tile = ct.load(a, pid, (16,))
# Float16 -> Float32: expects a float-to-float op.
@check "ftof"
ct.store(b, pid, convert.(Float32, tile))
return
end
end

# unsafe_trunc.(Int32, float32_tile) — ftoi via Type arg
@test @filecheck begin
@check_label "entry"
code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Int32,1,spec1d}}) do a, b
pid = ct.bid(1)
tile = ct.load(a, pid, (16,))
# Float32 -> Int32: expects a float-to-integer op.
@check "ftoi"
ct.store(b, pid, unsafe_trunc.(Int32, tile))
return
end
end

end
end

#=========================================================================
Expand Down
55 changes: 55 additions & 0 deletions test/execution/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,61 @@ end

end # fma broadcasting

@testset "type argument broadcasting" begin

@testset "convert.(Float16, tile)" begin
    # Kernel: each block loads a 16-element Float32 tile from `a` and stores it
    # into `b` after a broadcasted convert to Float16.
    function convert_f16_kernel(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float16,1})
        block = ct.bid(1)
        loaded = ct.load(a, block, (16,))
        narrowed = convert.(Float16, loaded)
        ct.store(b, block, narrowed)
        return
    end

    len = 1024
    src = CUDA.rand(Float32, len)
    dst = CUDA.zeros(Float16, len)
    nblocks = cld(len, 16)

    ct.launch(convert_f16_kernel, nblocks, src, dst)

    # Reference result: the same conversion performed on the host copy.
    expected = Float16.(Array(src))
    @test Array(dst) == expected
end

@testset "convert.(Float32, float16_tile)" begin
    # Kernel: each block loads a 16-element Float16 tile from `a` and stores it
    # into `b` after a broadcasted convert to Float32.
    function convert_f32_kernel(a::ct.TileArray{Float16,1}, b::ct.TileArray{Float32,1})
        block = ct.bid(1)
        loaded = ct.load(a, block, (16,))
        widened = convert.(Float32, loaded)
        ct.store(b, block, widened)
        return
    end

    len = 1024
    src = CUDA.rand(Float16, len)
    dst = CUDA.zeros(Float32, len)
    nblocks = cld(len, 16)

    ct.launch(convert_f32_kernel, nblocks, src, dst)

    # Reference result: the same conversion performed on the host copy.
    expected = Float32.(Array(src))
    @test Array(dst) == expected
end

@testset "unsafe_trunc.(Int32, float_tile)" begin
    # Kernel: each block loads a 16-element Float32 tile from `a` and stores the
    # broadcasted unsafe_trunc to Int32 into `b`.
    function unsafe_trunc_i32_kernel(a::ct.TileArray{Float32,1}, b::ct.TileArray{Int32,1})
        block = ct.bid(1)
        loaded = ct.load(a, block, (16,))
        truncated = unsafe_trunc.(Int32, loaded)
        ct.store(b, block, truncated)
        return
    end

    len = 1024
    # Integers drawn from -100:100 shifted by 0.7, so every input has a fractional part.
    host_vals = Float32.(rand(-100:100, len)) .+ 0.7f0
    src = CuArray(host_vals)
    dst = CUDA.zeros(Int32, len)
    nblocks = cld(len, 16)

    ct.launch(unsafe_trunc_i32_kernel, nblocks, src, dst)

    # Reference result: the same truncation performed on the host copy.
    expected = unsafe_trunc.(Int32, Array(src))
    @test Array(dst) == expected
end

end # type argument broadcasting

@testset "multi-arg map" begin
@testset "binary map(+, ...)" begin
function map_add_kernel(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1},
Expand Down