Skip to content

Commit f1465fe

Browse files
committed
Add config setting for the grad fixed point method
1 parent 902cebe commit f1465fe

File tree

2 files changed

+34
-14
lines changed

2 files changed

+34
-14
lines changed

varipeps/config.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,17 @@ class Wavevector_Type(IntEnum):
4040
TWO_PI_SYMMETRIC = auto() #: Use interval [-2pi, 2pi) for q vectors
4141

4242

43+
@unique
44+
class Grad_Fixed_Point_Method(IntEnum):
45+
ITERATIVE = auto() #: Use iterative method to calculate gradient of CTMRG routine
46+
LINEAR_SOLVER = (
47+
auto()
48+
) #: Use linear solver method to calculate gradient of CTMRG routine
49+
EIGEN_SOLVER = (
50+
auto()
51+
) #: Use eigen solver method to calculate gradient of CTMRG routine
52+
53+
4354
@unique
4455
class Slurm_Restart_Mode(IntEnum):
4556
DISABLED = (
@@ -73,6 +84,9 @@ class VariPEPS_Config:
7384
ad_custom_max_steps (:obj:`int`):
7485
Maximal number of steps for fix-pointer iteration of the custom VJP
7586
function.
87+
ad_custom_fixed_point_method (:obj:`~varipeps.config.Grad_Fixed_Point_Method`):
88+
Select method how the gradient of the CTMRG fixed point routine is
89+
calculated.
7690
checkpointing_ncon (:obj:`bool`):
7791
Enable AD checkpointing for the ncon calls.
7892
checkpointing_projectors (:obj:`bool`):
@@ -242,6 +256,9 @@ class VariPEPS_Config:
242256
ad_custom_verbose_output: bool = False
243257
ad_custom_convergence_eps: float = 1e-7
244258
ad_custom_max_steps: int = 75
259+
ad_custom_fixed_point_method: Grad_Fixed_Point_Method = (
260+
Grad_Fixed_Point_Method.LINEAR_SOLVER
261+
)
245262
checkpointing_ncon: bool = False
246263
checkpointing_projectors: bool = False
247264

@@ -406,6 +423,7 @@ class ConfigModuleWrapper:
406423
"Line_Search_Methods",
407424
"Projector_Method",
408425
"Wavevector_Type",
426+
"Grad_Fixed_Point_Method",
409427
"Slurm_Restart_Mode",
410428
"VariPEPS_Config",
411429
"config",

varipeps/ctmrg/routine.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import jax.debug as jdebug
1111

1212
from varipeps import varipeps_config, varipeps_global_state
13+
from varipeps.config import Grad_Fixed_Point_Method
1314
from varipeps.peps import PEPS_Tensor, PEPS_Tensor_Split_Transfer, PEPS_Unit_Cell
1415
from varipeps.utils.debug_print import debug_print
1516
from .absorption import do_absorption_step, do_absorption_step_split_transfer
@@ -1033,9 +1034,7 @@ def _ctmrg_rev_workhorse(peps_tensors, new_unitcell, new_unitcell_bar, config, s
10331034
)[1]
10341035
)
10351036

1036-
old_method = False
1037-
1038-
if old_method:
1037+
if config.ad_custom_fixed_point_method is Grad_Fixed_Point_Method.ITERATIVE:
10391038

10401039
def cond_func(carry):
10411040
_, _, _, converged, count, config, state = carry
@@ -1048,18 +1047,21 @@ def cond_func(carry):
10481047
(vjp_env, new_unitcell_bar, new_unitcell_bar, False, 0, config, state),
10491048
)
10501049
else:
1051-
env_fixed_point, arnoldi_worked = jax.pure_callback(
1052-
_ctmrg_rev_arnoldi,
1053-
jax.eval_shape(lambda x: (x, True), new_unitcell_bar),
1054-
vjp(
1055-
lambda u: do_absorption_step(peps_tensors, u, config, state),
1056-
new_unitcell,
1057-
)[1],
1058-
new_unitcell_bar,
1059-
)
1060-
end_count = 0
1050+
if config.ad_custom_fixed_point_method is Grad_Fixed_Point_Method.EIGEN_SOLVER:
1051+
env_fixed_point, arnoldi_worked = jax.pure_callback(
1052+
_ctmrg_rev_arnoldi,
1053+
jax.eval_shape(lambda x: (x, True), new_unitcell_bar),
1054+
vjp(
1055+
lambda u: do_absorption_step(peps_tensors, u, config, state),
1056+
new_unitcell,
1057+
)[1],
1058+
new_unitcell_bar,
1059+
)
1060+
else:
1061+
env_fixed_point = new_unitcell_bar
1062+
arnoldi_worked = False
10611063

1062-
debug_print("Arnoldi: {}", arnoldi_worked)
1064+
end_count = 0
10631065

10641066
def run_gmres(v, e):
10651067
def f_gmres(w):

0 commit comments

Comments
 (0)