@@ -18,6 +18,8 @@ trivtuple(N) = ntuple(identity, N)
1818@non_differentiable TensorOperations. tensorcontract_structure (args... )
1919@non_differentiable TensorOperations. tensorcontract_type (args... )
2020@non_differentiable TensorOperations. tensoralloc_contract (args... )
21+ @non_differentiable TensorOperations. promote_contract (args... )
22+ @non_differentiable TensorOperations. promote_add (args... )
2123
2224# Cannot free intermediate tensors when using AD
2325# Thus we change the forward passes: `istemp=false` and `tensorfree!` is a no-op
@@ -38,11 +40,11 @@ function ChainRulesCore.rrule(
3840 return output, tensoralloc_pullback
3941end
4042
41- # TODO : possibly use the non-inplace functions, to avoid depending on Base.copy
4243function ChainRulesCore. rrule (:: typeof (tensorscalar), C)
44+ projectC = ProjectTo (C)
4345 function tensorscalar_pullback (Δc)
44- ΔC = TensorOperations . tensoralloc ( typeof (C), TensorOperations . tensorstructure (C) )
45- return NoTangent (), fill! (ΔC, unthunk (Δc) )
46+ _Δc = unthunk (Δc )
47+ return NoTangent (), projectC (_Δc )
4648 end
4749 return tensorscalar (C), tensorscalar_pullback
4850end
@@ -95,7 +97,7 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
9597 ipA = invperm (linearize (pA))
9698 _dA = zerovector (A, VectorInterface. promote_add (ΔC, α))
9799 _dA = tensoradd! (_dA, ΔC, (ipA, ()), conjA, conjA ? α : conj (α), Zero (), ba... )
98- return projectA (_dA)
100+ projectA (_dA)
99101 end
100102 dα = @thunk let
101103 _dα = tensorscalar (
@@ -105,7 +107,7 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
105107 ((), ()), One (), ba...
106108 )
107109 )
108- return projectα (_dα)
110+ projectα (_dα)
109111 end
110112 dβ = @thunk let
111113 # TODO : consider using `inner`
@@ -116,7 +118,7 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
116118 ((), ()), One (), ba...
117119 )
118120 )
119- return projectβ (_dβ)
121+ projectβ (_dβ)
120122 end
121123 dba = map (_ -> NoTangent (), ba)
122124 return NoTangent (), dC, dA, NoTangent (), NoTangent (), dα, dβ, dba...
@@ -194,7 +196,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
194196 ipA,
195197 conjA ? α : conj (α), Zero (), ba...
196198 )
197- return projectA (_dA)
199+ projectA (_dA)
198200 end
199201 dB = @thunk let
200202 ipB = (invperm (linearize (pB)), ())
@@ -208,7 +210,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
208210 ipB,
209211 conjB ? α : conj (α), Zero (), ba...
210212 )
211- return projectB (_dB)
213+ projectB (_dB)
212214 end
213215 dα = @thunk let
214216 C_αβ = tensorcontract (A, pA, conjA, B, pB, conjB, pAB, One (), ba... )
@@ -220,7 +222,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
220222 ((), ()), One (), ba...
221223 )
222224 )
223- return projectα (_dα)
225+ projectα (_dα)
224226 end
225227 dβ = @thunk let
226228 # TODO : consider using `inner`
@@ -231,7 +233,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
231233 ((), ()), One (), ba...
232234 )
233235 )
234- return projectβ (_dβ)
236+ projectβ (_dβ)
235237 end
236238 dba = map (_ -> NoTangent (), ba)
237239 return NoTangent (), dC,
@@ -283,7 +285,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
283285 dA = @thunk let
284286 ip = invperm ((linearize (p)... , q[1 ]. .. , q[2 ]. .. ))
285287 Es = map (q[1 ], q[2 ]) do i1, i2
286- return one (
288+ one (
287289 TensorOperations. tensoralloc_add (
288290 scalartype (A), A, ((i1,), (i2,)), conjA
289291 )
@@ -297,7 +299,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
297299 (ip, ()),
298300 conjA ? α : conj (α), Zero (), ba...
299301 )
300- return projectA (_dA)
302+ projectA (_dA)
301303 end
302304 dα = @thunk let
303305 C_αβ = tensortrace (A, p, q, false , One (), ba... )
@@ -309,7 +311,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
309311 ((), ()), One (), ba...
310312 )
311313 )
312- return projectα (_dα)
314+ projectα (_dα)
313315 end
314316 dβ = @thunk let
315317 _dβ = tensorscalar (
@@ -319,7 +321,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
319321 ((), ()), One (), ba...
320322 )
321323 )
322- return projectβ (_dβ)
324+ projectβ (_dβ)
323325 end
324326 dba = map (_ -> NoTangent (), ba)
325327 return NoTangent (), dC, dA, NoTangent (), NoTangent (), NoTangent (), dα, dβ, dba...
0 commit comments