Skip to content

Commit f7170cf

Browse files
arhikmaleadt
andauthored
Add integer reduction support (#37)
Co-authored-by: Tim Besard <tim.besard@gmail.com>
1 parent 532bcc9 commit f7170cf

7 files changed

Lines changed: 290 additions & 30 deletions

File tree

src/bytecode/encodings.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1291,7 +1291,7 @@ function encode_ReduceOp!(body::Function, cb::CodeBuilder,
12911291
result_types::Vector{TypeId},
12921292
operands::Vector{Value},
12931293
dim::Int,
1294-
identities::Vector{<:ReduceIdentity},
1294+
identities::Vector{<:IdentityVal},
12951295
body_scalar_types::Vector{TypeId})
12961296
encode_varint!(cb.buf, Opcode.ReduceOp)
12971297

src/bytecode/writer.jl

Lines changed: 66 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -234,30 +234,41 @@ end
234234
=============================================================================#
235235

236236
"""
237-
ReduceIdentity
237+
IdentityVal
238238
239-
Abstract type for reduce identity attributes.
239+
Abstract type for binary operation identity attributes (reduce, scan, etc.).
240240
"""
241-
abstract type ReduceIdentity end
241+
abstract type IdentityVal end
242242

243243
"""
244-
FloatIdentity(value, type_id, dtype)
244+
FloatIdentityVal(value, type_id, dtype)
245245
246-
Float identity value for reduce operations.
246+
Float identity value for binary operations.
247247
"""
248-
struct FloatIdentity <: ReduceIdentity
248+
struct FloatIdentityVal <: IdentityVal
249249
value::Float64
250250
type_id::TypeId
251251
dtype::Type # Float16, Float32, Float64, etc.
252252
end
253253

254254
"""
255-
encode_tagged_float!(cb, identity::FloatIdentity)
255+
IntegerIdentityVal(value, type_id, dtype)
256+
257+
Integer identity value for binary operations.
258+
"""
259+
struct IntegerIdentityVal <: IdentityVal
260+
value::UInt128 # Store as UInt128 to handle all unsigned values up to 64 bits
261+
type_id::TypeId
262+
dtype::Type # Int8, Int16, Int32, Int64, UInt8, etc. (signedness inferred from dtype)
263+
end
264+
265+
"""
266+
encode_tagged_float!(cb, identity::FloatIdentityVal)
256267
257268
Encode a tagged float attribute for reduce identity.
258269
Format: tag(Float=0x02) + typeid + ap_int(value_bits)
259270
"""
260-
function encode_tagged_float!(cb::CodeBuilder, identity::FloatIdentity)
271+
function encode_tagged_float!(cb::CodeBuilder, identity::FloatIdentityVal)
261272
# Tag for Float attribute
262273
push!(cb.buf, 0x02)
263274
# Type ID
@@ -267,6 +278,41 @@ function encode_tagged_float!(cb::CodeBuilder, identity::FloatIdentity)
267278
encode_varint!(cb.buf, bits)
268279
end
269280

281+
"""
282+
encode_tagged_int!(cb, identity::IntegerIdentityVal)
283+
284+
Encode a tagged integer identity attribute.
285+
Format: tag(Int=0x01) + typeid + ap_int(value)
286+
"""
287+
function encode_tagged_int!(cb::CodeBuilder, identity::IntegerIdentityVal)
288+
# Tag for Int attribute
289+
push!(cb.buf, 0x01)
290+
# Type ID
291+
encode_typeid!(cb.buf, identity.type_id)
292+
# Mask value to correct bit width and apply zigzag encoding for signed types
293+
masked_value = mask_to_width(identity.value, identity.dtype)
294+
encode_varint!(cb.buf, masked_value) # masked_value are already zigzag encoded
295+
end
296+
297+
"""
298+
mask_to_width(value, dtype)
299+
300+
Mask a UInt128 value to the correct bit width for the given type.
301+
Applies zigzag encoding for signed types.
302+
"""
303+
function mask_to_width(value::UInt128, ::Type{T}) where T <: Integer
304+
bits = sizeof(T) * 8
305+
mask = (UInt128(1) << bits) - 1
306+
masked = value & mask
307+
U = unsigned(T)
308+
unsigned_masked = U(masked)
309+
if T <: Signed # do zig-zag encoding
310+
U((unsigned_masked << 1) (unsigned_masked >>> (bits - 1)))
311+
else
312+
unsigned_masked
313+
end
314+
end
315+
270316
"""
271317
float_to_bits(value, dtype)
272318
@@ -297,15 +343,24 @@ end
297343
"""
298344
encode_identity_array!(cb, identities)
299345
300-
Encode an array of reduce identity attributes.
346+
Encode an array of binary operation identity attributes.
347+
Dispatches on identity type to encode correctly.
301348
"""
302-
function encode_identity_array!(cb::CodeBuilder, identities::Vector{<:ReduceIdentity})
349+
function encode_identity_array!(cb::CodeBuilder, identities::Vector{<:IdentityVal})
303350
encode_varint!(cb.buf, length(identities))
304351
for identity in identities
305-
encode_tagged_float!(cb, identity)
352+
encode_identity!(cb, identity)
306353
end
307354
end
308355

356+
"""
357+
encode_identity!(cb, identity)
358+
359+
Encode a single identity attribute, dispatching on type.
360+
"""
361+
encode_identity!(cb::CodeBuilder, identity::FloatIdentityVal) = encode_tagged_float!(cb, identity)
362+
encode_identity!(cb::CodeBuilder, identity::IntegerIdentityVal) = encode_tagged_int!(cb, identity)
363+
309364
"""
310365
BytecodeWriter
311366

src/compiler/intrinsics.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using Base: compilerbarrier, donotdelete
88
using ..cuTile: Tile, TileArray, Constant, TensorView, PartitionView
99
using ..cuTile: Signedness, SignednessSigned, SignednessUnsigned
1010
using ..cuTile: ComparisonPredicate, CmpLessThan, CmpLessThanOrEqual, CmpGreaterThan, CmpGreaterThanOrEqual, CmpEqual, CmpNotEqual
11+
using ..cuTile: IdentityVal, FloatIdentityVal, IntegerIdentityVal
1112

1213
end
1314

src/compiler/intrinsics/core.jl

Lines changed: 61 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,7 @@ end
512512
Sum reduction along 0-indexed axis.
513513
Compiled to cuda_tile.reduce with ADD.
514514
"""
515-
@noinline function reduce_sum(tile::Tile{T, S}, ::Val{axis}) where {T <: AbstractFloat, S, axis}
515+
@noinline function reduce_sum(tile::Tile{T, S}, ::Val{axis}) where {T, S, axis}
516516
reduced_shape = ntuple(i -> S[i < axis + 1 ? i : i + 1], length(S) - 1)
517517
Tile{T, reduced_shape}()
518518
end
@@ -523,7 +523,7 @@ end
523523
Maximum reduction along 0-indexed axis.
524524
Compiled to cuda_tile.reduce with MAX.
525525
"""
526-
@noinline function reduce_max(tile::Tile{T, S}, ::Val{axis}) where {T <: AbstractFloat, S, axis}
526+
@noinline function reduce_max(tile::Tile{T, S}, ::Val{axis}) where {T, S, axis}
527527
reduced_shape = ntuple(i -> S[i < axis + 1 ? i : i + 1], length(S) - 1)
528528
Tile{T, reduced_shape}()
529529
end
@@ -562,28 +562,76 @@ function emit_reduce!(ctx::CGCtx, args, reduce_fn::Symbol)
562562
# Scalar type for reduction body (0D tile)
563563
scalar_tile_type = tile_type!(tt, dtype, Int[])
564564

565-
# Create identity value - use simple dtype (f32), not tile type
566-
identity_val = reduce_fn == :add ? -0.0 : (reduce_fn == :max ? -Inf : 0.0)
567-
identity = FloatIdentity(identity_val, dtype, elem_type)
565+
# Create identity value via dispatch on reduction function and element type
566+
identity = operation_identity(Val(reduce_fn), dtype, elem_type)
568567

569568
# Emit ReduceOp
570569
results = encode_ReduceOp!(cb, [output_tile_type], [input_tv.v], axis, [identity], [scalar_tile_type]) do block_args
571570
acc, elem = block_args[1], block_args[2]
572571

573-
if reduce_fn == :add
574-
res = encode_AddFOp!(cb, scalar_tile_type, acc, elem)
575-
elseif reduce_fn == :max
576-
res = encode_MaxFOp!(cb, scalar_tile_type, acc, elem)
577-
else
578-
error("Unsupported reduction function: $reduce_fn")
579-
end
580-
572+
res = encode_reduce_body(cb, scalar_tile_type, acc, elem, reduce_fn, elem_type)
581573
encode_YieldOp!(cb, [res])
582574
end
583575

584576
CGVal(results[1], output_tile_type, Tile{elem_type, Tuple(output_shape)}, output_shape)
585577
end
586578

579+
#=============================================================================#
580+
# Reduce Identity Values via Dispatch
581+
#=============================================================================#
582+
583+
"""
584+
to_uint128(value)
585+
586+
Convert an integer value to UInt128 for storage in IntegerIdentityVal.
587+
For signed types, this returns the two's complement bit representation.
588+
"""
589+
to_uint128(value::T) where T <: Unsigned = UInt128(value)
590+
to_uint128(value::T) where T <: Signed = UInt128(reinterpret(unsigned(T), value))
591+
592+
"""
593+
operation_identity(fn, dtype, elem_type) -> IdentityVal
594+
595+
Return the identity value for a binary operation (reduce, scan, etc.).
596+
Identity must satisfy: identity ⊕ x = x for the operation.
597+
"""
598+
599+
# Addition identity: 0 + x = x
600+
operation_identity(::Val{:add}, dtype, ::Type{T}) where T <: AbstractFloat =
601+
FloatIdentityVal(zero(T), dtype, T)
602+
operation_identity(::Val{:add}, dtype, ::Type{T}) where T <: Integer =
603+
IntegerIdentityVal(to_uint128(zero(T)), dtype, T)
604+
605+
# Maximum identity: max(typemin(T), x) = x
606+
operation_identity(::Val{:max}, dtype, ::Type{T}) where T <: AbstractFloat =
607+
FloatIdentityVal(typemin(T), dtype, T)
608+
operation_identity(::Val{:max}, dtype, ::Type{T}) where T <: Integer =
609+
IntegerIdentityVal(to_uint128(typemin(T)), dtype, T)
610+
611+
#=============================================================================#
612+
# Reduce Body Operations
613+
#=============================================================================#
614+
function encode_reduce_body(cb, type, acc, elem, op::Symbol, ::Type{T}) where T
615+
if T <: AbstractFloat
616+
if op == :add
617+
encode_AddFOp!(cb, type, acc, elem)
618+
elseif op == :max
619+
encode_MaxFOp!(cb, type, acc, elem)
620+
else
621+
error("Unsupported float reduction operation: $op")
622+
end
623+
else # Integer
624+
signedness = T <: Signed ? SignednessSigned : SignednessUnsigned
625+
if op == :add
626+
encode_AddIOp!(cb, type, acc, elem)
627+
elseif op == :max
628+
encode_MaxIOp!(cb, type, acc, elem; signedness)
629+
else
630+
error("Unsupported integer reduction operation: $op")
631+
end
632+
end
633+
end
634+
587635

588636
# cuda_tile.reshape
589637
@eval Intrinsics begin

src/language/operations.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -529,10 +529,10 @@ Returns a tile with the specified dimension removed.
529529
sums = ct.reduce_sum(tile, 2) # Returns (128,) tile
530530
```
531531
"""
532-
@inline function reduce_sum(tile::Tile{T, S}, axis::Integer) where {T <: AbstractFloat, S}
532+
@inline function reduce_sum(tile::Tile{T, S}, axis::Integer) where {T <: Number, S}
533533
Intrinsics.reduce_sum(tile, Val(axis - 1))
534534
end
535-
@inline function reduce_sum(tile::Tile{T, S}, ::Val{axis}) where {T <: AbstractFloat, S, axis}
535+
@inline function reduce_sum(tile::Tile{T, S}, ::Val{axis}) where {T <: Number, S, axis}
536536
Intrinsics.reduce_sum(tile, Val(axis - 1))
537537
end
538538

@@ -546,10 +546,10 @@ Maximum reduction along the specified axis (1-indexed).
546546
maxes = ct.reduce_max(tile, 2) # Max along axis 2
547547
```
548548
"""
549-
@inline function reduce_max(tile::Tile{T, S}, axis::Integer) where {T <: AbstractFloat, S}
549+
@inline function reduce_max(tile::Tile{T, S}, axis::Integer) where {T <: Number, S}
550550
Intrinsics.reduce_max(tile, Val(axis - 1))
551551
end
552-
@inline function reduce_max(tile::Tile{T, S}, ::Val{axis}) where {T <: AbstractFloat, S, axis}
552+
@inline function reduce_max(tile::Tile{T, S}, ::Val{axis}) where {T <: Number, S, axis}
553553
Intrinsics.reduce_max(tile, Val(axis - 1))
554554
end
555555

@@ -649,4 +649,3 @@ br = ct.extract(tile, (2, 2), (4, 4)) # Bottom-right (rows 5-8, cols 5-8)
649649
Intrinsics.extract(tile, Val(map(i -> i - 1, index)), Val(shape))
650650
@inline extract(tile::Tile{T}, ::Val{Index}, ::Val{Shape}) where {T, Index, Shape} =
651651
Intrinsics.extract(tile, Val(map(i -> i - 1, Index)), Val(Shape))
652-

test/codegen.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,62 @@
387387
end
388388
end
389389

390+
# Integer reduce_sum (Int32)
391+
@test @filecheck begin
392+
@check_label "entry"
393+
code_tiled(Tuple{ct.TileArray{Int32,2,spec2d}, ct.TileArray{Int32,1,spec1d}}) do a, b
394+
pid = ct.bid(1)
395+
tile = ct.load(a, pid, (4, 16))
396+
@check "reduce"
397+
@check "addi"
398+
sums = ct.reduce_sum(tile, 2)
399+
ct.store(b, pid, sums)
400+
return
401+
end
402+
end
403+
404+
# Integer reduce_max (Int32)
405+
@test @filecheck begin
406+
@check_label "entry"
407+
code_tiled(Tuple{ct.TileArray{Int32,2,spec2d}, ct.TileArray{Int32,1,spec1d}}) do a, b
408+
pid = ct.bid(1)
409+
tile = ct.load(a, pid, (4, 16))
410+
@check "reduce"
411+
@check "maxi"
412+
maxes = ct.reduce_max(tile, 2)
413+
ct.store(b, pid, maxes)
414+
return
415+
end
416+
end
417+
418+
# Unsigned reduce_sum (UInt32)
419+
@test @filecheck begin
420+
@check_label "entry"
421+
code_tiled(Tuple{ct.TileArray{UInt32,2,spec2d}, ct.TileArray{UInt32,1,spec1d}}) do a, b
422+
pid = ct.bid(1)
423+
tile = ct.load(a, pid, (4, 16))
424+
@check "reduce"
425+
@check "addi"
426+
sums = ct.reduce_sum(tile, 2)
427+
ct.store(b, pid, sums)
428+
return
429+
end
430+
end
431+
432+
# Unsigned reduce_max (UInt32)
433+
@test @filecheck begin
434+
@check_label "entry"
435+
code_tiled(Tuple{ct.TileArray{UInt32,2,spec2d}, ct.TileArray{UInt32,1,spec1d}}) do a, b
436+
pid = ct.bid(1)
437+
tile = ct.load(a, pid, (4, 16))
438+
@check "reduce"
439+
@check "maxi"
440+
maxes = ct.reduce_max(tile, 2)
441+
ct.store(b, pid, maxes)
442+
return
443+
end
444+
end
445+
390446
@testset "select" begin
391447
@test @filecheck begin
392448
@check_label "entry"

0 commit comments

Comments
 (0)