Skip to content

Commit f576930

Browse files
committed
Test SVD and Eig(h): allow benign sign change
1 parent e65ecfe commit f576930

File tree

1 file changed

+37
-11
lines changed

1 file changed

+37
-11
lines changed

tests/tensor/test_nlinalg.py

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -165,10 +165,33 @@ def test_svd(self, core_shape, full_matrix, compute_uv, batched, test_imag):
165165
pt_outputs = fn(a)
166166

167167
np_outputs = np_outputs if isinstance(np_outputs, tuple) else [np_outputs]
168+
if compute_uv:
169+
# In this case we sometimes get a sign flip on some columns in one impl and not the thore
170+
# The results are both correct, and we test that by reconstructing the original input
171+
U, S, Vh = pt_outputs
172+
S_diag = np.expand_dims(S, -2) * np.eye(S.shape[-1])
173+
174+
diff = a.shape[-2] - a.shape[-1]
175+
if full_matrix:
176+
if diff > 0:
177+
# tall
178+
S_diag = np.pad(S_diag, [(0, 0), (0, diff), (0, 0)][-a.ndim :])
179+
elif diff < 0:
180+
# wide
181+
S_diag = np.pad(S_diag, [(0, 0), (0, 0), (0, -diff)][-a.ndim :])
182+
183+
a_r = U @ S_diag @ Vh
184+
rtol = 1e-3 if config.floatX == "float32" else 1e-7
185+
np.testing.assert_allclose(a_r, a, rtol=rtol)
186+
187+
for np_val, pt_val in zip(np_outputs, pt_outputs, strict=True):
188+
# Check values are equivalent up to sign change
189+
np.testing.assert_allclose(np.abs(np_val), np.abs(pt_val), rtol=rtol)
168190

169-
rtol = 1e-5 if config.floatX == "float32" else 1e-7
170-
for np_val, pt_val in zip(np_outputs, pt_outputs, strict=True):
171-
np.testing.assert_allclose(np_val, pt_val, rtol=rtol)
191+
else:
192+
rtol = 1e-5 if config.floatX == "float32" else 1e-7
193+
for np_val, pt_val in zip(np_outputs, pt_outputs, strict=True):
194+
np.testing.assert_allclose(np_val, pt_val, rtol=rtol)
172195

173196
def test_svd_infer_shape(self):
174197
self.validate_shape((4, 4), full_matrices=True, compute_uv=True)
@@ -428,8 +451,8 @@ def test_eval(self):
428451

429452
w, v = fn(A_val)
430453
w_np, v_np = np.linalg.eig(A_val)
431-
np.testing.assert_allclose(w, w_np)
432-
np.testing.assert_allclose(v, v_np)
454+
np.testing.assert_allclose(np.abs(w), np.abs(w_np))
455+
np.testing.assert_allclose(np.abs(v), np.abs(v_np))
433456
assert_array_almost_equal(np.dot(A_val, v), w * v)
434457

435458
# Asymmetric input (real eigenvalues)
@@ -438,16 +461,16 @@ def test_eval(self):
438461

439462
w, v = fn(A_val)
440463
w_np, v_np = np.linalg.eig(A_val)
441-
np.testing.assert_allclose(w, w_np)
442-
np.testing.assert_allclose(v, v_np)
464+
np.testing.assert_allclose(np.abs(w), np.abs(w_np))
465+
np.testing.assert_allclose(np.abs(v), np.abs(v_np))
443466
assert_array_almost_equal(np.dot(A_val, v), w * v)
444467

445468
# Asymmetric input (complex eigenvalues)
446469
A_val = self.rng.normal(size=(5, 5))
447470
w, v = fn(A_val)
448471
w_np, v_np = np.linalg.eig(A_val)
449-
np.testing.assert_allclose(w, w_np)
450-
np.testing.assert_allclose(v, v_np)
472+
np.testing.assert_allclose(np.abs(w), np.abs(w_np))
473+
np.testing.assert_allclose(np.abs(v), np.abs(v_np))
451474
assert_array_almost_equal(np.dot(A_val, v), w * v)
452475

453476

@@ -464,11 +487,14 @@ def test_eval(self):
464487

465488
w, v = fn(A_val)
466489
w_np, v_np = np.linalg.eigh(A_val)
467-
np.testing.assert_allclose(w, w_np)
468-
np.testing.assert_allclose(v, v_np)
490+
# There are multiple valid results up to some sign changes
491+
# Check we can reconstruct input
469492
rtol = 1e-2 if self.dtype == "float32" else 1e-7
470493
np.testing.assert_allclose(np.dot(A_val, v), w * v, rtol=rtol)
471494

495+
np.testing.assert_allclose(np.abs(w), np.abs(w_np), rtol=rtol)
496+
np.testing.assert_allclose(np.abs(v), np.abs(v_np), rtol=rtol)
497+
472498
def test_uplo(self):
473499
S = self.S
474500
a = matrix(dtype=self.dtype)

0 commit comments

Comments
 (0)