@@ -165,10 +165,33 @@ def test_svd(self, core_shape, full_matrix, compute_uv, batched, test_imag):
165165 pt_outputs = fn (a )
166166
167167 np_outputs = np_outputs if isinstance (np_outputs , tuple ) else [np_outputs ]
168+ if compute_uv :
169+ # In this case we sometimes get a sign flip on some columns in one impl and not the thore
170+ # The results are both correct, and we test that by reconstructing the original input
171+ U , S , Vh = pt_outputs
172+ S_diag = np .expand_dims (S , - 2 ) * np .eye (S .shape [- 1 ])
173+
174+ diff = a .shape [- 2 ] - a .shape [- 1 ]
175+ if full_matrix :
176+ if diff > 0 :
177+ # tall
178+ S_diag = np .pad (S_diag , [(0 , 0 ), (0 , diff ), (0 , 0 )][- a .ndim :])
179+ elif diff < 0 :
180+ # wide
181+ S_diag = np .pad (S_diag , [(0 , 0 ), (0 , 0 ), (0 , - diff )][- a .ndim :])
182+
183+ a_r = U @ S_diag @ Vh
184+ rtol = 1e-3 if config .floatX == "float32" else 1e-7
185+ np .testing .assert_allclose (a_r , a , rtol = rtol )
186+
187+ for np_val , pt_val in zip (np_outputs , pt_outputs , strict = True ):
188+ # Check values are equivalent up to sign change
189+ np .testing .assert_allclose (np .abs (np_val ), np .abs (pt_val ), rtol = rtol )
168190
169- rtol = 1e-5 if config .floatX == "float32" else 1e-7
170- for np_val , pt_val in zip (np_outputs , pt_outputs , strict = True ):
171- np .testing .assert_allclose (np_val , pt_val , rtol = rtol )
191+ else :
192+ rtol = 1e-5 if config .floatX == "float32" else 1e-7
193+ for np_val , pt_val in zip (np_outputs , pt_outputs , strict = True ):
194+ np .testing .assert_allclose (np_val , pt_val , rtol = rtol )
172195
173196 def test_svd_infer_shape (self ):
174197 self .validate_shape ((4 , 4 ), full_matrices = True , compute_uv = True )
@@ -428,8 +451,8 @@ def test_eval(self):
428451
429452 w , v = fn (A_val )
430453 w_np , v_np = np .linalg .eig (A_val )
431- np .testing .assert_allclose (w , w_np )
432- np .testing .assert_allclose (v , v_np )
454+ np .testing .assert_allclose (np . abs ( w ), np . abs ( w_np ) )
455+ np .testing .assert_allclose (np . abs ( v ), np . abs ( v_np ) )
433456 assert_array_almost_equal (np .dot (A_val , v ), w * v )
434457
435458 # Asymmetric input (real eigenvalues)
@@ -438,16 +461,16 @@ def test_eval(self):
438461
439462 w , v = fn (A_val )
440463 w_np , v_np = np .linalg .eig (A_val )
441- np .testing .assert_allclose (w , w_np )
442- np .testing .assert_allclose (v , v_np )
464+ np .testing .assert_allclose (np . abs ( w ), np . abs ( w_np ) )
465+ np .testing .assert_allclose (np . abs ( v ), np . abs ( v_np ) )
443466 assert_array_almost_equal (np .dot (A_val , v ), w * v )
444467
445468 # Asymmetric input (complex eigenvalues)
446469 A_val = self .rng .normal (size = (5 , 5 ))
447470 w , v = fn (A_val )
448471 w_np , v_np = np .linalg .eig (A_val )
449- np .testing .assert_allclose (w , w_np )
450- np .testing .assert_allclose (v , v_np )
472+ np .testing .assert_allclose (np . abs ( w ), np . abs ( w_np ) )
473+ np .testing .assert_allclose (np . abs ( v ), np . abs ( v_np ) )
451474 assert_array_almost_equal (np .dot (A_val , v ), w * v )
452475
453476
@@ -464,11 +487,14 @@ def test_eval(self):
464487
465488 w , v = fn (A_val )
466489 w_np , v_np = np .linalg .eigh (A_val )
467- np . testing . assert_allclose ( w , w_np )
468- np . testing . assert_allclose ( v , v_np )
490+ # There are multiple valid results up to some sign changes
491+ # Check we can reconstruct input
469492 rtol = 1e-2 if self .dtype == "float32" else 1e-7
470493 np .testing .assert_allclose (np .dot (A_val , v ), w * v , rtol = rtol )
471494
495+ np .testing .assert_allclose (np .abs (w ), np .abs (w_np ), rtol = rtol )
496+ np .testing .assert_allclose (np .abs (v ), np .abs (v_np ), rtol = rtol )
497+
472498 def test_uplo (self ):
473499 S = self .S
474500 a = matrix (dtype = self .dtype )
0 commit comments