Skip to content

Commit dd3ea72

Browse files
authored
Backports to release 1.12.4 (#1524)
2 parents 997c4b7 + 7923683 commit dd3ea72

File tree

2 files changed

+16
-15
lines changed

2 files changed

+16
-15
lines changed

src/matmul.jl

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ end
312312
BlasFlag.SYRK
313313
elseif (tA_uc == 'C' && tB_uc == 'N') || (tA_uc == 'N' && tB_uc == 'C')
314314
BlasFlag.HERK
315-
else isntc
315+
else
316316
BlasFlag.GEMM
317317
end
318318
else
@@ -494,7 +494,7 @@ function matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α::Bool, β)
494494
return false
495495
end
496496

497-
# THE one big BLAS dispatch. This is split into two methods to improve latency
497+
# THE one big BLAS dispatch. This is split into syrk/herk/gemm and symm/hemm/none methods to improve latency
498498
Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
499499
α::Number, β::Number, val::BlasFlag.SyrkHerkGemm) where {T<:BlasFloat}
500500
mA, nA = lapack_size(tA, A)
@@ -507,6 +507,12 @@ Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix
507507
_syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, val)
508508
return C
509509
end
510+
511+
function generic_matmatmul_wrapper!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
512+
α::Number, β::Number, ::Val{BlasFlag.GEMM}) where {T<:BlasReal}
513+
gemm_wrapper!(C, tA, tB, A, B, α, β)
514+
end
515+
510516
Base.@constprop :aggressive function _syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, ::Val{BlasFlag.SYRK})
511517
if A === B
512518
tA_uc = uppercase(tA) # potentially strip a WrapperChar
@@ -583,14 +589,6 @@ Base.@constprop :aggressive generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::S
583589
_add::MulAddMul = MulAddMul()) where {T<:BlasFloat} =
584590
generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
585591

586-
function generic_matmatmul_wrapper!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
587-
α::Number, β::Number, ::Val{true}) where {T<:BlasReal}
588-
gemm_wrapper!(C, tA, tB, A, B, α, β)
589-
end
590-
Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
591-
alpha::Number, beta::Number, ::Val{false}) where {T<:BlasReal}
592-
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)
593-
end
594592
# legacy method
595593
Base.@constprop :aggressive generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
596594
_add::MulAddMul = MulAddMul()) where {T<:BlasReal} =

test/blas.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -721,15 +721,18 @@ end
721721
@test BLAS.her!('L', real(elty(2)), x, A) isa WrappedArray{elty,2}
722722
@test A == WrappedArray(elty[5 2+2im; 11+3im 20])
723723
# Level 3
724-
A = WrappedArray(elty[1+im 2+2im; 3+3im 4+4im])
724+
# Hermitian matrices require real diagonal elements
725+
A = WrappedArray(elty[1 2+2im; 2-2im 4])
725726
B = WrappedArray(elty[1+im 2+2im; 3+3im 4+4im])
726-
C = WrappedArray(elty[1+im 2+2im; 3+3im 4+4im])
727+
C = WrappedArray(elty[1 2+2im; 2-2im 4])
727728
@test BLAS.hemm!('L', 'U', elty(2), A, B, elty(1), C) isa WrappedArray{elty,2}
728-
@test C == WrappedArray([3+27im 6+38im; 35+27im 52+36im])
729+
@test C == WrappedArray([3+26im 6+38im; 34+22im 52+32im])
730+
C = WrappedArray(elty[1 2+2im; 2-2im 4]) # reset C to Hermitian
729731
@test BLAS.herk!('U', 'N', real(elty(2)), A, real(elty(1)), C) isa WrappedArray{elty,2}
730-
@test C == WrappedArray([23 50+38im; 35+27im 152])
732+
@test C == WrappedArray([19 22+22im; 2-2im 52])
733+
C = WrappedArray(elty[1 2+2im; 2-2im 4]) # reset C to Hermitian
731734
@test BLAS.her2k!('U', 'N', elty(2), A, B, real(elty(1)), C) isa WrappedArray{elty,2}
732-
@test C == WrappedArray([63 138+38im; 35+27im 352])
735+
@test C == WrappedArray([37 56+20im; 2-2im 68])
733736
end
734737
end
735738

0 commit comments

Comments
 (0)