Skip to content

Commit 902cebe

Browse files
committed
Implement Arnoldi-based version of fixed point gradient
1 parent c9fb181 commit 902cebe

File tree

1 file changed

+90
-21
lines changed

1 file changed

+90
-21
lines changed

varipeps/ctmrg/routine.py

Lines changed: 90 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from functools import partial
22
import enum
33

4+
import numpy as np
5+
from scipy.sparse.linalg import LinearOperator, eigs
6+
47
import jax.numpy as jnp
58
from jax import jit, custom_vjp, vjp, tree_util
69
from jax.lax import cond, while_loop
@@ -945,6 +948,51 @@ def _ctmrg_rev_while_body(carry):
945948
return vjp_env, initial_bar, bar_fixed_point, converged, count, config, state
946949

947950

951+
def _ctmrg_rev_arnoldi(vjp_operator, initial_v):
952+
v, v_treedev = jax.tree.flatten(initial_v)
953+
v_flat = np.concatenate([i.reshape(-1) for i in v])
954+
955+
def matvec(vec):
956+
new_vec = [None] * len(v)
957+
for i in range(len(v)):
958+
i_start = sum(j.size for j in v[:i])
959+
elems = slice(i_start, i_start + v[i].size)
960+
new_vec[i] = vec[elems].astype(v[i].dtype).reshape(v[i].shape)
961+
new_vec = jax.tree.unflatten(v_treedev, new_vec)
962+
963+
new_vec = vjp_operator((new_vec, jnp.array(0, dtype=jnp.float64)))[0]
964+
965+
new_vec, _ = jax.tree.flatten(new_vec)
966+
new_vec = np.concatenate([i.reshape(-1) for i in new_vec])
967+
968+
return np.append(new_vec + vec[-1] * v_flat, vec[-1])
969+
970+
lin_op = LinearOperator(
971+
(v_flat.shape[0] + 1, v_flat.shape[0] + 1),
972+
matvec=matvec,
973+
)
974+
975+
_, vec = eigs(
976+
lin_op, k=1, v0=np.append(v_flat, np.array(1, dtype=v_flat.dtype)), which="LM"
977+
)
978+
979+
vec = vec.reshape(-1)
980+
981+
if np.abs(vec[-1]) >= 1e-10:
982+
vec /= vec[-1]
983+
984+
result = [None] * len(v)
985+
for i in range(len(v)):
986+
i_start = sum(j.size for j in v[:i])
987+
elems = slice(i_start, i_start + v[i].size)
988+
result[i] = vec[elems].astype(v[i].dtype).reshape(v[i].shape)
989+
990+
if np.abs(vec[-1]) < 1e-10:
991+
return jax.tree.unflatten(v_treedev, result), False
992+
993+
return jax.tree.unflatten(v_treedev, result), True
994+
995+
948996
@jit
949997
def _ctmrg_rev_workhorse(peps_tensors, new_unitcell, new_unitcell_bar, config, state):
950998
if new_unitcell.is_triangular_peps():
@@ -988,6 +1036,7 @@ def _ctmrg_rev_workhorse(peps_tensors, new_unitcell, new_unitcell_bar, config, s
9881036
old_method = False
9891037

9901038
if old_method:
1039+
9911040
def cond_func(carry):
9921041
_, _, _, converged, count, config, state = carry
9931042

@@ -999,30 +1048,50 @@ def cond_func(carry):
9991048
(vjp_env, new_unitcell_bar, new_unitcell_bar, False, 0, config, state),
10001049
)
10011050
else:
1002-
def f(w):
1003-
new_w = vjp_env((w, jnp.array(0, dtype=jnp.float64)))[0]
1004-
1005-
new_w = new_w.replace_unique_tensors(
1006-
[
1007-
t_old.__sub__(t_new, checks=False)
1008-
for t_old, t_new in zip(
1009-
w.get_unique_tensors(),
1010-
new_w.get_unique_tensors(),
1011-
strict=True,
1012-
)
1013-
]
1014-
)
1051+
env_fixed_point, arnoldi_worked = jax.pure_callback(
1052+
_ctmrg_rev_arnoldi,
1053+
jax.eval_shape(lambda x: (x, True), new_unitcell_bar),
1054+
vjp(
1055+
lambda u: do_absorption_step(peps_tensors, u, config, state),
1056+
new_unitcell,
1057+
)[1],
1058+
new_unitcell_bar,
1059+
)
1060+
end_count = 0
1061+
1062+
debug_print("Arnoldi: {}", arnoldi_worked)
1063+
1064+
def run_gmres(v, e):
    # Fallback solver for the reverse-mode fixed point: solve
    # (I - A) w = new_unitcell_bar with GMRES when the Arnoldi-based
    # estimate did not succeed.  ``v`` and ``e`` are the pass-through
    # operands of the surrounding ``jax.lax.cond`` and are not read here —
    # they are replaced by the GMRES solution and its info value.
    def f_gmres(w):
        # Matvec of (I - A): apply the absorption-step VJP to ``w`` and
        # subtract the result from ``w``, tensor by tensor.
        new_w = vjp_env((w, jnp.array(0, dtype=jnp.float64)))[0]

        new_w = new_w.replace_unique_tensors(
            [
                # checks=False: skip the tensors' consistency checks on the
                # elementwise subtraction.
                t_old.__sub__(t_new, checks=False)
                for t_old, t_new in zip(
                    w.get_unique_tensors(),
                    new_w.get_unique_tensors(),
                    strict=True,
                )
            ]
        )

        return new_w

    # NOTE(review): "batched" is selected on GPU backends, "incremental"
    # otherwise — presumably for performance/stability on the respective
    # backend; confirm against the jax.scipy.sparse.linalg.gmres docs.
    is_gpu = jax.default_backend() == "gpu"

    # Solve with the environment cotangent as both right-hand side and
    # initial guess, using the configured convergence tolerance.
    v, e = jax.scipy.sparse.linalg.gmres(
        f_gmres,
        new_unitcell_bar,
        new_unitcell_bar,
        solve_method="batched" if is_gpu else "incremental",
        atol=config.ad_custom_convergence_eps,
    )

    return v, e
1092+
1093+
env_fixed_point, end_count = jax.lax.cond(
1094+
arnoldi_worked, lambda x, e: (x, e), run_gmres, env_fixed_point, end_count
10261095
)
10271096

10281097
converged = True

0 commit comments

Comments
 (0)