diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 8880dfcf1..d8ab9b0b2 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -31,6 +31,7 @@ jobs: - tensors - other - mooncake + - enzyme - chainrules os: - ubuntu-latest @@ -57,6 +58,7 @@ jobs: - tensors - other - mooncake + - enzyme - chainrules os: - ubuntu-latest diff --git a/Project.toml b/Project.toml index 7e82876c5..a3754950b 100644 --- a/Project.toml +++ b/Project.toml @@ -21,6 +21,7 @@ VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" @@ -29,6 +30,7 @@ cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" TensorKitAdaptExt = "Adapt" TensorKitCUDAExt = ["CUDA", "cuTENSOR"] TensorKitChainRulesCoreExt = "ChainRulesCore" +TensorKitEnzymeExt = "Enzyme" TensorKitFiniteDifferencesExt = "FiniteDifferences" TensorKitMooncakeExt = "Mooncake" @@ -41,6 +43,8 @@ CUDA = "5.9" ChainRulesCore = "1" ChainRulesTestUtils = "1" Combinatorics = "1" +Enzyme = "0.13.131" +EnzymeTestUtils = "0.2.5" FiniteDifferences = "0.12" GPUArrays = "11.3.1" JET = "0.9, 0.10, 0.11" @@ -53,7 +57,7 @@ Printf = "1" Random = "1" SafeTestsets = "0.1" ScopedValues = "1.3.0" -Strided = "2" +Strided = "=2.3.3" TensorKitSectors = "0.3.5" TensorOperations = "5.1" Test = "1" @@ -73,6 +77,8 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" @@ -86,4 +92,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" [targets] -test = ["ArgParse", "Adapt", "Aqua", "AllocCheck", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote", "Mooncake", "JET"] +test = ["ArgParse", "Adapt", "Aqua", "AllocCheck", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote", "Mooncake", "Enzyme", "EnzymeTestUtils", "JET"] + +[sources] +TensorOperations = {url = "https://github.com/quantumkithub/tensoroperations.jl", rev = "ksh/enzyme_update"} diff --git a/ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl b/ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl new file mode 100644 index 000000000..ab3061795 --- /dev/null +++ b/ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl @@ -0,0 +1,21 @@ +module TensorKitEnzymeExt + +using Enzyme +using TensorKit +import TensorKit as TK +using VectorInterface +using TensorOperations: TensorOperations, IndexTuple, Index2Tuple, linearize +import TensorOperations as TO +using MatrixAlgebraKit +using TupleTools +using Random: AbstractRNG + +include("utility.jl") +include("linalg.jl") +include("vectorinterface.jl") +include("tensoroperations.jl") +include("factorizations.jl") +include("indexmanipulations.jl") +#include("planaroperations.jl") + +end diff --git a/ext/TensorKitEnzymeExt/factorizations.jl b/ext/TensorKitEnzymeExt/factorizations.jl new file mode 100644 index 000000000..4314f7900 --- /dev/null +++ b/ext/TensorKitEnzymeExt/factorizations.jl @@ -0,0 +1,134 @@ +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(MatrixAlgebraKit.copy_input)}, + ::Type{RT}, + cache, + f::Annotation, + A::Annotation{<:AbstractTensorMap} + ) where {RT} + copy_shadow = cache + if !isa(A, Const) && !isnothing(copy_shadow) + add!(A.dval, copy_shadow) + end + return (nothing, nothing) +end + +for (f, pb) in ( + (:eig_full, :(MatrixAlgebraKit.eig_pullback!)), + (:eigh_full, :(MatrixAlgebraKit.eigh_pullback!)), + (:lq_compact, :(MatrixAlgebraKit.lq_pullback!)), + (:qr_compact, :(MatrixAlgebraKit.qr_pullback!)), + ) + @eval begin + function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + A::Annotation{<:AbstractTensorMap}, + alg::Const, + ) where {RT} + ret = $f(A.val, alg.val) + dret = make_zero(ret) + cache = (ret, dret) + return EnzymeRules.AugmentedReturn(ret, dret, cache) + end + function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + cache, + A::Annotation{<:AbstractTensorMap}, + alg::Const, + ) where {RT} + ret, dret = cache + $pb(A.dval, A.val, ret, dret) + return (nothing, nothing) + end + end +end + +for f in (:svd_compact, :svd_full) + @eval begin + function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + A::Annotation{<:AbstractTensorMap}, + alg::Const, + ) where {RT} + USVᴴ = $f(A.val, alg.val) + dUSVᴴ = make_zero(USVᴴ) + cache = (USVᴴ, dUSVᴴ) + return EnzymeRules.AugmentedReturn(USVᴴ, dUSVᴴ, cache) + end + function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + cache, + A::Annotation{<:AbstractTensorMap}, + alg::Const, + ) where {RT} + USVᴴ, dUSVᴴ = cache + MatrixAlgebraKit.svd_pullback!(A.dval, A.val, USVᴴ, dUSVᴴ) + return (nothing, nothing) + end + end + + # mutating version is not guaranteed to actually mutate + # so we can simply use the non-mutating version instead + f! = Symbol(f, :!) + #=@eval begin + function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f!)}, + ::Type{RT}, + A::Annotation{<:AbstractTensorMap}, + USVᴴ::Annotation, + alg::Const, + ) where {RT} + EnzymeRules.augmented_primal(func, RT, A, alg) + end + function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f!)}, + ::Type{RT}, + cache, + A::Annotation{<:AbstractTensorMap}, + USVᴴ::Annotation, + alg::Const, + ) where {RT} + EnzymeRules.reverse(func, RT, A, alg) + end + end=# #hmmmm +end + +# TODO +#= +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(svd_trunc)}, + ::Type{RT}, + A::Annotation{<:AbstractTensorMap}, + alg::Const, + ) where {RT} + + USVᴴ = svd_compact(A.val, alg.val.alg) + USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.val.trunc) + ϵ = MatrixAlgebraKit.truncation_error(diagview(USVᴴ[2]), ind) + dUSVᴴtrunc = make_zero(USVᴴtrunc) + cache = (USVᴴtrunc, dUSVᴴtrunc) + return EnzymeRules.AugmentedReturn(USVᴴtrunc, dUSVᴴtrunc, cache) +end +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(svd_trunc)}, + ::Type{RT}, + cache, + A::Annotation{<:AbstractTensorMap}, + alg::Const, +) where {RT} + USVᴴ, dUSVᴴ = cache + MatrixAlgebraKit.svd_pullback!(A.dval, A.val, USVᴴ, dUSVᴴ) + return (nothing, nothing) +end=# diff --git a/ext/TensorKitEnzymeExt/indexmanipulations.jl b/ext/TensorKitEnzymeExt/indexmanipulations.jl new file mode 100644 index 000000000..5e2c5f580 --- /dev/null +++ b/ext/TensorKitEnzymeExt/indexmanipulations.jl @@ -0,0 +1,308 @@ +for transform in (:permute, :transpose) + add_transform! = Symbol(:add_, transform, :!) + @eval function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(TK.$add_transform!)}, + ::Type{RT}, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + p::Const{<:Index2Tuple}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + ba::Const... + ) where {RT} + C_cache = !isa(β, Const) ? copy(C.val) : nothing + A_cache = EnzymeRules.overwritten(config)[3] ? copy(A.val) : nothing + # if we need to compute Δa, it is faster to allocate an intermediate permuted A + # and store that instead of repeating the permutation in the pullback each time. + # effectively, we replace `add_permute` by `add ∘ permute`. + Ap = if !isa(α, Const) + Ap = $transform(A.val, p.val) + add!(C.val, Ap, α.val, β.val) + Ap + else + bavs = map(a -> a.val, ba) + TK.$add_transform!(C.val, A.val, p.val, α.val, β.val, bavs...) + nothing + end + cache = (C_cache, A_cache, Ap) + primal = EnzymeRules.needs_primal(config) ? C.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? C.dval : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, cache) + end + @eval function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(TK.$add_transform!)}, + ::Type{RT}, + cache, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + p::Const{<:Index2Tuple}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + ba::Const... + ) where {RT} + C_cache, A_cache, Ap = cache + Cval = something(C_cache, C.val) + Aval = something(A_cache, A.val) + # ΔA + if !isa(A, Const) && !isa(C, Const) + ip = invperm(linearize(p.val)) + pΔA = _repartition(ip, Aval) + TC = VectorInterface.promote_scale(C.val, α.val) + bavs = map(a -> a.val, ba) + if scalartype(A.dval) <: Real && !(TC <: Real) + ΔAc = TO.tensoralloc_add(TC, C.dval, pΔA, false, Val(false)) + TK.$add_transform!(ΔAc, C.dval, pΔA, conj(α.val), Zero(), bavs...) + add!(A.dval, real(ΔAc)) + else + TK.$add_transform!(A.dval, C.dval, pΔA, conj(α.val), One(), bavs...) + end + end + Δαr = if !isnothing(Ap) && !isa(C, Const) + project_scalar(α.val, inner(Ap, C.dval)) + elseif !isnothing(Ap) + zero(α.val) + else + nothing + end + Δβr = if !isa(C, Const) && !isa(β, Const) + pullback_dβ(C.dval, Cval, β) + elseif !isa(β, Const) + zero(β.val) + else + nothing + end + !isa(C, Const) && pullback_dC!(C.dval, β.val) + return nothing, nothing, nothing, Δαr, Δβr, map(Returns(nothing), ba)... + end +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(TK.add_braid!)}, + ::Type{RT}, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + p::Const{<:Index2Tuple}, + levels::Const{<:IndexTuple}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + ba::Const... + ) where {RT} + C_cache = !isa(β, Const) ? copy(C.val) : nothing + A_cache = EnzymeRules.overwritten(config)[3] ? copy(A.val) : nothing + # if we need to compute Δa, it is faster to allocate an intermediate braided A + # and store that instead of repeating the permutation in the pullback each time. + # effectively, we replace `add_permute` by `add ∘ permute`. + Ap = if !isa(α, Const) + Ap = braid(A.val, p.val, levels.val) + add!(C.val, Ap, α.val, β.val) + Ap + else + bavs = map(a -> a.val, ba) + TK.add_braid!(C.val, A.val, p.val, levels.val, α.val, β.val, bavs...) + nothing + end + cache = (C_cache, A_cache, Ap) + primal = EnzymeRules.needs_primal(config) ? C.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? C.dval : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(TK.add_braid!)}, + ::Type{RT}, + cache, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + p::Const{<:Index2Tuple}, + levels::Const{<:IndexTuple}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + ba::Const... + ) where {RT} + C_cache, A_cache, Ap = cache + Cval = something(C_cache, C.val) + Aval = something(A_cache, A.val) + # ΔA + if !isa(A, Const) && !isa(C, Const) + ip = invperm(linearize(p.val)) + pΔA = _repartition(ip, Aval) + ilevels = TupleTools.permute(levels.val, linearize(p.val)) + TC = VectorInterface.promote_scale(C.dval, α.val) + bavs = map(a -> a.val, ba) + if scalartype(A.dval) <: Real && !(TC <: Real) + ΔAc = TO.tensoralloc_add(TC, C.dval, pΔA, false, Val(false)) + TK.add_braid!(ΔAc, C.dval, pΔA, ilevels, conj(α.val), Zero(), bavs...) + add!(A.dval, real(ΔAc)) + else + TK.add_braid!(A.dval, C.dval, pΔA, ilevels, conj(α.val), One(), bavs...) + end + end + Δαr = if !isnothing(Ap) && !isa(C, Const) + project_scalar(α.val, inner(Ap, C.dval)) + elseif !isnothing(Ap) + zero(α.val) + else + nothing + end + Δβr = if !isa(C, Const) && !isa(β, Const) + pullback_dβ(C.dval, Cval, β) + elseif !isa(β, Const) + zero(β.val) + else + nothing + end + !isa(C, Const) && pullback_dC!(C.dval, β.val) + return nothing, nothing, nothing, nothing, Δαr, Δβr, map(Returns(nothing), ba)... +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(twist!)}, + ::Type{RT}, + t::Annotation{<:AbstractTensorMap}, + inds::Const; + inv::Bool = false + ) where {RT} + twist!(t.val, inds.val; inv) + primal = EnzymeRules.needs_primal(config) ? t.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? t.dval : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, nothing) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(twist!)}, + ::Type{RT}, + cache, + t::Annotation{<:AbstractTensorMap}, + inds::Const; + inv::Bool = false + ) where {RT} + !isa(t, Const) && twist!(t.dval, inds.val; inv = !inv) + return (nothing, nothing) +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(flip)}, + ::Type{RT}, + t::Annotation{<:AbstractTensorMap}, + inds::Const; + inv::Bool = false + ) where {RT} + t′ = flip(t.val, inds.val; inv) + dt′ = make_zero(t′) + cache = dt′ + primal = EnzymeRules.needs_primal(config) ? t′ : nothing + shadow = EnzymeRules.needs_shadow(config) ? dt′ : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(flip)}, + ::Type{RT}, + cache, + t::Annotation{<:AbstractTensorMap}, + inds::Const; + inv::Bool = false, + ) where {RT} + dt′ = cache + if !isa(t, Const) + dt′′ = flip(dt′, inds.val; inv = !inv) + add!(t.dval, scalartype(t.dval) <: Real ? real(dt′′) : dt′′) + end + return (nothing, nothing) +end + +for insertunit in (:insertleftunit, :insertrightunit) + @eval begin + function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($insertunit)}, + ::Type{RT}, + tsrc::Annotation{<:AbstractTensorMap}, + ival::Const{<:Val}; + kwargs... + ) where {RT} + if tsrc.val isa TensorMap && !get(kwargs, :copy, false) && !isa(tsrc, Const) + tsrc_cache = copy(tsrc.val) + tdst = $insertunit(tsrc.val, ival.val; kwargs...) + Δtdst = $insertunit(tsrc.dval, ival.val; kwargs...) + else + tsrc_cache = nothing + tdst = $insertunit(tsrc.val, ival.val; kwargs...) + Δtdst = make_zero(tdst) + end + primal = EnzymeRules.needs_primal(config) ? tdst : nothing + shadow = EnzymeRules.needs_shadow(config) ? Δtdst : nothing + cache = (tsrc_cache, tdst, Δtdst) + return EnzymeRules.AugmentedReturn(primal, shadow, cache) + end + function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($insertunit)}, + ::Type{RT}, + cache, + tsrc::Annotation{<:AbstractTensorMap}, + ival::Const{<:Val}; + kwargs... + ) where {RT} + tsrc_cache, tdst, Δtdst = cache + # note: since data is already shared for <:TensorMap, don't have to do anything here! + if isnothing(tsrc_cache) && !isa(tsrc, Const) + for (c, b) in blocks(Δtdst) + add!(block(tsrc.dval, c), b) + end + end + return (nothing, nothing) + end + end +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(removeunit)}, + ::Type{RT}, + tsrc::Annotation{<:AbstractTensorMap}, + ival::Const{<:Val}; + kwargs... + ) where {RT} + # tdst shares data with tsrc if <:TensorMap & copy=false, in this case we have to deal with correctly + # sharing address spaces + if tsrc.val isa TensorMap && !get(kwargs, :copy, false) && !isa(tsrc, Const) + tsrc_cache = copy(tsrc.val) + tdst = removeunit(tsrc.val, ival.val; kwargs...) + Δtdst = removeunit(tsrc.dval, ival.val) + else + tsrc_cache = nothing + tdst = removeunit(tsrc.val, ival.val; kwargs...) + Δtdst = make_zero(tdst) + end + primal = EnzymeRules.needs_primal(config) ? tdst : nothing + shadow = EnzymeRules.needs_shadow(config) ? Δtdst : nothing + cache = (tsrc_cache, tdst, Δtdst) + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(removeunit)}, + ::Type{RT}, + cache, + tsrc::Annotation{<:AbstractTensorMap}, + ival::Const{<:Val}; + kwargs... + ) where {RT} + tsrc_cache, tdst, Δtdst = cache + # note: since data for <: TensorMap is already shared, don't have to do anything here! + if isnothing(tsrc_cache) && !isa(tsrc, Const) + for (c, b) in blocks(Δtdst) + add!(block(tsrc.dval, c), b) + end + end + return (nothing, nothing) +end diff --git a/ext/TensorKitEnzymeExt/linalg.jl b/ext/TensorKitEnzymeExt/linalg.jl new file mode 100644 index 000000000..a6848bb9a --- /dev/null +++ b/ext/TensorKitEnzymeExt/linalg.jl @@ -0,0 +1,137 @@ +# Shared +# ------ +pullback_dC!(ΔC, β) = scale!(ΔC, conj(β)) +pullback_dβ(ΔC, C, β) = !isa(β, Const) ? project_scalar(β.val, inner(C, ΔC)) : nothing + +# Can Enzyme do this itself? Apparently not... +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(mul!)}, + ::Type{RT}, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + B::Annotation{<:AbstractTensorMap}, + α::Annotation, + β::Annotation, + ) where {RT} + cacheC = !isa(β, Const) && copy(C.val) + cacheA = !isa(B, Const) && EnzymeRules.overwritten(config)[3] ? copy(A.val) : nothing + cacheB = !isa(A, Const) && EnzymeRules.overwritten(config)[4] ? copy(B.val) : nothing + AB = if !isa(α, Const) + AB = A.val * B.val + add!(C.val, AB, α.val, β.val) + AB + else + mul!(C.val, A.val, B.val, α.val, β.val) + nothing + end + primal = EnzymeRules.needs_primal(config) ? C.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? C.dval : nothing + cache = (cacheC, cacheA, cacheB, AB) + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(mul!)}, + ::Type{RT}, + cache, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + B::Annotation{<:AbstractTensorMap}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + ) where {RT} + if RT <: Const + Δα = isa(α, Const) ? nothing : zero(α.val) + Δβ = isa(β, Const) ? nothing : zero(β.val) + return (nothing, nothing, nothing, Δα, Δβ) + end + cacheC, cacheA, cacheB, AB = cache + Cval = something(cacheC, C.val) + Aval = something(cacheA, A.val) + Bval = something(cacheB, B.val) + + !isa(A, Const) && !isa(C, Const) && project_mul!(A.dval, C.dval, Bval', conj(α.val)) + !isa(B, Const) && !isa(C, Const) && project_mul!(B.dval, Aval', C.dval, conj(α.val)) + Δαr = if !isnothing(AB) && !isa(C, Const) + project_scalar(α.val, inner(AB, C.dval)) + elseif !isnothing(AB) + zero(α.val) + else + nothing + end + Δβr = if !isa(β, Const) && !isa(C, Const) + pullback_dβ(C.dval, Cval, β) + elseif !isa(β, Const) + zero(β.val) + else + nothing + end + !isa(C, Const) && pullback_dC!(C.dval, β.val) + + return (nothing, nothing, nothing, Δαr, Δβr) +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(tr)}, + ::Type{RT}, + A::Annotation{<:AbstractTensorMap}, + ) where {RT} + ret = func.val(A.val) + primal = EnzymeRules.needs_primal(config) ? ret : nothing + shadow = EnzymeRules.needs_shadow(config) ? zero(ret) : nothing + cache = EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(tr)}, + dret::Active, + cache, + A::Annotation{<:AbstractTensorMap}, + ) + Aval = something(cache, A.val) + Δtrace = dret.val + if !isa(A, Const) + for (_, b) in blocks(A.dval) + TensorKit.diagview(b) .+= Δtrace + end + end + return (nothing,) +end +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(tr)}, + ::Type{<:Const}, + cache, + A::Annotation{<:AbstractTensorMap}, + ) + return (nothing,) +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(inv)}, + ::Type{RT}, + A::Annotation{<:AbstractTensorMap}, + ) where {RT} + ret = inv(A.val) + primal = EnzymeRules.needs_primal(config) ? ret : nothing + shadow = EnzymeRules.needs_shadow(config) ? zero(ret) : nothing + cache = (ret, shadow) + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(inv)}, + ::Type{RT}, + cache, + A::Annotation{<:AbstractTensorMap}, + ) where {RT} + Ainv, ΔAinv = cache + !isa(A, Const) && mul!(A.dval, Ainv' * ΔAinv, Ainv', -1, One()) + return (nothing,) +end diff --git a/ext/TensorKitEnzymeExt/planaroperations.jl b/ext/TensorKitEnzymeExt/planaroperations.jl new file mode 100644 index 000000000..ef5c24019 --- /dev/null +++ b/ext/TensorKitEnzymeExt/planaroperations.jl @@ -0,0 +1,101 @@ +# planartrace! +# ------------ +# TODO: Fix planartrace pullback +# This implementation is slightly more involved than its non-planar counterpart +# this is because we lack a general `pAB` argument in `planarcontract`, and need +# to keep things planar along the way. +# In particular, we can't simply tensor product with multiple identities in one go +# if they aren't "contiguous", e.g. p = ((1, 4, 5), ()), q = ((2, 6), (3, 7)) + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + ::Const{typeof(TensorKit.planartrace!)}, + ::Type{RT}, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + p::Const{<:Index2Tuple}, + q::Const{<:Index2Tuple}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + backend::Const, allocator::Const + ) + cacheC = !isa(β, Const) && copy(C.val) + cacheA = EnzymeRules.overwritten(config)[3] ? copy(A.val) : nothing + + TensorKit.planartrace!(C.val, A.val, p.val, q.val, α.val, β.val, backend.val, allocator.val) + primal = EnzymeRules.needs_primal(config) ? C.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? C.dval : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, (cacheC, cacheA)) +end +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + ::Const{typeof(TensorKit.planartrace!)}, + ::Type{RT}, + cache, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + p::Const{<:Index2Tuple}, + q::Const{<:Index2Tuple}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + backend::Const, allocator::Const + ) + cacheC, cacheA = cache + Cval = something(cacheC, C.val) + Aval = something(cacheA, A.val) + + if !isa(A, Const) && !isa(C, Const) + planartrace_pullback_ΔA!(A.dval, C.dval, Aval, p.val, q.val, α.val, backend.val, allocator.val) + end + Δαr = if !isa(α, Const) && !isa(C, Const) + planartrace_pullback_Δα(C.dval, A.val, p.val, q.val, α.val, backend.val, allocator.val) + elseif !isa(α, Const) + zero(α.val) + else + nothing + end + Δβr = if !isa(β, Const) && !isa(C, Const) + pullback_dβ(C.dval, C.val, β) + elseif !isa(β, Const) + zero(β.val) + else + nothing + end + !isa(C, Const) && pullback_dC!(C.dval, β.val) + + return nothing, nothing, nothing, nothing, Δαr, Δβr, nothing, nothing +end + +function planartrace_pullback_dA!( + ΔA, ΔC, A, p, q, α, backend, allocator + ) + if length(q[1]) == 0 + ip = invperm(linearize(p)) + pΔA = _repartition(ip, A) + TK.add_transpose!(ΔA, ΔC, pΔA, conj(α), One(), backend, allocator) + return nothing + end + if length(q[1]) == 1 + ip = invperm((p[1]..., q[2]..., p[2]..., q[1]...)) + pdA = _repartition(ip, A) + E = one!(TO.tensoralloc_add(scalartype(A), A, q, false)) + twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E))) + pE = ((), trivtuple(TO.numind(q))) + pΔC = (trivtuple(TO.numind(p)), ()) + TensorKit.planaradd!(ΔA, ΔC ⊗ E, pdA, conj(α), One(), backend, allocator) + return nothing + end + error("The reverse rule for `planartrace` is not yet implemented") +end + +function planartrace_pullback_dα( + ΔC, A, p, q, α, backend, allocator + ) + # TODO: this result might be easier to compute as: + # C′ = βC + α * trace(A) ⟹ At = (C′ - βC) / α + At = TO.tensoralloc_add(scalartype(A), A, p, false, Val(true), allocator) + TensorKit.planartrace!(At, A, p, q, One(), Zero(), backend, allocator) + Δα = project_scalar(α, inner(At, ΔC)) + TO.tensorfree!(At, allocator) + return Δα +end diff --git a/ext/TensorKitEnzymeExt/tensoroperations.jl b/ext/TensorKitEnzymeExt/tensoroperations.jl new file mode 100644 index 000000000..661fc661e --- /dev/null +++ b/ext/TensorKitEnzymeExt/tensoroperations.jl @@ -0,0 +1,213 @@ +# tensorcontract! +# --------------- +# TODO: it might be beneficial to compare here if it would make sense to simply compute the +# rrule of permute-permute-gemm-permute, rather than using the contractions directly. +# This could possibly out save some permutations being carried out twice, at the cost of having +# to store some more intermediate objects. +# For example, the combination `ΔC, pΔC, false` appears in the pullback for ΔA and ΔB, so effectively +# this permutation is done multiple times. + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(TensorKit.blas_contract!)}, + ::Type{RT}, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + pA::Const{<:Index2Tuple}, + B::Annotation{<:AbstractTensorMap}, + pB::Const{<:Index2Tuple}, + pAB::Const{<:Index2Tuple}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + backend::Const, + allocator::Const + ) where {RT} + Ccache = isa(β, Const) ? nothing : copy(C.val) + A_needs_cache = EnzymeRules.overwritten(config)[3] && !(typeof(B) <: Const) && !(typeof(C) <: Const) + Acache = A_needs_cache ? copy(A.val) : nothing + B_needs_cache = EnzymeRules.overwritten(config)[5] && !(typeof(A) <: Const) && !(typeof(C) <: Const) + Bcache = B_needs_cache ? copy(B.val) : nothing + AB = if !isa(α, Const) + AB = TO.tensorcontract(A.val, pA.val, false, B.val, pB.val, false, pAB.val, One(), backend.val, allocator.val) + add!(C.val, AB, α.val, β.val) + AB + else + TensorKit.blas_contract!(C.val, A.val, pA.val, B.val, pB.val, pAB.val, α.val, β.val, backend.val, allocator.val) + nothing + end + primal = EnzymeRules.needs_primal(config) ? C.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? C.dval : nothing + cache = (Ccache, Acache, Bcache, AB) + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(TensorKit.blas_contract!)}, + ::Type{RT}, + cache, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + pA::Const{<:Index2Tuple}, + B::Annotation{<:AbstractTensorMap}, + pB::Const{<:Index2Tuple}, + pAB::Const{<:Index2Tuple}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + backend::Const, + allocator::Const + ) where {RT} + cacheC, cacheA, cacheB, AB = cache + Cval = cacheC + Aval = something(cacheA, A.val) + Bval = something(cacheB, B.val) + + Δα = isnothing(AB) ? nothing : project_scalar(α.val, inner(AB, C.dval)) + Δβ = isa(β, Const) ? nothing : pullback_dβ(C.dval, Cval, β) + + if !isa(A, Const) + blas_contract_pullback_ΔA!( + A.dval, C.dval, Aval, pA.val, Bval, pB.val, pAB.val, α.val, backend.val, allocator.val + ) # this typically returns nothing + end + if !isa(B, Const) + blas_contract_pullback_ΔB!( + B.dval, C.dval, Aval, pA.val, Bval, pB.val, pAB.val, α.val, backend.val, allocator.val + ) # this typically returns nothing + end + pullback_dC!(C.dval, β.val) # this typically returns nothing + return nothing, nothing, nothing, nothing, nothing, nothing, Δα, Δβ, nothing, nothing +end + +function blas_contract_pullback_ΔA!( + ΔA, ΔC, A, pA, B, pB, pAB, α, backend, allocator + ) + ipAB = invperm(linearize(pAB)) + pΔC = _repartition(ipAB, TO.numout(pA)) + ipA = _repartition(invperm(linearize(pA)), A) + + tB = twist( + B, + TupleTools.vcat( + filter(x -> !isdual(space(B, x)), pB[1]), + filter(x -> isdual(space(B, x)), pB[2]) + ); copy = false + ) + + project_contract!( + ΔA, + ΔC, pΔC, false, + tB, reverse(pB), true, + ipA, conj(α), backend, allocator + ) + + return nothing +end + +function blas_contract_pullback_ΔB!( + ΔB, ΔC, A, pA, B, pB, pAB, α, backend, allocator + ) + ipAB = invperm(linearize(pAB)) + pΔC = _repartition(ipAB, TO.numout(pA)) + ipB = _repartition(invperm(linearize(pB)), B) + + tA = twist( + A, + TupleTools.vcat( + filter(x -> isdual(space(A, x)), pA[1]), + filter(x -> !isdual(space(A, x)), pA[2]) + ); copy = false + ) + + project_contract!( + ΔB, + tA, reverse(pA), true, + ΔC, pΔC, false, + ipB, conj(α), backend, allocator + ) + + return nothing +end + + +# tensortrace! +# ------------ + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(TensorKit.trace_permute!)}, + ::Type{RT}, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + p::Const{<:Index2Tuple}, + q::Const{<:Index2Tuple}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + backend::Const, + ) where {RT} + C_cache = !isa(β, Const) ? copy(C.val) : nothing + A_cache = EnzymeRules.overwritten(config)[3] ? copy(A.val) : nothing + At = if !isa(α, Const) + At = TO.tensortrace(A.val, p.val, q.val, false, One(), backend.val) + add!(C.val, At, α.val, β.val) + At + else + TensorKit.trace_permute!(C.val, A.val, p.val, q.val, α.val, β.val, backend.val) + nothing + end + primal = EnzymeRules.needs_primal(config) ? C.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? C.dval : nothing + cache = (C_cache, A_cache, At) + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(TensorKit.trace_permute!)}, + ::Type{RT}, + cache, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + p::Const{<:Index2Tuple}, + q::Const{<:Index2Tuple}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + backend::Const, + ) where {RT} + C_cache, A_cache, At = cache + Aval = something(A_cache, A.val) + Cval = something(C_cache, C.val) + !isa(A, Const) && !isa(C, Const) && trace_permute_pullback_ΔA!(A.dval, C.dval, Aval, p.val, q.val, α.val, backend.val) + Δαr = if !isa(C, Const) && !isnothing(At) + project_scalar(α.val, inner(At, C.dval)) + elseif !isnothing(At) + zero(α.val) + else + nothing + end + Δβr = if !isa(β, Const) && !isa(C, Const) + pullback_dβ(C.dval, Cval, β) + elseif !isa(β, Const) + zero(β.val) + else + nothing + end + !isa(C, Const) && pullback_dC!(C.dval, β.val) + return nothing, nothing, nothing, nothing, Δαr, Δβr, nothing +end + +function trace_permute_pullback_ΔA!( + ΔA, ΔC, A, p, q, α, backend + ) + ip = invperm((linearize(p)..., q[1]..., q[2]...)) + pdA = _repartition(ip, A) + E = one!(TO.tensoralloc_add(scalartype(A), A, q, false)) + twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E))) + pE = ((), trivtuple(TO.numind(q))) + pΔC = (trivtuple(TO.numind(p)), ()) + TO.tensorproduct!( + ΔA, ΔC, pΔC, false, E, pE, false, pdA, conj(α), One(), backend + ) + return nothing +end diff --git a/ext/TensorKitEnzymeExt/utility.jl b/ext/TensorKitEnzymeExt/utility.jl new file mode 100644 index 000000000..733108456 --- /dev/null +++ b/ext/TensorKitEnzymeExt/utility.jl @@ -0,0 +1,63 @@ +# Projection +# ---------- +""" + project_scalar(x::Number, dx::Number) + +Project a computed tangent `dx` onto the correct tangent type for `x`. +For example, we might compute a complex `dx` but only require the real part. +""" +project_scalar(x::Number, dx::Number) = oftype(x, dx) +project_scalar(x::Real, dx::Complex) = project_scalar(x, real(dx)) + +# in-place multiplication and accumulation which might project to (real) +# TODO: this could probably be done without allocating +function project_mul!(C, A, B, α) + TC = TO.promote_contract(scalartype(A), scalartype(B), scalartype(α)) + return if !(TC <: Real) && scalartype(C) <: Real + add!(C, real(mul!(zerovector(C, TC), A, B, α))) + else + mul!(C, A, B, α, One()) + end +end +function project_contract!(C, A, pA, conjA, B, pB, conjB, pAB, α, backend, allocator) + TA = TensorKit.promote_permute(A) + TB = TensorKit.promote_permute(B) + TC = TO.promote_contract(TA, TB, scalartype(α)) + + return if scalartype(C) <: Real && !(TC <: Real) + add!(C, real(TO.tensorcontract!(zerovector(C, TC), A, pA, conjA, B, pB, conjB, pAB, α, Zero(), backend, allocator))) + else + TO.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, One(), backend, allocator) + end +end + +# IndexTuple utility +# ------------------ +trivtuple(N) = ntuple(identity, N) + +Base.@constprop :aggressive function _repartition(p::IndexTuple, N₁::Int) + length(p) >= N₁ || + throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)")) + return TupleTools.getindices(p, trivtuple(N₁)), + TupleTools.getindices(p, trivtuple(length(p) - N₁) .+ N₁) +end +Base.@constprop :aggressive function _repartition(p::Index2Tuple, N₁::Int) + return _repartition(linearize(p), N₁) +end +function _repartition(p::Union{IndexTuple, Index2Tuple}, ::Index2Tuple{N₁}) where {N₁} + return _repartition(p, N₁) +end +function _repartition(p::Union{IndexTuple, Index2Tuple}, t::AbstractTensorMap) + return _repartition(p, TensorKit.numout(t)) +end + +# Ignore derivatives +# ------------------ + +@inline EnzymeRules.inactive(::typeof(TensorKit.fusionblockstructure), arg) = nothing +@inline EnzymeRules.inactive(::typeof(TensorKit.select), s::HomSpace, i::Index2Tuple) = nothing +@inline EnzymeRules.inactive(::typeof(TensorKit.flip), s::HomSpace, i::Any) = nothing +@inline EnzymeRules.inactive(::typeof(TensorKit.permute), s::HomSpace, i::Index2Tuple) = nothing +@inline EnzymeRules.inactive(::typeof(TensorKit.braid), s::HomSpace, i::Index2Tuple, ::IndexTuple) = nothing +@inline EnzymeRules.inactive(::typeof(TensorKit.compose), s1::HomSpace, s2::HomSpace) = nothing +@inline EnzymeRules.inactive(::typeof(TensorOperations.tensorcontract), c::HomSpace, p::Index2Tuple, α::Bool, b::HomSpace, q::Index2Tuple, β::Bool, pq::Index2Tuple) = nothing diff --git a/ext/TensorKitEnzymeExt/vectorinterface.jl b/ext/TensorKitEnzymeExt/vectorinterface.jl new file mode 100644 index 000000000..c300e1209 --- /dev/null +++ b/ext/TensorKitEnzymeExt/vectorinterface.jl @@ -0,0 +1,165 @@ +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(scale!)}, + ::Type{RT}, + C::Annotation{<:AbstractTensorMap}, + α::Annotation{<:Number}, + ) where {RT} + C_cache = !isa(α, Const) ? copy(C.val) : nothing + scale!(C.val, α.val) + primal = EnzymeRules.needs_primal(config) ? C.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? C.dval : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, C_cache) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(scale!)}, + ::Type{RT}, + cache, + C::Annotation{<:AbstractTensorMap}, + α::Annotation{<:Number}, + ) where {RT} + Cval = something(cache, C.val) + Δα = if !isa(α, Const) && !isa(C, Const) + project_scalar(α.val, inner(Cval, C.dval)) + elseif !isa(α, Const) + zero(α.val) + else + nothing + end + !isa(C, Const) && scale!(C.dval, conj(α.val)) + return (nothing, Δα) +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(scale!)}, + ::Type{RT}, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + α::Annotation{<:Number}, + ) where {RT} + A_cache = !isa(α, Const) && EnzymeRules.overwritten(config)[3] ? copy(A.val) : nothing + scale!(C.val, A.val, α.val) + primal = EnzymeRules.needs_primal(config) ? C.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? C.dval : nothing + cache = A_cache + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(scale!)}, + ::Type{RT}, + cache, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + α::Annotation{<:Number}, + ) where {RT} + Aval = something(cache, A.val) + Δα = if !isa(α, Const) && !isa(C, Const) + project_scalar(α.val, inner(Aval, C.dval)) + elseif !isa(α, Const) + zero(α.val) + else + nothing + end + !isa(A, Const) && !isa(C, Const) && add!(A.dval, C.dval, conj(α.val)) + !isa(C, Const) && zerovector!(C.dval) + return (nothing, nothing, Δα) +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(add!)}, + ::Type{RT}, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + ) where {RT} + A_cache = !isa(α, Const) && EnzymeRules.overwritten(config)[3] ? copy(A.val) : nothing + C_cache = !isa(β, Const) ? copy(C.val) : nothing + add!(C.val, A.val, α.val, β.val) + primal = EnzymeRules.needs_primal(config) ? C.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? C.dval : nothing + cache = (A_cache, C_cache) + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(add!)}, + ::Type{RT}, + cache, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + ) where {RT} + A_cache, C_cache = cache + Aval = something(A_cache, A.val) + Cval = something(C_cache, C.val) + Δα = if !isa(α, Const) && !isa(C, Const) + project_scalar(α.val, inner(Aval, C.dval)) + elseif !isa(α, Const) + zero(α.val) + else + nothing + end + Δβ = if !isa(β, Const) && !isa(C, Const) + project_scalar(β.val, inner(Cval, C.dval)) + elseif !isa(β, Const) + zero(β.val) + else + nothing + end + !isa(A, Const) && !isa(C, Const) && add!(A.dval, C.dval, conj(α.val)) + !isa(C, Const) && scale!(C.dval, conj(β.val)) + return (nothing, nothing, Δα, Δβ) +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(inner)}, + ::Type{RT}, + A::Annotation{<:AbstractTensorMap}, + B::Annotation{<:AbstractTensorMap}, + ) where {RT} + A_cache = !isa(B, Const) && EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing + B_cache = !isa(A, Const) && EnzymeRules.overwritten(config)[3] ? copy(B.val) : nothing + ret = inner(A.val, B.val) + primal = EnzymeRules.needs_primal(config) ? ret : nothing + shadow = EnzymeRules.needs_shadow(config) ? zero(ret) : nothing + cache = (A_cache, B_cache) + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(inner)}, + dret::Active, + cache, + A::Annotation{<:AbstractTensorMap}, + B::Annotation{<:AbstractTensorMap}, + ) + A_cache, B_cache = cache + Aval = something(A_cache, A.val) + Bval = something(B_cache, B.val) + Δs = dret.val + !isa(A, Const) && add!(A.dval, Bval, conj(Δs)) + !isa(B, Const) && add!(B.dval, Aval, Δs) + return (nothing, nothing) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(inner)}, + ::Type{<:Const}, + cache, + A::Annotation{<:AbstractTensorMap}, + B::Annotation{<:AbstractTensorMap}, + ) + return (nothing, nothing) +end diff --git a/src/factorizations/adjoint.jl b/src/factorizations/adjoint.jl index eae8989ce..8b349bf43 100644 --- a/src/factorizations/adjoint.jl +++ b/src/factorizations/adjoint.jl @@ -81,6 +81,23 @@ for (left_f, right_f) in zip( end end +# 2-arg functions +for (left_f, right_f) in zip( + (:qr_full, :qr_compact), + (:lq_full, :lq_compact) + ) + left_f! = Symbol(left_f, :!) + right_f! = Symbol(right_f, :!) + @eval function MAK.$left_f!(t::AdjointTensorMap, F, alg::MAK.Algorithm{:Householder}) + F′ = $right_f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg)) + return reverse(adjoint.(F′)) + end + @eval function MAK.$right_f!(t::AdjointTensorMap, F, alg::MAK.Algorithm{:Householder}) + F′ = $left_f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg)) + return reverse(adjoint.(F′)) + end +end + # 3-arg functions for f in (:svd_full, :svd_compact, :svd_trunc) f! = Symbol(f, :!) diff --git a/src/factorizations/matrixalgebrakit.jl b/src/factorizations/matrixalgebrakit.jl index b44f20653..d742860da 100644 --- a/src/factorizations/matrixalgebrakit.jl +++ b/src/factorizations/matrixalgebrakit.jl @@ -43,6 +43,20 @@ for f! in ( end end +for f! in (:qr_compact!, :qr_full!, :lq_compact!, :lq_full!) + @eval function MAK.$f!(t::AbstractTensorMap, F, alg::MAK.Algorithm{:Householder}) + foreachblock(t, F...) do _, (tblock, Fblocks...) + Fblocks′ = $f!(tblock, Fblocks, alg) + # deal with the case where the output is not in-place + for (b′, b) in zip(Fblocks′, Fblocks) + b === b′ || copy!(b, b′) + end + return nothing + end + return F + end +end + # Handle these separately because single output instead of tuple for f! in ( :qr_null!, :lq_null!, diff --git a/test/enzyme/factorizations/eig.jl b/test/enzyme/factorizations/eig.jl new file mode 100644 index 000000000..994793b0f --- /dev/null +++ b/test/enzyme/factorizations/eig.jl @@ -0,0 +1,62 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using VectorInterface: Zero, One +using MatrixAlgebraKit +using Enzyme, EnzymeTestUtils +using Random + +@isdefined(TestSetup) || include("../../setup.jl") +using .TestSetup + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) +eltypes = (Float64, ComplexF64) + +function remove_eiggauge_dependence!( + ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_degeneracy_atol(D) + ) + gaugepart = V' * ΔV + for (c, b) in blocks(gaugepart) + Dc = diagview(block(D, c)) + # for some reason this fails only on tests, and I cannot reproduce it in an + # interactive session. + # b[abs.(transpose(diagview(Dc)) .- diagview(Dc)) .>= degeneracy_atol] .= 0 + for j in axes(b, 2), i in axes(b, 1) + abs(Dc[i] - Dc[j]) >= degeneracy_atol && (b[i, j] = 0) + end + end + mul!(ΔV, V / (V' * V), gaugepart, -1, 1) + return ΔV +end + +@timedtestset "Enzyme - Factorizations (EIG): $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes, t in (randn(T, V[1] ← V[1]), rand(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2])) + atol = default_tol(T) + rtol = default_tol(T) + DV = eig_full(t) + ΔDV = (DiagonalTensorMap(randn!(similar(DV[1].data)), space(DV[1])), randn!(similar(DV[2]))) + remove_eiggauge_dependence!(ΔDV[2], DV...) + EnzymeTestUtils.test_reverse(eig_full, Duplicated, (t, Duplicated); output_tangent = ΔDV, atol, rtol) +end diff --git a/test/enzyme/factorizations/eigh.jl b/test/enzyme/factorizations/eigh.jl new file mode 100644 index 000000000..cea6f9a4d --- /dev/null +++ b/test/enzyme/factorizations/eigh.jl @@ -0,0 +1,63 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using VectorInterface: Zero, One +using MatrixAlgebraKit +using Enzyme, EnzymeTestUtils +using Random + +@isdefined(TestSetup) || include("../setup.jl") +using .TestSetup + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) +eltypes = (Float64, ComplexF64) + +function remove_eighgauge_dependence!( + ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_degeneracy_atol(D) + ) + gaugepart = project_antihermitian!(V' * ΔV) + for (c, b) in blocks(gaugepart) + Dc = diagview(block(D, c)) + # for some reason this fails only on tests, and I cannot reproduce it in an + # interactive session. + # b[abs.(transpose(diagview(Dc)) .- diagview(Dc)) .>= degeneracy_atol] .= 0 + for j in axes(b, 2), i in axes(b, 1) + abs(Dc[i] - Dc[j]) >= degeneracy_atol && (b[i, j] = 0) + end + end + mul!(ΔV, V, gaugepart, -1, 1) + return ΔV +end + +@timedtestset "Enzyme - Factorizations (EIGH): $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes, t in (randn(T, V[1] ← V[1]), rand(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2])) + atol = default_tol(T) + rtol = default_tol(T) + th = project_hermitian(t) + DV = eigh_full(th) + ΔDV = (DiagonalTensorMap(randn!(similar(DV[1].data)), space(DV[1])), randn!(similar(DV[2]))) + remove_eighgauge_dependence!(ΔDV[2], DV...) + EnzymeTestUtils.test_reverse(eigh_full ∘ project_hermitian, Duplicated, (t, Duplicated); output_tangent = ΔDV, atol, rtol) +end diff --git a/test/enzyme/factorizations/lq.jl b/test/enzyme/factorizations/lq.jl new file mode 100644 index 000000000..12538a388 --- /dev/null +++ b/test/enzyme/factorizations/lq.jl @@ -0,0 +1,77 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using VectorInterface: Zero, One +using MatrixAlgebraKit +using Enzyme, EnzymeTestUtils +using Random + +@isdefined(TestSetup) || include("../../setup.jl") +using .TestSetup + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) +eltypes = (Float64, ComplexF64) + +function remove_lqgauge_dependence!(ΔQ, t, Q) + for (c, b) in blocks(ΔQ) + m, n = size(block(t, c)) + minmn = min(m, n) + Qc = block(Q, c) + Q1 = view(Qc, 1:minmn, 1:n) + ΔQ2 = view(b, (minmn + 1):n, :) + mul!(ΔQ2, ΔQ2 * Q1', Q1) + end + return ΔQ +end + +@timedtestset "Enzyme - Factorizations (LQ): $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + + A = randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]) + + A = randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]) + + EnzymeTestUtils.test_reverse(lq_compact, Duplicated, (A, Duplicated); atol, rtol) + + # lq_full/lq_null requires being careful with gauges + LQ = lq_full(A) + ΔLQ = randn!.(similar.(LQ)) + remove_lqgauge_dependence!(ΔLQ[2], A, LQ[2]) + EnzymeTestUtils.test_reverse(lq_full, Duplicated, (A, Duplicated); output_tangent = ΔLQ, atol, rtol) + #EnzymeTestUtils.test_reverse(lq_null, Duplicated, (A, Duplicated); atol, rtol) + + A = randn(T, V[1] ⊗ V[2] ← V[1]) + + EnzymeTestUtils.test_reverse(lq_compact, Duplicated, (A, Duplicated); atol, rtol) + + # lq_full/lq_null requires being careful with gauges + LQ = lq_full(A) + ΔLQ = randn!.(similar.(LQ)) + remove_lqgauge_dependence!(ΔLQ[2], A, LQ[2]) + EnzymeTestUtils.test_reverse(lq_full, Duplicated, (A, Duplicated); output_tangent = ΔLQ, atol, rtol) + #EnzymeTestUtils.test_reverse(lq_null, Duplicated, (A, Duplicated); atol, rtol) +end diff --git a/test/enzyme/factorizations/qr.jl b/test/enzyme/factorizations/qr.jl new file mode 100644 index 000000000..80412a9bc --- /dev/null +++ b/test/enzyme/factorizations/qr.jl @@ -0,0 +1,75 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using VectorInterface: Zero, One +using MatrixAlgebraKit +using Enzyme, EnzymeTestUtils +using Random + +@isdefined(TestSetup) || include("../../setup.jl") +using .TestSetup + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) +eltypes = (Float64, ComplexF64) + +function remove_qrgauge_dependence!(ΔQ, t, Q) + for (c, b) in blocks(ΔQ) + m, n = size(block(t, c)) + minmn = min(m, n) + Qc = block(Q, c) + Q1 = view(Qc, 1:m, 1:minmn) + ΔQ2 = view(b, :, (minmn + 1):m) + mul!(ΔQ2, Q1, Q1' * ΔQ2) + end + return ΔQ +end + +@timedtestset "Enzyme - Factorizations (QR): $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + + A = randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]) + + EnzymeTestUtils.test_reverse(qr_compact, Duplicated, (A, Duplicated); atol, rtol) + + # qr_full/qr_null requires being careful with gauges + QR = qr_full(A) + ΔQR = randn!.(similar.(QR)) + remove_qrgauge_dependence!(ΔQR[1], A, QR[1]) + EnzymeTestUtils.test_reverse(qr_full, Duplicated, (A, Duplicated); output_tangent = ΔQR, atol, rtol) + #EnzymeTestUtils.test_reverse(qr_null, Duplicated, (A, Duplicated); atol, rtol) + + A = randn(T, V[1] ⊗ V[2] ← V[1]) + + EnzymeTestUtils.test_reverse(qr_compact, Duplicated, (A, Duplicated); atol, rtol) + + # qr_full/qr_null requires being careful with gauges + QR = qr_full(A) + ΔQR = randn!.(similar.(QR)) + remove_qrgauge_dependence!(ΔQR[1], A, QR[1]) + EnzymeTestUtils.test_reverse(qr_full, Duplicated, (A, Duplicated); output_tangent = ΔQR, atol, rtol) + #EnzymeTestUtils.test_reverse(qr_null, Duplicated, (A, Duplicated); atol, rtol) +end diff --git a/test/enzyme/factorizations/svd.jl b/test/enzyme/factorizations/svd.jl new file mode 100644 index 000000000..d25febfa3 --- /dev/null +++ b/test/enzyme/factorizations/svd.jl @@ -0,0 +1,77 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using VectorInterface: Zero, One +using MatrixAlgebraKit +using Enzyme, EnzymeTestUtils +using Random + +@isdefined(TestSetup) || include("../setup.jl") +using .TestSetup + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) +eltypes = (Float64, ComplexF64) + +function remove_svdgauge_dependence!( + ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = MatrixAlgebraKit.default_pullback_degeneracy_atol(S) + ) + UdU = U' * ΔU + VdV = Vᴴ * ΔVᴴ' + gaugepart = project_antihermitian!(UdU + VdV) + for (c, b) in blocks(gaugepart) + Sd = diagview(block(S, c)) + # for some reason this fails only on tests, and I cannot reproduce it in an + # interactive session. + # b[abs.(transpose(diagview(Sc)) .- diagview(Sc)) .>= degeneracy_atol] .= 0 + for j in axes(b, 2), i in axes(b, 1) + abs(Sd[i] - Sd[j]) >= degeneracy_atol && (b[i, j] = 0) + end + end + mul!(ΔU, U, gaugepart, -1, 1) + return ΔU, ΔVᴴ +end + +@timedtestset "Enzyme - Factorizations (SVD): $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes, t in (randn(T, V[1] ← V[1]), randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4])) + atol = default_tol(T) + rtol = default_tol(T) + USVᴴ = svd_compact(t) + ΔUSVᴴ = (TensorMap(randn!(similar(USVᴴ[1].data)), space(USVᴴ[1])), DiagonalTensorMap(randn!(similar(USVᴴ[2].data)), space(USVᴴ[2], 1)), TensorMap(randn!(similar(USVᴴ[3].data)), space(USVᴴ[3]))) + remove_svdgauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...) + EnzymeTestUtils.test_reverse(svd_compact, Duplicated, (t, Duplicated); output_tangent = ΔUSVᴴ, atol, rtol) + + #=USVᴴ = svd_full(t) + ΔUSVᴴ = (TensorMap(randn!(similar(USVᴴ[1].data)), space(USVᴴ[1])), TensorMap(randn!(similar(USVᴴ[2].data)), space(USVᴴ[2])), TensorMap(randn!(similar(USVᴴ[3].data)), space(USVᴴ[3]))) + remove_svdgauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...) + EnzymeTestUtils.test_reverse(svd_full, Duplicated, (t, Duplicated); output_tangent = ΔUSVᴴ, atol, rtol)=# + + V_trunc = spacetype(t)(c => min(size(b)...) ÷ 2 for (c, b) in blocks(t)) + trunc = truncspace(V_trunc) + alg = MatrixAlgebraKit.select_algorithm(svd_trunc_no_error, t, nothing; trunc) + USVᴴtrunc = svd_trunc(t, alg) + ΔUSVᴴtrunc = randn!(similar.(USVᴴtrunc)) + remove_svdgauge_dependence!(ΔUSVᴴtrunc[1], ΔUSVᴴtrunc[3], USVᴴtrunc...) + EnzymeTestUtils.test_reverse(svd_trunc_no_error, Duplicated, (t, Duplicated), (alg, Const); output_tangent = ΔUSVᴴtrunc, atol, rtol) +end diff --git a/test/enzyme/indexmanipulations/add_braid.jl b/test/enzyme/indexmanipulations/add_braid.jl new file mode 100644 index 000000000..2b7b46386 --- /dev/null +++ b/test/enzyme/indexmanipulations/add_braid.jl @@ -0,0 +1,64 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using VectorInterface: Zero, One +using Enzyme, EnzymeTestUtils +using Random + +@isdefined(TestSetup) || include("../../setup.jl") +using .TestSetup + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 2, -1 => 1)', + Vect[U1Irrep](0 => 1, 1 => 1, -1 => 2), + Vect[U1Irrep](0 => 1, 1 => 2, -1 => 1)', + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) +eltypes = (Float64, ComplexF64) + +@timedtestset "Enzyme - Index Manipulations (add_braid!):" begin + @timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + Vstr = TensorKit.type_repr(sectortype(eltype(V))) + @timedtestset "add_braid! Tα $Tα Tβ $Tβ" for Tα in (Active, Const), Tβ in (Active, Const) + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + α = randn(T) + β = randn(T) + p = randcircshift(numout(A), numin(A)) + levels = Tuple(randperm(numind(A))) + C = randn!(transpose(A, p)) + EnzymeTestUtils.test_reverse(TensorKit.add_braid!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (levels, Const), (α, Tα), (β, Tβ); atol, rtol, testset_name = "add_braid! V $Vstr Tα $Tα Tβ $Tβ") + if !(T <: Real) + EnzymeTestUtils.test_reverse(TensorKit.add_braid!, Duplicated, (C, Duplicated), (real(A), Duplicated), (p, Const), (levels, Const), (α, Tα), (β, Tβ); atol, rtol, testset_name = "add_braid! V $Vstr Tα $Tα Tβ $Tβ") + EnzymeTestUtils.test_reverse(TensorKit.add_braid!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (levels, Const), (real(α), Tα), (β, Tβ); atol, rtol, testset_name = "add_braid! V $Vstr Tα $Tα Tβ $Tβ") + EnzymeTestUtils.test_reverse(TensorKit.add_braid!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (levels, Const), (real(α), Tα), (real(β), Tβ); atol, rtol, testset_name = "add_braid! V $Vstr Tα $Tα Tβ $Tβ") + end + end + end +end diff --git a/test/enzyme/indexmanipulations/add_permute.jl b/test/enzyme/indexmanipulations/add_permute.jl new file mode 100644 index 000000000..022f6d5e4 --- /dev/null +++ b/test/enzyme/indexmanipulations/add_permute.jl @@ -0,0 +1,64 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using VectorInterface: Zero, One +using Enzyme, EnzymeTestUtils +using Random + +@isdefined(TestSetup) || include("../../setup.jl") +using .TestSetup + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 2, -1 => 1)', + Vect[U1Irrep](0 => 1, 1 => 1, -1 => 2), + Vect[U1Irrep](0 => 1, 1 => 2, -1 => 1)', + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) +eltypes = (Float64, ComplexF64) + +@timedtestset "Enzyme - Index Manipulations (add_permute!):" begin + @timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + symmetricbraiding = BraidingStyle(sectortype(eltype(V))) isa SymmetricBraiding + + symmetricbraiding && @timedtestset "add_permute!" begin + # repeat a couple times to get some distribution of arrows + for ri in 1:5 + @testset for Tα in (Const, Active), Tβ in (Const, Active) + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + α = randn(T) + β = randn(T) + p = randindextuple(numind(A)) + C = randn!(permute(A, p)) + EnzymeTestUtils.test_reverse(TensorKit.add_permute!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (α, Tα), (β, Tβ); atol, rtol) + end + end + end + end +end diff --git a/test/enzyme/indexmanipulations/add_transpose.jl b/test/enzyme/indexmanipulations/add_transpose.jl new file mode 100644 index 000000000..f2e2421bd --- /dev/null +++ b/test/enzyme/indexmanipulations/add_transpose.jl @@ -0,0 +1,69 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using VectorInterface: Zero, One +using Enzyme, EnzymeTestUtils +using Random + +@isdefined(TestSetup) || include("../../setup.jl") +using .TestSetup + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 2, -1 => 1)', + Vect[U1Irrep](0 => 1, 1 => 1, -1 => 2), + Vect[U1Irrep](0 => 1, 1 => 2, -1 => 1)', + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) +eltypes = (Float64, ComplexF64) + +@timedtestset "Enzyme - Index Manipulations (add_transpose!):" begin + @timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + + @timedtestset "add_transpose! Tα $Tα Tβ $Tβ" for Tα in (Const, Active), Tβ in (Const, Active) + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + α = randn(T) + β = randn(T) + + # repeat a couple times to get some distribution of arrows + @testet for ri in 1:2 + p = randcircshift(numout(A), numin(A)) + C = randn!(transpose(A, p)) + EnzymeTestUtils.test_reverse(TensorKit.add_transpose!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (One(), Const), (Zero(), Const); atol, rtol) + EnzymeTestUtils.test_reverse(TensorKit.add_transpose!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (α, Tα), (β, Tβ); atol, rtol) + if !(T <: Real) + EnzymeTestUtils.test_reverse(TensorKit.add_transpose!, Duplicated, (C, Duplicated), (real(A), Duplicated), (p, Const), (α, Tα), (β, Tβ); atol, rtol) + EnzymeTestUtils.test_reverse(TensorKit.add_transpose!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (real(α), Tα), (β, Tβ); atol, rtol) + EnzymeTestUtils.test_reverse(TensorKit.add_transpose!, Duplicated, (C, Duplicated), (real(A), Duplicated), (p, Const), (real(α), Tα), (β, Tβ); atol, rtol) + end + A = C + end + end + end +end diff --git a/test/enzyme/indexmanipulations/flip.jl b/test/enzyme/indexmanipulations/flip.jl new file mode 100644 index 000000000..363199827 --- /dev/null +++ b/test/enzyme/indexmanipulations/flip.jl @@ -0,0 +1,56 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using VectorInterface: Zero, One +using Enzyme, EnzymeTestUtils +using Random + +@isdefined(TestSetup) || include("../../setup.jl") +using .TestSetup + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 2, -1 => 1)', + Vect[U1Irrep](0 => 1, 1 => 1, -1 => 2), + Vect[U1Irrep](0 => 1, 1 => 2, -1 => 1)', + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) +eltypes = (Float64, ComplexF64) + +@timedtestset "Enzyme - Index Manipulations (flip):" begin + @timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + for TA in (Duplicated,) + EnzymeTestUtils.test_reverse(flip, TA, (A, TA), (1, Const); atol, rtol, fkwargs = (inv = false,)) + EnzymeTestUtils.test_reverse(flip, TA, (A, TA), [1, 3]; atol, rtol, fkwargs = (inv = true,)) + EnzymeTestUtils.test_reverse(flip, TA, (A, TA), (1, Const); atol, rtol) + EnzymeTestUtils.test_reverse(flip, TA, (A, TA), ([1, 3], Const); atol, rtol) + end + end +end diff --git a/test/enzyme/indexmanipulations/insertunit.jl b/test/enzyme/indexmanipulations/insertunit.jl new file mode 100644 index 000000000..4f1136627 --- /dev/null +++ b/test/enzyme/indexmanipulations/insertunit.jl @@ -0,0 +1,59 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using VectorInterface: Zero, One +using Enzyme, EnzymeTestUtils +using Random + +@isdefined(TestSetup) || include("../../setup.jl") +using .TestSetup + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 2, -1 => 1)', + Vect[U1Irrep](0 => 1, 1 => 1, -1 => 2), + Vect[U1Irrep](0 => 1, 1 => 2, -1 => 1)', + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) +eltypes = (Float64, ComplexF64) + +@timedtestset verbose = true "Enzyme - Index Manipulations (insertunit):" begin + @timedtestset verbose = true "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + @testset for insertunit in (insertleftunit, insertrightunit), TA in (Duplicated,) + EnzymeTestUtils.test_reverse(insertunit, TA, (A, TA), (Val(1), Const); atol, rtol) + EnzymeTestUtils.test_reverse(insertunit, TA, (A, TA), (Val(4), Const); atol, rtol) + EnzymeTestUtils.test_reverse(insertunit, TA, (A', TA), (Val(2), Const); atol, rtol) + EnzymeTestUtils.test_reverse(insertunit, TA, (A, TA), (Val(1), Const); atol, rtol, fkwargs = (copy = false,)) + EnzymeTestUtils.test_reverse(insertunit, TA, (A, TA), (Val(2), Const); atol, rtol, fkwargs = (copy = true,)) + EnzymeTestUtils.test_reverse(insertunit, TA, (A, TA), (Val(3), Const); atol, rtol, fkwargs = (copy = false, dual = true, conj = true)) + EnzymeTestUtils.test_reverse(insertunit, TA, (A', TA), (Val(3), Const); atol, rtol, fkwargs = (copy = false, dual = true, conj = true)) + end + end +end diff --git a/test/enzyme/indexmanipulations/removeunit.jl b/test/enzyme/indexmanipulations/removeunit.jl new file mode 100644 index 000000000..cb4ed253e --- /dev/null +++ b/test/enzyme/indexmanipulations/removeunit.jl @@ -0,0 +1,56 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using VectorInterface: Zero, One +using Enzyme, EnzymeTestUtils +using Random + +@isdefined(TestSetup) || include("../../setup.jl") +using .TestSetup + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 2, -1 => 1)', + Vect[U1Irrep](0 => 1, 1 => 1, -1 => 2), + Vect[U1Irrep](0 => 1, 1 => 2, -1 => 1)', + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) +eltypes = (Float64, ComplexF64) + +@timedtestset verbose = true "Enzyme - Index Manipulations:" begin + @timedtestset verbose = true "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + for TB in (Duplicated,), i in 1:2 + B = insertleftunit(A, i; dual = rand(Bool)) + EnzymeTestUtils.test_reverse(removeunit, TB, (B, TB), (Val(i), Const); atol, rtol) + EnzymeTestUtils.test_reverse(removeunit, TB, (B, TB), (Val(i), Const); atol, rtol, fkwargs = (copy = false,)) + EnzymeTestUtils.test_reverse(removeunit, TB, (B, TB), (Val(i), Const); atol, rtol, fkwargs = (copy = true,)) + end + end +end diff --git a/test/enzyme/indexmanipulations/twist.jl b/test/enzyme/indexmanipulations/twist.jl new file mode 100644 index 000000000..6726089e6 --- /dev/null +++ b/test/enzyme/indexmanipulations/twist.jl @@ -0,0 +1,58 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using VectorInterface: Zero, One +using Enzyme, EnzymeTestUtils +using Random + +@isdefined(TestSetup) || include("../../setup.jl") +using .TestSetup + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 2, -1 => 1)', + Vect[U1Irrep](0 => 1, 1 => 1, -1 => 2), + Vect[U1Irrep](0 => 1, 1 => 2, -1 => 1)', + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) +eltypes = (Float64, ComplexF64) + +@timedtestset verbose = true "Enzyme - Index Manipulations (twist):" begin + @timedtestset verbose = true "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + if !(T <: Real && !(sectorscalartype(sectortype(A)) <: Real)) + for TA in (Duplicated,) + EnzymeTestUtils.test_reverse(twist!, TA, (A, TA), (1, Const); atol, rtol, fkwargs = (inv = false,)) + EnzymeTestUtils.test_reverse(twist!, TA, (A, TA), ([1, 3], Const); atol, rtol, fkwargs = (inv = true,)) + EnzymeTestUtils.test_reverse(twist!, TA, (A, TA), (1, Const); atol, rtol) + EnzymeTestUtils.test_reverse(twist!, TA, (A, TA), ([1, 3], Const); atol, rtol) + end + end + end +end diff --git a/test/enzyme/linalg/inv.jl b/test/enzyme/linalg/inv.jl new file mode 100644 index 000000000..3f063c456 --- /dev/null +++ b/test/enzyme/linalg/inv.jl @@ -0,0 +1,48 @@ +using Test, TestExtras +using TensorKit +using Enzyme, EnzymeTestUtils +using Random + +@isdefined(TestSetup) || include("../../setup.jl") +using .TestSetup + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) +eltypes = (Float64, ComplexF64) + +@timedtestset "Enzyme - LinearAlgebra (inv):" begin + @timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + D1 = randn(T, V[1] ← V[1]) + D2 = randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]) + D3 = randn(T, V[1] ⊗ V[2] ⊗ V[3] ← V[1] ⊗ V[2] ⊗ V[3]) + @testset "inv: TD $TD" for TD in (Const, Duplicated) + EnzymeTestUtils.test_reverse(inv, TD, (D1, TD); atol, rtol) + EnzymeTestUtils.test_reverse(inv, TD, (D2, TD); atol, rtol) + EnzymeTestUtils.test_reverse(inv, TD, (D3, TD); atol, rtol) + end + end +end diff --git a/test/enzyme/linalg/mul.jl b/test/enzyme/linalg/mul.jl new file mode 100644 index 000000000..7ac0cfbf4 --- /dev/null +++ b/test/enzyme/linalg/mul.jl @@ -0,0 +1,53 @@ +using Test, TestExtras +using TensorKit +using Enzyme, EnzymeTestUtils +using Random + +@isdefined(TestSetup) || include("../../setup.jl") +using .TestSetup + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) +eltypes = (Float64, ComplexF64) + +@timedtestset verbose = true "Enzyme - LinearAlgebra (mul):" begin + @timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + + C = randn(T, V[1] ⊗ V[2] ← V[5]) + A = randn(T, codomain(C) ← V[3] ⊗ V[4]) + B = randn(T, domain(A) ← domain(C)) + α = randn(T) + β = randn(T) + + @testset "mul: TC $TC, TA $TA, TB $TB" for TC in (Const, Duplicated), TA in (Const, Duplicated), TB in (Const, Duplicated) + @testset "Tα $Tα, Tβ $Tβ" for Tα in (Active, Const), Tβ in (Active, Const) + EnzymeTestUtils.test_reverse(mul!, TC, (C, TC), (A, TA), (B, TB), (α, Tα), (β, Tβ); atol, rtol) + end + EnzymeTestUtils.test_reverse(mul!, TC, (C, TC), (A, TA), (B, TB); atol, rtol) + end + end +end diff --git a/test/enzyme/linalg/norm.jl b/test/enzyme/linalg/norm.jl new file mode 100644 index 000000000..d90a6101d --- /dev/null +++ b/test/enzyme/linalg/norm.jl @@ -0,0 +1,45 @@ +using Test, TestExtras +using TensorKit +using Enzyme, EnzymeTestUtils +using Random + +@isdefined(TestSetup) || include("../../setup.jl") +using .TestSetup + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) +eltypes = (Float64, ComplexF64) + +@timedtestset verbose = true "Enzyme - LinearAlgebra (norm):" begin + @timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + C = randn(T, V[1] ⊗ V[2] ← V[5]) + @testset "norm: RT $RT, TC $TC" for RT in (Const, Active), TC in (Const, Duplicated) + EnzymeTestUtils.test_reverse(norm, RT, (C, TC), (2, Const); atol, rtol) + EnzymeTestUtils.test_reverse(norm, RT, (C', TC), (2, Const); atol, rtol) + end + end +end diff --git a/test/enzyme/linalg/tr.jl b/test/enzyme/linalg/tr.jl new file mode 100644 index 000000000..0f5a3e87a --- /dev/null +++ b/test/enzyme/linalg/tr.jl @@ -0,0 +1,48 @@ +using Test, TestExtras +using TensorKit +using Enzyme, EnzymeTestUtils +using Random + +@isdefined(TestSetup) || include("../../setup.jl") +using .TestSetup + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) +eltypes = (Float64, ComplexF64) + +@timedtestset verbose = true "Enzyme - LinearAlgebra (tr):" begin + @timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + D1 = randn(T, V[1] ← V[1]) + D2 = randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]) + D3 = randn(T, V[1] ⊗ V[2] ⊗ V[3] ← V[1] ⊗ V[2] ⊗ V[3]) + @testset "tr: RT $RT, TD $TD" for RT in (Const, Active), TD in (Const, Duplicated) + EnzymeTestUtils.test_reverse(tr, RT, (D1, TD); atol, rtol) + EnzymeTestUtils.test_reverse(tr, RT, (D2, TD); atol, rtol) + EnzymeTestUtils.test_reverse(tr, RT, (D3, TD); atol, rtol) + end + end +end diff --git a/test/enzyme/planaroperations/planarcontract.jl b/test/enzyme/planaroperations/planarcontract.jl new file mode 100644 index 000000000..53db41b2f --- /dev/null +++ b/test/enzyme/planaroperations/planarcontract.jl @@ -0,0 +1,85 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using VectorInterface: Zero, One +using Enzyme, EnzymeTestUtils +using Random + +@isdefined(TestSetup) || include("../../setup.jl") +using .TestSetup +using .TestSetup: _repartition + +rng = Random.default_rng() + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) +eltypes = (Float64, ComplexF64) + +@timedtestset "Enzyme - PlanarOperations (planarcontract): $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + for _ in 1:5 + d = 0 + local V1, V2, V3, k1, k2, k3 + # retry a couple times to make sure there are at least some nonzero elements + for _ in 1:10 + k1 = rand(0:3) + k2 = rand(0:2) + k3 = rand(0:2) + V1 = prod(v -> rand(Bool) ? v' : v, rand(V, k1); init = one(V[1])) + V2 = prod(v -> rand(Bool) ? v' : v, rand(V, k2); init = one(V[1])) + V3 = prod(v -> rand(Bool) ? v' : v, rand(V, k3); init = one(V[1])) + d = min(dim(V1 ← V2), dim(V1' ← V2), dim(V2 ← V3), dim(V2' ← V3)) + d > 1 && break + end + k′ = rand(0:(k1 + k2)) + pA = randcircshift(k′, k1 + k2 - k′, k1) + ipA = _repartition(invperm(linearize(pA)), k′) + k′ = rand(0:(k2 + k3)) + pB = randcircshift(k′, k2 + k3 - k′, k2) + ipB = _repartition(invperm(linearize(pB)), k′) + # TODO: primal value already is broken for this? + # pAB = randcircshift(k1, k3) + pAB = _repartition(tuple((1:(k1 + k3))...), k1) + + α = randn(T) + β = randn(T) + + A = randn(T, permute(V1 ← V2, ipA)) + B = randn(T, permute(V2 ← V3, ipB)) + C = randn!( + TensorOperations.tensoralloc_contract( + T, A, pA, false, B, pB, false, pAB, Val(false) + ) + ) + for TC in (Duplicated,), TA in (Duplicated,), TB in (Duplicated,) + EnzymeTestUtils.test_reverse(TensorKit.planarcontract!, TC, (C, TC), (A, TA), (pA, Const), (B, TB), (pB, Const), (pAB, Const), (One(), Const), (Zero(), Const); atol, rtol) + EnzymeTestUtils.test_reverse(TensorKit.planarcontract!, TC, (C, TC), (A, TA), (pA, Const), (B, TB), (pB, Const), (pAB, Const), (α, Const), (β, Const); atol, rtol) + EnzymeTestUtils.test_reverse(TensorKit.planarcontract!, TC, (C, TC), (A, TA), (pA, Const), (B, TB), (pB, Const), (pAB, Const), (α, Const), (β, Active); atol, rtol) + EnzymeTestUtils.test_reverse(TensorKit.planarcontract!, TC, (C, TC), (A, TA), (pA, Const), (B, TB), (pB, Const), (pAB, Const), (α, Active), (β, Const); atol, rtol) + EnzymeTestUtils.test_reverse(TensorKit.planarcontract!, TC, (C, TC), (A, TA), (pA, Const), (B, TB), (pB, Const), (pAB, Const), (α, Active), (β, Active); atol, rtol) + end + end +end diff --git a/test/enzyme/planaroperations/planartrace.jl b/test/enzyme/planaroperations/planartrace.jl new file mode 100644 index 000000000..ebf0ffbd6 --- /dev/null +++ b/test/enzyme/planaroperations/planartrace.jl @@ -0,0 +1,66 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using VectorInterface: Zero, One +using Enzyme, EnzymeTestUtils +using Random + +@isdefined(TestSetup) || include("../../setup.jl") +using .TestSetup +using .TestSetup: _repartition + +rng = Random.default_rng() + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) +eltypes = (Float64, ComplexF64) + +@timedtestset "Enzyme - PlanarOperations (planartrace): $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + for _ in 1:5 + k1 = rand(0:2) + k2 = rand(0:1) + V1 = map(v -> rand(Bool) ? v' : v, rand(V, k1)) + V2 = map(v -> rand(Bool) ? v' : v, rand(V, k2)) + V3 = prod(x -> x ⊗ x', V2[1:k2]; init = one(V[1])) + V4 = prod(x -> x ⊗ x', V2[(k2 + 1):end]; init = one(V[1])) + + k′ = rand(0:(k1 + 2k2)) + (_p, _q) = randcircshift(k′, k1 + 2k2 - k′, k1) + p = _repartition(_p, rand(0:k1)) + q = (tuple(_q[1:2:end]...), tuple(_q[2:2:end]...)) + ip = _repartition(invperm(linearize((_p, _q))), k′) + A = randn(T, permute(prod(V1) ⊗ V3 ← V4, ip)) + + α = randn(T) + β = randn(T) + C = randn!(TensorOperations.tensoralloc_add(T, A, p, false, Val(false))) + EnyzmeTestUtils.test_reverse(TensorKit.planartrace!, Active, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (α, Const), (β, Const), (TensorOperations.DefaultBackend(), Const), (TensorOperations.DefaultAllocator(), Const); atol, rtol) + EnyzmeTestUtils.test_reverse(TensorKit.planartrace!, Active, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (α, Active), (β, Const), (TensorOperations.DefaultBackend(), Const), (TensorOperations.DefaultAllocator(), Const); atol, rtol) + EnyzmeTestUtils.test_reverse(TensorKit.planartrace!, Active, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (α, Const), (β, Active), (TensorOperations.DefaultBackend(), Const), (TensorOperations.DefaultAllocator(), Const); atol, rtol) + EnyzmeTestUtils.test_reverse(TensorKit.planartrace!, Active, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (α, Active), (β, Active), (TensorOperations.DefaultBackend(), Const), (TensorOperations.DefaultAllocator(), Const); atol, rtol) + end +end diff --git a/test/enzyme/tensoroperations/contract.jl b/test/enzyme/tensoroperations/contract.jl new file mode 100644 index 000000000..e83b073e3 --- /dev/null +++ b/test/enzyme/tensoroperations/contract.jl @@ -0,0 +1,127 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using VectorInterface: One, Zero +using Enzyme, EnzymeTestUtils +Enzyme.Compiler.VERBOSE_ERRORS[] = true + +@isdefined(TestSetup) || include("../../setup.jl") +using .TestSetup + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 2, -1 => 1)', + Vect[U1Irrep](0 => 1, 1 => 1, -1 => 2), + Vect[U1Irrep](0 => 1, 1 => 2, -1 => 1)', + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) +eltypes = (Float64, ComplexF64) + +@timedtestset "Enzyme - TensorOperations" begin + @timedtestset verbose = true "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + symmetricbraiding = BraidingStyle(sectortype(eltype(V))) isa SymmetricBraiding + symmetricbraiding && @timedtestset "tensorcontract!" begin + d = 0 + local V1, V2, V3 + # retry a couple times to make sure there are at least some nonzero elements + for _ in 1:10 + k1 = rand(0:3) + k2 = rand(0:2) + k3 = rand(0:2) + V1 = prod(v -> rand(Bool) ? v' : v, rand(V, k1); init = one(V[1])) + V2 = prod(v -> rand(Bool) ? v' : v, rand(V, k2); init = one(V[1])) + V3 = prod(v -> rand(Bool) ? v' : v, rand(V, k3); init = one(V[1])) + d = min(dim(V1 ← V2), dim(V1' ← V2), dim(V2 ← V3), dim(V2' ← V3)) + d > 0 && break + end + ipA = randindextuple(length(V1) + length(V2)) + pA = _repartition(invperm(linearize(ipA)), length(V1)) + ipB = randindextuple(length(V2) + length(V3)) + pB = _repartition(invperm(linearize(ipB)), length(V2)) + pAB = randindextuple(length(V1) + length(V3)) + + α = randn(T) + β = randn(T) + V2_conj = prod(conj, V2; init = one(V[1])) + A = randn(T, permute(V1 ← V2, ipA)) + B = randn(T, permute(V2 ← V3, ipB)) + C = randn!( + TensorOperations.tensoralloc_contract( + T, A, pA, false, B, pB, false, pAB, Val(false) + ) + ) + + for α_ in ((One(), Const), (α, Const), (α, Active)), + β_ in ((Zero(), Const), (β, Const), (β, Active)) + EnzymeTestUtils.test_reverse( + TensorKit.blas_contract!, Duplicated, + (copy(C), Duplicated), (A, Duplicated), (pA, Const), + (B, Duplicated), (pB, Const), (pAB, Const), + α_, β_, + (TensorOperations.DefaultBackend(), Const), + (TensorOperations.DefaultAllocator(), Const); + atol, rtol, + testset_name = "blas_contract! α $α_ β $β_", + ) + end + if !(T <: Real) + EnzymeTestUtils.test_reverse( + TensorKit.blas_contract!, Duplicated, + (copy(C), Duplicated), (A, Duplicated), (pA, Const), + (B, Duplicated), (pB, Const), (pAB, Const), + (real(α), Active), (real(β), Active), + (TensorOperations.DefaultBackend(), Const), + (TensorOperations.DefaultAllocator(), Const); + atol, rtol, + testset_name = "blas_contract! real(α) real(β)", + ) + EnzymeTestUtils.test_reverse( + TensorKit.blas_contract!, Duplicated, + (copy(C), Duplicated), (real(A), Duplicated), (pA, Const), + (B, Duplicated), (pB, Const), (pAB, Const), + (real(α), Active), (real(β), Active), + (TensorOperations.DefaultBackend(), Const), + (TensorOperations.DefaultAllocator(), Const); + atol, rtol, + testset_name = "blas_contract! real(A) real(α) real(β)", + ) + EnzymeTestUtils.test_reverse( + TensorKit.blas_contract!, Duplicated, + (copy(C), Duplicated), (A, Duplicated), (pA, Const), + (real(B), Duplicated), (pB, Const), (pAB, Const), + (real(α), Active), (real(β), Active), + (TensorOperations.DefaultBackend(), Const), + (TensorOperations.DefaultAllocator(), Const); + atol, rtol, + testset_name = "blas_contract! real(B) real(α) real(β)", + ) + end + end + end +end diff --git a/test/enzyme/tensoroperations/trace.jl b/test/enzyme/tensoroperations/trace.jl new file mode 100644 index 000000000..a428ec1e9 --- /dev/null +++ b/test/enzyme/tensoroperations/trace.jl @@ -0,0 +1,75 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using VectorInterface: One, Zero +using Enzyme, EnzymeTestUtils +Enzyme.Compiler.VERBOSE_ERRORS[] = true + +@isdefined(TestSetup) || include("../../setup.jl") +using .TestSetup + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 2, -1 => 1)', + Vect[U1Irrep](0 => 1, 1 => 1, -1 => 2), + Vect[U1Irrep](0 => 1, 1 => 2, -1 => 1)', + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) +eltypes = (Float64, ComplexF64) + +@timedtestset "Enzyme - TensorOperations (trace)" begin + @timedtestset verbose = true "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + symmetricbraiding = BraidingStyle(sectortype(eltype(V))) isa SymmetricBraiding + symmetricbraiding && @timedtestset "trace_permute!" begin + k1 = rand(0:2) + k2 = rand(1:2) + V1 = map(v -> rand(Bool) ? v' : v, rand(V, k1)) + V2 = map(v -> rand(Bool) ? v' : v, rand(V, k2)) + + (_p, _q) = randindextuple(k1 + 2 * k2, k1) + p = _repartition(_p, rand(0:k1)) + q = _repartition(_q, k2) + ip = _repartition(invperm(linearize((_p, _q))), rand(0:(k1 + 2 * k2))) + A = randn(T, permute(prod(V1) ⊗ prod(V2) ← prod(V2), ip)) + + α = randn(T) + β = randn(T) + C = randn!(TensorOperations.tensoralloc_add(T, A, p, false, Val(false))) + for TC in (Const, Duplicated), TA in (Const, Duplicated), Tα in (Const, Active), Tβ in (Const, Active) + EnzymeTestUtils.test_reverse( + TensorKit.trace_permute!, TC, + (copy(C), TC), (A, TA), (p, Const), (q, Const), + (α, Tα), (β, Tβ), (TensorOperations.DefaultBackend(), Const); + atol, rtol, + testset_name = "trace_permute! TC $TC TA $TA Tα $Tα Tβ $Tβ", + ) + end + end + end +end diff --git a/test/enzyme/vectorinterface/add.jl b/test/enzyme/vectorinterface/add.jl new file mode 100644 index 000000000..87366b0cb --- /dev/null +++ b/test/enzyme/vectorinterface/add.jl @@ -0,0 +1,56 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using Enzyme, EnzymeTestUtils +using Random, FiniteDifferences + +@isdefined(TestSetup) || include("../../setup.jl") +using .TestSetup + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) +eltypes = (Float64, ComplexF64) + +@testset "Enzyme - VectorInterface (add!)" begin + @timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + + C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + α = randn(T) + β = randn(T) + + @testset for TC in (Duplicated, Const), TA in (Duplicated, Const) + EnzymeTestUtils.test_reverse(add!, TC, (copy(C), TC), (A, TA); atol, rtol) + @testset for Tα in (Active, Const) + EnzymeTestUtils.test_reverse(add!, TC, (copy(C), TC), (A, TA), (α, Tα); atol, rtol) + @testset for Tβ in (Active, Const) + EnzymeTestUtils.test_reverse(add!, TC, (copy(C), TC), (A, TA), (α, Tα), (β, Tβ); atol, rtol) + end + end + end + end +end diff --git a/test/enzyme/vectorinterface/inner.jl b/test/enzyme/vectorinterface/inner.jl new file mode 100644 index 000000000..98b74ee34 --- /dev/null +++ b/test/enzyme/vectorinterface/inner.jl @@ -0,0 +1,47 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using Enzyme, EnzymeTestUtils +using Random, FiniteDifferences + +@isdefined(TestSetup) || include("../../setup.jl") +using .TestSetup + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) +eltypes = (Float64, ComplexF64) + +@testset "Enzyme - VectorInterface" begin + @timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + @testset for RT in (Active,), TC in (Duplicated, Const), TA in (Duplicated, Const) + EnzymeTestUtils.test_reverse(inner, RT, (C, TC), (A, TA); atol, rtol) + EnzymeTestUtils.test_reverse(inner, RT, (C', TC), (A', TA); atol, rtol) + end + end +end diff --git a/test/enzyme/vectorinterface/scale.jl b/test/enzyme/vectorinterface/scale.jl new file mode 100644 index 000000000..7d6d17fd6 --- /dev/null +++ b/test/enzyme/vectorinterface/scale.jl @@ -0,0 +1,54 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using Enzyme, EnzymeTestUtils +using Random, FiniteDifferences + +@isdefined(TestSetup) || include("../../setup.jl") +using .TestSetup + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) +eltypes = (Float64, ComplexF64) + +@testset "Enzyme - VectorInterface (scale!)" begin + @timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + α = randn(T) + @testset for TC in (Duplicated, Const), Tα in (Active, Const) + EnzymeTestUtils.test_reverse(scale!, TC, (C, TC), (α, Tα); atol, rtol) + EnzymeTestUtils.test_reverse(scale!, TC, (C', TC), (α, Tα); atol, rtol) + @testset for TA in (Duplicated, Const) + EnzymeTestUtils.test_reverse(scale!, TC, (C, TC), (A, TA), (α, Tα); atol, rtol) + EnzymeTestUtils.test_reverse(scale!, TC, (C', TC), (A', TA), (α, Tα); atol, rtol) + EnzymeTestUtils.test_reverse(scale!, TC, (copy(C'), TC), (A', TA), (α, Tα); atol, rtol) + EnzymeTestUtils.test_reverse(scale!, TC, (C', TC), (copy(A'), TA), (α, Tα); atol, rtol) + end + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index ad7b4006e..87dc8274d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -26,9 +26,9 @@ else groups = settings[:groups] end -checktestgroup(group) = isdir(joinpath(@__DIR__, group)) || +#=checktestgroup(group) = isdir(joinpath(@__DIR__, group)) || throw(ArgumentError("Invalid group ($group), no such folder")) -foreach(checktestgroup, groups) +foreach(checktestgroup, groups)=# @info "Loaded test groups:" groups @@ -57,7 +57,7 @@ istestfile(fn) = endswith(fn, ".jl") && !contains(fn, "setup") # somehow AD tests are unreasonably slow on Apple CI # and ChainRulesTestUtils doesn't like prereleases - if group == "chainrules" || group == "mooncake" + if group == "chainrules" || group == "mooncake" || group == "enzyme" Sys.isapple() && get(ENV, "CI", "false") == "true" && continue isempty(VERSION.prerelease) || continue end