From 3344b45c93be111f4f63dda842bd427923578310 Mon Sep 17 00:00:00 2001 From: yitan1 Date: Thu, 4 Dec 2025 10:48:42 +0800 Subject: [PATCH 1/4] Fix rrule of `tensorscalar` to avoid inplace function and delete unnecessary `return` in rrule --- ext/TensorOperationsChainRulesCoreExt.jl | 36 +++++++++++++++--------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/ext/TensorOperationsChainRulesCoreExt.jl b/ext/TensorOperationsChainRulesCoreExt.jl index 3d7fbd4..913198b 100644 --- a/ext/TensorOperationsChainRulesCoreExt.jl +++ b/ext/TensorOperationsChainRulesCoreExt.jl @@ -18,6 +18,7 @@ trivtuple(N) = ntuple(identity, N) @non_differentiable TensorOperations.tensorcontract_structure(args...) @non_differentiable TensorOperations.tensorcontract_type(args...) @non_differentiable TensorOperations.tensoralloc_contract(args...) +@non_differentiable Base.promote_op(args...) # Cannot free intermediate tensors when using AD # Thus we change the forward passes: `istemp=false` and `tensorfree!` is a no-op @@ -40,9 +41,18 @@ end # TODO: possibly use the non-inplace functions, to avoid depending on Base.copy function ChainRulesCore.rrule(::typeof(tensorscalar), C) + projectC = ProjectTo(C) function tensorscalar_pullback(Δc) - ΔC = TensorOperations.tensoralloc(typeof(C), TensorOperations.tensorstructure(C)) - return NoTangent(), fill!(ΔC, unthunk(Δc)) + # ΔC = TensorOperations.tensoralloc(typeof(C), TensorOperations.tensorstructure(C)) + _Δc = unthunk(Δc) + # @show _Δc + # _Δc_scalar = + # _Δc isa Number ? Δc : + # _Δc isa AbstractArray ? only(Δc) : + # throw(ArgumentError("unexpected Δc: $(typeof(_Δc))")) + # @show projectC(_Δc_scalar) + return NoTangent(), projectC(_Δc) + # return NoTangent(), fill!(ΔC, unthunk(Δc)) end return tensorscalar(C), tensorscalar_pullback end @@ -95,7 +105,7 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba) ipA = invperm(linearize(pA)) _dA = zerovector(A, VectorInterface.promote_add(ΔC, α)) _dA = tensoradd!(_dA, ΔC, (ipA, ()), conjA, conjA ? α : conj(α), Zero(), ba...) - return projectA(_dA) + projectA(_dA) end dα = @thunk let _dα = tensorscalar( @@ -105,7 +115,7 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba) ((), ()), One(), ba... ) ) - return projectα(_dα) + projectα(_dα) end dβ = @thunk let # TODO: consider using `inner` @@ -116,7 +126,7 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba) ((), ()), One(), ba... ) ) - return projectβ(_dβ) + projectβ(_dβ) end dba = map(_ -> NoTangent(), ba) return NoTangent(), dC, dA, NoTangent(), NoTangent(), dα, dβ, dba... @@ -194,7 +204,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba) ipA, conjA ? α : conj(α), Zero(), ba... ) - return projectA(_dA) + projectA(_dA) end dB = @thunk let ipB = (invperm(linearize(pB)), ()) @@ -208,7 +218,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba) ipB, conjB ? α : conj(α), Zero(), ba... ) - return projectB(_dB) + projectB(_dB) end dα = @thunk let C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...) @@ -220,7 +230,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba) ((), ()), One(), ba... ) ) - return projectα(_dα) + projectα(_dα) end dβ = @thunk let # TODO: consider using `inner` @@ -231,7 +241,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba) ((), ()), One(), ba... ) ) - return projectβ(_dβ) + projectβ(_dβ) end dba = map(_ -> NoTangent(), ba) return NoTangent(), dC, @@ -283,7 +293,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba) dA = @thunk let ip = invperm((linearize(p)..., q[1]..., q[2]...)) Es = map(q[1], q[2]) do i1, i2 - return one( + one( TensorOperations.tensoralloc_add( scalartype(A), A, ((i1,), (i2,)), conjA ) @@ -297,7 +307,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba) (ip, ()), conjA ? α : conj(α), Zero(), ba... ) - return projectA(_dA) + projectA(_dA) end dα = @thunk let C_αβ = tensortrace(A, p, q, false, One(), ba...) @@ -309,7 +319,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba) ((), ()), One(), ba... ) ) - return projectα(_dα) + projectα(_dα) end dβ = @thunk let _dβ = tensorscalar( @@ -319,7 +329,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba) ((), ()), One(), ba... ) ) - return projectβ(_dβ) + projectβ(_dβ) end dba = map(_ -> NoTangent(), ba) return NoTangent(), dC, dA, NoTangent(), NoTangent(), NoTangent(), dα, dβ, dba... From 14d823f52e1f44b4401ccd8efccb1f7c69cbb52d Mon Sep 17 00:00:00 2001 From: yitan1 Date: Thu, 4 Dec 2025 11:03:53 +0800 Subject: [PATCH 2/4] remove TODO of tensorscalar --- ext/TensorOperationsChainRulesCoreExt.jl | 9 --------- 1 file changed, 9 deletions(-) diff --git a/ext/TensorOperationsChainRulesCoreExt.jl b/ext/TensorOperationsChainRulesCoreExt.jl index 913198b..6b7c090 100644 --- a/ext/TensorOperationsChainRulesCoreExt.jl +++ b/ext/TensorOperationsChainRulesCoreExt.jl @@ -39,20 +39,11 @@ function ChainRulesCore.rrule( return output, tensoralloc_pullback end -# TODO: possibly use the non-inplace functions, to avoid depending on Base.copy function ChainRulesCore.rrule(::typeof(tensorscalar), C) projectC = ProjectTo(C) function tensorscalar_pullback(Δc) - # ΔC = TensorOperations.tensoralloc(typeof(C), TensorOperations.tensorstructure(C)) _Δc = unthunk(Δc) - # @show _Δc - # _Δc_scalar = - # _Δc isa Number ? Δc : - # _Δc isa AbstractArray ? only(Δc) : - # throw(ArgumentError("unexpected Δc: $(typeof(_Δc))")) - # @show projectC(_Δc_scalar) return NoTangent(), projectC(_Δc) - # return NoTangent(), fill!(ΔC, unthunk(Δc)) end return tensorscalar(C), tensorscalar_pullback end From e509e2b00be92bdc42db706578a844a7ee58051c Mon Sep 17 00:00:00 2001 From: yitan1 Date: Thu, 4 Dec 2025 11:26:12 +0800 Subject: [PATCH 3/4] style: apply Runic formatting --- ext/TensorOperationsChainRulesCoreExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/TensorOperationsChainRulesCoreExt.jl b/ext/TensorOperationsChainRulesCoreExt.jl index 6b7c090..6ddd5dc 100644 --- a/ext/TensorOperationsChainRulesCoreExt.jl +++ b/ext/TensorOperationsChainRulesCoreExt.jl @@ -40,7 +40,7 @@ function ChainRulesCore.rrule( end function ChainRulesCore.rrule(::typeof(tensorscalar), C) - projectC = ProjectTo(C) + projectC = ProjectTo(C) function tensorscalar_pullback(Δc) _Δc = unthunk(Δc) return NoTangent(), projectC(_Δc) @@ -284,7 +284,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba) dA = @thunk let ip = invperm((linearize(p)..., q[1]..., q[2]...)) Es = map(q[1], q[2]) do i1, i2 - one( + one( TensorOperations.tensoralloc_add( scalartype(A), A, ((i1,), (i2,)), conjA ) From a7070733e371e42423ed3cd64d665f0905faf275 Mon Sep 17 00:00:00 2001 From: yitan1 Date: Mon, 15 Dec 2025 13:06:19 +0800 Subject: [PATCH 4/4] fix: update promote_op to use TensorOperations functions for consistency --- ext/TensorOperationsChainRulesCoreExt.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ext/TensorOperationsChainRulesCoreExt.jl b/ext/TensorOperationsChainRulesCoreExt.jl index 6ddd5dc..a838917 100644 --- a/ext/TensorOperationsChainRulesCoreExt.jl +++ b/ext/TensorOperationsChainRulesCoreExt.jl @@ -18,7 +18,8 @@ trivtuple(N) = ntuple(identity, N) @non_differentiable TensorOperations.tensorcontract_structure(args...) @non_differentiable TensorOperations.tensorcontract_type(args...) @non_differentiable TensorOperations.tensoralloc_contract(args...) -@non_differentiable Base.promote_op(args...) +@non_differentiable TensorOperations.promote_contract(args...) +@non_differentiable TensorOperations.promote_add(args...) # Cannot free intermediate tensors when using AD # Thus we change the forward passes: `istemp=false` and `tensorfree!` is a no-op