Skip to content

Commit 033f150

Browse files
authored
Define TensorAlgebra.matricize to return SparseArrays.SparseMatrixCSC (#75)
1 parent 78dece8 commit 033f150

File tree

10 files changed

+124
-2
lines changed

10 files changed

+124
-2
lines changed

Project.toml

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SparseArraysBase"
22
uuid = "0d5efcca-f356-4864-8770-e1ed8d78f208"
33
authors = ["ITensor developers <support@itensor.org> and contributors"]
4-
version = "0.7.7"
4+
version = "0.7.8"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
@@ -14,8 +14,15 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1414
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1515
MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261"
1616
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
17+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1718
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
1819

20+
[weakdeps]
21+
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
22+
23+
[extensions]
24+
SparseArraysBaseTensorAlgebraExt = ["TensorAlgebra", "SparseArrays"]
25+
1926
[compat]
2027
Accessors = "0.1.41"
2128
Adapt = "4.3.0"
@@ -29,11 +36,16 @@ LinearAlgebra = "1.10"
2936
MapBroadcast = "0.1.5"
3037
Random = "1.10.0"
3138
SafeTestsets = "0.1"
39+
SparseArrays = "1.10"
3240
Suppressor = "0.2"
41+
TensorAlgebra = "0.6.2"
3342
Test = "1.10"
3443
TypeParameterAccessors = "0.4.3"
3544
julia = "1.10"
3645

46+
[workspace]
47+
projects = ["benchmark", "dev", "docs", "examples", "test"]
48+
3749
[extras]
3850
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3951
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"

docs/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
44
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
55
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
66

7+
[sources]
8+
SparseArraysBase = {path = ".."}
9+
710
[compat]
811
Dictionaries = "0.4.4"
912
Documenter = "1.8.1"

examples/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
33
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
44
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
55

6+
[sources]
7+
SparseArraysBase = {path = ".."}
8+
69
[compat]
710
Dictionaries = "0.4.4"
811
SparseArraysBase = "0.7.0"
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
module SparseArraysBaseTensorAlgebraExt
2+
3+
using SparseArrays: SparseMatrixCSC
4+
using SparseArraysBase: AnyAbstractSparseArray, AnyAbstractSparseMatrix, SparseArrayDOK
5+
using TensorAlgebra: TensorAlgebra, BlockedTrivialPermutation, BlockedTuple, FusionStyle,
6+
ReshapeFusion, matricize, unmatricize
7+
8+
struct SparseArrayFusion <: FusionStyle end
9+
TensorAlgebra.FusionStyle(::Type{<:AnyAbstractSparseArray}) = SparseArrayFusion()
10+
11+
function TensorAlgebra.matricize(
12+
style::SparseArrayFusion, a::AbstractArray, length_codomain::Val
13+
)
14+
m = matricize(ReshapeFusion(), a, length_codomain)
15+
return convert(SparseMatrixCSC, m)
16+
end
17+
function TensorAlgebra.unmatricize(
18+
style::SparseArrayFusion,
19+
m::AbstractMatrix,
20+
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
21+
axes_domain::Tuple{Vararg{AbstractUnitRange}},
22+
)
23+
a = unmatricize(ReshapeFusion(), m, axes_codomain, axes_domain)
24+
# TODO: Use `similar_type(m)` instead of hardcoding to `SparseArrayDOK`?
25+
return convert(SparseArrayDOK, a)
26+
end
27+
28+
end

src/SparseArraysBase.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,6 @@ include("wrappers.jl")
2525
include("abstractsparsearray.jl")
2626
include("sparsearraydok.jl")
2727
include("oneelementarray.jl")
28+
include("sparsearrays.jl")
2829

2930
end

src/abstractsparsearray.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ using Dictionaries: AbstractDictionary
22

33
abstract type AbstractSparseArray{T, N} <: AbstractArray{T, N} end
44

5+
Base.convert(T::Type{<:AbstractSparseArray}, a::AbstractArray) = a isa T ? a : T(a)
6+
57
using DerivableInterfaces: @array_aliases
68
# Define AbstractSparseVector, AnyAbstractSparseArray, etc.
79
@array_aliases AbstractSparseArray

src/sparsearrays.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
using SparseArrays: SparseArrays, AbstractSparseMatrixCSC, SparseMatrixCSC, findnz
2+
3+
function eachstoredindex(m::AbstractSparseMatrixCSC)
4+
I, J, V = findnz(m)
5+
# TODO: This loses the compile time element type, is there a better lazy way?
6+
return Iterators.map(CartesianIndex, zip(I, J))
7+
end
8+
function eachstoredindex(a::Base.ReshapedArray{<:Any, <:Any, <:AbstractSparseMatrixCSC})
9+
return @interface SparseArrayInterface() eachstoredindex(a)
10+
end
11+
12+
function SparseArrays.SparseMatrixCSC{Tv, Ti}(m::AnyAbstractSparseMatrix) where {Tv, Ti}
13+
m′ = SparseMatrixCSC{Tv, Ti}(undef, size(m))
14+
for I in eachstoredindex(m)
15+
m′[I] = m[I]
16+
end
17+
return m′
18+
end
19+
20+
function SparseArrayDOK(a::Base.ReshapedArray{<:Any, <:Any, <:AbstractSparseMatrixCSC})
21+
return SparseArrayDOK{eltype(a), ndims(a)}(a)
22+
end
23+
function SparseArrayDOK{T}(
24+
a::Base.ReshapedArray{<:Any, <:Any, <:AbstractSparseMatrixCSC}
25+
) where {T}
26+
return SparseArrayDOK{T, ndims(a)}(a)
27+
end
28+
function SparseArrayDOK{T, N}(
29+
a::Base.ReshapedArray{<:Any, N, <:AbstractSparseMatrixCSC}
30+
) where {T, N}
31+
a′ = SparseArrayDOK{T, N}(undef, size(a))
32+
for I in eachstoredindex(a)
33+
a′[I] = a[I]
34+
end
35+
return a′
36+
end

test/Project.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,16 @@ JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
99
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1010
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
11+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1112
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
1213
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
1314
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
15+
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
1416
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1517

18+
[sources]
19+
SparseArraysBase = {path = ".."}
20+
1621
[compat]
1722
Adapt = "4.2.0"
1823
Aqua = "0.8.11"
@@ -23,7 +28,9 @@ JLArrays = "0.2.0, 0.3"
2328
LinearAlgebra = "<0.0.1, 1"
2429
Random = "<0.0.1, 1"
2530
SafeTestsets = "0.1.0"
31+
SparseArrays = "1.10"
2632
SparseArraysBase = "0.7.0"
2733
StableRNGs = "1.0.2"
2834
Suppressor = "0.2.8"
35+
TensorAlgebra = "0.6"
2936
Test = "<0.0.1, 1"

test/test_dense.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ arrayts = (Array, JLArray)
3434
dev(elt[2, 4]), Dict([CartesianIndex(1, 2) => 1, CartesianIndex(3, 4) => 2]), (3, 4)
3535
)
3636
d = dense(s)
37-
@show typeof(d)
3837
@test d isa arrayt{elt, 2}
3938
@test d == dev(elt[0 2 0 0; 0 0 0 0; 0 0 0 4])
4039
end

