Skip to content

Commit 130f031

Browse files
committed
rewrite rule tensorcontract in terms of blas_contract!
1 parent 13a143c commit 130f031

2 files changed

Lines changed: 51 additions & 57 deletions

File tree

ext/TensorKitMooncakeExt/tensoroperations.jl

Lines changed: 39 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -4,72 +4,69 @@ Mooncake.@is_primitive(
44
DefaultCtx,
55
ReverseMode,
66
Tuple{
7-
typeof(TO.tensorcontract!),
7+
typeof(TensorKit.blas_contract!),
88
AbstractTensorMap,
9-
AbstractTensorMap, Index2Tuple, Bool,
10-
AbstractTensorMap, Index2Tuple, Bool,
9+
AbstractTensorMap, Index2Tuple,
10+
AbstractTensorMap, Index2Tuple,
1111
Index2Tuple,
1212
Number, Number,
13-
Vararg{Any},
13+
Any, Any,
1414
}
1515
)
1616

1717
function Mooncake.rrule!!(
18-
::CoDual{typeof(TO.tensorcontract!)},
18+
::CoDual{typeof(TensorKit.blas_contract!)},
1919
C_ΔC::CoDual{<:AbstractTensorMap},
20-
A_ΔA::CoDual{<:AbstractTensorMap}, pA_ΔpA::CoDual{<:Index2Tuple}, conjA_ΔconjA::CoDual{Bool},
21-
B_ΔB::CoDual{<:AbstractTensorMap}, pB_ΔpB::CoDual{<:Index2Tuple}, conjB_ΔconjB::CoDual{Bool},
20+
A_ΔA::CoDual{<:AbstractTensorMap}, pA_ΔpA::CoDual{<:Index2Tuple},
21+
B_ΔB::CoDual{<:AbstractTensorMap}, pB_ΔpB::CoDual{<:Index2Tuple},
2222
pAB_ΔpAB::CoDual{<:Index2Tuple},
2323
α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number},
24-
ba_Δba::CoDual...,
24+
backend_Δbackend::CoDual, allocator_Δallocator::CoDual
2525
)
2626
# prepare arguments
2727
(C, ΔC), (A, ΔA), (B, ΔB) = arrayify.((C_ΔC, A_ΔA, B_ΔB))
2828
pA, pB, pAB = primal.((pA_ΔpA, pB_ΔpB, pAB_ΔpAB))
29-
conjA, conjB = primal.((conjA_ΔconjA, conjB_ΔconjB))
3029
α, β = primal.((α_Δα, β_Δβ))
31-
ba = primal.(ba_Δba)
30+
backend, allocator = primal.((backend_Δbackend, allocator_Δallocator))
3231

3332
# primal call
3433
C_cache = copy(C)
35-
TO.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba...)
34+
TensorKit.blas_contract!(C, A, pA, B, pB, pAB, α, β, backend, allocator)
3635

37-
function tensorcontract_pullback(::NoRData)
36+
function blas_contract_pullback(::NoRData)
3837
copy!(C, C_cache)
3938

