From 94f20059e57e22b8810938d8415527eead8cecb9 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 18 Jun 2026 17:42:40 -0400 Subject: [PATCH 1/2] Add @nospecialize to mapreduce dispatch chain to cut compile time The strided mapreduce machinery (map/map!/mapreducedim!/_mapreduce and the inner bookkeeping + threading helpers) was specialized on the map/reduce function types f/op/initop. Combined with the (M, N, eltype) axes this caused a combinatorial explosion of specializations for downstream packages such as TensorOperations, which generate many distinct closures. Annotate the outer entry points and the deeper bookkeeping/threading helpers with @nospecialize so they compile once per (M, N, eltype) regardless of the function types. The expensive @generated `_mapreduce_kernel!` stays fully specialized and is reached via a function barrier (one dynamic dispatch per coarse call), so steady-state runtime is unchanged. Also split the body of the @generated `_mapreduce_kernel!` into a sibling plain function `_mapreduce_kernel_expr(f, op, initop, N, M)` that returns the Expr, for clarity; the generated kernel itself is otherwise unchanged. `_mapreduce_block!` is preserved as the GPU extension override point. Co-Authored-By: Claude Opus 4.8 (1M context) --- src/mapreduce.jl | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/src/mapreduce.jl b/src/mapreduce.jl index 9182ee8..4f5b8cd 100644 --- a/src/mapreduce.jl +++ b/src/mapreduce.jl @@ -33,17 +33,17 @@ function Base._mapreduce_dim(f, op, ::NamedTuple{()}, A::StridedView, dims) end function Base.map( - f::F, a1::StridedView{<:Any, N}, + @nospecialize(f), a1::StridedView{<:Any, N}, A::Vararg{StridedView{<:Any, N}} - ) where {F, N} + ) where {N} T = Base.promote_eltype(a1, A...) return map!(f, similar(a1, T), a1, A...) end function Base.map!( - f::F, b::StridedView{<:Any, N}, a1::StridedView{<:Any, N}, + @nospecialize(f), b::StridedView{<:Any, N}, a1::StridedView{<:Any, N}, A::Vararg{StridedView{<:Any, N}} - ) where {F, N} + ) where {N} dims = size(b) # Check dimesions @@ -59,7 +59,7 @@ function Base.map!( return b end -function _mapreduce(f, op, A::StridedView{T}, nt = nothing) where {T} +function _mapreduce(@nospecialize(f), @nospecialize(op), A::StridedView{T}, nt = nothing) where {T} if isempty(A) b = Base.mapreduce_empty(f, op, T) return nt === nothing ? b : op(b, nt.init) @@ -79,7 +79,7 @@ function _mapreduce(f, op, A::StridedView{T}, nt = nothing) where {T} end function Base.mapreducedim!( - f, op, b::StridedView{<:Any, N}, + @nospecialize(f), @nospecialize(op), b::StridedView{<:Any, N}, a1::StridedView{<:Any, N}, A::Vararg{StridedView{<:Any, N}} ) where {N} @@ -93,7 +93,7 @@ function Base.mapreducedim!( end function _mapreducedim!( - (f), (op), (initop), + @nospecialize(f), @nospecialize(op), @nospecialize(initop), dims::Dims, arrays::Tuple{Vararg{StridedView}} ) if any(isequal(0), dims) @@ -107,7 +107,7 @@ function _mapreducedim!( end function _mapreduce_fuse!( - (f), (op), (initop), + @nospecialize(f), @nospecialize(op), @nospecialize(initop), dims::Dims, arrays::Tuple{Vararg{StridedView}} ) # Fuse dimensions if possible: assume that at least one array, e.g. the output array in @@ -130,7 +130,7 @@ function _mapreduce_fuse!( end function _mapreduce_order!( - (f), (op), (initop), + @nospecialize(f), @nospecialize(op), @nospecialize(initop), dims, strides, arrays ) M = length(arrays) @@ -155,7 +155,7 @@ end const MINTHREADLENGTH = 1 << 15 # minimal length before any kind of threading is applied function _mapreduce_block!( - (f), (op), (initop), + @nospecialize(f), @nospecialize(op), @nospecialize(initop), dims, strides, offsets, costs, arrays ) bytestrides = map((s, stride) -> s .* stride, sizeof.(eltype.(arrays)), strides) @@ -214,7 +214,7 @@ end # nthreads: number of threads spacing: extra addition to offset of array 1, to account for # reduction function _mapreduce_threaded!( - (f), (op), (initop), + @nospecialize(f), @nospecialize(op), @nospecialize(initop), dims, blocks, strides, offsets, costs, arrays, nthreads, spacing, taskindex ) @@ -261,6 +261,14 @@ end strides::NTuple{M, NTuple{N, Int}}, offsets::NTuple{M, Int} ) where {N, M} + return _mapreduce_kernel_expr(f, op, initop, N, M) +end + +# Build the body of `_mapreduce_kernel!` as an `Expr`. Split out from the +# `@generated` function so the generation logic is a plain function. `f`, `op`, +# `initop` are the *types* of the corresponding arguments (as seen inside +# `@generated`); `N`/`M` are the ndims / number of arrays. +function _mapreduce_kernel_expr(f, op, initop, N::Int, M::Int) # many variables blockloopvars = Array{Symbol}(undef, N) From 25b35b43157392ad1815197b439c8cc6f16bdcad Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 18 Jun 2026 17:46:05 -0400 Subject: [PATCH 2/2] Apply suggestion from @lkdvos --- src/mapreduce.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/mapreduce.jl b/src/mapreduce.jl index 4f5b8cd..ff8180a 100644 --- a/src/mapreduce.jl +++ b/src/mapreduce.jl @@ -264,10 +264,6 @@ end return _mapreduce_kernel_expr(f, op, initop, N, M) end -# Build the body of `_mapreduce_kernel!` as an `Expr`. Split out from the -# `@generated` function so the generation logic is a plain function. `f`, `op`, -# `initop` are the *types* of the corresponding arguments (as seen inside -# `@generated`); `N`/`M` are the ndims / number of arrays. function _mapreduce_kernel_expr(f, op, initop, N::Int, M::Int) # many variables