Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/algorithms/time_evolution/apply_gate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ When `A`, `B` are PEPOTensors,
5 1 4 1 4 1
```
"""
function _qr_bond(A::PT, B::PT; gate_ax::Int = 1) where {PT <: Union{PEPSTensor, PEPOTensor}}
function _qr_bond(A::PT, B::PT; gate_ax::Int = 1, kwargs...) where {PT <: Union{PEPSTensor, PEPOTensor}}
@assert 1 <= gate_ax <= numout(A)
permA, permB, permX, permY = if A isa PEPSTensor
((2, 4, 5), (1, 3)), ((2, 3, 4), (1, 5)), (1, 4, 2, 3), Tuple(1:4)
Expand All @@ -59,8 +59,8 @@ function _qr_bond(A::PT, B::PT; gate_ax::Int = 1) where {PT <: Union{PEPSTensor,
((1, 3, 5, 6), (2, 4)), ((1, 3, 4, 5), (2, 6)), (1, 2, 5, 3, 4), Tuple(1:5)
end
end
X, a = left_orth!(permute(A, permA; copy = true); positive = true)
Y, b = left_orth!(permute(B, permB; copy = true); positive = true)
X, a = left_orth!(permute(A, permA; copy = true); kwargs...)
Y, b = left_orth!(permute(B, permB; copy = true); kwargs...)
X, Y = permute(X, permX), permute(Y, permY)
b = permute(b, ((3, 2), (1,)))
return X, a, b, Y
Expand Down
75 changes: 35 additions & 40 deletions src/algorithms/time_evolution/simpleupdate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,50 +66,48 @@ function TimeEvolver(
return TimeEvolver(alg, dt, nstep, gate, state)
end

"""
Optimized simple update of nearest neighbor bonds utilizing
reduced bond tensors without decomposing the gate into a 2-site MPO.
function _bond_rotation(x, bonddir::Int, rev::Bool; inv::Bool = false)
return if bonddir == 1 # x-bond
rev ? rot180(x) : x
elseif bonddir == 2 # y-bond
if rev
inv ? rotr90(x) : rotl90(x)
else
inv ? rotl90(x) : rotr90(x)
end
else
error("`bonddir` must be 1 (for x-bonds) or 2 (for y-bonds).")
end
end

When `purified = true`, `gate` acts on the codomain physical legs of `state`.
Otherwise, `gate` acts on both the codomain and the domain physical legs of `state`.
"""
Simple update optimized for nearest neighbor gates
utilizing reduced bond tensors with the physical leg.
"""
function _su_iter!(
state::InfiniteState, gate::NNGate, env::SUWeight,
sites::Vector{CartesianIndex{2}}, truncs::Vector{E};
purified::Bool = true
) where {E <: TruncationStrategy}
sites::Vector{CartesianIndex{2}}, alg::SimpleUpdate
)
Nr, Nc = size(state)
truncs = _get_cluster_trunc(alg.trunc, sites, (Nr, Nc))
@assert length(sites) == 2 && length(truncs) == 1
Ms, open_vaxs, = _get_cluster(state, sites, env; permute = false)
normalize!.(Ms, Inf)
# rotate
bond, rev = _nn_bondrev(sites..., (Nr, Nc))
A, B = if bond[1] == 1 # x-bond
rev ? map(rot180, Ms) : Ms
else # y-bond
rev ? map(rotl90, Ms) : map(rotr90, Ms)
end
A, B = _bond_rotation.(Ms, bond[1], rev; inv = false)
# apply gate
ϵ, s = 0.0, nothing
gate_axs = purified ? (1:1) : (1:2)
gate_axs = alg.purified ? (1:1) : (1:2)
for gate_ax in gate_axs
X, a, b, Y = _qr_bond(A, B; gate_ax)
X, a, b, Y = _qr_bond(A, B; gate_ax, positive = true)
a, s, b, ϵ′ = _apply_gate(a, b, gate, truncs[1])
ϵ = max(ϵ, ϵ′)
A, B = _qr_bond_undo(X, a, b, Y)
end
# rotate back
if bond[1] == 1 # x-bond
if rev
A, B = rot180(A), rot180(B)
end
else # y-bond
if rev
A, B = rotr90(A), rotr90(B)
else
A, B = rotl90(A), rotl90(B)
end
end
A = _bond_rotation(A, bond[1], rev; inv = true)
B = _bond_rotation(B, bond[1], rev; inv = true)
# remove environment weights
siteA, siteB = map(sites) do site
return CartesianIndex(mod1(site[1], Nr), mod1(site[2], Nc))
Expand All @@ -120,8 +118,8 @@ function _su_iter!(
normalize!(A, Inf)
normalize!(B, Inf)
normalize!(s, Inf)
state.A[siteA], state.A[siteB] = A, B
env.data[bond...] = s
state[siteA], state[siteB] = A, B
env[bond...] = s
return ϵ
end

Expand All @@ -134,40 +132,37 @@ function su_iter(
)
Nr, Nc, = size(state)
state2, env2, ϵ = deepcopy(state), deepcopy(env), 0.0
purified = alg.purified
for (sites, gate) in gates.terms
if length(sites) == 1
# 1-site gate
# TODO: special treatment for bipartite state
site = sites[1]
r, c = mod1(site[1], Nr), mod1(site[2], Nc)
state2.A[r, c] = _apply_sitegate(state2.A[r, c], gate; purified)
state2[r, c] = _apply_sitegate(state2[r, c], gate; alg.purified)
elseif length(sites) == 2
(d, r, c), = _nn_bondrev(sites..., (Nr, Nc))
if alg.bipartite
length(sites) > 2 && error("Multi-site MPO gates are not compatible with bipartite states.")
r > 1 && continue
end
truncs = _get_cluster_trunc(alg.trunc, sites, size(state)[1:2])
ϵ′ = _su_iter!(state2, gate, env2, sites, truncs; purified)
ϵ′ = _su_iter!(state2, gate, env2, sites, alg)
ϵ = max(ϵ, ϵ′)
(!alg.bipartite) && continue
if d == 1
rp1, cp1 = _next(r, Nr), _next(c, Nc)
state2.A[rp1, cp1] = deepcopy(state2.A[r, c])
state2.A[rp1, c] = deepcopy(state2.A[r, cp1])
env2.data[1, rp1, cp1] = deepcopy(env2.data[1, r, c])
state2[rp1, cp1] = deepcopy(state2[r, c])
state2[rp1, c] = deepcopy(state2[r, cp1])
env2[1, rp1, cp1] = deepcopy(env2[1, r, c])
else
rm1, cm1 = _prev(r, Nr), _prev(c, Nc)
state2.A[rm1, cm1] = deepcopy(state2.A[r, c])
state2.A[r, cm1] = deepcopy(state2.A[rm1, c])
env2.data[2, rm1, cm1] = deepcopy(env2.data[2, r, c])
state2[rm1, cm1] = deepcopy(state2[r, c])
state2[r, cm1] = deepcopy(state2[rm1, c])
env2[2, rm1, cm1] = deepcopy(env2[2, r, c])
end
else
# N-site MPO gate (N ≥ 2)
alg.bipartite && error("Multi-site MPO gates are not compatible with bipartite states.")
truncs = _get_cluster_trunc(alg.trunc, sites, size(state)[1:2])
ϵ′ = _su_iter!(state2, gate, env2, sites, truncs; purified)
ϵ′ = _su_iter!(state2, gate, env2, sites, alg)
ϵ = max(ϵ, ϵ′)
end
end
Expand Down
14 changes: 7 additions & 7 deletions src/algorithms/time_evolution/simpleupdate3site.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,9 @@ function _get_cluster(
Ms = map(zip(sites, open_vaxs, perms)) do (site, vaxs, perm)
s = CartesianIndex(mod1(site[1], Nr), mod1(site[2], Nc))
M = if env === nothing
state.A[s]
state[s]
else
absorb_weight(state.A[s], env, s[1], s[2], vaxs)
absorb_weight(state[s], env, s[1], s[2], vaxs)
end
return permute ? TensorKit.permute(M, perm) : M
end
Expand All @@ -164,18 +164,18 @@ Simple update with an N-site MPO `gate` (N ≥ 2).
"""
function _su_iter!(
state::InfiniteState, gate::Vector{T}, env::SUWeight,
sites::Vector{CartesianIndex{2}}, truncs::Vector{E};
purified::Bool = true
) where {T <: AbstractTensorMap, E <: TruncationStrategy}
sites::Vector{CartesianIndex{2}}, alg::SimpleUpdate
) where {T <: AbstractTensorMap}
Nr, Nc = size(state)
truncs = _get_cluster_trunc(alg.trunc, sites, (Nr, Nc))
Ms, open_vaxs, invperms = _get_cluster(state, sites, env)
flips = [isdual(space(M, 1)) for M in Ms[2:end]]
Vphys = [codomain(M, 2) for M in Ms]
normalize!.(Ms, Inf)
# flip virtual arrows in `Ms` to ←
_flip_virtuals!(Ms, flips)
# apply gate MPOs and truncate
gate_axs = purified ? (1:1) : (1:2)
gate_axs = alg.purified ? (1:1) : (1:2)
wts, ϵs = nothing, nothing
for gate_ax in gate_axs
_apply_gatempo!(Ms, gate; gate_ax)
Expand Down Expand Up @@ -206,7 +206,7 @@ function _su_iter!(
# remove weights on open axes of the cluster
M = absorb_weight(M, env, s′[1], s′[2], vaxs; inv = true)
# update state tensors
state.A[s′] = normalize(M, Inf)
state[s′] = normalize(M, Inf)
end
return maximum(ϵs)
end
Expand Down
Loading