Skip to content

Commit fe9d514

Browse files
committed
Fix VUMPS, IDMRG(2)
1 parent c5afadd commit fe9d514

5 files changed

Lines changed: 82 additions & 16 deletions

File tree

examples/J1J2_mpi.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
using Pkg
2+
Pkg.activate("/home/afeuerpfeil/.julia/dev/MPSKitParallel")
3+
using MPSKit
4+
using MPSKitModels
5+
using TensorKit
6+
using MPSKitParallel
7+
using MPIHelper
8+
using MPI
9+
MPSKit.Defaults.set_scheduler!(:serial)
10+
11+
MPI.Init()
12+
mpi_rank() = MPI.Comm_rank(MPI.COMM_WORLD)
13+
mpi_size() = MPI.Comm_size(MPI.COMM_WORLD)
14+
15+
N=4
16+
J2=0.3
17+
H_J1 = @mpoham sum(S_exchange(;spin=1//2){i, j} for (i, j) in nearest_neighbours(InfiniteChain(N)));
18+
H_J2 = @mpoham sum(rmul!(S_exchange(;spin=1//2){i, j}, J2) for (i, j) in next_nearest_neighbours(InfiniteChain(N)));
19+
20+
H_J1J2 = H_J1 + H_J2;
21+
state = InfiniteMPS(fill(2, N), fill(20, N));
22+
state = MPIHelper.bcast(state, MPI.COMM_WORLD)
23+
24+
25+
ψ_inf, envs, delta = find_groundstate(
26+
state, H_J1J2, VUMPS(; maxiter = 20, tol = 1.0e-12, verbosity=1)
27+
);
28+
29+
if mpi_rank() == 0
30+
H_mpi = @mpoham sum(S_exchange(;spin=1//2){i, j} for (i, j) in nearest_neighbours(InfiniteChain(N)));
31+
elseif mpi_rank() == 1
32+
H_mpi = @mpoham sum(rmul!(S_exchange(;spin=1//2){i, j}, J2) for (i, j) in next_nearest_neighbours(InfiniteChain(N)));
33+
else
34+
error("This example only works with 2 MPI processes.")
35+
end
36+
H_mpi = MPIOperator(H_mpi)
37+
38+
39+
println("Hey, I am rank=$(mpi_rank()) out of $(mpi_size()) processes.")
40+
41+
ψ_infmpi, envs_infmpi, delta_infmpi = find_groundstate(state, H_mpi, verbosity=1); ## This tests VUMPS and GradientGrassmann
42+
43+
println("Hey, I am rank=$(mpi_rank()) out of $(mpi_size()) processes. abs(dot(ψ_inf, ψ_infmpi)) = $(abs(dot(ψ_inf, ψ_infmpi)))")
44+
45+
ψ_infmpi, envs_infmpi, delta_infmpi = find_groundstate(state, H_mpi, IDMRG2(; maxiter = 20, tol = 1.0e-12, verbosity=1, trscheme=truncrank(50)));
46+
47+
println("Hey, I am rank=$(mpi_rank()) out of $(mpi_size()) processes. abs(dot(ψ_inf, ψ_infmpi)) = $(abs(dot(ψ_inf, ψ_infmpi)))")
48+
49+
ψ_infmpi, envs_infmpi, delta_infmpi = find_groundstate(state, H_mpi, IDMRG(; maxiter = 20, tol = 1.0e-12, verbosity=1));
50+
51+
println("Hey, I am rank=$(mpi_rank()) out of $(mpi_size()) processes. abs(dot(ψ_inf, ψ_infmpi)) = $(abs(dot(ψ_inf, ψ_infmpi)))")
52+

src/MPIOperator/mpioperator.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ end
1717

1818
function (Op::MPIOperator{O})(x::S) where {O,S}
1919
y_per_rank = parent(Op)(x)
20-
y = MPIHelper.allreduce(y_per_rank, +, MPI.COMM_WORLD)
20+
y = MPIHelper.allreduce(y_per_rank, Base.:+, MPI.COMM_WORLD)
2121
return y
2222
end
2323

src/MPSKitParallel.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@ import MPSKit: environments, AbstractMPSEnvironments, InfiniteEnvironments
2424
import MPSKit: C_hamiltonian, AC_hamiltonian, AC2_hamiltonian, C_projection, AC_projection, AC2_projection
2525
import MPSKit: exact_diagonalization
2626

27-
using MPSKit: IterativeSolver, VUMPSState, AbstractMPS, Multiline, eachsite, fixedpoint, regauge!, left_orth, left_orth!, transfer_leftenv!, transfer_rightenv!, svd_trunc!
28-
using MPSKit: AC, AC2, _transpose_front, _transpose_tail, _mul_tail, AC_hamiltonian, AC2_hamiltonian
27+
using MPSKit: IterativeSolver, VUMPSState, AbstractMPS, Multiline, eachsite, fixedpoint, regauge!, left_orth, left_orth!, right_orth, right_orth!, transfer_leftenv!, transfer_rightenv!, svd_trunc!
28+
using MPSKit: AC2, _transpose_front, _transpose_tail, _mul_front, _mul_tail, AC_hamiltonian, AC2_hamiltonian, _firstspace
29+
using MPSKit: _mul_front
2930
using MPSKit.DynamicTols: updatetol
3031
using Base.Threads: @spawn, @sync
3132

src/algorithms/groundstate/idmrg.jl

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,43 @@
1+
function check_state(psi::InfiniteMPS,pos=-10,dir=0)
2+
psibc=MPIHelper.bcast(psi, MPI.COMM_WORLD)
3+
if !mpi_is_root()
4+
for typ in [:AL,:AR,:C]
5+
psi_c=getfield(psibc, typ)
6+
psi_l=getfield(psi, typ)
7+
for i in eachindex(psi_c)
8+
pc=psi_c[i]
9+
pl=psi_l[i]
10+
@assert norm(pc-pl)<1e-10 "Wrong MPS on rank $(mpi_rank()) for tensor $typ at site $i: norm difference=$(norm(pc-pl)) and dir=$dir pos=$pos"
11+
end
12+
end
13+
end
14+
end
15+
116
function MPSKit._localupdate_sweep_idmrg!::AbstractMPS, H::MPIOperator, envs, alg_eigsolve, ::IDMRG)
17+
E=0
218
C_old = ψ.C[0]
319
# left to right sweep
420
for pos in 1:length(ψ)
521
h = AC_hamiltonian(pos, ψ, H, ψ, envs)
6-
_, ψ.AC[pos] = fixedpoint(h, ψ.AC[pos], :SR, alg_eigsolve)
7-
if pos == length(ψ)
8-
# AC needed in next sweep
9-
ψ.AL[pos], ψ.C[pos] = mpi_execute_on_root_and_bcast(left_orth,ψ.AC[pos])
10-
else
11-
ψ.AL[pos], ψ.C[pos] = mpi_execute_on_root_and_bcast(left_orth!,ψ.AC[pos])
12-
end
22+
23+
ψ.AC[pos] = MPIHelper.bcast(fixedpoint(h, ψ.AC[pos], :SR, alg_eigsolve)[2], MPI.COMM_WORLD)
24+
ψ.AL[pos], ψ.C[pos] = mpi_execute_on_root_and_bcast(left_orth,ψ.AC[pos])
25+
1326
transfer_leftenv!(envs, ψ, H, ψ, pos + 1)
1427
end
1528

1629
# right to left sweep
1730
for pos in length(ψ):-1:1
1831
h = AC_hamiltonian(pos, ψ, H, ψ, envs)
19-
_, ψ.AC[pos] = fixedpoint(h, ψ.AC[pos], :SR, alg_eigsolve)
2032

21-
ψ.C[pos - 1], temp = mpi_right_orth!(_transpose_tail.AC[pos]; copy = (pos == 1)))
33+
E, ψ.AC[pos] = fixedpoint(h, ψ.AC[pos], :SR, alg_eigsolve)
34+
ψ.AC[pos] = MPIHelper.bcast.AC[pos], MPI.COMM_WORLD)
35+
ψ.C[pos - 1], temp = mpi_execute_on_root_and_bcast(right_orth!,_transpose_tail.AC[pos]; copy = (pos == 1)))
2236
ψ.AR[pos] = _transpose_front(temp)
2337

2438
transfer_rightenv!(envs, ψ, H, ψ, pos - 1)
2539
end
26-
return ψ, envs, C_old
40+
return ψ, envs, C_old, E
2741
end
2842

2943
function MPSKit._localupdate_sweep_idmrg!::AbstractMPS, H::MPIOperator, envs, alg_eigsolve, alg::IDMRG2)
@@ -93,7 +107,7 @@ function MPSKit._localupdate_sweep_idmrg!(ψ::AbstractMPS, H::MPIOperator, envs,
93107
ψ.AR[1] = _transpose_front.C[end] \ _transpose_tail.AC[1]))
94108
ac2 = AC2(ψ, 0; kind = :ACAR)
95109
h_ac2 = AC2_hamiltonian(0, ψ, H, ψ, envs)
96-
_, ac2′ = fixedpoint(h_ac2, ac2, :SR, alg_eigsolve)
110+
E, ac2′ = fixedpoint(h_ac2, ac2, :SR, alg_eigsolve)
97111
al, c, ar = mpi_execute_on_root_and_bcast(svd_trunc!, ac2′; trunc = alg.trscheme, alg = alg.alg_svd)
98112
normalize!(c)
99113

@@ -107,5 +121,5 @@ function MPSKit._localupdate_sweep_idmrg!(ψ::AbstractMPS, H::MPIOperator, envs,
107121
transfer_leftenv!(envs, ψ, H, ψ, 1)
108122
transfer_rightenv!(envs, ψ, H, ψ, 0)
109123

110-
return ψ, envs, C_old
124+
return ψ, envs, C_old, E
111125
end

src/algorithms/groundstate/vumps.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ end
5151

5252
function MPSKit.gauge_step!(it::IterativeSolver{<:VUMPS}, state::VUMPSState{S, MPIOperator{O}, E}, ACs::AbstractVector) where {S, O, E}
5353
alg_gauge = updatetol(it.alg_gauge, state.iter, state.ϵ)
54-
println("Gauging!")
5554
if mpi_is_root()
5655
psi = InfiniteMPS(ACs, state.mps.C[end]; alg_gauge.tol, alg_gauge.maxiter)
5756
else

0 commit comments

Comments
 (0)