40-
ΔCr = tensorcontract_pullback_ΔC!(ΔC, β)
41-
ΔAr = tensorcontract_pullback_ΔA!(
42-
ΔA, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba...
39+
ΔAr = blas_contract_pullback_ΔA!(
40+
ΔA, ΔC, A, pA, B, pB, pAB, α, backend, allocator
4341
)
44-
ΔBr = tensorcontract_pullback_ΔB!(
45-
ΔB, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba...
42+
ΔBr = blas_contract_pullback_ΔB!(
43+
ΔB, ΔC, A, pA, B, pB, pAB, α, backend, allocator
4644
)
47-
Δαr = tensorcontract_pullback_Δα(
48-
ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba...
45+
Δαr = blas_contract_pullback_Δα(
46+
ΔC, A, pA, B, pB, pAB, α, backend, allocator
4947
)
50-
Δβr = tensorcontract_pullback_Δβ(ΔC, C, β)
48+
Δβr = blas_contract_pullback_Δβ(ΔC, C, β)
49+
ΔCr = blas_contract_pullback_ΔC!(ΔC, β)
5150

5251
return NoRData(), ΔCr,
53-
ΔAr, NoRData(), NoRData(),
54-
ΔBr, NoRData(), NoRData(),
52+
ΔAr, NoRData(),
53+
ΔBr, NoRData(),
5554
NoRData(),
5655
Δαr, Δβr,
57-
map(ba_ -> NoRData(), ba)...
56+
NoRData(), NoRData()
5857
end
5958

60-
return C_ΔC, tensorcontract_pullback
59+
return C_ΔC, blas_contract_pullback
6160
end
6261

63-
tensorcontract_pullback_ΔC!(ΔC, β) = (scale!(ΔC, conj(β)); NoRData())
62+
blas_contract_pullback_ΔC!(ΔC, β) = (scale!(ΔC, conj(β)); NoRData())
6463

65-
function tensorcontract_pullback_ΔA!(
66-
ΔA, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba...
64+
function blas_contract_pullback_ΔA!(
65+
ΔA, ΔC, A, pA, B, pB, pAB, α, backend, allocator
6766
)
6867
ipAB = invperm(linearize(pAB))
6968
pΔC = _repartition(ipAB, TO.numout(pA))
7069
ipA = _repartition(invperm(linearize(pA)), A)
71-
conjΔC = conjA
72-
conjB′ = conjA ? conjB : !conjB
7370

7471
tB = twist(
7572
B,
@@ -81,24 +78,22 @@ function tensorcontract_pullback_ΔA!(
8178

8279
TO.tensorcontract!(
8380
ΔA,
84-
ΔC, pΔC, conjΔC,
85-
tB, reverse(pB), conjB′,
81+
ΔC, pΔC, false,
82+
tB, reverse(pB), true,
8683
ipA,
87-
conjA ? α : conj(α), Zero(),
88-
ba...
84+
conj(α), Zero(),
85+
backend, allocator
8986
)
9087

9188
return NoRData()
9289
end
9390

94-
function tensorcontract_pullback_ΔB!(
95-
ΔB, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba...
91+
function blas_contract_pullback_ΔB!(
92+
ΔB, ΔC, A, pA, B, pB, pAB, α, backend, allocator
9693
)
9794
ipAB = invperm(linearize(pAB))
9895
pΔC = _repartition(ipAB, TO.numout(pA))
9996
ipB = _repartition(invperm(linearize(pB)), B)
100-
conjΔC = conjB
101-
conjA′ = conjB ? conjA : !conjA
10297

10398
tA = twist(
10499
A,
@@ -110,27 +105,27 @@ function tensorcontract_pullback_ΔB!(
110105

111106
TO.tensorcontract!(
112107
ΔB,
113-
tA, reverse(pA), conjA′,
114-
ΔC, pΔC, conjΔC,
108+
tA, reverse(pA), true,
109+
ΔC, pΔC, false,
115110
ipB,
116-
conjB ? α : conj(α), Zero(), ba...
111+
conj(α), Zero(), backend, allocator
117112
)
118113

119114
return NoRData()
120115
end
121116

122-
function tensorcontract_pullback_Δα(
123-
ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba...
117+
function blas_contract_pullback_Δα(
118+
ΔC, A, pA, B, pB, pAB, α, backend, allocator
124119
)
125120
Tdα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α)))
126121
Tdα === NoRData && return NoRData()
127122

128-
AB = TO.tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...)
123+
AB = TO.tensorcontract(A, pA, false, B, pB, false, pAB, One(), backend, allocator)
129124
Δα = inner(AB, ΔC)
130125
return Mooncake._rdata(Δα)
131126
end
132127

133-
function tensorcontract_pullback_Δβ(ΔC, C, β)
128+
function blas_contract_pullback_Δβ(ΔC, C, β)
134129
Tdβ = Mooncake.rdata_type(Mooncake.tangent_type(typeof(β)))
135130
Tdβ === NoRData && return NoRData()
136131

test/autodiff/mooncake.jl

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -231,20 +231,19 @@ for V in spacelist
231231
β = randn(T)
232232
V2_conj = prod(conj, V2; init = one(V[1]))
233233

234-
for conjA in (false, true), conjB in (false, true)
235-
A = randn(T, permute(V1 (conjA ? V2_conj : V2), ipA))
236-
B = randn(T, permute((conjB ? V2_conj : V2) V3, ipB))
237-
C = randn!(
238-
TensorOperations.tensoralloc_contract(
239-
T, A, pA, conjA, B, pB, conjB, pAB, Val(false)
240-
)
241-
)
242-
Mooncake.TestUtils.test_rule(
243-
rng, tensorcontract!, C, A, pA, conjA, B, pB, conjB, pAB, α, β;
244-
atol, rtol, mode
234+
A = randn(T, permute(V1 V2, ipA))
235+
B = randn(T, permute(V2 V3, ipB))
236+
C = randn!(
237+
TensorOperations.tensoralloc_contract(
238+
T, A, pA, false, B, pB, false, pAB, Val(false)
245239
)
246-
247-
end
240+
)
241+
Mooncake.TestUtils.test_rule(
242+
rng, TensorKit.blas_contract!,
243+
C, A, pA, B, pB, pAB, α, β,
244+
TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator();
245+
atol, rtol, mode
246+
)
248247
end
249248
end
250249

0 commit comments

Comments
 (0)