From 3de8daa5cf0bb88dd3f8c7c4582d00b364a5be5b Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sat, 13 Dec 2025 13:35:42 +0530 Subject: [PATCH 1/2] Fix GEMM dispatch for complex-real matmul (#1520) This should fix #1519. The issue on master is that we have specialized dispatches for arrays of the same eltype, and the complex-real matmul ends up in `generic_matmatmul!`. This adds an extra method to ensure that the complex-real case also reaches BLAS. ```julia julia> A = ones(ComplexF64, 400, 400); B = ones(size(A)); C = similar(A); julia> @btime mul!($C, $A, $B); 57.043 ms (0 allocations: 0 bytes) # master 1.608 ms (0 allocations: 0 bytes) # this PR ``` --- src/matmul.jl | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/src/matmul.jl b/src/matmul.jl index e45c9efb..82ccc9f7 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -312,7 +312,7 @@ end BlasFlag.SYRK elseif (tA_uc == 'C' && tB_uc == 'N') || (tA_uc == 'N' && tB_uc == 'C') BlasFlag.HERK - else isntc + else BlasFlag.GEMM end else @@ -494,7 +494,7 @@ function matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α::Bool, β) return false end -# THE one big BLAS dispatch. This is split into two methods to improve latency +# THE one big BLAS dispatch. This is split into syrk/herk/gemm and symm/hemm/none methods to improve latency Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, α::Number, β::Number, val::BlasFlag.SyrkHerkGemm) where {T<:BlasFloat} mA, nA = lapack_size(tA, A) @@ -507,6 +507,12 @@ Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix _syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, val) return C end + +function generic_matmatmul_wrapper!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, + α::Number, β::Number, ::Val{BlasFlag.GEMM}) where {T<:BlasReal} + gemm_wrapper!(C, tA, tB, A, B, α, β) +end + Base.@constprop :aggressive function _syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, ::Val{BlasFlag.SYRK}) if A === B tA_uc = uppercase(tA) # potentially strip a WrapperChar @@ -583,14 +589,6 @@ Base.@constprop :aggressive generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::S _add::MulAddMul = MulAddMul()) where {T<:BlasFloat} = generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta) -function generic_matmatmul_wrapper!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, - α::Number, β::Number, ::Val{true}) where {T<:BlasReal} - gemm_wrapper!(C, tA, tB, A, B, α, β) -end -Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, - alpha::Number, beta::Number, ::Val{false}) where {T<:BlasReal} - _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta) -end # legacy method Base.@constprop :aggressive generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, _add::MulAddMul = MulAddMul()) where {T<:BlasReal} = From 79236831acfeb4a3b2b60ac375003658930bd05e Mon Sep 17 00:00:00 2001 From: "Viral B. Shah" Date: Sat, 13 Dec 2025 21:40:33 -0500 Subject: [PATCH 2/2] Make sure inputs to the hemm etc tests are Hermitian (#1522) Fix #1496 https://github.com/JuliaLinearAlgebra/AppleAccelerate.jl/issues/87 Co-authored-by: Viral B. Shah --- test/blas.jl | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/test/blas.jl b/test/blas.jl index b7f4a03a..6cda7fbe 100644 --- a/test/blas.jl +++ b/test/blas.jl @@ -721,15 +721,18 @@ end @test BLAS.her!('L', real(elty(2)), x, A) isa WrappedArray{elty,2} @test A == WrappedArray(elty[5 2+2im; 11+3im 20]) # Level 3 - A = WrappedArray(elty[1+im 2+2im; 3+3im 4+4im]) + # Hermitian matrices require real diagonal elements + A = WrappedArray(elty[1 2+2im; 2-2im 4]) B = WrappedArray(elty[1+im 2+2im; 3+3im 4+4im]) - C = WrappedArray(elty[1+im 2+2im; 3+3im 4+4im]) + C = WrappedArray(elty[1 2+2im; 2-2im 4]) @test BLAS.hemm!('L', 'U', elty(2), A, B, elty(1), C) isa WrappedArray{elty,2} - @test C == WrappedArray([3+27im 6+38im; 35+27im 52+36im]) + @test C == WrappedArray([3+26im 6+38im; 34+22im 52+32im]) + C = WrappedArray(elty[1 2+2im; 2-2im 4]) # reset C to Hermitian @test BLAS.herk!('U', 'N', real(elty(2)), A, real(elty(1)), C) isa WrappedArray{elty,2} - @test C == WrappedArray([23 50+38im; 35+27im 152]) + @test C == WrappedArray([19 22+22im; 2-2im 52]) + C = WrappedArray(elty[1 2+2im; 2-2im 4]) # reset C to Hermitian @test BLAS.her2k!('U', 'N', elty(2), A, B, real(elty(1)), C) isa WrappedArray{elty,2} - @test C == WrappedArray([63 138+38im; 35+27im 352]) + @test C == WrappedArray([37 56+20im; 2-2im 68]) end end