Skip to content

Commit fd64bfe

Browse files
authored
Alter chain rules implementation to improve support for higher-order derivatives (#233)
* Fix rrule of `tensorscalar` to avoid inplace function and delete unnecessary `return` in rrule
* remove TODO of tensorscalar
* style: apply Runic formatting
* fix: update promote_op to use TensorOperations functions for consistency
1 parent 34bae64 commit fd64bfe

File tree

1 file changed

+16
-14
lines changed

1 file changed

+16
-14
lines changed

ext/TensorOperationsChainRulesCoreExt.jl

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ trivtuple(N) = ntuple(identity, N)
1818
@non_differentiable TensorOperations.tensorcontract_structure(args...)
1919
@non_differentiable TensorOperations.tensorcontract_type(args...)
2020
@non_differentiable TensorOperations.tensoralloc_contract(args...)
21+
@non_differentiable TensorOperations.promote_contract(args...)
22+
@non_differentiable TensorOperations.promote_add(args...)
2123

2224
# Cannot free intermediate tensors when using AD
2325
# Thus we change the forward passes: `istemp=false` and `tensorfree!` is a no-op
@@ -38,11 +40,11 @@ function ChainRulesCore.rrule(
3840
return output, tensoralloc_pullback
3941
end
4042

41-
# TODO: possibly use the non-inplace functions, to avoid depending on Base.copy
4243
function ChainRulesCore.rrule(::typeof(tensorscalar), C)
44+
projectC = ProjectTo(C)
4345
function tensorscalar_pullback(Δc)
44-
ΔC = TensorOperations.tensoralloc(typeof(C), TensorOperations.tensorstructure(C))
45-
return NoTangent(), fill!(ΔC, unthunk(Δc))
46+
_Δc = unthunk(Δc)
47+
return NoTangent(), projectC(_Δc)
4648
end
4749
return tensorscalar(C), tensorscalar_pullback
4850
end
@@ -95,7 +97,7 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
9597
ipA = invperm(linearize(pA))
9698
_dA = zerovector(A, VectorInterface.promote_add(ΔC, α))
9799
_dA = tensoradd!(_dA, ΔC, (ipA, ()), conjA, conjA ? α : conj(α), Zero(), ba...)
98-
return projectA(_dA)
100+
projectA(_dA)
99101
end
100102
dα = @thunk let
101103
_dα = tensorscalar(
@@ -105,7 +107,7 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
105107
((), ()), One(), ba...
106108
)
107109
)
108-
return projectα(_dα)
110+
projectα(_dα)
109111
end
110112
dβ = @thunk let
111113
# TODO: consider using `inner`
@@ -116,7 +118,7 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
116118
((), ()), One(), ba...
117119
)
118120
)
119-
return projectβ(_dβ)
121+
projectβ(_dβ)
120122
end
121123
dba = map(_ -> NoTangent(), ba)
122124
return NoTangent(), dC, dA, NoTangent(), NoTangent(), dα, dβ, dba...
@@ -194,7 +196,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
194196
ipA,
195197
conjA ? α : conj(α), Zero(), ba...
196198
)
197-
return projectA(_dA)
199+
projectA(_dA)
198200
end
199201
dB = @thunk let
200202
ipB = (invperm(linearize(pB)), ())
@@ -208,7 +210,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
208210
ipB,
209211
conjB ? α : conj(α), Zero(), ba...
210212
)
211-
return projectB(_dB)
213+
projectB(_dB)
212214
end
213215
dα = @thunk let
214216
C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...)
@@ -220,7 +222,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
220222
((), ()), One(), ba...
221223
)
222224
)
223-
return projectα(_dα)
225+
projectα(_dα)
224226
end
225227
dβ = @thunk let
226228
# TODO: consider using `inner`
@@ -231,7 +233,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
231233
((), ()), One(), ba...
232234
)
233235
)
234-
return projectβ(_dβ)
236+
projectβ(_dβ)
235237
end
236238
dba = map(_ -> NoTangent(), ba)
237239
return NoTangent(), dC,
@@ -283,7 +285,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
283285
dA = @thunk let
284286
ip = invperm((linearize(p)..., q[1]..., q[2]...))
285287
Es = map(q[1], q[2]) do i1, i2
286-
return one(
288+
one(
287289
TensorOperations.tensoralloc_add(
288290
scalartype(A), A, ((i1,), (i2,)), conjA
289291
)
@@ -297,7 +299,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
297299
(ip, ()),
298300
conjA ? α : conj(α), Zero(), ba...
299301
)
300-
return projectA(_dA)
302+
projectA(_dA)
301303
end
302304
dα = @thunk let
303305
C_αβ = tensortrace(A, p, q, false, One(), ba...)
@@ -309,7 +311,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
309311
((), ()), One(), ba...
310312
)
311313
)
312-
return projectα(_dα)
314+
projectα(_dα)
313315
end
314316
dβ = @thunk let
315317
_dβ = tensorscalar(
@@ -319,7 +321,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
319321
((), ()), One(), ba...
320322
)
321323
)
322-
return projectβ(_dβ)
324+
projectβ(_dβ)
323325
end
324326
dba = map(_ -> NoTangent(), ba)
325327
return NoTangent(), dC, dA, NoTangent(), NoTangent(), NoTangent(), dα, dβ, dba...

0 commit comments

Comments
 (0)