|
512 | 512 | Sum reduction along 0-indexed axis. |
513 | 513 | Compiled to cuda_tile.reduce with ADD. |
514 | 514 | """ |
515 | | - @noinline function reduce_sum(tile::Tile{T, S}, ::Val{axis}) where {T <: AbstractFloat, S, axis} |
| 515 | + @noinline function reduce_sum(tile::Tile{T, S}, ::Val{axis}) where {T, S, axis} |
516 | 516 | reduced_shape = ntuple(i -> S[i < axis + 1 ? i : i + 1], length(S) - 1) |
517 | 517 | Tile{T, reduced_shape}() |
518 | 518 | end |
|
523 | 523 | Maximum reduction along 0-indexed axis. |
524 | 524 | Compiled to cuda_tile.reduce with MAX. |
525 | 525 | """ |
526 | | - @noinline function reduce_max(tile::Tile{T, S}, ::Val{axis}) where {T <: AbstractFloat, S, axis} |
| 526 | + @noinline function reduce_max(tile::Tile{T, S}, ::Val{axis}) where {T, S, axis} |
527 | 527 | reduced_shape = ntuple(i -> S[i < axis + 1 ? i : i + 1], length(S) - 1) |
528 | 528 | Tile{T, reduced_shape}() |
529 | 529 | end |
@@ -562,28 +562,76 @@ function emit_reduce!(ctx::CGCtx, args, reduce_fn::Symbol) |
562 | 562 | # Scalar type for reduction body (0D tile) |
563 | 563 | scalar_tile_type = tile_type!(tt, dtype, Int[]) |
564 | 564 |
|
565 | | - # Create identity value - use simple dtype (f32), not tile type |
566 | | - identity_val = reduce_fn == :add ? -0.0 : (reduce_fn == :max ? -Inf : 0.0) |
567 | | - identity = FloatIdentity(identity_val, dtype, elem_type) |
| 565 | + # Create identity value via dispatch on reduction function and element type |
| 566 | + identity = operation_identity(Val(reduce_fn), dtype, elem_type) |
568 | 567 |
|
569 | 568 | # Emit ReduceOp |
570 | 569 | results = encode_ReduceOp!(cb, [output_tile_type], [input_tv.v], axis, [identity], [scalar_tile_type]) do block_args |
571 | 570 | acc, elem = block_args[1], block_args[2] |
572 | 571 |
|
573 | | - if reduce_fn == :add |
574 | | - res = encode_AddFOp!(cb, scalar_tile_type, acc, elem) |
575 | | - elseif reduce_fn == :max |
576 | | - res = encode_MaxFOp!(cb, scalar_tile_type, acc, elem) |
577 | | - else |
578 | | - error("Unsupported reduction function: $reduce_fn") |
579 | | - end |
580 | | - |
| 572 | + res = encode_reduce_body(cb, scalar_tile_type, acc, elem, reduce_fn, elem_type) |
581 | 573 | encode_YieldOp!(cb, [res]) |
582 | 574 | end |
583 | 575 |
|
584 | 576 | CGVal(results[1], output_tile_type, Tile{elem_type, Tuple(output_shape)}, output_shape) |
585 | 577 | end |
586 | 578 |
|
| 579 | +#=============================================================================# |
| 580 | +# Reduce Identity Values via Dispatch |
| 581 | +#=============================================================================# |
| 582 | + |
| 583 | +""" |
| 584 | + to_uint128(value) |
| 585 | +
|
| 586 | +Convert an integer value to UInt128 for storage in IntegerIdentityVal. |
| 587 | +For signed types, this returns the two's complement bit representation. |
| 588 | +""" |
| 589 | +to_uint128(value::T) where T <: Unsigned = UInt128(value) |
| 590 | +to_uint128(value::T) where T <: Signed = UInt128(reinterpret(unsigned(T), value)) |
| 591 | + |
| 592 | +""" |
| 593 | + operation_identity(fn, dtype, elem_type) -> IdentityVal |
| 594 | +
|
| 595 | +Return the identity value for a binary operation (reduce, scan, etc.). |
| 596 | +Identity must satisfy: identity ⊕ x = x for the operation. |
| 597 | +""" |
| 598 | + |
| 599 | +# Addition identity: 0 + x = x |
| 600 | +operation_identity(::Val{:add}, dtype, ::Type{T}) where T <: AbstractFloat = |
| 601 | + FloatIdentityVal(zero(T), dtype, T) |
| 602 | +operation_identity(::Val{:add}, dtype, ::Type{T}) where T <: Integer = |
| 603 | + IntegerIdentityVal(to_uint128(zero(T)), dtype, T) |
| 604 | + |
| 605 | +# Maximum identity: max(typemin(T), x) = x |
| 606 | +operation_identity(::Val{:max}, dtype, ::Type{T}) where T <: AbstractFloat = |
| 607 | + FloatIdentityVal(typemin(T), dtype, T) |
| 608 | +operation_identity(::Val{:max}, dtype, ::Type{T}) where T <: Integer = |
| 609 | + IntegerIdentityVal(to_uint128(typemin(T)), dtype, T) |
| 610 | + |
| 611 | +#=============================================================================# |
| 612 | +# Reduce Body Operations |
| 613 | +#=============================================================================# |
| 614 | +function encode_reduce_body(cb, type, acc, elem, op::Symbol, ::Type{T}) where T |
| 615 | + if T <: AbstractFloat |
| 616 | + if op == :add |
| 617 | + encode_AddFOp!(cb, type, acc, elem) |
| 618 | + elseif op == :max |
| 619 | + encode_MaxFOp!(cb, type, acc, elem) |
| 620 | + else |
| 621 | + error("Unsupported float reduction operation: $op") |
| 622 | + end |
| 623 | + else # Integer |
| 624 | + signedness = T <: Signed ? SignednessSigned : SignednessUnsigned |
| 625 | + if op == :add |
| 626 | + encode_AddIOp!(cb, type, acc, elem) |
| 627 | + elseif op == :max |
| 628 | + encode_MaxIOp!(cb, type, acc, elem; signedness) |
| 629 | + else |
| 630 | + error("Unsupported integer reduction operation: $op") |
| 631 | + end |
| 632 | + end |
| 633 | +end |
| 634 | + |
587 | 635 |
|
588 | 636 | # cuda_tile.reshape |
589 | 637 | @eval Intrinsics begin |
|
0 commit comments