
Commit 8b788a7

Merge branch 'master' of github.com:Hugo-W/pyEEG
2 parents b5d6251 + 9502796

1 file changed

Lines changed: 171 additions & 0 deletions

File tree

solver.py

import numpy as np
from scipy.sparse.linalg import spilu
from scipy.sparse import csc_matrix

def svd_solver(A, b, lambda_=0., truncated_svd=False, verbose=False):
    """
    Solve the linear system Ax = b using the SVD method.

    This method assumes that we are solving the normal equation:
        (X^T X + lambda I) x = X^T y
    Thus, A = X^T X and b = X^T y.

    Parameters:
    A : ndarray
        Matrix A. Typically of shape (n_features * n_lags, n_features * n_lags) in the context of TRF.
    b : ndarray
        Right-hand side vector. Typically of shape (n_features * n_lags, n_outputs) in the context of TRF.
    lambda_ : float, optional
        Regularization parameter.
    truncated_svd : bool, optional
        Whether to use the truncated SVD method. If True, lambda_ must be strictly between 0 and 1;
        it represents the fraction of the total variance to keep.
    verbose : bool, optional
        If True, print details about the truncation.

    Returns:
    x : ndarray
        Solution vector.
    """
    # Check that A is symmetric
    assert np.allclose(A, A.T), 'Matrix A must be symmetric'
    U, s, Vt = np.linalg.svd(A, full_matrices=False, hermitian=True)
    if truncated_svd:
        assert 0 < lambda_ < 1, 'For truncated SVD, lambda_ is the fraction of variance to keep'
        n_components = np.sum(np.cumsum(s) / np.sum(s) < lambda_) + 1

        if verbose:
            print(f'Keeping {n_components} components (out of {len(s)})')
            print(f'Variance explained: {s[:n_components].sum() / s.sum()}')
            print(f"Singular values: {s[:n_components]}")
        U = U[:, :n_components]
        s = s[:n_components]
        Vt = Vt[:n_components, :]
        lambda_ = 0.
    s_inv = np.diag(1 / (s + lambda_))
    return Vt.T @ s_inv @ U.T @ b
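
# A minimal sanity check (illustrative sketch; the helper name
# _check_svd_solver is hypothetical, not part of the API): for a symmetric
# positive-definite A, the lambda_-regularized SVD solve should match a
# direct solve of (A + lambda_ * I) x = b.
def _check_svd_solver():
    rng = np.random.default_rng(0)
    X = rng.standard_normal((100, 5))
    A = X.T @ X                        # normal-equation matrix (SPD)
    b = X.T @ rng.standard_normal(100)
    lam = 0.1
    x_svd = svd_solver(A, b, lambda_=lam)
    x_ref = np.linalg.solve(A + lam * np.eye(5), b)
    assert np.allclose(x_svd, x_ref)
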
def incomplete_cholesky_preconditioner(A):
    """
    Compute an Incomplete Cholesky-type preconditioner for matrix A.

    Note: this is implemented with SciPy's incomplete LU factorization
    (spilu), which plays the same role for a symmetric positive-definite A.

    Parameters:
    A : ndarray
        Symmetric positive-definite matrix.

    Returns:
    M_inv : function
        Function that applies the preconditioner.
    """
    A_sparse = csc_matrix(A)
    ilu = spilu(A_sparse)
    M_inv = lambda x: ilu.solve(x)
    return M_inv
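
# Illustrative sketch (the helper name _check_ilu_preconditioner is
# hypothetical): the returned M_inv should act as an approximate inverse
# of A, so M_inv(A @ x) should be close to x.
def _check_ilu_preconditioner():
    rng = np.random.default_rng(0)
    G = rng.standard_normal((5, 5))
    A = G @ G.T + 5 * np.eye(5)        # well-conditioned SPD matrix
    M_inv = incomplete_cholesky_preconditioner(A)
    x = rng.standard_normal(5)
    print('preconditioner error:', np.max(np.abs(M_inv(A @ x) - x)))
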
def diagonal_preconditioner(A):
    """
    Compute the Diagonal (Jacobi) preconditioner for matrix A.

    Parameters:
    A : ndarray
        Symmetric positive-definite matrix.

    Returns:
    M_inv : function
        Function that applies the preconditioner.
    """
    diag = np.diag(A)
    M_inv = lambda x: x / diag
    return M_inv
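
# Illustrative sketch (the helper name _check_diagonal_preconditioner is
# hypothetical): applying the Jacobi preconditioner is the same as
# multiplying by diag(A)^{-1}.
def _check_diagonal_preconditioner():
    A = np.array([[4., 1.], [1., 3.]])
    M_inv = diagonal_preconditioner(A)
    r = np.array([1., 2.])
    assert np.allclose(M_inv(r), np.diag(1. / np.diag(A)) @ r)
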
def conjugate_gradient(A, b, x0=None, tol=1e-10, max_iter=None, lambda_=0., preconditioner=None, verbose=False):
    """
    Solve the linear system Ax = b using the Conjugate Gradient method.
    A must be square, symmetric and positive-definite.

    Parameters:
    A : ndarray
        Symmetric positive-definite matrix.
    b : ndarray
        Right-hand side vector.
    x0 : ndarray, optional
        Initial guess for the solution.
    tol : float, optional
        Tolerance for convergence.
    max_iter : int, optional
        Maximum number of iterations.
    lambda_ : float, optional
        Regularization parameter (Tikhonov regularization).
    preconditioner : function, optional
        Function that builds the preconditioner (e.g. Incomplete Cholesky or Diagonal).
        It takes the matrix A and returns a function mapping a vector to its preconditioned version.
    verbose : bool, optional
        If True, print convergence information.

    Returns:
    x : ndarray
        Solution vector.

    Note:
    The Conjugate Gradient method is an iterative method that solves the linear system Ax = b.
    If A is not square, fall back on the normal equation (X^T X + lambda I) x = X^T y,
    with A = X^T X and b = X^T y, which is then solvable with the CG method.
    """
    assert A.shape[0] == A.shape[1], 'Matrix A must be square; please use the normal equation (X^T X) beta = X^T y, with A = X^T X and b = X^T y'
    n = len(b)
    if x0 is None:
        x0 = np.zeros(n)
    if max_iter is None:
        max_iter = n

    if lambda_ > 0:
        A = A + lambda_ * np.eye(n)  # Tikhonov regularization

    # Preconditioner
    if preconditioner is not None:
        M_inv = preconditioner(A)
    else:
        M_inv = lambda x: x

    x = x0
    r = b - A @ x
    z = M_inv(r)
    p = z
    rs_old = np.dot(r, z)

    for i in range(max_iter):
        Ap = A @ p
        alpha = rs_old / np.dot(p, Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        z = M_inv(r)
        rs_new = np.dot(r, z)

        if np.sqrt(rs_new) < tol:
            if verbose:
                print(f'Converged in {i+1} iterations')
            return x

        p = z + (rs_new / rs_old) * p
        rs_old = rs_new

    if verbose:
        print(f'Did not converge; reached max iterations ({max_iter})')

    return x
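
# Illustrative sketch (the helper name _check_conjugate_gradient is
# hypothetical): solve a small SPD system with and without the Jacobi
# preconditioner; both runs should agree with a direct solve for a
# well-conditioned system.
def _check_conjugate_gradient():
    rng = np.random.default_rng(0)
    G = rng.standard_normal((5, 5))
    A = G @ G.T + 5 * np.eye(5)        # well-conditioned SPD matrix
    b = rng.standard_normal(5)
    x_ref = np.linalg.solve(A, b)
    x_plain = conjugate_gradient(A, b, max_iter=50)
    x_jacobi = conjugate_gradient(A, b, max_iter=50, preconditioner=diagonal_preconditioner)
    assert np.allclose(x_plain, x_ref)
    assert np.allclose(x_jacobi, x_ref)
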
# Example usage
if __name__ == "__main__":
    A = np.random.rand(5, 5)
    # A bit of multicollinearity in A, making it slightly rank-deficient:
    A[0] = A[2] * 0.1 + np.random.rand(5)
    A = A @ A.T
    b = np.random.rand(5)
    x0 = np.zeros_like(b)
    # Compare with the pseudo-inverse solution
    cg_solution = conjugate_gradient(A, b, x0, lambda_=.00001)
    pseudo_inverse_solution = np.linalg.pinv(A) @ b
    svd_solution = svd_solver(A, b, lambda_=0.00001)
    svd_truncated_solution = svd_solver(A, b, lambda_=1 - 1e-8, truncated_svd=True, verbose=True)
    print("CG Solution:\t\t\t", cg_solution)
    print("Pseudo-inverse Solution:\t", pseudo_inverse_solution)
    print("SVD Solution:\t\t\t", svd_solution)
    print("SVD Truncated Solution:\t\t", svd_truncated_solution)
    # They should be (approximately) equal
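    # Illustrative check (hypothetical addition): with such a small lambda_,
    # the iterative and direct solutions should agree to a loose tolerance.
    print("max |CG - pinv|: ", np.max(np.abs(cg_solution - pseudo_inverse_solution)))
    print("max |SVD - pinv|:", np.max(np.abs(svd_solution - pseudo_inverse_solution)))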
