diff --git a/src/mapreduce.jl b/src/mapreduce.jl index ff8180a..d7c4659 100644 --- a/src/mapreduce.jl +++ b/src/mapreduce.jl @@ -340,6 +340,17 @@ function _mapreduce_kernel_expr(f, op, initop, N::Int, M::Int) outerreturnstrideex[i] = returnex end + # Unit-stride fast path for the innermost (vectorized) loop dimension. + # We special-case for contiguous loads/stores to avoid SIMD gather/scatter + # in favor of SIMD load/store, which streams memory more efficiently. + unitstep1ex = :($(Ivars[1]) += 1) + unitstep2ex = Expr(:block) + for j in 2:M + push!(unitstep2ex.args, :($(Ivars[j]) += 1)) + end + firststrides = Expr(:tuple, (stridevars[1, j] for j in 1:M)...) + unitstridecond = :(all(==(1), $firststrides)) + if op == Nothing ex = Expr(:(=), lhsex, fcallex) exa = Expr(:(=), :a, fcallex) @@ -358,6 +369,14 @@ function _mapreduce_kernel_expr(f, op, initop, N::Int, M::Int) end $lhsex = a $(returnstride2ex[i]) + elseif $unitstridecond + @simd for $(innerloopvars[i]) in Base.OneTo($(blockdimvars[i])) + $ex + $unitstep1ex + $unitstep2ex + end + $(returnstride1ex[i]) + $(returnstride2ex[i]) else @simd for $(innerloopvars[i]) in Base.OneTo($(blockdimvars[i])) $ex