Skip to content

Commit eb007bb

Browse files
authored
Fix GEMM dispatch for complex-real matmul (#1520)
This should fix #1519. The issue on master is that we have specialized dispatches for arrays of the same eltype, and the complex-real matmul ends up in `generic_matmatmul!`. This adds an extra method to ensure that the complex-real case also reaches BLAS.

```julia
julia> A = ones(ComplexF64, 400, 400); B = ones(size(A)); C = similar(A);

julia> @btime mul!($C, $A, $B);
  57.043 ms (0 allocations: 0 bytes) # master
  1.608 ms (0 allocations: 0 bytes) # this PR
```
1 parent f41bf65 commit eb007bb

File tree

1 file changed

+8
-10
lines changed

1 file changed

+8
-10
lines changed

src/matmul.jl

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ end
317317
BlasFlag.SYRK
318318
elseif (tA_uc == 'C' && tB_uc == 'N') || (tA_uc == 'N' && tB_uc == 'C')
319319
BlasFlag.HERK
320-
else isntc
320+
else
321321
BlasFlag.GEMM
322322
end
323323
else
@@ -499,7 +499,7 @@ function matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α::Bool, β)
499499
return false
500500
end
501501

502-
# THE one big BLAS dispatch. This is split into two methods to improve latency
502+
# THE one big BLAS dispatch. This is split into syrk/herk/gemm and symm/hemm/none methods to improve latency
503503
Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
504504
α::Number, β::Number, val::BlasFlag.SyrkHerkGemm) where {T<:Number}
505505
mA, nA = lapack_size(tA, A)
@@ -511,6 +511,12 @@ Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix
511511
_syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, val)
512512
return C
513513
end
514+
515+
function generic_matmatmul_wrapper!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
516+
α::Number, β::Number, ::Val{BlasFlag.GEMM}) where {T<:BlasReal}
517+
gemm_wrapper!(C, tA, tB, A, B, α, β)
518+
end
519+
514520
Base.@constprop :aggressive function _syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, ::Val{BlasFlag.SYRK})
515521
if A === B
516522
tA_uc = uppercase(tA) # potentially strip a WrapperChar
@@ -657,14 +663,6 @@ Base.@constprop :aggressive generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::S
657663
_add::MulAddMul = MulAddMul()) where {T<:BlasFloat} =
658664
generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
659665

660-
function generic_matmatmul_wrapper!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
661-
α::Number, β::Number, ::Val{true}) where {T<:BlasReal}
662-
gemm_wrapper!(C, tA, tB, A, B, α, β)
663-
end
664-
Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
665-
alpha::Number, beta::Number, ::Val{false}) where {T<:BlasReal}
666-
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)
667-
end
668666
# legacy method
669667
Base.@constprop :aggressive generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
670668
_add::MulAddMul = MulAddMul()) where {T<:BlasReal} =

0 commit comments

Comments
 (0)