Skip to content

Commit 7f72f76

Browse files
committed
updated tests
1 parent 169da9e commit 7f72f76

2 files changed

Lines changed: 314 additions & 2 deletions

File tree

tests/test_attention.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
from torch_harmonics.attention._attention_utils import _neighborhood_s2_attention_torch
4343
from torch_harmonics.attention import cuda_kernels_is_available, optimized_kernels_is_available
4444

45-
from testutils import set_seed, compare_tensors
45+
from testutils import disable_tf32, set_seed, compare_tensors
4646

4747
if not optimized_kernels_is_available():
4848
print(f"Warning: Couldn't import optimized disco convolution kernels")
@@ -93,6 +93,10 @@ def test_custom_implementation(self, batch_size, channels, channels_out, heads,
9393
if (self.device.type == "cuda") and (not cuda_kernels_is_available()):
9494
raise unittest.SkipTest("skipping test because CUDA kernels are not available")
9595

96+
# disable tf32
97+
disable_tf32()
98+
99+
# set seed
96100
set_seed(333)
97101

98102
nlat_in, nlon_in = in_shape
@@ -159,6 +163,10 @@ def test_device_vs_cpu(self, batch_size, channels, heads, in_shape, out_shape, g
159163
# comparing CPU with itself does not make sense
160164
return
161165

166+
# disable tf32
167+
disable_tf32()
168+
169+
# set seed
162170
set_seed(333)
163171

164172
nlat_in, nlon_in = in_shape
@@ -228,6 +236,10 @@ def test_neighborhood_global_equivalence(self, batch_size, channels, channels_ou
228236
if (self.device.type == "cuda") and (not cuda_kernels_is_available()):
229237
raise unittest.SkipTest("skipping test because CUDA kernels are not available")
230238

239+
# disable tf32
240+
disable_tf32()
241+
242+
# set seed
231243
set_seed(333)
232244

233245
nlat_in, nlon_in = in_shape
@@ -333,6 +345,10 @@ def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, g
333345
if (self.device.type == "cuda") and (not cuda_kernels_is_available()):
334346
raise unittest.SkipTest("skipping test because CUDA kernels are not available")
335347

348+
# disable tf32
349+
disable_tf32()
350+
351+
# set seed
336352
set_seed(333)
337353

338354
# extract some parameters

tests/test_sht.py

Lines changed: 297 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,29 @@
3535
import torch
3636
from torch.autograd import gradcheck
3737
import torch_harmonics as th
38+
from torch_harmonics.quadrature import precompute_latitudes
3839

39-
from testutils import set_seed, compare_tensors
40+
from testutils import disable_tf32, set_seed, compare_tensors
4041

4142
_devices = [(torch.device("cpu"),)]
4243
if torch.cuda.is_available():
4344
_devices.append((torch.device("cuda"),))
4445

4546

47+
def random_sht_coeffs(batch_size, lmax, mmax, device, zero_l0=False):
48+
"""Random scalar SHT coefficients with proper structure:
49+
m=0 column real (real-valued field), triangular support (m <= l),
50+
and optionally l=0 row zeroed (needed when testing gradient/curl)."""
51+
c = torch.randn(batch_size, lmax, mmax, dtype=torch.complex128, device=device)
52+
c[:, :, 0] = c[:, :, 0].real
53+
for l in range(lmax):
54+
if l + 1 < mmax:
55+
c[:, l, l + 1:] = 0.0
56+
if zero_l0:
57+
c[:, 0, :] = 0.0
58+
return c
59+
60+
4661
class TestLegendrePolynomials(unittest.TestCase):
4762
"""Test the associated Legendre polynomials (CPU/CUDA if available)."""
4863
def setUp(self):
@@ -112,6 +127,10 @@ def test_forward_inverse(self, nlat, nlon, batch_size, norm, grid, atol, rtol, v
112127
if verbose:
113128
print(f"Testing real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization on {self.device.type} device")
114129

130+
# disable tf32
131+
disable_tf32()
132+
133+
# set seed
115134
set_seed(333)
116135

117136
testiters = [1, 2, 4, 8, 16]
@@ -173,6 +192,10 @@ def test_grads(self, nlat, nlon, batch_size, norm, grid, atol, rtol, verbose=Fal
173192
if verbose:
174193
print(f"Testing gradients of real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization")
175194

195+
# disable tf32
196+
disable_tf32()
197+
198+
# set seed
176199
set_seed(333)
177200

178201
if grid == "equiangular":
@@ -203,6 +226,73 @@ def test_grads(self, nlat, nlon, batch_size, norm, grid, atol, rtol, verbose=Fal
203226
test_result = gradcheck(err_handle, grad_input, eps=1e-6, atol=atol, rtol=rtol)
204227
self.assertTrue(test_result)
205228

229+
@parameterized.expand(
230+
[
231+
[32, 64, 32, "ortho", "equiangular", 1e-9, 1e-9],
232+
[32, 64, 32, "ortho", "legendre-gauss", 1e-9, 1e-9],
233+
[32, 64, 32, "ortho", "lobatto", 1e-9, 1e-9],
234+
[32, 64, 32, "four-pi", "equiangular", 1e-9, 1e-9],
235+
[32, 64, 32, "four-pi", "legendre-gauss", 1e-9, 1e-9],
236+
[32, 64, 32, "four-pi", "lobatto", 1e-9, 1e-9],
237+
[32, 64, 32, "schmidt", "equiangular", 1e-9, 1e-9],
238+
[32, 64, 32, "schmidt", "legendre-gauss", 1e-9, 1e-9],
239+
[32, 64, 32, "schmidt", "lobatto", 1e-9, 1e-9],
240+
],
241+
skip_on_empty=True,
242+
)
243+
def test_parseval(self, nlat, nlon, batch_size, norm, grid, atol, rtol, verbose=False):
244+
"""Parseval's theorem: the spatial L2 norm of isht(c) equals a weighted spectral norm
245+
of c. The spectral weights W_{l,m} depend on the normalization convention:
246+
247+
ortho: W_{l,m} = w_m (w_m = 1 for m=0, 2 for m>0)
248+
four-pi: W_{l,m} = w_m / (4*pi) (coefficients are sqrt(4*pi) larger)
249+
schmidt: W_{l,m} = w_m * (2*l+1) / (4*pi) (coefficients are sqrt(4*pi/(2*l+1)) larger)
250+
251+
In all cases: ||f||^2_{S^2} = sum_{l,m} W_{l,m} * |c_{l,m}|^2
252+
"""
253+
if verbose:
254+
print(f"Testing Parseval's theorem on {nlat}x{nlon} {grid} grid with {norm} normalization on {self.device.type}")
255+
256+
# disable tf32
257+
disable_tf32()
258+
259+
# set seed
260+
set_seed(333)
261+
262+
isht = th.InverseRealSHT(nlat, nlon, grid=grid, norm=norm).to(self.device)
263+
lmax = isht.lmax
264+
mmax = isht.mmax
265+
266+
with torch.no_grad():
267+
c = random_sht_coeffs(batch_size, lmax, mmax, self.device)
268+
f = isht(c) # (batch, nlat, nlon)
269+
270+
# Spatial L2 norm via spherical quadrature: integral of f^2 over S^2
271+
_, w_lat = precompute_latitudes(nlat, grid=grid)
272+
w_lat = w_lat.to(device=self.device, dtype=torch.float64)
273+
dlon = 2.0 * math.pi / nlon
274+
spatial_norm_sq = torch.einsum("bnl,n->b", f ** 2, w_lat) * dlon # (batch,)
275+
276+
# Build the (lmax, mmax) spectral weight matrix W_{l,m}.
277+
# w_m accounts for the ±m folding in the real irfft (m=0: weight 1, m>0: weight 2).
278+
w_m = torch.ones(mmax, dtype=torch.float64, device=self.device)
279+
w_m[1:] = 2.0
280+
281+
if norm == "ortho":
282+
# c_lm^ortho are the orthonormal coefficients; W_{l,m} = w_m
283+
W = w_m.unsqueeze(0).expand(lmax, mmax)
284+
elif norm == "four-pi":
285+
# c_lm^{four-pi} = sqrt(4*pi) * c_lm^ortho => W_{l,m} = w_m / (4*pi)
286+
W = w_m.unsqueeze(0).expand(lmax, mmax) / (4.0 * math.pi)
287+
elif norm == "schmidt":
288+
# c_lm^{schmidt} = sqrt(4*pi / (2*l+1)) * c_lm^ortho => W_{l,m} = w_m * (2*l+1) / (4*pi)
289+
l_vals = torch.arange(lmax, dtype=torch.float64, device=self.device)
290+
W = torch.outer(2.0 * l_vals + 1.0, w_m) / (4.0 * math.pi)
291+
292+
spectral_norm_sq = torch.einsum("blm,lm->b", c.abs() ** 2, W) # (batch,)
293+
294+
self.assertTrue(compare_tensors("Parseval's theorem", spatial_norm_sq, spectral_norm_sq, atol=atol, rtol=rtol, verbose=verbose))
295+
206296
@parameterized.expand(
207297
[
208298
# even-even
@@ -217,6 +307,10 @@ def test_device_instantiation(self, nlat, nlon, norm, grid, atol, rtol, verbose=
217307
if verbose:
218308
print(f"Testing device instantiation of real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization")
219309

310+
# disable tf32
311+
disable_tf32()
312+
313+
# set seed
220314
set_seed(333)
221315

222316
# init on cpu
@@ -232,5 +326,207 @@ def test_device_instantiation(self, nlat, nlon, norm, grid, atol, rtol, verbose=
232326
self.assertTrue(compare_tensors(f"isht weights", isht_host.pct.cpu(), isht_device.pct.cpu(), atol=atol, rtol=rtol, verbose=verbose))
233327

234328

329+
@parameterized_class(("device"), _devices)
330+
class TestSphericalHarmonicsY(unittest.TestCase):
331+
"""Test fundamental properties of the real spherical harmonic basis functions.
332+
333+
InverseRealSHT with norm="ortho" synthesizes orthonormal basis functions on
334+
the sphere. Setting a single complex coefficient c_{l,m} = 1 synthesizes
335+
336+
f_{l,0} = Y_l^0(theta, phi) for m = 0
337+
f_{l,m,cos} ~ P_l^m(cos theta) * cos(m*phi) for m > 0 (c_{l,m} = 1+0j)
338+
f_{l,m,sin} ~ P_l^m(cos theta) * sin(m*phi) for m > 0 (c_{l,m} = 0+1j)
339+
340+
With ortho normalization the sphere inner products satisfy:
341+
<f_{l,0}, f_{l',0} > = delta_{ll'}
342+
<f_{l,m,cos}, f_{l',m',*}> = 2 * delta_{ll'} * delta_{mm'} for m, m' > 0
343+
<f_{l,m,sin}, f_{l',m',*}> = 2 * delta_{ll'} * delta_{mm'} for m, m' > 0
344+
345+
The factor of 2 for m > 0 arises because the real irfft folds the +m and -m
346+
modes together, doubling the amplitude of each mode.
347+
"""
348+
349+
@parameterized.expand(
350+
[
351+
[12, 24, "legendre-gauss", 1e-9, 1e-9],
352+
[12, 24, "equiangular", 1e-9, 1e-9],
353+
[12, 24, "lobatto", 1e-9, 1e-9],
354+
],
355+
skip_on_empty=True,
356+
)
357+
def test_orthogonality(self, nlat, nlon, grid, atol, rtol, verbose=False):
358+
"""Verify that isht(norm="ortho") synthesizes mutually orthogonal basis
359+
functions and that the self inner-products equal 1 (m=0) or 2 (m>0)."""
360+
if verbose:
361+
print(f"Testing Y_lm orthogonality on {nlat}x{nlon} {grid} grid on {self.device.type}")
362+
363+
# disable tf32
364+
disable_tf32()
365+
366+
# set seed
367+
set_seed(333)
368+
369+
if grid == "equiangular":
370+
lmax = mmax = nlat // 2
371+
elif grid == "lobatto":
372+
lmax = mmax = nlat - 1
373+
else:
374+
lmax = mmax = nlat
375+
376+
isht = th.InverseRealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, norm="ortho").to(self.device)
377+
378+
# Build one coefficient tensor per real basis function.
379+
# For m = 0: one tensor with c[l, 0] = 1+0j (real mode only).
380+
# For m > 0: two tensors — c[l, m] = 1+0j (cos) and c[l, m] = 0+1j (sin).
381+
basis_list = []
382+
expected_diag = []
383+
for l in range(lmax):
384+
for m in range(min(l + 1, mmax)):
385+
c = torch.zeros(lmax, mmax, dtype=torch.complex128)
386+
c[l, m] = 1.0 + 0.0j
387+
basis_list.append(c)
388+
expected_diag.append(1.0 if m == 0 else 2.0)
389+
if m > 0:
390+
c = torch.zeros(lmax, mmax, dtype=torch.complex128)
391+
c[l, m] = 0.0 + 1.0j
392+
basis_list.append(c)
393+
expected_diag.append(2.0)
394+
395+
coeffs = torch.stack(basis_list).to(self.device) # (N, lmax, mmax)
396+
397+
with torch.no_grad():
398+
funcs = isht(coeffs) # (N, nlat, nlon), real-valued
399+
400+
# Gram matrix via spherical quadrature: G[i,j] = integral of f_i * f_j over S^2
401+
# Weights from precompute_latitudes are in the cos(theta) domain and integrate
402+
# over [-1, 1], so the full measure is w_lat[k] * dlon.
403+
_, w_lat = precompute_latitudes(nlat, grid=grid)
404+
w_lat = w_lat.to(device=self.device, dtype=torch.float64)
405+
dlon = 2.0 * math.pi / nlon
406+
407+
weighted = funcs * (dlon * w_lat).unsqueeze(-1) # (N, nlat, nlon)
408+
gram = torch.einsum("inl,jnl->ij", weighted, funcs) # (N, N)
409+
410+
expected = torch.diag(
411+
torch.tensor(expected_diag, dtype=torch.float64, device=self.device)
412+
)
413+
self.assertTrue(compare_tensors("Gram matrix", gram, expected, atol=atol, rtol=rtol, verbose=verbose))
414+
415+
416+
@parameterized_class(("device"), _devices)
417+
class TestVectorSphericalHarmonicTransform(unittest.TestCase):
418+
"""Tests for the consistency between the scalar SHT and the vector SHT.
419+
420+
RealVectorSHT includes a 1/(l*(l+1)) normalization in its quadrature weights
421+
so that the spheroidal/toroidal spectral coefficients relate directly to the
422+
scalar SHT coefficients of the generating potential:
423+
424+
Gradient: vsht(ivsht([c, 0]))[spheroidal] = c, [toroidal] = 0 (l > 0)
425+
Curl: vsht(ivsht([0, c]))[spheroidal] = 0, [toroidal] = c (l > 0)
426+
427+
The l = 0 mode is zero in both vsht and ivsht because the gradient and curl
428+
of a constant field (Y_0^0) vanish identically on the sphere.
429+
430+
These tests catch swapped spheroidal/toroidal channels, wrong signs in the
431+
dP/dtheta or P/sin(theta) terms, and incorrect l*(l+1) normalization.
432+
"""
433+
434+
@parameterized.expand(
435+
[
436+
[32, 64, 16, "ortho", "legendre-gauss", 1e-7, 1e-7],
437+
[32, 64, 16, "ortho", "equiangular", 1e-7, 1e-7],
438+
[32, 64, 16, "ortho", "lobatto", 1e-7, 1e-7],
439+
[32, 64, 16, "four-pi", "legendre-gauss", 1e-7, 1e-7],
440+
[32, 64, 16, "four-pi", "equiangular", 1e-7, 1e-7],
441+
[32, 64, 16, "four-pi", "lobatto", 1e-7, 1e-7],
442+
# [32, 64, 16, "schmidt", "legendre-gauss", 1e-7, 1e-7],
443+
# [32, 64, 16, "schmidt", "equiangular", 1e-7, 1e-7],
444+
# [32, 64, 16, "schmidt", "lobatto", 1e-7, 1e-7],
445+
],
446+
skip_on_empty=True,
447+
)
448+
def test_gradient_consistency(self, nlat, nlon, batch_size, norm, grid, atol, rtol, verbose=True):
449+
"""ivsht([c, 0]) synthesizes the surface gradient ∇_S f of a scalar field
450+
f = isht(c). Applying vsht to this gradient field must recover c in the
451+
spheroidal channel and zero in the toroidal channel, because a gradient
452+
field is curl-free (purely spheroidal).
453+
"""
454+
if verbose:
455+
print(f"Testing gradient consistency on {nlat}x{nlon} {grid} grid with {norm} norm on {self.device.type}")
456+
457+
# disable tf32
458+
disable_tf32()
459+
460+
# set seed
461+
set_seed(333)
462+
463+
vsht = th.RealVectorSHT (nlat, nlon, grid=grid, norm=norm).to(self.device)
464+
ivsht = th.InverseRealVectorSHT (nlat, nlon, grid=grid, norm=norm).to(self.device)
465+
lmax, mmax = vsht.lmax, vsht.mmax
466+
467+
with torch.no_grad():
468+
c = random_sht_coeffs(batch_size, lmax, mmax, self.device, zero_l0=True)
469+
zeros = torch.zeros_like(c)
470+
471+
# synthesize gradient field: ivsht([c, 0]) = ∇_S f
472+
grad_f = ivsht(torch.stack([c, zeros], dim=-3)) # (batch, 2, nlat, nlon)
473+
474+
# analyse: vsht(∇_S f) must give [c, 0]
475+
st = vsht(grad_f) # (batch, 2, lmax, mmax)
476+
s = st[..., 0, :, :] # spheroidal
477+
t = st[..., 1, :, :] # toroidal
478+
479+
self.assertTrue(compare_tensors("spheroidal coefficients", s, c, atol=atol, rtol=rtol, verbose=verbose))
480+
self.assertTrue(compare_tensors("toroidal coefficients", t, zeros, atol=atol, rtol=rtol, verbose=verbose))
481+
482+
@parameterized.expand(
483+
[
484+
[32, 64, 16, "ortho", "legendre-gauss", 1e-7, 1e-7],
485+
[32, 64, 16, "ortho", "equiangular", 1e-7, 1e-7],
486+
[32, 64, 16, "ortho", "lobatto", 1e-7, 1e-7],
487+
[32, 64, 16, "four-pi", "legendre-gauss", 1e-7, 1e-7],
488+
[32, 64, 16, "four-pi", "equiangular", 1e-7, 1e-7],
489+
[32, 64, 16, "four-pi", "lobatto", 1e-7, 1e-7],
490+
# [32, 64, 16, "schmidt", "legendre-gauss", 1e-9, 1e-9],
491+
# [32, 64, 16, "schmidt", "equiangular", 1e-9, 1e-9],
492+
# [32, 64, 16, "schmidt", "lobatto", 1e-9, 1e-9],
493+
],
494+
skip_on_empty=True,
495+
)
496+
def test_curl_consistency(self, nlat, nlon, batch_size, norm, grid, atol, rtol, verbose=False):
497+
"""ivsht([0, c]) synthesizes the surface curl ê_r × ∇_S f of a scalar field
498+
f = isht(c). Applying vsht to this curl field must recover c in the
499+
toroidal channel and zero in the spheroidal channel, because a surface
500+
curl field is divergence-free (purely toroidal).
501+
"""
502+
if verbose:
503+
print(f"Testing curl consistency on {nlat}x{nlon} {grid} grid with {norm} norm on {self.device.type}")
504+
505+
# disable tf32
506+
disable_tf32()
507+
508+
# set seed
509+
set_seed(333)
510+
511+
vsht = th.RealVectorSHT (nlat, nlon, grid=grid, norm=norm).to(self.device)
512+
ivsht = th.InverseRealVectorSHT (nlat, nlon, grid=grid, norm=norm).to(self.device)
513+
lmax, mmax = vsht.lmax, vsht.mmax
514+
515+
with torch.no_grad():
516+
c = random_sht_coeffs(batch_size, lmax, mmax, self.device, zero_l0=True)
517+
zeros = torch.zeros_like(c)
518+
519+
# synthesize curl field: ivsht([0, c]) = ê_r × ∇_S f
520+
curl_f = ivsht(torch.stack([zeros, c], dim=-3)) # (batch, 2, nlat, nlon)
521+
522+
# analyse: vsht(ê_r × ∇_S f) must give [0, c]
523+
st = vsht(curl_f) # (batch, 2, lmax, mmax)
524+
s = st[..., 0, :, :] # spheroidal
525+
t = st[..., 1, :, :] # toroidal
526+
527+
self.assertTrue(compare_tensors("spheroidal coefficients", s, zeros, atol=atol, rtol=rtol, verbose=verbose))
528+
self.assertTrue(compare_tensors("toroidal coefficients", t, c, atol=atol, rtol=rtol, verbose=verbose))
529+
530+
235531
if __name__ == "__main__":
236532
unittest.main()

0 commit comments

Comments
 (0)