3535import torch
3636from torch .autograd import gradcheck
3737import torch_harmonics as th
38+ from torch_harmonics .quadrature import precompute_latitudes
3839
39- from testutils import set_seed , compare_tensors
40+ from testutils import disable_tf32 , set_seed , compare_tensors
4041
4142_devices = [(torch .device ("cpu" ),)]
4243if torch .cuda .is_available ():
4344 _devices .append ((torch .device ("cuda" ),))
4445
4546
def random_sht_coeffs(batch_size, lmax, mmax, device, zero_l0=False):
    """Draw random scalar SHT coefficients with the proper structure.

    The returned tensor has shape ``(batch_size, lmax, mmax)`` and dtype
    ``complex128`` with:

      * a real m=0 column (required for the synthesized field to be real-valued),
      * triangular support, i.e. entries with m > l are zero,
      * optionally a zeroed l=0 row (needed when testing gradient/curl, since
        the gradient and curl of the constant mode vanish on the sphere).

    Parameters
    ----------
    batch_size : int
        Number of independent coefficient sets.
    lmax, mmax : int
        Spectral truncation in degree (l) and order (m).
    device : torch.device
        Device on which to allocate the tensor.
    zero_l0 : bool, optional
        If True, zero out the l=0 row. Default: False.

    Returns
    -------
    torch.Tensor
        Complex128 tensor of shape ``(batch_size, lmax, mmax)``.
    """
    c = torch.randn(batch_size, lmax, mmax, dtype=torch.complex128, device=device)
    # m=0 column must be real for a real-valued field
    c[:, :, 0] = c[:, :, 0].real
    # zero all entries above the diagonal (m > l) in one vectorized step
    # instead of looping over degrees
    support = torch.tril(torch.ones(lmax, mmax, dtype=torch.bool, device=device))
    c[:, ~support] = 0.0
    if zero_l0:
        c[:, 0, :] = 0.0
    return c
