You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
This should fix #1519. The issue on master is that we have specialized
dispatches for arrays of the same eltype, and the complex-real matmul
ends up in `generic_matmatmul!`. This adds an extra method to ensure
that the complex-real case also reaches BLAS.
```julia
julia> A = ones(ComplexF64, 400, 400); B = ones(size(A)); C = similar(A);
julia> @btime mul!($C, $A, $B);
57.043 ms (0 allocations: 0 bytes) # master
1.608 ms (0 allocations: 0 bytes) # this PR
```
@@ -499,7 +499,7 @@ function matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α::Bool, β)
499
499
return false
500
500
end
501
501
502
-
# THE one big BLAS dispatch. This is split into two methods to improve latency
502
+
# THE one big BLAS dispatch. This is split into syrk/herk/gemm and symm/hemm/none methods to improve latency
503
503
Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
504
504
α::Number, β::Number, val::BlasFlag.SyrkHerkGemm) where {T<:Number}
505
505
mA, nA = lapack_size(tA, A)
@@ -511,6 +511,12 @@ Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix
511
511
_syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, val)
512
512
return C
513
513
end
514
+
515
+
function generic_matmatmul_wrapper!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
516
+
α::Number, β::Number, ::Val{BlasFlag.GEMM}) where {T<:BlasReal}
517
+
gemm_wrapper!(C, tA, tB, A, B, α, β)
518
+
end
519
+
514
520
Base.@constprop :aggressive function _syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, ::Val{BlasFlag.SYRK})
515
521
if A === B
516
522
tA_uc = uppercase(tA) # potentially strip a WrapperChar
@@ -657,14 +663,6 @@ Base.@constprop :aggressive generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::S
657
663
_add::MulAddMul=MulAddMul()) where {T<:BlasFloat} =
658
664
generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
659
665
660
-
function generic_matmatmul_wrapper!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
661
-
α::Number, β::Number, ::Val{true}) where {T<:BlasReal}
662
-
gemm_wrapper!(C, tA, tB, A, B, α, β)
663
-
end
664
-
Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
665
-
alpha::Number, beta::Number, ::Val{false}) where {T<:BlasReal}
0 commit comments