@@ -4,72 +4,69 @@ Mooncake.@is_primitive(
44 DefaultCtx,
55 ReverseMode,
66 Tuple{
7- typeof (TO . tensorcontract !),
7+ typeof (TensorKit . blas_contract !),
88 AbstractTensorMap,
9- AbstractTensorMap, Index2Tuple, Bool,
10- AbstractTensorMap, Index2Tuple, Bool,
9+ AbstractTensorMap, Index2Tuple,
10+ AbstractTensorMap, Index2Tuple,
1111 Index2Tuple,
1212 Number, Number,
13- Vararg{ Any} ,
13+ Any, Any ,
1414 }
1515)
1616
1717function Mooncake. rrule!! (
18- :: CoDual{typeof(TO.tensorcontract !)} ,
18+ :: CoDual{typeof(TensorKit.blas_contract !)} ,
1919 C_ΔC:: CoDual{<:AbstractTensorMap} ,
20- A_ΔA:: CoDual{<:AbstractTensorMap} , pA_ΔpA:: CoDual{<:Index2Tuple} , conjA_ΔconjA :: CoDual{Bool} ,
21- B_ΔB:: CoDual{<:AbstractTensorMap} , pB_ΔpB:: CoDual{<:Index2Tuple} , conjB_ΔconjB :: CoDual{Bool} ,
20+ A_ΔA:: CoDual{<:AbstractTensorMap} , pA_ΔpA:: CoDual{<:Index2Tuple} ,
21+ B_ΔB:: CoDual{<:AbstractTensorMap} , pB_ΔpB:: CoDual{<:Index2Tuple} ,
2222 pAB_ΔpAB:: CoDual{<:Index2Tuple} ,
2323 α_Δα:: CoDual{<:Number} , β_Δβ:: CoDual{<:Number} ,
24- ba_Δba :: CoDual... ,
24+ backend_Δbackend :: CoDual , allocator_Δallocator :: CoDual
2525 )
2626 # prepare arguments
2727 (C, ΔC), (A, ΔA), (B, ΔB) = arrayify .((C_ΔC, A_ΔA, B_ΔB))
2828 pA, pB, pAB = primal .((pA_ΔpA, pB_ΔpB, pAB_ΔpAB))
29- conjA, conjB = primal .((conjA_ΔconjA, conjB_ΔconjB))
3029 α, β = primal .((α_Δα, β_Δβ))
31- ba = primal .(ba_Δba )
30+ backend, allocator = primal .((backend_Δbackend, allocator_Δallocator) )
3231
3332 # primal call
3433 C_cache = copy (C)
35- TO . tensorcontract ! (C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba ... )
34+ TensorKit . blas_contract ! (C, A, pA, B, pB, pAB, α, β, backend, allocator )
3635
37- function tensorcontract_pullback (:: NoRData )
36+ function blas_contract_pullback (:: NoRData )
3837 copy! (C, C_cache)
3938
40- ΔCr = tensorcontract_pullback_ΔC! (ΔC, β)
41- ΔAr = tensorcontract_pullback_ΔA! (
42- ΔA, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba...
39+ ΔAr = blas_contract_pullback_ΔA! (
40+ ΔA, ΔC, A, pA, B, pB, pAB, α, backend, allocator
4341 )
44- ΔBr = tensorcontract_pullback_ΔB ! (
45- ΔB, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba ...
42+ ΔBr = blas_contract_pullback_ΔB ! (
43+ ΔB, ΔC, A, pA, B, pB, pAB, α, backend, allocator
4644 )
47- Δαr = tensorcontract_pullback_Δα (
48- ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba ...
45+ Δαr = blas_contract_pullback_Δα (
46+ ΔC, A, pA, B, pB, pAB, α, backend, allocator
4947 )
50- Δβr = tensorcontract_pullback_Δβ (ΔC, C, β)
48+ Δβr = blas_contract_pullback_Δβ (ΔC, C, β)
49+ ΔCr = blas_contract_pullback_ΔC! (ΔC, β)
5150
5251 return NoRData (), ΔCr,
53- ΔAr, NoRData (), NoRData (),
54- ΔBr, NoRData (), NoRData (),
52+ ΔAr, NoRData (),
53+ ΔBr, NoRData (),
5554 NoRData (),
5655 Δαr, Δβr,
57- map (ba_ -> NoRData (), ba) ...
56+ NoRData (), NoRData ()
5857 end
5958
60- return C_ΔC, tensorcontract_pullback
59+ return C_ΔC, blas_contract_pullback
6160end
6261
63- tensorcontract_pullback_ΔC ! (ΔC, β) = (scale! (ΔC, conj (β)); NoRData ())
62+ blas_contract_pullback_ΔC ! (ΔC, β) = (scale! (ΔC, conj (β)); NoRData ())
6463
65- function tensorcontract_pullback_ΔA ! (
66- ΔA, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba ...
64+ function blas_contract_pullback_ΔA ! (
65+ ΔA, ΔC, A, pA, B, pB, pAB, α, backend, allocator
6766 )
6867 ipAB = invperm (linearize (pAB))
6968 pΔC = _repartition (ipAB, TO. numout (pA))
7069 ipA = _repartition (invperm (linearize (pA)), A)
71- conjΔC = conjA
72- conjB′ = conjA ? conjB : ! conjB
7370
7471 tB = twist (
7572 B,
@@ -81,24 +78,22 @@ function tensorcontract_pullback_ΔA!(
8178
8279 TO. tensorcontract! (
8380 ΔA,
84- ΔC, pΔC, conjΔC ,
85- tB, reverse (pB), conjB′ ,
81+ ΔC, pΔC, false ,
82+ tB, reverse (pB), true ,
8683 ipA,
87- conjA ? α : conj (α), Zero (),
88- ba ...
84+ conj (α), Zero (),
85+ backend, allocator
8986 )
9087
9188 return NoRData ()
9289end
9390
94- function tensorcontract_pullback_ΔB ! (
95- ΔB, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba ...
91+ function blas_contract_pullback_ΔB ! (
92+ ΔB, ΔC, A, pA, B, pB, pAB, α, backend, allocator
9693 )
9794 ipAB = invperm (linearize (pAB))
9895 pΔC = _repartition (ipAB, TO. numout (pA))
9996 ipB = _repartition (invperm (linearize (pB)), B)
100- conjΔC = conjB
101- conjA′ = conjB ? conjA : ! conjA
10297
10398 tA = twist (
10499 A,
@@ -110,27 +105,27 @@ function tensorcontract_pullback_ΔB!(
110105
111106 TO. tensorcontract! (
112107 ΔB,
113- tA, reverse (pA), conjA′ ,
114- ΔC, pΔC, conjΔC ,
108+ tA, reverse (pA), true ,
109+ ΔC, pΔC, false ,
115110 ipB,
116- conjB ? α : conj (α), Zero (), ba ...
111+ conj (α), Zero (), backend, allocator
117112 )
118113
119114 return NoRData ()
120115end
121116
122- function tensorcontract_pullback_Δα (
123- ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba ...
117+ function blas_contract_pullback_Δα (
118+ ΔC, A, pA, B, pB, pAB, α, backend, allocator
124119 )
125120 Tdα = Mooncake. rdata_type (Mooncake. tangent_type (typeof (α)))
126121 Tdα === NoRData && return NoRData ()
127122
128- AB = TO. tensorcontract (A, pA, conjA , B, pB, conjB , pAB, One (), ba ... )
123+ AB = TO. tensorcontract (A, pA, false , B, pB, false , pAB, One (), backend, allocator )
129124 Δα = inner (AB, ΔC)
130125 return Mooncake. _rdata (Δα)
131126end
132127
133- function tensorcontract_pullback_Δβ (ΔC, C, β)
128+ function blas_contract_pullback_Δβ (ΔC, C, β)
134129 Tdβ = Mooncake. rdata_type (Mooncake. tangent_type (typeof (β)))
135130 Tdβ === NoRData && return NoRData ()
136131
0 commit comments