Skip to content

Commit bb27392

Browse files
lkdvosclaude
andauthored
Add TruncationUnion to support supplying minimal ranks (#183)
* Add `TruncationUnion` and `|` operator for truncation strategies `TruncationUnion` is the symmetric counterpart to `TruncationIntersection`, keeping values that are present in *any* component strategy. `NoTruncation` acts as the absorbing element (union with "keep all" = keep all). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * Add `minrank` keyword argument to `TruncationStrategy` `minrank` provides a lower bound on the number of kept values, composing with other constraints via `TruncationUnion`. When no upper-bound constraints are active, `minrank` alone returns `truncrank(minrank)` directly. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * Add tests for `TruncationUnion` and `minrank` Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * Update truncations documentation for `TruncationUnion` and `minrank` Add examples for the `minrank` keyword argument and the `|` operator, and add `@docs` entries for `TruncationIntersection` and `TruncationUnion`. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * Add GPU overrides for `_ind_union` in CUDA and AMDGPU extensions Mirrors the existing `_ind_intersect` overrides: collect GPU index vectors to CPU before computing the union, since `union` does not work on GPU arrays. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix GPU support maybe * bold any/all * consistent punctuation --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 570b221 commit bb27392

File tree

8 files changed

+209
-15
lines changed

8 files changed

+209
-15
lines changed

docs/src/user_interface/truncations.md

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ all(>(2.9), diagview(Dtrunc))
5656
true
5757
```
5858

59+
Use `maxrank` together with a tolerance to keep at most `maxrank` values above the tolerance (intersection):
60+
5961
```jldoctest truncations; output=false
6062
Dtrunc, Vtrunc, ϵ = eigh_trunc(A; trunc = (maxrank = 2, atol = 2.9));
6163
size(Dtrunc, 1) <= 2 && all(>(2.9), diagview(Dtrunc))
@@ -64,6 +66,16 @@ size(Dtrunc, 1) <= 2 && all(>(2.9), diagview(Dtrunc))
6466
true
6567
```
6668

69+
Use `minrank` together with a tolerance to guarantee at least `minrank` values are kept (union):
70+
71+
```jldoctest truncations; output=false
72+
Dtrunc, Vtrunc, ϵ = eigh_trunc(A; trunc = (atol = 3.5, minrank = 2));
73+
size(Dtrunc, 1) >= 2
74+
75+
# output
76+
true
77+
```
78+
6779
In general, the keyword arguments that are supported can be found in the `TruncationStrategy` docstring:
6880

6981
```@docs; canonical = false
@@ -84,6 +96,8 @@ size(Dtrunc, 1) <= 2
8496
true
8597
```
8698

99+
Strategies can be combined with `&` (intersection: keep values satisfying **all** conditions):
100+
87101
```jldoctest truncations; output=false
88102
Dtrunc, Vtrunc, ϵ = eigh_trunc(A; trunc = truncrank(2) & trunctol(; atol = 2.9))
89103
size(Dtrunc, 1) <= 2 && all(>(2.9), diagview(Dtrunc))
@@ -92,6 +106,17 @@ size(Dtrunc, 1) <= 2 && all(>(2.9), diagview(Dtrunc))
92106
true
93107
```
94108

109+
Strategies can also be combined with `|` (union: keep values satisfying **any** condition).
110+
This is useful to set a lower bound on the number of kept values with `minrank`:
111+
112+
```jldoctest truncations; output=false
113+
Dtrunc, Vtrunc, ϵ = eigh_trunc(A; trunc = trunctol(; atol = 3.5) | truncrank(2))
114+
size(Dtrunc, 1) >= 2
115+
116+
# output
117+
true
118+
```
119+
95120
## Truncation Strategies
96121

97122
MatrixAlgebraKit provides several built-in truncation strategies:
@@ -104,11 +129,20 @@ truncfilter
104129
truncerror
105130
```
106131

107-
Truncation strategies can be combined using the `&` operator to create intersection-based truncation.
108-
When strategies are combined, only the values that satisfy all conditions are kept.
132+
Strategies can be composed using the `&` operator ([`TruncationIntersection`](@ref)) to keep only values satisfying **all** conditions,
133+
or the `|` operator ([`TruncationUnion`](@ref)) to keep values satisfying **any** condition.
134+
135+
```@docs; canonical=false
136+
TruncationIntersection
137+
TruncationUnion
138+
```
109139

110140
```julia
141+
# Keep at most 10 values, all above tolerance (intersection)
111142
combined_trunc = truncrank(10) & trunctol(; atol = 1e-6);
143+
144+
# Keep values above tolerance, but always at least 3 (union)
145+
combined_trunc = trunctol(; atol = 1e-6) | truncrank(3);
112146
```
113147

114148
## Truncation Error

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,15 @@ function MatrixAlgebraKit._mul_herm!(C::StridedROCMatrix{T}, A::StridedROCMatrix
161161
return C
162162
end
163163

164-
# TODO: intersect doesn't work on GPU
164+
# TODO: intersect/union don't work on GPU
165165
MatrixAlgebraKit._ind_intersect(A::ROCVector{Int}, B::ROCVector{Int}) =
166166
MatrixAlgebraKit._ind_intersect(collect(A), collect(B))
167+
MatrixAlgebraKit._ind_union(A::AbstractVector{<:Integer}, B::ROCVector{Int}) =
168+
MatrixAlgebraKit._ind_union(A, collect(B))
169+
MatrixAlgebraKit._ind_union(A::ROCVector{Int}, B::AbstractVector{<:Integer}) =
170+
MatrixAlgebraKit._ind_union(collect(A), B)
171+
MatrixAlgebraKit._ind_union(A::ROCVector{Int}, B::ROCVector{Int}) =
172+
MatrixAlgebraKit._ind_union(collect(A), collect(B))
167173

168174
function _sylvester(A::AnyROCMatrix, B::AnyROCMatrix, C::AnyROCMatrix)
169175
hX = sylvester(collect(A), collect(B), collect(C))

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,15 @@ function MatrixAlgebraKit._mul_herm!(C::StridedCuMatrix{T}, A::StridedCuMatrix{T
165165
return C
166166
end
167167

168-
# TODO: intersect doesn't work on GPU
168+
# TODO: intersect/union don't work on GPU
169169
MatrixAlgebraKit._ind_intersect(A::CuVector{Int}, B::CuVector{Int}) =
170170
MatrixAlgebraKit._ind_intersect(collect(A), collect(B))
171+
MatrixAlgebraKit._ind_union(A::AbstractVector{<:Integer}, B::CuVector{Int}) =
172+
MatrixAlgebraKit._ind_union(A, collect(B))
173+
MatrixAlgebraKit._ind_union(A::CuVector{Int}, B::AbstractVector{<:Integer}) =
174+
MatrixAlgebraKit._ind_union(collect(A), B)
175+
MatrixAlgebraKit._ind_union(A::CuVector{Int}, B::CuVector{Int}) =
176+
MatrixAlgebraKit._ind_union(collect(A), collect(B))
171177

172178
function _sylvester(A::AnyCuMatrix, B::AnyCuMatrix, C::AnyCuMatrix)
173179
# https://github.com/JuliaGPU/CUDA.jl/issues/3021

src/MatrixAlgebraKit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ export notrunc, truncrank, trunctol, truncerror, truncfilter
5757
eval(
5858
Expr(
5959
:public, :TruncationByOrder, :TruncationByFilter, :TruncationByValue,
60-
:TruncationByError, :TruncationIntersection, :truncate
60+
:TruncationByError, :TruncationIntersection, :TruncationUnion, :truncate
6161
)
6262
)
6363
eval(

src/implementations/truncation.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,32 @@ _ind_intersect(A::AbstractUnitRange{Int}, B::AbstractVector{Int}) = _ind_interse
137137
# when all else fails, call intersect
138138
_ind_intersect(A, B) = intersect(A, B)
139139

140+
function findtruncated(values::AbstractVector, strategy::TruncationUnion)
141+
length(strategy.components) == 0 && return Base.OneTo(0)
142+
length(strategy.components) == 1 && return findtruncated(values, only(strategy.components))
143+
ind1 = findtruncated(values, strategy.components[1])
144+
ind2 = findtruncated(values, TruncationUnion(Base.tail(strategy.components)))
145+
return _ind_union(ind1, ind2)
146+
end
147+
function findtruncated_svd(values::AbstractVector, strategy::TruncationUnion)
148+
length(strategy.components) == 0 && return Base.OneTo(0)
149+
length(strategy.components) == 1 && return findtruncated_svd(values, only(strategy.components))
150+
ind1 = findtruncated_svd(values, strategy.components[1])
151+
ind2 = findtruncated_svd(values, TruncationUnion(Base.tail(strategy.components)))
152+
return _ind_union(ind1, ind2)
153+
end
154+
155+
_ind_union(A::AbstractVector{Bool}, B::AbstractVector{Bool}) = A .| B
156+
function _ind_union(A::AbstractVector{Bool}, B::AbstractVector)
157+
result = copy(A)
158+
result[B] .= true
159+
return result
160+
end
161+
_ind_union(A::AbstractVector, B::AbstractVector{Bool}) = _ind_union(B, A)
162+
_ind_union(A::Base.OneTo, B::Base.OneTo) = Base.OneTo(max(length(A), length(B)))
163+
_ind_union(A::AbstractUnitRange, B::AbstractUnitRange) = union(A, B)
164+
_ind_union(A, B) = union(A, B)
165+
140166
# Truncation error
141167
# ----------------
142168
truncation_error(values::AbstractVector, ind) = truncation_error!(copy(values), ind)

src/interface/truncation.jl

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
const docs_truncation_kwargs = """
2-
* `atol::Real` : Absolute tolerance for the truncation
3-
* `rtol::Real` : Relative tolerance for the truncation
4-
* `maxrank::Real` : Maximal rank for the truncation
5-
* `maxerror::Real` : Maximal truncation error.
6-
* `filter` : Custom filter to select truncated values.
2+
* `atol::Real` : Absolute tolerance for the truncation
3+
* `rtol::Real` : Relative tolerance for the truncation
4+
* `maxrank::Integer` : Maximal rank for the truncation
5+
* `minrank::Integer` : Minimal rank for the truncation
6+
* `maxerror::Real` : Maximal truncation error
7+
* `filter` : Custom filter to select truncated values
78
"""
89

910
const docs_truncation_strategies = """
@@ -28,16 +29,18 @@ Select a truncation strategy based on the provided keyword arguments.
2829
## Keyword arguments
2930
The following keyword arguments are all optional, and their default value (`nothing`)
3031
will be ignored. It is also allowed to combine multiple of these, in which case the kept
31-
values will consist of the intersection of the different truncated strategies.
32+
values will consist of the intersection of the different truncated strategies (except
33+
`minrank`, which uses union semantics to guarantee a lower bound on the number of kept values).
3234
3335
$docs_truncation_kwargs
3436
"""
3537
function TruncationStrategy(;
3638
atol::Union{Real, Nothing} = nothing,
3739
rtol::Union{Real, Nothing} = nothing,
38-
maxrank::Union{Real, Nothing} = nothing,
40+
maxrank::Union{Integer, Nothing} = nothing,
41+
minrank::Union{Integer, Nothing} = nothing,
3942
maxerror::Union{Real, Nothing} = nothing,
40-
filter = nothing
43+
filter = nothing,
4144
)
4245
strategy = notrunc()
4346

@@ -51,6 +54,14 @@ function TruncationStrategy(;
5154
isnothing(maxerror) || (strategy &= truncerror(; atol = maxerror))
5255
isnothing(filter) || (strategy &= truncfilter(filter))
5356

57+
# union constraint: guarantee a lower bound on number of kept values
58+
# special-case NoTruncation: keeping everything already satisfies any minrank
59+
if !isnothing(minrank) && !(strategy isa NoTruncation)
60+
strategy |= truncrank(minrank)
61+
elseif !isnothing(minrank)
62+
strategy = truncrank(minrank)
63+
end
64+
5465
return strategy
5566
end
5667

@@ -222,6 +233,43 @@ Base.:&(::NoTruncation, ::NoTruncation) = notrunc()
222233
Base.:&(::NoTruncation, trunc::TruncationIntersection) = trunc
223234
Base.:&(trunc::TruncationIntersection, ::NoTruncation) = trunc
224235

236+
"""
237+
TruncationUnion(trunc::TruncationStrategy, truncs::TruncationStrategy...)
238+
239+
Truncation strategy that composes multiple truncation strategies, keeping values that are
240+
present in any of them.
241+
"""
242+
struct TruncationUnion{T <: Tuple{Vararg{TruncationStrategy}}} <: TruncationStrategy
243+
components::T
244+
end
245+
function TruncationUnion(trunc::TruncationStrategy, truncs::TruncationStrategy...)
246+
return TruncationUnion((trunc, truncs...))
247+
end
248+
249+
function Base.:|(trunc1::TruncationStrategy, trunc2::TruncationStrategy)
250+
return TruncationUnion((trunc1, trunc2))
251+
end
252+
253+
# flatten components
254+
function Base.:|(trunc1::TruncationUnion, trunc2::TruncationUnion)
255+
return TruncationUnion((trunc1.components..., trunc2.components...))
256+
end
257+
function Base.:|(trunc1::TruncationUnion, trunc2::TruncationStrategy)
258+
return TruncationUnion((trunc1.components..., trunc2))
259+
end
260+
function Base.:|(trunc1::TruncationStrategy, trunc2::TruncationUnion)
261+
return TruncationUnion((trunc1, trunc2.components...))
262+
end
263+
264+
# NoTruncation is the absorbing element for | (union with "keep all" = keep all)
265+
Base.:|(::NoTruncation, ::TruncationStrategy) = notrunc()
266+
Base.:|(::TruncationStrategy, ::NoTruncation) = notrunc()
267+
Base.:|(::NoTruncation, ::NoTruncation) = notrunc()
268+
269+
# disambiguate
270+
Base.:|(::NoTruncation, ::TruncationUnion) = notrunc()
271+
Base.:|(::TruncationUnion, ::NoTruncation) = notrunc()
272+
225273
@doc """
226274
truncation_error(values, ind)
227275
Compute the truncation error as the 2-norm of the values that are not kept by `ind`.

test/testsuite/decompositions/svd.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,29 @@ function test_svd_trunc(
212212
@test diagview(S2) diagview(S)[1:2]
213213
end
214214
end
215+
@testset "mix minrank and tol" begin
216+
m4 = 4
217+
U = instantiate_unitary(T, A, m4)
218+
Sdiag = similar(A, real(eltype(T)), m4)
219+
copyto!(Sdiag, [0.9, 0.3, 0.1, 0.01])
220+
S = Diagonal(Sdiag)
221+
Vᴴ = instantiate_unitary(T, A, m4)
222+
A = U * S * Vᴴ
223+
for trunc_fun in (
224+
(rtol, minrank) -> (; rtol, minrank),
225+
(rtol, minrank) -> trunctol(; rtol) | truncrank(minrank),
226+
)
227+
# trunctol(rtol=0.5) keeps 1 value, truncrank(3) keeps 3, union keeps 3
228+
U1, S1, V1ᴴ, ϵ1 = svd_trunc(A; trunc = trunc_fun(0.5, 3))
229+
@test length(diagview(S1)) == 3
230+
@test diagview(S1) diagview(S)[1:3]
231+
232+
# trunctol(rtol=0.2) keeps 2 values, truncrank(1) keeps 1, union keeps 2
233+
U2, S2, V2ᴴ = svd_trunc_no_error(A; trunc = trunc_fun(0.2, 1))
234+
@test length(diagview(S2)) == 2
235+
@test diagview(S2) diagview(S)[1:2]
236+
end
237+
end
215238
@testset "specify truncation algorithm" begin
216239
atol = sqrt(eps(real(eltype(T))))
217240
m4 = 4
@@ -294,6 +317,29 @@ function test_svd_trunc_algs(
294317
@test collect(diagview(S2)) collect(diagview(S)[1:2])
295318
end
296319
end
320+
@testset "mix minrank and tol" begin
321+
m4 = 4
322+
U = instantiate_unitary(T, A, m4)
323+
Sdiag = similar(A, real(eltype(T)), m4)
324+
copyto!(Sdiag, real(eltype(T))[0.9, 0.3, 0.1, 0.01])
325+
S = Diagonal(Sdiag)
326+
Vᴴ = instantiate_unitary(T, A, m4)
327+
A = U * S * Vᴴ
328+
for trunc_fun in (
329+
(rtol, minrank) -> (; rtol, minrank),
330+
(rtol, minrank) -> trunctol(; rtol) | truncrank(minrank),
331+
)
332+
# trunctol(rtol=0.5) keeps 1 value, truncrank(3) keeps 3, union keeps 3
333+
U1, S1, V1ᴴ, ϵ1 = svd_trunc(A; trunc = trunc_fun(0.5, 3), alg)
334+
@test length(diagview(S1)) == 3
335+
@test collect(diagview(S1)) collect(diagview(S)[1:3])
336+
337+
# trunctol(rtol=0.2) keeps 2 values, truncrank(1) keeps 1, union keeps 2
338+
U2, S2, V2ᴴ, ϵ2 = svd_trunc(A; trunc = trunc_fun(0.2, 1), alg)
339+
@test length(diagview(S2)) == 2
340+
@test collect(diagview(S2)) collect(diagview(S)[1:2])
341+
end
342+
end
297343
@testset "specify truncation algorithm" begin
298344
atol = sqrt(eps(real(eltype(T))))
299345
m4 = 4

test/truncate.jl

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
using MatrixAlgebraKit
22
using Test
33
using TestExtras
4-
using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationByOrder,
5-
TruncationByValue, TruncationStrategy, findtruncated, findtruncated_svd
4+
using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationUnion,
5+
TruncationByOrder, TruncationByValue, TruncationStrategy, findtruncated, findtruncated_svd
66

77
@testset "truncate" begin
88
trunc = @constinferred TruncationStrategy()
@@ -65,4 +65,32 @@ using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationByOrder,
6565
@test issetequal(values[@constinferred(findtruncated(values, strategy))], values[2:5])
6666
vals_sorted = sort(values; by = abs, rev = true)
6767
@test vals_sorted[@constinferred(findtruncated_svd(vals_sorted, strategy))] == vals_sorted[1:4]
68+
69+
# TruncationUnion / minrank
70+
trunc = @constinferred TruncationStrategy(; minrank = 3)
71+
@test trunc isa TruncationByOrder
72+
@test trunc == truncrank(3)
73+
74+
trunc = @constinferred TruncationStrategy(; atol, minrank = 3)
75+
@test trunc isa TruncationUnion
76+
@test trunc == trunctol(; atol) | truncrank(3)
77+
78+
# | operator
79+
values2 = [1.0, 0.9, 0.5, 0.3, 0.01]
80+
# trunctol keeps 1:3 (above 0.4), truncrank(4) keeps 1:4, union keeps 1:4
81+
strategy = trunctol(; atol = 0.4) | truncrank(4)
82+
@test @constinferred(findtruncated_svd(values2, strategy)) == 1:4
83+
# trunctol keeps 1:3, truncrank(2) keeps 1:2, union keeps 1:3
84+
strategy = trunctol(; atol = 0.4) | truncrank(2)
85+
@test @constinferred(findtruncated_svd(values2, strategy)) == 1:3
86+
87+
# notrunc is absorbing for |
88+
@test (notrunc() | truncrank(3)) isa NoTruncation
89+
@test (truncrank(3) | notrunc()) isa NoTruncation
90+
91+
# TruncationUnion flattening
92+
union1 = truncrank(2) | trunctol(; atol = 0.4)
93+
union2 = union1 | truncrank(4)
94+
@test union2 isa TruncationUnion
95+
@test length(union2.components) == 3
6896
end

0 commit comments

Comments
 (0)