Skip to content

Commit 80bd0bc

Browse files
authored
Promote storagetypes (#370)
* add `storagetype` implementations for more different tensor types * remove braidingtensor storagetype specialization * introduce storagetype promotion system * use storagetype promotion system * bypass storagetype promotion system for braidingtensor * immediately use storagetype * add implementation for n-ary promotion * update Changelog * make Aqua happy
1 parent 91fa268 commit 80bd0bc

7 files changed

Lines changed: 108 additions & 22 deletions

File tree

docs/src/Changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ When releasing a new version, move the "Unreleased" changes to a new version sec
2222

2323
### Added
2424

25+
- A more robust promotion system for `storagetype`s to better handle working with unions and other abstract tensor map types ([#370](https://github.com/QuantumKitHub/TensorKit.jl/pull/370)).
2526

2627
### Changed
2728

src/tensors/abstracttensor.jl

Lines changed: 89 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,30 @@ function InnerProductStyle(::Type{TT}) where {TT <: AbstractTensorMap}
3838
return InnerProductStyle(spacetype(TT))
3939
end
4040

41+
# storage types and promotion system
42+
# ----------------------------------
4143
@doc """
4244
storagetype(t::AbstractTensorMap) -> Type{A<:AbstractVector}
4345
storagetype(T::Type{<:AbstractTensorMap}) -> Type{A<:AbstractVector}
4446
4547
Return the type of vector that stores the data of a tensor.
48+
If this is not overloaded for a given tensor type, the default value of `storagetype(scalartype(t))` is returned.
49+
50+
See also [`similarstoragetype`](@ref).
4651
""" storagetype
52+
storagetype(t) = storagetype(typeof(t))
53+
function storagetype(::Type{T}) where {T <: AbstractTensorMap}
54+
if T isa Union
55+
# attempt to be slightly more specific by promoting unions
56+
Ma = storagetype(T.a)
57+
Mb = storagetype(T.b)
58+
return promote_storagetype(Ma, Mb)
59+
else
60+
# fallback definition by using scalartype
61+
return similarstoragetype(scalartype(T))
62+
end
63+
end
64+
storagetype(T::Type) = throw(MethodError(storagetype, T))
4765

4866
# storage type determination and promotion - hooks for specializing
4967
# the default implementation tries to leverarge inference and `similar`
@@ -69,6 +87,8 @@ appropriate storage types. Additionally this registers the default storage type
6987
used in constructor-like calls, and therefore will return the exact same type for a `DenseVector`
7088
input. The latter is used in `similar`-like calls, and therefore will return the type of calling
7189
`similar` on the given `DenseVector`, which need not coincide with the original type.
90+
91+
See also [`promote_storagetype`](@ref).
7292
""" similarstoragetype
7393

7494
# implement in type domain
@@ -102,6 +122,74 @@ similarstoragetype(::Type{D}, ::Type{T}) where {D <: AbstractDict{<:Sector, <:Ab
102122
# default storage type for numbers
103123
similarstoragetype(::Type{T}) where {T <: Number} = Vector{T}
104124

125+
@doc """
126+
promote_storagetype([T], A, B, C...)
127+
promote_storagetype([T], TA, TB, TC...)
128+
129+
Determine an appropriate storage type for the combination of tensors `A` and `B`, or tensors of type `TA` and `TB`.
130+
Optionally, a scalartype `T` for the destination can be supplied that might differ from the inputs.
131+
""" promote_storagetype
132+
133+
@inline promote_storagetype(A::AbstractTensorMap, B::AbstractTensorMap, Cs::AbstractTensorMap...) =
134+
promote_storagetype(storagetype(A), storagetype(B), map(storagetype, Cs)...)
135+
@inline promote_storagetype(::Type{T}, A::AbstractTensorMap, B::AbstractTensorMap, Cs::AbstractTensorMap...) where {T <: Number} =
136+
promote_storagetype(similarstoragetype(A, T), similarstoragetype(B, T), map(Base.Fix2(similarstoragetype, T), Cs)...)
137+
138+
@inline function promote_storagetype(
139+
::Type{A}, ::Type{B}, Cs::Type{<:AbstractTensorMap}...
140+
) where {A <: AbstractTensorMap, B <: AbstractTensorMap}
141+
return promote_storagetype(storagetype(A), storagetype(B), map(storagetype, Cs)...)
142+
end
143+
@inline function promote_storagetype(
144+
::Type{T}, ::Type{A}, ::Type{B}, Cs::Type{<:AbstractTensorMap}...
145+
) where {T <: Number, A <: AbstractTensorMap, B <: AbstractTensorMap}
146+
return promote_storagetype(similarstoragetype(A, T), similarstoragetype(B, T), map(Base.Fix2(similarstoragetype, T), Cs)...)
147+
end
148+
149+
# promotion system in the same spirit as base/promotion.jl
150+
promote_storagetype(::Type{Base.Bottom}, ::Type{Base.Bottom}) = Base.Bottom
151+
promote_storagetype(::Type{T}, ::Type{T}) where {T} = T
152+
promote_storagetype(::Type{T}, ::Type{Base.Bottom}) where {T} = T
153+
promote_storagetype(::Type{Base.Bottom}, ::Type{T}) where {T} = T
154+
155+
function promote_storagetype(::Type{T}, ::Type{S}) where {T, S}
156+
@inline
157+
# Try promote_storage_rule in both orders. Typically only one is defined,
158+
# and there is a fallback returning Bottom below, so the common case is
159+
# promote_storagetype(T, S) =>
160+
# promote_storage_result(T, S, result, Bottom) =>
161+
# typejoin(result, Bottom) => result
162+
return promote_storage_result(T, S, promote_storage_rule(T, S), promote_storage_rule(S, T))
163+
end
164+
165+
@inline promote_storagetype(T, S, U) = promote_storagetype(promote_storagetype(T, S), U)
166+
@inline promote_storagetype(T, S, U, V...) = promote_storagetype(promote_storagetype(T, S), U, V...)
167+
168+
@doc """
169+
promote_storage_rule(type1, type2)
170+
171+
Specifies what type should be used by [`promote_storagetype`](@ref) when given values of types `type1` and
172+
`type2`. This function should not be called directly, but should have definitions added to
173+
it for new types as appropriate.
174+
""" promote_storage_rule
175+
176+
promote_storage_rule(::Type, ::Type) = Base.Bottom
177+
# Define some methods to avoid needing to enumerate unrelated possibilities when presented
178+
# with Type{<:T}, and return a value in general accordance with the result given by promote_type
179+
promote_storage_rule(::Type{Base.Bottom}, slurp...) = Base.Bottom
180+
promote_storage_rule(::Type{Base.Bottom}, ::Type{Base.Bottom}, slurp...) = Base.Bottom # not strictly necessary, since the next method would match unambiguously anyways
181+
promote_storage_rule(::Type{Base.Bottom}, ::Type{T}, slurp...) where {T} = T
182+
promote_storage_rule(::Type{T}, ::Type{Base.Bottom}, slurp...) where {T} = T
183+
184+
promote_storage_result(::Type, ::Type, ::Type{T}, ::Type{S}) where {T, S} = (@inline; promote_storagetype(T, S))
185+
# If no promote_storage_rule is defined, both directions give Bottom => error
186+
promote_storage_result(T::Type, S::Type, ::Type{Base.Bottom}, ::Type{Base.Bottom}) =
187+
throw(ArgumentError("No promotion rule defined for storagetype `$T` and `$S`"))
188+
189+
# promotion rules for common vector types
190+
promote_storage_rule(::Type{T}, ::Type{S}) where {T <: DenseVector, S <: DenseVector} =
191+
T === S ? T : throw(ArgumentError("No promotion rule defined for storagetype `$T` and `$S`"))
192+
105193
# tensor characteristics: space and index information
106194
#-----------------------------------------------------
107195
"""
@@ -224,8 +312,7 @@ end
224312
# tensor characteristics: work on instances and pass to type
225313
#------------------------------------------------------------
226314
InnerProductStyle(t::AbstractTensorMap) = InnerProductStyle(typeof(t))
227-
storagetype(t) = storagetype(typeof(t))
228-
storagetype(T::Type) = throw(MethodError(storagetype, T))
315+
229316
blocktype(t::AbstractTensorMap) = blocktype(typeof(t))
230317

231318
numout(t::AbstractTensorMap) = numout(typeof(t))

src/tensors/braidingtensor.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,17 @@ end
5757

5858
space(b::BraidingTensor) = b.adjoint ? b.V1 b.V2 b.V2 b.V1 : b.V2 b.V1 b.V1 b.V2
5959

60-
# TODO: this will probably give issues with GPUs, so we should try to avoid
61-
# calling this method alltogether
62-
storagetype(::Type{BraidingTensor{T, S}}) where {T, S} = Vector{T}
60+
# specializations to ignore the storagetype of BraidingTensor
61+
promote_storagetype(::Type{A}, ::Type{B}) where {A <: BraidingTensor, B <: AbstractTensorMap} = storagetype(B)
62+
promote_storagetype(::Type{A}, ::Type{B}) where {A <: AbstractTensorMap, B <: BraidingTensor} = storagetype(A)
63+
promote_storagetype(::Type{A}, ::Type{B}) where {A <: BraidingTensor, B <: BraidingTensor} = storagetype(A)
64+
65+
promote_storagetype(::Type{T}, ::Type{A}, ::Type{B}) where {T <: Number, A <: BraidingTensor, B <: AbstractTensorMap} =
66+
similarstoragetype(B, T)
67+
promote_storagetype(::Type{T}, ::Type{A}, ::Type{B}) where {T <: Number, A <: AbstractTensorMap, B <: BraidingTensor} =
68+
similarstoragetype(A, T)
69+
promote_storagetype(::Type{T}, ::Type{A}, ::Type{B}) where {T <: Number, A <: BraidingTensor, B <: BraidingTensor} =
70+
similarstoragetype(A, T)
6371

6472
function Base.getindex(b::BraidingTensor)
6573
sectortype(b) === Trivial || throw(SectorMismatch())

src/tensors/diagonal.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -272,9 +272,8 @@ function TO.tensorcontract_type(
272272
::Index2Tuple{1, 1}
273273
)
274274
S = check_spacetype(A, B)
275-
TC′ = promote_permute(TC, sectortype(S))
276-
M = promote_storagetype(similarstoragetype(A, TC′), similarstoragetype(B, TC′))
277-
return DiagonalTensorMap{TC, S, M}
275+
M = promote_storagetype(promote_permute(TC, sectortype(S)), A, B)
276+
return DiagonalTensorMap{scalartype(M), S, M}
278277
end
279278

280279
function TO.tensoralloc(
@@ -303,9 +302,8 @@ end
303302

304303
function compose_dest(A::DiagonalTensorMap, B::DiagonalTensorMap)
305304
S = check_spacetype(A, B)
306-
TC = TO.promote_contract(scalartype(A), scalartype(B), One)
307-
M = promote_storagetype(similarstoragetype(A, TC), similarstoragetype(B, TC))
308-
TTC = DiagonalTensorMap{TC, S, M}
305+
M = promote_storagetype(TO.promote_contract(scalartype(A), scalartype(B), One), A, B)
306+
TTC = DiagonalTensorMap{scalartype(M), S, M}
309307
structure = codomain(A) domain(B)
310308
return TO.tensoralloc(TTC, structure, Val(false))
311309
end

src/tensors/linalg.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@ LinearAlgebra.normalize(t::AbstractTensorMap, p::Real = 2) = scale(t, inv(norm(t
2323
# permutations, which might require complex scalartypes even if the inputs are real.
2424
function compose_dest(A::AbstractTensorMap, B::AbstractTensorMap)
2525
S = check_spacetype(A, B)
26-
TC = TO.promote_contract(scalartype(A), scalartype(B), One)
27-
M = promote_storagetype(similarstoragetype(A, TC), similarstoragetype(B, TC))
26+
M = promote_storagetype(TO.promote_contract(scalartype(A), scalartype(B), One), A, B)
2827
TTC = tensormaptype(S, numout(A), numin(B), M)
2928
structure = codomain(A) domain(B)
3029
return TO.tensoralloc(TTC, structure, Val(false))

src/tensors/tensor.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -560,7 +560,6 @@ end
560560
function Base.promote_rule(
561561
::Type{<:TT₁}, ::Type{<:TT₂}
562562
) where {S, N₁, N₂, TT₁ <: TensorMap{<:Any, S, N₁, N₂}, TT₂ <: TensorMap{<:Any, S, N₁, N₂}}
563-
T = VectorInterface.promote_add(scalartype(TT₁), scalartype(TT₂))
564-
A = promote_storagetype(similarstoragetype(TT₁, T), similarstoragetype(TT₂, T))
563+
A = promote_storagetype(VectorInterface.promote_add(scalartype(TT₁), scalartype(TT₂)), TT₁, TT₂)
565564
return tensormaptype(S, N₁, N₂, A)
566565
end

src/tensors/tensoroperations.jl

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -153,16 +153,10 @@ function TO.tensorcontract_type(
153153
::Index2Tuple{N₁, N₂}
154154
) where {N₁, N₂}
155155
S = check_spacetype(A, B)
156-
TC′ = promote_permute(TC, sectortype(S))
157-
M = promote_storagetype(similarstoragetype(A, TC′), similarstoragetype(B, TC′))
156+
M = promote_storagetype(promote_permute(TC, sectortype(S)), A, B)
158157
return tensormaptype(S, N₁, N₂, M)
159158
end
160159

161-
# TODO: handle actual promotion rule system
162-
function promote_storagetype(::Type{M₁}, ::Type{M₂}) where {M₁, M₂}
163-
return M₁ === M₂ ? M₁ : throw(ArgumentError("Cannot determine storage type for combining `$M₁` and `$M₂`"))
164-
end
165-
166160
function TO.tensorcontract_structure(
167161
A::AbstractTensorMap, pA::Index2Tuple, conjA::Bool,
168162
B::AbstractTensorMap, pB::Index2Tuple, conjB::Bool,

0 commit comments

Comments
 (0)