diff --git a/ext/TensorOperationsChainRulesCoreExt.jl b/ext/TensorOperationsChainRulesCoreExt.jl index 3d7fbd4..a838917 100644 --- a/ext/TensorOperationsChainRulesCoreExt.jl +++ b/ext/TensorOperationsChainRulesCoreExt.jl @@ -18,6 +18,8 @@ trivtuple(N) = ntuple(identity, N) @non_differentiable TensorOperations.tensorcontract_structure(args...) @non_differentiable TensorOperations.tensorcontract_type(args...) @non_differentiable TensorOperations.tensoralloc_contract(args...) +@non_differentiable TensorOperations.promote_contract(args...) +@non_differentiable TensorOperations.promote_add(args...) # Cannot free intermediate tensors when using AD # Thus we change the forward passes: `istemp=false` and `tensorfree!` is a no-op @@ -38,11 +40,11 @@ function ChainRulesCore.rrule( return output, tensoralloc_pullback end -# TODO: possibly use the non-inplace functions, to avoid depending on Base.copy function ChainRulesCore.rrule(::typeof(tensorscalar), C) + projectC = ProjectTo(C) function tensorscalar_pullback(Δc) - ΔC = TensorOperations.tensoralloc(typeof(C), TensorOperations.tensorstructure(C)) - return NoTangent(), fill!(ΔC, unthunk(Δc)) + _Δc = unthunk(Δc) + return NoTangent(), projectC(_Δc) end return tensorscalar(C), tensorscalar_pullback end @@ -95,7 +97,7 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba) ipA = invperm(linearize(pA)) _dA = zerovector(A, VectorInterface.promote_add(ΔC, α)) _dA = tensoradd!(_dA, ΔC, (ipA, ()), conjA, conjA ? α : conj(α), Zero(), ba...) - return projectA(_dA) + projectA(_dA) end dα = @thunk let _dα = tensorscalar( @@ -105,7 +107,7 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba) ((), ()), One(), ba... ) ) - return projectα(_dα) + projectα(_dα) end dβ = @thunk let # TODO: consider using `inner` @@ -116,7 +118,7 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba) ((), ()), One(), ba... ) ) - return projectβ(_dβ) + projectβ(_dβ) end dba = map(_ -> NoTangent(), ba) return NoTangent(), dC, dA, NoTangent(), NoTangent(), dα, dβ, dba... @@ -194,7 +196,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba) ipA, conjA ? α : conj(α), Zero(), ba... ) - return projectA(_dA) + projectA(_dA) end dB = @thunk let ipB = (invperm(linearize(pB)), ()) @@ -208,7 +210,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba) ipB, conjB ? α : conj(α), Zero(), ba... ) - return projectB(_dB) + projectB(_dB) end dα = @thunk let C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...) @@ -220,7 +222,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba) ((), ()), One(), ba... ) ) - return projectα(_dα) + projectα(_dα) end dβ = @thunk let # TODO: consider using `inner` @@ -231,7 +233,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba) ((), ()), One(), ba... ) ) - return projectβ(_dβ) + projectβ(_dβ) end dba = map(_ -> NoTangent(), ba) return NoTangent(), dC, @@ -283,7 +285,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba) dA = @thunk let ip = invperm((linearize(p)..., q[1]..., q[2]...)) Es = map(q[1], q[2]) do i1, i2 - return one( + one( TensorOperations.tensoralloc_add( scalartype(A), A, ((i1,), (i2,)), conjA ) @@ -297,7 +299,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba) (ip, ()), conjA ? α : conj(α), Zero(), ba... ) - return projectA(_dA) + projectA(_dA) end dα = @thunk let C_αβ = tensortrace(A, p, q, false, One(), ba...) @@ -309,7 +311,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba) ((), ()), One(), ba... ) ) - return projectα(_dα) + projectα(_dα) end dβ = @thunk let _dβ = tensorscalar( @@ -319,7 +321,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba) ((), ()), One(), ba... ) ) - return projectβ(_dβ) + projectβ(_dβ) end dba = map(_ -> NoTangent(), ba) return NoTangent(), dC, dA, NoTangent(), NoTangent(), NoTangent(), dα, dβ, dba...