Skip to content

Commit fb5ae40

Browse files
committed
one more round of eig changes
1 parent dfb97cd commit fb5ae40

6 files changed

Lines changed: 28 additions & 26 deletions

File tree

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <
2222
return QRIteration(; kwargs...)
2323
end
2424
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}}
25-
return QRIteration(; balanced = false, kwargs...)
25+
return QRIteration(; kwargs...)
2626
end
2727
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}}
2828
return DivideAndConquer(; kwargs...)

ext/MatrixAlgebraKitGenericSchurExt.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,9 @@ using GenericSchur
99
const GSFloat = Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}}
1010

1111
function MatrixAlgebraKit.default_eig_algorithm(
12-
::Type{T};
13-
balanced::Bool = false, driver::Driver = GS(), kwargs...
12+
::Type{T}; driver::Driver = GS(), kwargs...
1413
) where {T <: StridedMatrix{<:GSFloat}}
15-
return QRIteration(; driver, balanced, kwargs...)
14+
return QRIteration(; driver, kwargs...)
1615
end
1716

1817
function geev!(::GS, A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...)

src/implementations/eig.jl

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -91,18 +91,16 @@ end
9191
# IMPLEMENTATIONS
9292
# ==========================
9393

94-
geev!(driver::Driver, args...; kwargs...) = throw(ArgumentError("$driver does not provide $f!"))
94+
geev!(driver::Driver, args...; kwargs...) = throw(ArgumentError("$driver does not provide `geev!`"))
9595
function geevx!(driver::Driver, A, Dd, V; kwargs...)
9696
@warn "$driver does not provide `geevx!`, falling back to `geev!`" maxlog = 1
97-
return geev!(driver, A, Dd, V; kwargs...)
97+
return geev!(driver, A, Dd, V)
9898
end
99-
_has_geevx!(::Driver) = false
10099

101100
# LAPACK implementations
102101
for f! in (:geev!, :geevx!)
103102
@eval $f!(::LAPACK, args...; kwargs...) = YALAPACK.$f!(args...; kwargs...)
104103
end
105-
_has_geevx!(::LAPACK) = true
106104

107105
# driver dispatch
108106
@inline qr_iteration_eig_full!(A, Dd, V; driver::Driver = DefaultDriver(), kwargs...) =
@@ -118,17 +116,17 @@ _has_geevx!(::LAPACK) = true
118116
# Implementation
119117
function qr_iteration_eig_full!(
120118
driver::Driver, A, Dd, V;
121-
fixgauge::Bool = default_fixgauge(), balanced::Bool = _has_geevx!(driver), kwargs...
119+
fixgauge::Bool = default_fixgauge(), scale::Bool = true, permute::Bool = true
122120
)
123-
(balanced ? geevx! : geev!)(driver, A, Dd, V; kwargs...)
121+
(scale & permute) ? geev!(driver, A, Dd, V) : geevx!(driver, A, Dd, V; scale, permute)
124122
fixgauge && gaugefix!(eig_full!, V)
125123
return Dd, V
126124
end
127125
function qr_iteration_eig_vals!(
128126
driver::Driver, A, D, V;
129-
fixgauge::Bool = default_fixgauge(), balanced::Bool = _has_geevx!(driver), kwargs...
127+
fixgauge::Bool = default_fixgauge(), scale::Bool = true, permute::Bool = true
130128
)
131-
(balanced ? geevx! : geev!)(driver, A, D, V; kwargs...)
129+
(scale & permute) ? geev!(driver, A, D, V) : geevx!(driver, A, D, V; scale, permute)
132130
return D
133131
end
134132

@@ -188,15 +186,15 @@ end
188186

189187
# Deprecations
190188
# ------------
191-
for (lapack_algtype, balanced_val) in ((:LAPACK_Simple, false), (:LAPACK_Expert, true))
189+
for lapack_algtype in (:LAPACK_Simple, :LAPACK_Expert)
192190
@eval begin
193191
Base.@deprecate(
194192
eig_full!(A, DV, alg::$lapack_algtype),
195-
eig_full!(A, DV, QRIteration(; balanced = $balanced_val, alg.kwargs...))
193+
eig_full!(A, DV, QRIteration(; alg.kwargs...))
196194
)
197195
Base.@deprecate(
198196
eig_vals!(A, D, alg::$lapack_algtype),
199-
eig_vals!(A, D, QRIteration(; balanced = $balanced_val, alg.kwargs...))
197+
eig_vals!(A, D, QRIteration(; alg.kwargs...))
200198
)
201199
end
202200
end

