Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 16 additions & 14 deletions ext/TensorOperationsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ trivtuple(N) = ntuple(identity, N)
# Declare these TensorOperations helpers as non-differentiable via ChainRulesCore, so
# that AD does not try to differentiate through calls to them.
@non_differentiable TensorOperations.tensorcontract_structure(args...)
@non_differentiable TensorOperations.tensorcontract_type(args...)
@non_differentiable TensorOperations.tensoralloc_contract(args...)
@non_differentiable TensorOperations.promote_contract(args...)
@non_differentiable TensorOperations.promote_add(args...)

# Cannot free intermediate tensors when using AD
# Thus we change the forward passes: `istemp=false` and `tensorfree!` is a no-op
Expand All @@ -38,11 +40,11 @@ function ChainRulesCore.rrule(
return output, tensoralloc_pullback
end

# TODO: possibly use the non-inplace functions, to avoid depending on Base.copy
# Reverse rule for `tensorscalar(C)`: returns the primal `tensorscalar(C)` together with
# a pullback that maps the scalar cotangent `Δc` back onto a tangent for `C`.
function ChainRulesCore.rrule(::typeof(tensorscalar), C)
    # Build the projector once, outside the pullback, so the closure captures it.
    projectC = ProjectTo(C)
    function tensorscalar_pullback(Δc)
        # The incoming cotangent may be a thunk; materialise it before projecting.
        _Δc = unthunk(Δc)
        # NOTE(review): relies on `ProjectTo(C)` accepting a scalar cotangent, as
        # ChainRulesCore does for zero-dimensional arrays — confirm for other tensor types.
        return NoTangent(), projectC(_Δc)
    end
    return tensorscalar(C), tensorscalar_pullback
end
Expand Down Expand Up @@ -95,7 +97,7 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
ipA = invperm(linearize(pA))
_dA = zerovector(A, VectorInterface.promote_add(ΔC, α))
_dA = tensoradd!(_dA, ΔC, (ipA, ()), conjA, conjA ? α : conj(α), Zero(), ba...)
return projectA(_dA)
projectA(_dA)
end
dα = @thunk let
_dα = tensorscalar(
Expand All @@ -105,7 +107,7 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
((), ()), One(), ba...
)
)
return projectα(_dα)
projectα(_dα)
end
dβ = @thunk let
# TODO: consider using `inner`
Expand All @@ -116,7 +118,7 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
((), ()), One(), ba...
)
)
return projectβ(_dβ)
projectβ(_dβ)
end
dba = map(_ -> NoTangent(), ba)
return NoTangent(), dC, dA, NoTangent(), NoTangent(), dα, dβ, dba...
Expand Down Expand Up @@ -194,7 +196,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
ipA,
conjA ? α : conj(α), Zero(), ba...
)
return projectA(_dA)
projectA(_dA)
end
dB = @thunk let
ipB = (invperm(linearize(pB)), ())
Expand All @@ -208,7 +210,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
ipB,
conjB ? α : conj(α), Zero(), ba...
)
return projectB(_dB)
projectB(_dB)
end
dα = @thunk let
C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...)
Expand All @@ -220,7 +222,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
((), ()), One(), ba...
)
)
return projectα(_dα)
projectα(_dα)
end
dβ = @thunk let
# TODO: consider using `inner`
Expand All @@ -231,7 +233,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
((), ()), One(), ba...
)
)
return projectβ(_dβ)
projectβ(_dβ)
end
dba = map(_ -> NoTangent(), ba)
return NoTangent(), dC,
Expand Down Expand Up @@ -283,7 +285,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
dA = @thunk let
ip = invperm((linearize(p)..., q[1]..., q[2]...))
Es = map(q[1], q[2]) do i1, i2
return one(
one(
TensorOperations.tensoralloc_add(
scalartype(A), A, ((i1,), (i2,)), conjA
)
Expand All @@ -297,7 +299,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
(ip, ()),
conjA ? α : conj(α), Zero(), ba...
)
return projectA(_dA)
projectA(_dA)
end
dα = @thunk let
C_αβ = tensortrace(A, p, q, false, One(), ba...)
Expand All @@ -309,7 +311,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
((), ()), One(), ba...
)
)
return projectα(_dα)
projectα(_dα)
end
dβ = @thunk let
_dβ = tensorscalar(
Expand All @@ -319,7 +321,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
((), ()), One(), ba...
)
)
return projectβ(_dβ)
projectβ(_dβ)
end
dba = map(_ -> NoTangent(), ba)
return NoTangent(), dC, dA, NoTangent(), NoTangent(), NoTangent(), dα, dβ, dba...
Expand Down
Loading