Skip to content

Commit 2b39359

Browse files
committed
rework type stability
1 parent aa6870e commit 2b39359

File tree

7 files changed

+21
-19
lines changed

7 files changed

+21
-19
lines changed

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ using LinearAlgebra: BlasFloat
1414

1515
include("yarocsolver.jl")
1616

17-
MatrixAlgebraKit.default_householder_driver(::StridedROCMatrix{<:BlasFloat}) = ROCSOLVER()
17+
MatrixAlgebraKit.default_householder_driver(::Type{A}) where {A <: StridedROCMatrix{<:BlasFloat}} = ROCSOLVER()
1818
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedROCMatrix}
1919
return ROCSOLVER_QRIteration(; kwargs...)
2020
end

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ using LinearAlgebra: BlasFloat
1515

1616
include("yacusolver.jl")
1717

18-
MatrixAlgebraKit.default_householder_driver(::StridedCuVecOrMat{<:BlasFloat}) = CUSOLVER()
18+
MatrixAlgebraKit.default_householder_driver(::Type{A}) where {A <: StridedCuVecOrMat{<:BlasFloat}} = CUSOLVER()
1919
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}}
2020
return CUSOLVER_QRIteration(; kwargs...)
2121
end

src/algorithms.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ Finally, the same behavior is obtained when the keyword arguments are
8888
passed as the third positional argument in the form of a `NamedTuple`.
8989
""" select_algorithm
9090

91-
Base.@assume_effects :foldable function select_algorithm(f::F, A, alg::Alg = nothing; kwargs...) where {F, Alg}
91+
@inline function select_algorithm(f::F, A, alg::Alg = nothing; kwargs...) where {F, Alg}
9292
if isnothing(alg)
9393
return default_algorithm(f, A; kwargs...)
9494
elseif alg isa Symbol

src/interface/decompositions.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,17 @@ Depending on the driver, various other keywords may be (un)available to customiz
8080
"""
8181
@algdef Householder
8282

83-
default_householder_driver(A) = Native()
83+
default_householder_driver(A) = default_householder_driver(typeof(A))
84+
default_householder_driver(::Type) = Native()
8485

85-
default_householder_driver(::YALAPACK.MaybeBlasMat) = LAPACK()
86+
default_householder_driver(::Type{A}) where {A <: YALAPACK.MaybeBlasMat} = LAPACK()
8687

8788
# note: StridedVector fallback is needed for handling reshaped parent types
88-
default_householder_driver(::StridedVector{<:BlasFloat}) = LAPACK()
89-
default_householder_driver(A::Union{SubArray, Base.ReshapedArray}) = default_householder_driver(parent(A))
89+
default_householder_driver(::Type{A}) where {A <: StridedVector{<:BlasFloat}} = LAPACK()
90+
default_householder_driver(::Type{<:SubArray{T, N, A}}) where {T, N, A} =
91+
default_householder_driver(A)
92+
default_householder_driver(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} =
93+
default_householder_driver(A)
9094

9195

9296
# General Eigenvalue Decomposition

src/interface/lq.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,9 @@ See also [`qr_full(!)`](@ref lq_full) and [`qr_compact(!)`](@ref lq_compact).
7070
# -------------------
7171
default_lq_algorithm(A; kwargs...) = default_lq_algorithm(typeof(A); kwargs...)
7272

73-
default_lq_algorithm(T::Type; kwargs...) =
74-
throw(MethodError(default_lq_algorithm, (T,)))
75-
default_lq_algorithm(::Type{T}; kwargs...) where {T <: AbstractMatrix} =
76-
Householder(; kwargs...)
73+
default_lq_algorithm(T::Type; kwargs...) = throw(MethodError(default_lq_algorithm, (T,)))
74+
default_lq_algorithm(::Type{T}; driver = default_householder_driver(T), kwargs...) where {T <: AbstractMatrix} =
75+
Householder(; driver, kwargs...)
7776
default_lq_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} =
7877
DiagonalAlgorithm(; kwargs...)
7978
default_lq_algorithm(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} =

src/interface/qr.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,9 @@ See also [`lq_full(!)`](@ref lq_full) and [`lq_compact(!)`](@ref lq_compact).
7070
# -------------------
7171
default_qr_algorithm(A; kwargs...) = default_qr_algorithm(typeof(A); kwargs...)
7272

73-
default_qr_algorithm(T::Type; kwargs...) =
74-
throw(MethodError(default_qr_algorithm, (T,)))
75-
default_qr_algorithm(::Type{T}; kwargs...) where {T <: AbstractMatrix} =
76-
Householder(; kwargs...)
73+
default_qr_algorithm(T::Type; kwargs...) = throw(MethodError(default_qr_algorithm, (T,)))
74+
default_qr_algorithm(::Type{T}; driver = default_householder_driver(T), kwargs...) where {T <: AbstractMatrix} =
75+
Householder(; driver, kwargs...)
7776
default_qr_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} =
7877
DiagonalAlgorithm(; kwargs...)
7978
default_qr_algorithm(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} =

test/algorithms.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using MatrixAlgebraKit
22
using Test
33
using TestExtras
44
using MatrixAlgebraKit: LAPACK_SVDAlgorithm, PolarViaSVD, TruncatedAlgorithm,
5-
default_algorithm, select_algorithm, Householder
5+
default_algorithm, select_algorithm, Householder, LAPACK
66

77
@testset "default_algorithm" begin
88
A = randn(3, 3)
@@ -17,21 +17,21 @@ using MatrixAlgebraKit: LAPACK_SVDAlgorithm, PolarViaSVD, TruncatedAlgorithm,
1717
LAPACK_MultipleRelativelyRobustRepresentations()
1818
end
1919
for f in (lq_full!, lq_full, lq_compact!, lq_compact, lq_null!, lq_null)
20-
@test @constinferred(default_algorithm(f, A)) == Householder()
20+
@test @constinferred(default_algorithm(f, A)) == Householder(; driver = LAPACK())
2121
end
2222
for f in (left_polar!, left_polar, right_polar!, right_polar)
2323
@test @constinferred(default_algorithm(f, A)) ==
2424
PolarViaSVD(LAPACK_DivideAndConquer())
2525
end
2626
for f in (qr_full!, qr_full, qr_compact!, qr_compact, qr_null!, qr_null)
27-
@test @constinferred(default_algorithm(f, A)) == Householder()
27+
@test @constinferred(default_algorithm(f, A)) == Householder(; driver = LAPACK())
2828
end
2929
for f in (schur_full!, schur_full, schur_vals!, schur_vals)
3030
@test @constinferred(default_algorithm(f, A)) === LAPACK_Expert()
3131
end
3232

3333
@test @constinferred(default_algorithm(qr_compact!, A; blocksize = 2)) ==
34-
Householder(; blocksize = 2)
34+
Householder(; driver = LAPACK(), blocksize = 2)
3535
end
3636

3737
@testset "select_algorithm" begin

0 commit comments

Comments
 (0)