59+
60+
4661class TestLegendrePolynomials (unittest .TestCase ):
4762 """Test the associated Legendre polynomials (CPU/CUDA if available)."""
4863 def setUp (self ):
@@ -112,6 +127,10 @@ def test_forward_inverse(self, nlat, nlon, batch_size, norm, grid, atol, rtol, v
112127 if verbose :
113128 print (f"Testing real-valued SHT on { nlat } x{ nlon } { grid } grid with { norm } normalization on { self .device .type } device" )
114129
130+ # disable tf32
131+ disable_tf32 ()
132+
133+ # set seed
115134 set_seed (333 )
116135
117136 testiters = [1 , 2 , 4 , 8 , 16 ]
@@ -173,6 +192,10 @@ def test_grads(self, nlat, nlon, batch_size, norm, grid, atol, rtol, verbose=Fal
173192 if verbose :
174193 print (f"Testing gradients of real-valued SHT on { nlat } x{ nlon } { grid } grid with { norm } normalization" )
175194
195+ # disable tf32
196+ disable_tf32 ()
197+
198+ # set seed
176199 set_seed (333 )
177200
178201 if grid == "equiangular" :
@@ -203,6 +226,73 @@ def test_grads(self, nlat, nlon, batch_size, norm, grid, atol, rtol, verbose=Fal
203226 test_result = gradcheck (err_handle , grad_input , eps = 1e-6 , atol = atol , rtol = rtol )
204227 self .assertTrue (test_result )
205228
229+ @parameterized .expand (
230+ [
231+ [32 , 64 , 32 , "ortho" , "equiangular" , 1e-9 , 1e-9 ],
232+ [32 , 64 , 32 , "ortho" , "legendre-gauss" , 1e-9 , 1e-9 ],
233+ [32 , 64 , 32 , "ortho" , "lobatto" , 1e-9 , 1e-9 ],
234+ [32 , 64 , 32 , "four-pi" , "equiangular" , 1e-9 , 1e-9 ],
235+ [32 , 64 , 32 , "four-pi" , "legendre-gauss" , 1e-9 , 1e-9 ],
236+ [32 , 64 , 32 , "four-pi" , "lobatto" , 1e-9 , 1e-9 ],
237+ [32 , 64 , 32 , "schmidt" , "equiangular" , 1e-9 , 1e-9 ],
238+ [32 , 64 , 32 , "schmidt" , "legendre-gauss" , 1e-9 , 1e-9 ],
239+ [32 , 64 , 32 , "schmidt" , "lobatto" , 1e-9 , 1e-9 ],
240+ ],
241+ skip_on_empty = True ,
242+ )
243+ def test_parseval (self , nlat , nlon , batch_size , norm , grid , atol , rtol , verbose = False ):
244+ """Parseval's theorem: the spatial L2 norm of isht(c) equals a weighted spectral norm
245+ of c. The spectral weights W_{l,m} depend on the normalization convention:
246+
247+ ortho: W_{l,m} = w_m (w_m = 1 for m=0, 2 for m>0)
248+ four-pi: W_{l,m} = w_m / (4*pi) (coefficients are sqrt(4*pi) larger)
249+ schmidt: W_{l,m} = w_m * (2*l+1) / (4*pi) (coefficients are sqrt(4*pi/(2*l+1)) larger)
250+
251+ In all cases: ||f||^2_{S^2} = sum_{l,m} W_{l,m} * |c_{l,m}|^2
252+ """
253+ if verbose :
254+ print (f"Testing Parseval's theorem on { nlat } x{ nlon } { grid } grid with { norm } normalization on { self .device .type } " )
255+
256+ # disable tf32
257+ disable_tf32 ()
258+
259+ # set seed
260+ set_seed (333 )
261+
262+ isht = th .InverseRealSHT (nlat , nlon , grid = grid , norm = norm ).to (self .device )
263+ lmax = isht .lmax
264+ mmax = isht .mmax
265+
266+ with torch .no_grad ():
267+ c = random_sht_coeffs (batch_size , lmax , mmax , self .device )
268+ f = isht (c ) # (batch, nlat, nlon)
269+
270+ # Spatial L2 norm via spherical quadrature: integral of f^2 over S^2
271+ _ , w_lat = precompute_latitudes (nlat , grid = grid )
272+ w_lat = w_lat .to (device = self .device , dtype = torch .float64 )
273+ dlon = 2.0 * math .pi / nlon
274+ spatial_norm_sq = torch .einsum ("bnl,n->b" , f ** 2 , w_lat ) * dlon # (batch,)
275+
276+ # Build the (lmax, mmax) spectral weight matrix W_{l,m}.
277+ # w_m accounts for the ±m folding in the real irfft (m=0: weight 1, m>0: weight 2).
278+ w_m = torch .ones (mmax , dtype = torch .float64 , device = self .device )
279+ w_m [1 :] = 2.0
280+
281+ if norm == "ortho" :
282+ # c_lm^ortho are the orthonormal coefficients; W_{l,m} = w_m
283+ W = w_m .unsqueeze (0 ).expand (lmax , mmax )
284+ elif norm == "four-pi" :
285+ # c_lm^{four-pi} = sqrt(4*pi) * c_lm^ortho => W_{l,m} = w_m / (4*pi)
286+ W = w_m .unsqueeze (0 ).expand (lmax , mmax ) / (4.0 * math .pi )
287+ elif norm == "schmidt" :
288+ # c_lm^{schmidt} = sqrt(4*pi / (2*l+1)) * c_lm^ortho => W_{l,m} = w_m * (2*l+1) / (4*pi)
289+ l_vals = torch .arange (lmax , dtype = torch .float64 , device = self .device )
290+ W = torch .outer (2.0 * l_vals + 1.0 , w_m ) / (4.0 * math .pi )
291+
292+ spectral_norm_sq = torch .einsum ("blm,lm->b" , c .abs () ** 2 , W ) # (batch,)
293+
294+ self .assertTrue (compare_tensors ("Parseval's theorem" , spatial_norm_sq , spectral_norm_sq , atol = atol , rtol = rtol , verbose = verbose ))
295+
206296 @parameterized .expand (
207297 [
208298 # even-even
@@ -217,6 +307,10 @@ def test_device_instantiation(self, nlat, nlon, norm, grid, atol, rtol, verbose=
217307 if verbose :
218308 print (f"Testing device instantiation of real-valued SHT on { nlat } x{ nlon } { grid } grid with { norm } normalization" )
219309
310+ # disable tf32
311+ disable_tf32 ()
312+
313+ # set seed
220314 set_seed (333 )
221315
222316 # init on cpu
@@ -232,5 +326,207 @@ def test_device_instantiation(self, nlat, nlon, norm, grid, atol, rtol, verbose=
232326 self .assertTrue (compare_tensors (f"isht weights" , isht_host .pct .cpu (), isht_device .pct .cpu (), atol = atol , rtol = rtol , verbose = verbose ))
233327
234328
@parameterized_class(("device"), _devices)
class TestSphericalHarmonicsY(unittest.TestCase):
    """Checks on the real spherical harmonic basis synthesized by the inverse SHT.

    With norm="ortho", setting a single complex coefficient c_{l,m} = 1 in
    InverseRealSHT synthesizes

        f_{l,0}      = Y_l^0(theta, phi)              for m = 0
        f_{l,m,cos} ~ P_l^m(cos theta) * cos(m*phi)   for m > 0 (c_{l,m} = 1+0j)
        f_{l,m,sin} ~ P_l^m(cos theta) * sin(m*phi)   for m > 0 (c_{l,m} = 0+1j)

    and the sphere inner products of these functions satisfy

        <f_{l,0},     f_{l',0}>    = delta_{ll'}
        <f_{l,m,cos}, f_{l',m',*}> = 2 * delta_{ll'} * delta_{mm'}   (m, m' > 0)
        <f_{l,m,sin}, f_{l',m',*}> = 2 * delta_{ll'} * delta_{mm'}   (m, m' > 0)

    The factor of 2 for m > 0 comes from the real irfft folding the +m and -m
    modes together, which doubles each mode's amplitude.
    """

    @parameterized.expand(
        [
            [12, 24, "legendre-gauss", 1e-9, 1e-9],
            [12, 24, "equiangular", 1e-9, 1e-9],
            [12, 24, "lobatto", 1e-9, 1e-9],
        ],
        skip_on_empty=True,
    )
    def test_orthogonality(self, nlat, nlon, grid, atol, rtol, verbose=False):
        """Verify that isht(norm="ortho") synthesizes mutually orthogonal basis
        functions whose self inner-products equal 1 (m=0) or 2 (m>0)."""
        if verbose:
            print(f"Testing Y_lm orthogonality on {nlat}x{nlon} {grid} grid on {self.device.type}")

        # disable tf32
        disable_tf32()

        # set seed
        set_seed(333)

        if grid == "equiangular":
            lmax = mmax = nlat // 2
        elif grid == "lobatto":
            lmax = mmax = nlat - 1
        else:
            lmax = mmax = nlat

        isht = th.InverseRealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, norm="ortho").to(self.device)

        def _one_hot(l, m, value):
            # single-coefficient spectrum selecting exactly one basis function
            z = torch.zeros(lmax, mmax, dtype=torch.complex128)
            z[l, m] = value
            return z

        # One coefficient tensor per real basis function:
        #   m = 0: c[l, 0] = 1+0j (real mode only)
        #   m > 0: c[l, m] = 1+0j (cos branch) and c[l, m] = 0+1j (sin branch)
        basis = []
        diag_entries = []
        for l in range(lmax):
            for m in range(min(l + 1, mmax)):
                basis.append(_one_hot(l, m, 1.0 + 0.0j))
                diag_entries.append(1.0 if m == 0 else 2.0)
                if m > 0:
                    basis.append(_one_hot(l, m, 0.0 + 1.0j))
                    diag_entries.append(2.0)

        coeffs = torch.stack(basis).to(self.device)  # (N, lmax, mmax)

        with torch.no_grad():
            funcs = isht(coeffs)  # (N, nlat, nlon), real-valued

        # Gram matrix G[i, j] = integral over S^2 of f_i * f_j, evaluated with
        # the grid's quadrature rule. Latitude weights from precompute_latitudes
        # are in the cos(theta) domain and integrate over [-1, 1], so the full
        # measure per point is w_lat[k] * dlon.
        _, w_lat = precompute_latitudes(nlat, grid=grid)
        w_lat = w_lat.to(device=self.device, dtype=torch.float64)
        dlon = 2.0 * math.pi / nlon

        weighted = funcs * (dlon * w_lat).unsqueeze(-1)  # (N, nlat, nlon)
        gram = torch.einsum("inl,jnl->ij", weighted, funcs)  # (N, N)

        expected = torch.diag(torch.tensor(diag_entries, dtype=torch.float64, device=self.device))
        self.assertTrue(compare_tensors("Gram matrix", gram, expected, atol=atol, rtol=rtol, verbose=verbose))
415+
@parameterized_class(("device"), _devices)
class TestVectorSphericalHarmonicTransform(unittest.TestCase):
    """Tests for the consistency between the scalar SHT and the vector SHT.

    RealVectorSHT includes a 1/(l*(l+1)) normalization in its quadrature weights
    so that the spheroidal/toroidal spectral coefficients relate directly to the
    scalar SHT coefficients of the generating potential:

        Gradient: vsht(ivsht([c, 0]))[spheroidal] = c, [toroidal] = 0   (l > 0)
        Curl:     vsht(ivsht([0, c]))[spheroidal] = 0, [toroidal] = c   (l > 0)

    The l = 0 mode is zero in both vsht and ivsht because the gradient and curl
    of a constant field (Y_0^0) vanish identically on the sphere.

    These tests catch swapped spheroidal/toroidal channels, wrong signs in the
    dP/dtheta or P/sin(theta) terms, and incorrect l*(l+1) normalization.
    """

    @parameterized.expand(
        [
            [32, 64, 16, "ortho", "legendre-gauss", 1e-7, 1e-7],
            [32, 64, 16, "ortho", "equiangular", 1e-7, 1e-7],
            [32, 64, 16, "ortho", "lobatto", 1e-7, 1e-7],
            [32, 64, 16, "four-pi", "legendre-gauss", 1e-7, 1e-7],
            [32, 64, 16, "four-pi", "equiangular", 1e-7, 1e-7],
            [32, 64, 16, "four-pi", "lobatto", 1e-7, 1e-7],
            # NOTE(review): schmidt cases disabled below — presumably pending
            # support/accuracy work in the vector transform; confirm before enabling.
            # [32, 64, 16, "schmidt", "legendre-gauss", 1e-7, 1e-7],
            # [32, 64, 16, "schmidt", "equiangular", 1e-7, 1e-7],
            # [32, 64, 16, "schmidt", "lobatto", 1e-7, 1e-7],
        ],
        skip_on_empty=True,
    )
    def test_gradient_consistency(self, nlat, nlon, batch_size, norm, grid, atol, rtol, verbose=False):
        """ivsht([c, 0]) synthesizes the surface gradient ∇_S f of a scalar field
        f = isht(c). Applying vsht to this gradient field must recover c in the
        spheroidal channel and zero in the toroidal channel, because a gradient
        field is curl-free (purely spheroidal).

        Note: verbose defaults to False, consistent with every other test in
        this file (it previously defaulted to True, a leftover debug setting).
        """
        if verbose:
            print(f"Testing gradient consistency on {nlat}x{nlon} {grid} grid with {norm} norm on {self.device.type}")

        # disable tf32
        disable_tf32()

        # set seed
        set_seed(333)

        vsht = th.RealVectorSHT(nlat, nlon, grid=grid, norm=norm).to(self.device)
        ivsht = th.InverseRealVectorSHT(nlat, nlon, grid=grid, norm=norm).to(self.device)
        lmax, mmax = vsht.lmax, vsht.mmax

        with torch.no_grad():
            # l=0 row zeroed: the gradient of the constant mode vanishes
            c = random_sht_coeffs(batch_size, lmax, mmax, self.device, zero_l0=True)
            zeros = torch.zeros_like(c)

            # synthesize gradient field: ivsht([c, 0]) = ∇_S f
            grad_f = ivsht(torch.stack([c, zeros], dim=-3))  # (batch, 2, nlat, nlon)

            # analyse: vsht(∇_S f) must give [c, 0]
            st = vsht(grad_f)  # (batch, 2, lmax, mmax)
            s = st[..., 0, :, :]  # spheroidal
            t = st[..., 1, :, :]  # toroidal

        self.assertTrue(compare_tensors("spheroidal coefficients", s, c, atol=atol, rtol=rtol, verbose=verbose))
        self.assertTrue(compare_tensors("toroidal coefficients", t, zeros, atol=atol, rtol=rtol, verbose=verbose))

    @parameterized.expand(
        [
            [32, 64, 16, "ortho", "legendre-gauss", 1e-7, 1e-7],
            [32, 64, 16, "ortho", "equiangular", 1e-7, 1e-7],
            [32, 64, 16, "ortho", "lobatto", 1e-7, 1e-7],
            [32, 64, 16, "four-pi", "legendre-gauss", 1e-7, 1e-7],
            [32, 64, 16, "four-pi", "equiangular", 1e-7, 1e-7],
            [32, 64, 16, "four-pi", "lobatto", 1e-7, 1e-7],
            # NOTE(review): schmidt cases disabled below — presumably pending
            # support/accuracy work in the vector transform; confirm before enabling.
            # [32, 64, 16, "schmidt", "legendre-gauss", 1e-9, 1e-9],
            # [32, 64, 16, "schmidt", "equiangular", 1e-9, 1e-9],
            # [32, 64, 16, "schmidt", "lobatto", 1e-9, 1e-9],
        ],
        skip_on_empty=True,
    )
    def test_curl_consistency(self, nlat, nlon, batch_size, norm, grid, atol, rtol, verbose=False):
        """ivsht([0, c]) synthesizes the surface curl ê_r × ∇_S f of a scalar field
        f = isht(c). Applying vsht to this curl field must recover c in the
        toroidal channel and zero in the spheroidal channel, because a surface
        curl field is divergence-free (purely toroidal).
        """
        if verbose:
            print(f"Testing curl consistency on {nlat}x{nlon} {grid} grid with {norm} norm on {self.device.type}")

        # disable tf32
        disable_tf32()

        # set seed
        set_seed(333)

        vsht = th.RealVectorSHT(nlat, nlon, grid=grid, norm=norm).to(self.device)
        ivsht = th.InverseRealVectorSHT(nlat, nlon, grid=grid, norm=norm).to(self.device)
        lmax, mmax = vsht.lmax, vsht.mmax

        with torch.no_grad():
            # l=0 row zeroed: the curl of the constant mode vanishes
            c = random_sht_coeffs(batch_size, lmax, mmax, self.device, zero_l0=True)
            zeros = torch.zeros_like(c)

            # synthesize curl field: ivsht([0, c]) = ê_r × ∇_S f
            curl_f = ivsht(torch.stack([zeros, c], dim=-3))  # (batch, 2, nlat, nlon)

            # analyse: vsht(ê_r × ∇_S f) must give [0, c]
            st = vsht(curl_f)  # (batch, 2, lmax, mmax)
            s = st[..., 0, :, :]  # spheroidal
            t = st[..., 1, :, :]  # toroidal

        self.assertTrue(compare_tensors("spheroidal coefficients", s, zeros, atol=atol, rtol=rtol, verbose=verbose))
        self.assertTrue(compare_tensors("toroidal coefficients", t, c, atol=atol, rtol=rtol, verbose=verbose))
529+
530+
# Run the full test suite when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()