diff --git a/pyqmri/solver.py b/pyqmri/solver.py index 07c0feb..081e07c 100644 --- a/pyqmri/solver.py +++ b/pyqmri/solver.py @@ -165,8 +165,8 @@ def run(self, data, iters=30, lambd=1e-5, tol=1e-8, guess=None, for i in range(iters): self._operator_lhs(Ax, p) Ax = Ax + lambd*p - alpha = (clarray.vdot(res, res) / - (clarray.vdot(p, Ax))).real.get() + alpha = float((clarray.vdot(res, res) / + (clarray.vdot(p, Ax))).real.get()) x = (x + alpha*p) res_new = res - alpha*Ax delta = np.linalg.norm(res_new.get())**2 /\ @@ -177,8 +177,8 @@ def run(self, data, iters=30, lambd=1e-5, tol=1e-8, guess=None, del Ax, \ b, res, p, data, res_new return np.squeeze(x.get()) - beta = (clarray.vdot(res_new, res_new) / - clarray.vdot(res, res)).real.get() + beta = float((clarray.vdot(res_new, res_new) / + clarray.vdot(res, res)).real.get()) p = res_new + beta * p (res, res_new) = (res_new, res) del Ax, b, res, p, data, res_new