From d0b24a5f57a3c11eba1cc1554733c0d39a512dd9 Mon Sep 17 00:00:00 2001 From: Yue Zhengyuan Date: Thu, 19 Mar 2026 20:31:04 +0800 Subject: [PATCH 1/2] Simple update refactoring --- src/algorithms/time_evolution/apply_gate.jl | 6 +- src/algorithms/time_evolution/simpleupdate.jl | 71 ++++++++----------- .../time_evolution/simpleupdate3site.jl | 14 ++-- 3 files changed, 41 insertions(+), 50 deletions(-) diff --git a/src/algorithms/time_evolution/apply_gate.jl b/src/algorithms/time_evolution/apply_gate.jl index 171f327c2..363b23f8d 100644 --- a/src/algorithms/time_evolution/apply_gate.jl +++ b/src/algorithms/time_evolution/apply_gate.jl @@ -48,7 +48,7 @@ When `A`, `B` are PEPOTensors, 5 1 4 1 4 1 ``` """ -function _qr_bond(A::PT, B::PT; gate_ax::Int = 1) where {PT <: Union{PEPSTensor, PEPOTensor}} +function _qr_bond(A::PT, B::PT; gate_ax::Int = 1, kwargs...) where {PT <: Union{PEPSTensor, PEPOTensor}} @assert 1 <= gate_ax <= numout(A) permA, permB, permX, permY = if A isa PEPSTensor ((2, 4, 5), (1, 3)), ((2, 3, 4), (1, 5)), (1, 4, 2, 3), Tuple(1:4) @@ -59,8 +59,8 @@ function _qr_bond(A::PT, B::PT; gate_ax::Int = 1) where {PT <: Union{PEPSTensor, ((1, 3, 5, 6), (2, 4)), ((1, 3, 4, 5), (2, 6)), (1, 2, 5, 3, 4), Tuple(1:5) end end - X, a = left_orth!(permute(A, permA; copy = true); positive = true) - Y, b = left_orth!(permute(B, permB; copy = true); positive = true) + X, a = left_orth!(permute(A, permA; copy = true); kwargs...) + Y, b = left_orth!(permute(B, permB; copy = true); kwargs...) X, Y = permute(X, permX), permute(Y, permY) b = permute(b, ((3, 2), (1,))) return X, a, b, Y diff --git a/src/algorithms/time_evolution/simpleupdate.jl b/src/algorithms/time_evolution/simpleupdate.jl index 1a06fcbcc..f0446d8a1 100644 --- a/src/algorithms/time_evolution/simpleupdate.jl +++ b/src/algorithms/time_evolution/simpleupdate.jl @@ -66,50 +66,44 @@ function TimeEvolver( return TimeEvolver(alg, dt, nstep, gate, state) end -""" -Optimized simple update of nearest neighbor bonds utilizing -reduced bond tensors without decomposing the gate into a 2-site MPO. +function _bond_rotation(x, bonddir::Int, rev::Bool; inv::Bool = false) + return if bonddir == 1 # x-bond + rev ? rot180(x) : x + elseif bonddir == 2 # y-bond + rev ? (inv ? rotr90(x) : rotl90(x)) : (inv ? rotl90(x) : rotr90(x)) + else + error("`bonddir` must be 1 (for x-bonds) or 2 (for y-bonds).") + end +end -When `purified = true`, `gate` acts on the codomain physical legs of `state`. -Otherwise, `gate` acts on both the codomain and the domain physical legs of `state`. +""" +Simple update optimized for nearest neighbor gates +utilizing reduced bond tensors with the physical leg. """ function _su_iter!( state::InfiniteState, gate::NNGate, env::SUWeight, - sites::Vector{CartesianIndex{2}}, truncs::Vector{E}; - purified::Bool = true - ) where {E <: TruncationStrategy} + sites::Vector{CartesianIndex{2}}, alg::SimpleUpdate + ) Nr, Nc = size(state) + truncs = _get_cluster_trunc(alg.trunc, sites, (Nr, Nc)) @assert length(sites) == 2 && length(truncs) == 1 Ms, open_vaxs, = _get_cluster(state, sites, env; permute = false) normalize!.(Ms, Inf) # rotate bond, rev = _nn_bondrev(sites..., (Nr, Nc)) - A, B = if bond[1] == 1 # x-bond - rev ? map(rot180, Ms) : Ms - else # y-bond - rev ? map(rotl90, Ms) : map(rotr90, Ms) - end + A, B = _bond_rotation.(Ms, bond[1], rev; inv = false) # apply gate ϵ, s = 0.0, nothing - gate_axs = purified ? (1:1) : (1:2) + gate_axs = alg.purified ? (1:1) : (1:2) for gate_ax in gate_axs - X, a, b, Y = _qr_bond(A, B; gate_ax) + X, a, b, Y = _qr_bond(A, B; gate_ax, positive = true) a, s, b, ϵ′ = _apply_gate(a, b, gate, truncs[1]) ϵ = max(ϵ, ϵ′) A, B = _qr_bond_undo(X, a, b, Y) end # rotate back - if bond[1] == 1 # x-bond - if rev - A, B = rot180(A), rot180(B) - end - else # y-bond - if rev - A, B = rotr90(A), rotr90(B) - else - A, B = rotl90(A), rotl90(B) - end - end + A = _bond_rotation(A, bond[1], rev; inv = true) + B = _bond_rotation(B, bond[1], rev; inv = true) # remove environment weights siteA, siteB = map(sites) do site return CartesianIndex(mod1(site[1], Nr), mod1(site[2], Nc)) @@ -120,8 +114,8 @@ function _su_iter!( normalize!(A, Inf) normalize!(B, Inf) normalize!(s, Inf) - state.A[siteA], state.A[siteB] = A, B - env.data[bond...] = s + state[siteA], state[siteB] = A, B + env[bond...] = s return ϵ end @@ -134,40 +128,37 @@ function su_iter( ) Nr, Nc, = size(state) state2, env2, ϵ = deepcopy(state), deepcopy(env), 0.0 - purified = alg.purified for (sites, gate) in gates.terms if length(sites) == 1 # 1-site gate # TODO: special treatment for bipartite state site = sites[1] r, c = mod1(site[1], Nr), mod1(site[2], Nc) - state2.A[r, c] = _apply_sitegate(state2.A[r, c], gate; purified) + state2[r, c] = _apply_sitegate(state2[r, c], gate; alg.purified) elseif length(sites) == 2 (d, r, c), = _nn_bondrev(sites..., (Nr, Nc)) if alg.bipartite length(sites) > 2 && error("Multi-site MPO gates are not compatible with bipartite states.") r > 1 && continue end - truncs = _get_cluster_trunc(alg.trunc, sites, size(state)[1:2]) - ϵ′ = _su_iter!(state2, gate, env2, sites, truncs; purified) + ϵ′ = _su_iter!(state2, gate, env2, sites, alg) ϵ = max(ϵ, ϵ′) (!alg.bipartite) && continue if d == 1 rp1, cp1 = _next(r, Nr), _next(c, Nc) - state2.A[rp1, cp1] = deepcopy(state2.A[r, c]) - state2.A[rp1, c] = deepcopy(state2.A[r, cp1]) - env2.data[1, rp1, cp1] = deepcopy(env2.data[1, r, c]) + state2[rp1, cp1] = deepcopy(state2[r, c]) + state2[rp1, c] = deepcopy(state2[r, cp1]) + env2[1, rp1, cp1] = deepcopy(env2[1, r, c]) else rm1, cm1 = _prev(r, Nr), _prev(c, Nc) - state2.A[rm1, cm1] = deepcopy(state2.A[r, c]) - state2.A[r, cm1] = deepcopy(state2.A[rm1, c]) - env2.data[2, rm1, cm1] = deepcopy(env2.data[2, r, c]) + state2[rm1, cm1] = deepcopy(state2[r, c]) + state2[r, cm1] = deepcopy(state2[rm1, c]) + env2[2, rm1, cm1] = deepcopy(env2[2, r, c]) end else # N-site MPO gate (N ≥ 2) alg.bipartite && error("Multi-site MPO gates are not compatible with bipartite states.") - truncs = _get_cluster_trunc(alg.trunc, sites, size(state)[1:2]) - ϵ′ = _su_iter!(state2, gate, env2, sites, truncs; purified) + ϵ′ = _su_iter!(state2, gate, env2, sites, alg) ϵ = max(ϵ, ϵ′) end end diff --git a/src/algorithms/time_evolution/simpleupdate3site.jl b/src/algorithms/time_evolution/simpleupdate3site.jl index c7792cca4..4b091c830 100644 --- a/src/algorithms/time_evolution/simpleupdate3site.jl +++ b/src/algorithms/time_evolution/simpleupdate3site.jl @@ -150,9 +150,9 @@ function _get_cluster( Ms = map(zip(sites, open_vaxs, perms)) do (site, vaxs, perm) s = CartesianIndex(mod1(site[1], Nr), mod1(site[2], Nc)) M = if env === nothing - state.A[s] + state[s] else - absorb_weight(state.A[s], env, s[1], s[2], vaxs) + absorb_weight(state[s], env, s[1], s[2], vaxs) end return permute ? TensorKit.permute(M, perm) : M end @@ -164,10 +164,10 @@ Simple update with an N-site MPO `gate` (N ≥ 2). """ function _su_iter!( state::InfiniteState, gate::Vector{T}, env::SUWeight, - sites::Vector{CartesianIndex{2}}, truncs::Vector{E}; - purified::Bool = true - ) where {T <: AbstractTensorMap, E <: TruncationStrategy} + sites::Vector{CartesianIndex{2}}, alg::SimpleUpdate + ) where {T <: AbstractTensorMap} Nr, Nc = size(state) + truncs = _get_cluster_trunc(alg.trunc, sites, (Nr, Nc)) Ms, open_vaxs, invperms = _get_cluster(state, sites, env) flips = [isdual(space(M, 1)) for M in Ms[2:end]] Vphys = [codomain(M, 2) for M in Ms] @@ -175,7 +175,7 @@ function _su_iter!( # flip virtual arrows in `Ms` to ← _flip_virtuals!(Ms, flips) # apply gate MPOs and truncate - gate_axs = purified ? (1:1) : (1:2) + gate_axs = alg.purified ? (1:1) : (1:2) wts, ϵs = nothing, nothing for gate_ax in gate_axs _apply_gatempo!(Ms, gate; gate_ax) @@ -206,7 +206,7 @@ function _su_iter!( # remove weights on open axes of the cluster M = absorb_weight(M, env, s′[1], s′[2], vaxs; inv = true) # update state tensors - state.A[s′] = normalize(M, Inf) + state[s′] = normalize(M, Inf) end return maximum(ϵs) end From bfc29c4c522434ad84e583748f4277a9e4ad1621 Mon Sep 17 00:00:00 2001 From: Yue Zhengyuan Date: Fri, 20 Mar 2026 08:39:47 +0800 Subject: [PATCH 2/2] Unwrap nested ternary `?` --- src/algorithms/time_evolution/simpleupdate.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/algorithms/time_evolution/simpleupdate.jl b/src/algorithms/time_evolution/simpleupdate.jl index f0446d8a1..5e10b2097 100644 --- a/src/algorithms/time_evolution/simpleupdate.jl +++ b/src/algorithms/time_evolution/simpleupdate.jl @@ -70,7 +70,11 @@ function _bond_rotation(x, bonddir::Int, rev::Bool; inv::Bool = false) return if bonddir == 1 # x-bond rev ? rot180(x) : x elseif bonddir == 2 # y-bond - rev ? (inv ? rotr90(x) : rotl90(x)) : (inv ? rotl90(x) : rotr90(x)) + if rev + inv ? rotr90(x) : rotl90(x) + else + inv ? rotl90(x) : rotr90(x) + end else error("`bonddir` must be 1 (for x-bonds) or 2 (for y-bonds).") end