test/test_tensoralgebraext.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
using SparseArrays: SparseMatrixCSC, findnz, nnz
2+
using SparseArraysBase: SparseMatrixDOK, eachstoredindex, isstored, sparsezeros,
3+
storedlength
4+
using TensorAlgebra: contract, matricize
5+
using Test: @testset, @test
6+
7+
@testset "TensorAlgebraExt (eltype = $elt)" for elt in (Float32, ComplexF64)
8+
a = sparsezeros(elt, (2, 2, 2))
9+
a[1, 1, 1] = 1
10+
a[2, 1, 2] = 2
11+
12+
# matricize
13+
m = matricize(a, (1, 3), (2,))
14+
@test m isa SparseMatrixCSC{elt}
15+
@test nnz(m) == 2
16+
@test isstored(m, 1, 1)
17+
@test m[1, 1] elt(1)
18+
@test isstored(m, 4, 1)
19+
@test m[4, 1] elt(2)
20+
@test issetequal(eachstoredindex(m), [CartesianIndex(1, 1), CartesianIndex(4, 1)])
21+
for I in setdiff(CartesianIndices(m), [CartesianIndex(1, 1), CartesianIndex(4, 1)])
22+
@test m[I] zero(elt)
23+
end
24+
25+
# contract
26+
b, l = contract(a, ("i", "j", "k"), a, ("j", "k", "l"))
27+
@test b isa SparseMatrixDOK{elt}
28+
@test storedlength(b) == 1
29+
@test only(eachstoredindex(b)) == CartesianIndex(1, 1)
30+
@test b[1, 1] elt(1)
31+
end

0 commit comments

Comments
 (0)