src/implementations/schur.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,10 @@ end
5353
# IMPLEMENTATIONS
5454
# ==========================
5555

56-
for f! in (:gees!, :geesx!)
57-
@eval $f!(driver::Driver, args...) = throw(ArgumentError("$driver does not provide $f!"))
56+
gees!(driver::Driver, args...) = throw(ArgumentError("$driver does not provide `gees!`"))
57+
function geesx!(driver::Driver, A, Dd, V; kwargs...)
58+
@warn "$driver does not provide `geesx!`, falling back to `gees!`" maxlog = 1
59+
return gees!(driver, A, Dd, V)
5860
end
5961

6062
# LAPACK implementations
@@ -74,13 +76,13 @@ end
7476
qr_iteration_schur_vals!(default_driver(QRIteration, A), A, Z, vals; kwargs...)
7577

7678
# Implementation
77-
function qr_iteration_schur_full!(driver::Driver, A, T, Z, vals; balanced::Bool = false)
78-
(balanced ? geesx! : gees!)(driver, A, Z, vals)
79+
function qr_iteration_schur_full!(driver::Driver, A, T, Z, vals; expert::Bool = false)
80+
expert ? geesx!(driver, A, Z, vals) : gees!(driver, A, Z, vals)
7981
T === A || copy!(T, A)
8082
return T, Z, vals
8183
end
82-
function qr_iteration_schur_vals!(driver::Driver, A, Z, vals; balanced::Bool = false)
83-
(balanced ? geesx! : gees!)(driver, A, Z, vals)
84+
function qr_iteration_schur_vals!(driver::Driver, A, Z, vals; expert::Bool = false)
85+
expert ? geesx!(driver, A, Z, vals) : gees!(driver, A, Z, vals)
8486
return vals
8587
end
8688

@@ -100,15 +102,15 @@ end
100102

101103
# Deprecations
102104
# ------------
103-
for (lapack_algtype, balanced_val) in ((:LAPACK_Simple, false), (:LAPACK_Expert, true))
105+
for (lapack_algtype, expert_val) in ((:LAPACK_Simple, false), (:LAPACK_Expert, true))
104106
@eval begin
105107
Base.@deprecate(
106108
schur_full!(A, TZv, alg::$lapack_algtype),
107-
schur_full!(A, TZv, QRIteration(; balanced = $balanced_val, alg.kwargs...))
109+
schur_full!(A, TZv, QRIteration(; expert = $expert_val, alg.kwargs...))
108110
)
109111
Base.@deprecate(
110112
schur_vals!(A, vals, alg::$lapack_algtype),
111-
schur_vals!(A, vals, QRIteration(; balanced = $balanced_val, alg.kwargs...))
113+
schur_vals!(A, vals, QRIteration(; expert = $expert_val, alg.kwargs...))
112114
)
113115
end
114116
end

test/eig.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,10 @@ for T in (BLASFloats..., GenericFloats...)
3535
if !is_buildkite
3636
TestSuite.test_eig(T, (m, m))
3737
if T BLASFloats
38-
LAPACK_EIG_ALGS = (QRIteration(), QRIteration(balanced = true))
38+
LAPACK_EIG_ALGS = (
39+
QRIteration(),
40+
QRIteration(scale = false), # to trigger geevx!
41+
)
3942
TestSuite.test_eig_algs(T, (m, m), LAPACK_EIG_ALGS)
4043
elseif T GenericFloats
4144
GS_EIG_ALGS = (QRIteration(; driver = GS()),)

test/schur.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ for T in (BLASFloats..., GenericFloats...)
2828
if !is_buildkite
2929
TestSuite.test_schur(T, (m, m))
3030
if T BLASFloats
31-
LAPACK_SCHUR_ALGS = (QRIteration(), QRIteration(balanced = true))
31+
LAPACK_SCHUR_ALGS = (QRIteration(), QRIteration(expert = true))
3232
TestSuite.test_schur_algs(T, (m, m), LAPACK_SCHUR_ALGS)
3333
end
3434
#AT = Diagonal{T, Vector{T}}

0 commit comments

Comments
 (0)