1010import jax .debug as jdebug
1111
1212from varipeps import varipeps_config , varipeps_global_state
13+ from varipeps .config import Grad_Fixed_Point_Method
1314from varipeps .peps import PEPS_Tensor , PEPS_Tensor_Split_Transfer , PEPS_Unit_Cell
1415from varipeps .utils .debug_print import debug_print
1516from .absorption import do_absorption_step , do_absorption_step_split_transfer
@@ -1033,9 +1034,7 @@ def _ctmrg_rev_workhorse(peps_tensors, new_unitcell, new_unitcell_bar, config, s
10331034 )[1 ]
10341035 )
10351036
1036- old_method = False
1037-
1038- if old_method :
1037+ if config .ad_custom_fixed_point_method is Grad_Fixed_Point_Method .ITERATIVE :
10391038
10401039 def cond_func (carry ):
10411040 _ , _ , _ , converged , count , config , state = carry
@@ -1048,18 +1047,21 @@ def cond_func(carry):
10481047 (vjp_env , new_unitcell_bar , new_unitcell_bar , False , 0 , config , state ),
10491048 )
10501049 else :
1051- env_fixed_point , arnoldi_worked = jax .pure_callback (
1052- _ctmrg_rev_arnoldi ,
1053- jax .eval_shape (lambda x : (x , True ), new_unitcell_bar ),
1054- vjp (
1055- lambda u : do_absorption_step (peps_tensors , u , config , state ),
1056- new_unitcell ,
1057- )[1 ],
1058- new_unitcell_bar ,
1059- )
1060- end_count = 0
1050+ if config .ad_custom_fixed_point_method is Grad_Fixed_Point_Method .EIGEN_SOLVER :
1051+ env_fixed_point , arnoldi_worked = jax .pure_callback (
1052+ _ctmrg_rev_arnoldi ,
1053+ jax .eval_shape (lambda x : (x , True ), new_unitcell_bar ),
1054+ vjp (
1055+ lambda u : do_absorption_step (peps_tensors , u , config , state ),
1056+ new_unitcell ,
1057+ )[1 ],
1058+ new_unitcell_bar ,
1059+ )
1060+ else :
1061+ env_fixed_point = new_unitcell_bar
1062+ arnoldi_worked = False
10611063
1062- debug_print ( "Arnoldi: {}" , arnoldi_worked )
1064+ end_count = 0
10631065
10641066 def run_gmres (v , e ):
10651067 def f_gmres (w ):
0 commit comments