From 2d663a8e71972b764b386597c65013abb6635012 Mon Sep 17 00:00:00 2001 From: Akshay Subramaniam <6964110+akshaysubr@users.noreply.github.com> Date: Mon, 12 Jan 2026 23:04:48 -0800 Subject: [PATCH 1/4] Refactor tests to pytest with float32/float64 dtype parameterization - Convert all test scripts to pytest format with fixtures and assertions - Add conftest.py with shared fixtures (device, nside, dtype, tolerances) - Parameterize all SHT tests with both float32 and float64 dtypes - Add tolerance helpers: get_impl_tol(), get_roundtrip_tol(), get_bluestein_tol() - Add pytest configuration to pyproject.toml - Move profiling scripts to benchmarks/ and examples/ directories - Remove obsolete test_sht_two_stream_overlap_profiling.py Signed-off-by: Akshay Subramaniam <6964110+akshaysubr@users.noreply.github.com> --- benchmarks/sht_profiling.py | 83 +++++++ benchmarks/stream_overlap_profiling.py | 93 ++++++++ examples/wmap_optimization.py | 130 +++++++++++ pyproject.toml | 30 +++ tests/conftest.py | 138 ++++++++++++ tests/test_batch_remap.py | 64 ++++-- tests/test_data_remapping.py | 94 ++++---- tests/test_differentiability.py | 142 ++++++++++-- tests/test_grad.py | 127 ++++++----- tests/test_harmonic_transform.py | 107 +++++---- tests/test_regridding.py | 205 ++++++++++++------ tests/test_sht_bluestein.py | 90 ++++++-- tests/test_sht_cuda.py | 89 ++++++-- tests/test_sht_cuda_batch.py | 110 +++++++--- tests/test_sht_cuda_stream.py | 110 ++++++---- .../test_sht_two_stream_overlap_profiling.py | 37 ---- 16 files changed, 1265 insertions(+), 384 deletions(-) create mode 100644 benchmarks/sht_profiling.py create mode 100644 benchmarks/stream_overlap_profiling.py create mode 100644 examples/wmap_optimization.py create mode 100644 tests/conftest.py delete mode 100755 tests/test_sht_two_stream_overlap_profiling.py diff --git a/benchmarks/sht_profiling.py b/benchmarks/sht_profiling.py new file mode 100644 index 0000000..bac875d --- /dev/null +++ b/benchmarks/sht_profiling.py @@ -0,0 
+1,83 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Benchmark: Profile SHTCUDA batched operations with NVTX markers. + +This script profiles the performance of batched spherical harmonic transforms +using NVTX markers for visualization in NSight Systems or NSight Compute. + +Usage: + # Basic run + python benchmarks/sht_profiling.py + + # With NSight Systems profiling + nsys profile python benchmarks/sht_profiling.py +""" + +import torch + +from cuhpx import SHTCUDA + + +def main(): + """Run SHT profiling benchmark.""" + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if device.type == "cpu": + print("Error: This benchmark requires CUDA") + return + + # Configuration + nside = int(input("nside: ")) + nbatch = int(input("batch size: ")) + lmax = int(input("lmax: ")) + mmax = lmax + npix = 12 * nside**2 + + print("\nBenchmark configuration:") + print(f" nside: {nside}") + print(f" batch_size: {nbatch}") + print(f" lmax: {lmax}") + print(f" npix: {npix}") + print() + + # Create input signal + signal = torch.randn(nbatch, npix, dtype=torch.float32, device=device) + + # Create SHT transform + sht = SHTCUDA(nside, lmax=lmax, mmax=mmax, quad_weights="ring") + + # Warmup + print("Warming up...") + for _ in range(3): + _ = sht(signal) + torch.cuda.synchronize() + + # Profile with NVTX markers + 
print("Running profiled iterations...") + n_iterations = 10 + + for i in range(n_iterations): + torch.cuda.nvtx.range_push(f"SHTCUDA batch iter {i}") + _ = sht(signal) + torch.cuda.nvtx.range_pop() + + torch.cuda.synchronize() + print(f"Completed {n_iterations} iterations") + print("Use 'nsys profile' to capture detailed profiling data") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/stream_overlap_profiling.py b/benchmarks/stream_overlap_profiling.py new file mode 100644 index 0000000..b9d9705 --- /dev/null +++ b/benchmarks/stream_overlap_profiling.py @@ -0,0 +1,93 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Benchmark: Profile CUDA stream overlap for SHT operations. + +This script profiles the behavior of SHT operations across multiple CUDA streams +using NVTX markers for visualization in NSight Systems. 
+ +Usage: + # Basic run + python benchmarks/stream_overlap_profiling.py + + # With NSight Systems profiling + nsys profile python benchmarks/stream_overlap_profiling.py +""" + +import torch +import torch.cuda.nvtx as nvtx + +from cuhpx import cuhpx_fft + + +def main(): + """Run stream overlap profiling benchmark.""" + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if device.type == "cpu": + print("Error: This benchmark requires CUDA") + return + + # Configuration + nside = int(input("nside: ")) + m = int(input("m, the first dim: ")) + n = int(input("n, the second dim: ")) + + npix = 12 * nside**2 + lmax = 2 * nside + 1 + + print("\nBenchmark configuration:") + print(f" nside: {nside}") + print(f" batch dims: ({m}, {n})") + print(f" npix: {npix}") + print(f" lmax: {lmax}") + print() + + # Create input signals + signal1 = torch.randn(m, n, npix, dtype=torch.float32, device=device) + signal2 = signal1.clone() + + # Create two CUDA streams + stream1 = torch.cuda.Stream() + stream2 = torch.cuda.Stream() + + # Profile with NVTX markers + print("Running profiled stream operations...") + + # Perform SHT on stream1 + nvtx.range_push("Stream 1 SHT Operation") + with torch.cuda.stream(stream1): + result1 = cuhpx_fft.healpix_rfft_batch(signal1, nside, nside) + stream1.synchronize() + nvtx.range_pop() + + # Perform SHT on stream2 + nvtx.range_push("Stream 2 SHT Operation") + with torch.cuda.stream(stream2): + result2 = cuhpx_fft.healpix_rfft_batch(signal2, nside, nside) + stream2.synchronize() + nvtx.range_pop() + + # Compare results + nvtx.range_push("Compare Results") + comparison = torch.allclose(result1, result2) + nvtx.range_pop() + + print(f"Results from two streams identical: {comparison}") + print("Use 'nsys profile' to visualize stream overlap") + + +if __name__ == "__main__": + main() diff --git a/examples/wmap_optimization.py b/examples/wmap_optimization.py new file mode 100644 index 0000000..31d9c49 --- /dev/null +++ 
b/examples/wmap_optimization.py @@ -0,0 +1,130 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Example: Spectral optimization on WMAP data using differentiable SHT. + +This example demonstrates how to use cuHPX's differentiable inverse spherical +harmonic transform to fit spherical harmonic coefficients to WMAP sky map data. + +Usage: + python examples/wmap_optimization.py + +Requirements: + - healpy (pip install healpy) + - CUDA-capable GPU + - Network access to download WMAP data from NASA +""" + +import os +import urllib.request + +import healpy as hp +import torch +import torch.nn as nn + +from cuhpx import iSHTCUDA + + +def download_wmap_data(filename="wmap_band_iqumap_r9_7yr_W_v4.fits"): + """Download WMAP 7-year W-band sky map if not present.""" + if os.path.exists(filename): + print(f"Using existing file: {filename}") + return filename + + url = f"http://lambda.gsfc.nasa.gov/data/map/dr4/skymaps/7yr/raw/{filename}" + print(f"Downloading WMAP data from {url}...") + + try: + urllib.request.urlretrieve(url, filename) # noqa: S310 + print(f"Downloaded: {filename}") + return filename + except Exception as e: + raise RuntimeError(f"Failed to download WMAP data: {e}") + + +class SpectralModel(nn.Module): + """Neural network module with learnable spherical harmonic coefficients. 
+ + This model learns to represent a HEALPix map as spherical harmonic + coefficients. The forward pass applies the inverse SHT to produce + a pixel-space representation. + """ + + def __init__(self, nside, lmax, mmax, device): + super().__init__() + self.coeffs = nn.Parameter(torch.randn(lmax, mmax, dtype=torch.complex128)) + self.isht = iSHTCUDA(nside, lmax=lmax, mmax=mmax).to(device) + + def forward(self): + return self.isht(self.coeffs) + + +def main(): + """Run WMAP spectral optimization example.""" + # Configuration + nside = int(input("Enter the nside value (e.g., 64, 128): ")) + lmax = int(input("Enter the lmax value (e.g., 2*nside+1): ")) + mmax = lmax + n_iterations = 500 + learning_rate = 5e-2 + + # Setup device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if device.type == "cpu": + print("Warning: CUDA not available, running on CPU (will be slow)") + + # Download and load WMAP data + fits_file = download_wmap_data() + wmap_map_I = hp.read_map(fits_file) + wmap = hp.ud_grade(wmap_map_I, nside) + + # Convert to torch tensor + signal = torch.from_numpy(wmap).to(device) + print(f"Loaded WMAP data: nside={nside}, npix={len(wmap)}") + + # Initialize model and optimizer + model = SpectralModel(nside, lmax, mmax, device).to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) + + # Training loop + print(f"\nOptimizing spherical harmonic coefficients (lmax={lmax})...") + print("-" * 50) + + losses = [] + for iteration in range(n_iterations): + # Forward pass + prediction = model() + loss = (prediction - signal).pow(2).mean() + + # Backward pass + optimizer.zero_grad() + loss.backward() + optimizer.step() + + losses.append(loss.item()) + + if iteration % 50 == 0 or iteration == n_iterations - 1: + print(f"Iteration {iteration:4d}: MSE loss = {loss.item():.6e}") + + # Report results + print("-" * 50) + print(f"Final MSE loss: {losses[-1]:.6e}") + print(f"Loss reduction: {losses[0] / losses[-1]:.1f}x") + 
print("\nOptimization complete!") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index bf8656e..3496d16 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,9 @@ classifiers = [ "Programming Language :: Python :: 3.12", ] +[project.optional-dependencies] +test = ["pytest>=7.0", "pytest-cov>=4.0", "healpy"] + [project.urls] "Homepage" = "https://github.com/NVlabs/cuHPX" @@ -130,3 +133,30 @@ target-version = 'py38' "ARG", # Unused function args -> fixtures nevertheless are functionally relevant... "FBT", # Don't care about booleans as positional arguments in tests, e.g. via @pytest.mark.parametrize() ] + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] +python_functions = ["test_*"] +markers = [ + "cuda: marks tests as requiring CUDA GPU", + "slow: marks tests as slow running", +] +filterwarnings = [ + "ignore::SyntaxWarning", +] +addopts = "-v --tb=short" + +[tool.coverage.run] +source = ["cuhpx"] +branch = true +omit = ["*/tests/*", "*/__pycache__/*"] + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "raise NotImplementedError", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", +] +show_missing = true diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..3a9784e --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,138 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pytest configuration and shared fixtures for cuHPX tests.""" + +import pytest +import torch + +# Check if CUDA is available once at module load +CUDA_AVAILABLE = torch.cuda.is_available() + + +def pytest_configure(config): + """Register custom markers.""" + config.addinivalue_line("markers", "cuda: marks tests as requiring CUDA GPU") + config.addinivalue_line("markers", "slow: marks tests as slow running") + + +@pytest.fixture +def device(): + """Provide CUDA device, skip test if unavailable.""" + if not CUDA_AVAILABLE: + pytest.skip("CUDA not available") + return torch.device("cuda") + + +@pytest.fixture(params=[32, 64]) +def nside(request): + """Parametrize nside values for tests.""" + return request.param + + +@pytest.fixture +def nside_small(): + """Return a small nside value for quick tests.""" + return 32 + + +@pytest.fixture +def lmax(nside): + """Return lmax based on nside (healpy default: 3*nside - 1).""" + return 3 * nside - 1 + + +@pytest.fixture +def mmax(lmax): + """Return mmax (same as lmax by default).""" + return lmax + + +@pytest.fixture +def npix(nside): + """Return number of HEALPix pixels for given nside.""" + return 12 * nside**2 + + +@pytest.fixture(params=[torch.float32, torch.float64]) +def dtype(request): + """Parametrize float dtypes for tests.""" + return request.param + + +@pytest.fixture +def complex_dtype(dtype): + """Return corresponding complex dtype for a real dtype.""" + return torch.complex64 if dtype == torch.float32 else torch.complex128 + + +# Tolerance configurations for different test types +# These are based on empirical measurements of algorithm accuracy + +# For CUDA vs PyTorch implementation comparison (same algorithm, different impl) +IMPL_COMPARISON_TOL = { + torch.float32: {"sht": (1e-4, 1e-5), "isht": (1e-3, 1e-2)}, + torch.float64: {"sht": (1e-8, 1e-8), "isht": (1e-5, 1e-5)}, +} + +# For Bluestein vs 
standard comparison (different algorithms, both float64 internally) +# The Bluestein algorithm has inherent numerical differences from standard FFT +# that are algorithm-limited (~1e-6 for SHT, ~1e-4 for iSHT) not precision-limited +BLUESTEIN_COMPARISON_TOL = { + "sht": (1e-5, 1e-5), + "isht": (1e-4, 1e-4), +} + +# For roundtrip tests (algorithm-limited, not precision-limited) +# These tolerances are the same for both dtypes because error is algorithmic +ROUNDTRIP_TOL = {"rtol": 0.01, "atol": 0.05} + + +def get_impl_tol(dtype, transform_type): + """Get (rtol, atol) for implementation comparison tests. + + Args: + dtype: torch.float32 or torch.float64 + transform_type: "sht" or "isht" + + Returns: + Tuple of (rtol, atol) + """ + return IMPL_COMPARISON_TOL[dtype][transform_type] + + +def get_roundtrip_tol(): + """Get (rtol, atol) for roundtrip tests. + + Roundtrip error is algorithm-limited, not precision-limited, + so same tolerances for float32 and float64. + """ + return ROUNDTRIP_TOL["rtol"], ROUNDTRIP_TOL["atol"] + + +def get_bluestein_tol(transform_type): + """Get (rtol, atol) for Bluestein vs standard comparison tests. + + Bluestein uses a different FFT algorithm (convolution-based) which + has inherent numerical differences from standard FFT. These are + algorithm-limited, not precision-limited. + + Args: + transform_type: "sht" or "isht" + + Returns: + Tuple of (rtol, atol) + """ + return BLUESTEIN_COMPARISON_TOL[transform_type] diff --git a/tests/test_batch_remap.py b/tests/test_batch_remap.py index 959e343..4cfb223 100644 --- a/tests/test_batch_remap.py +++ b/tests/test_batch_remap.py @@ -13,27 +13,65 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+"""Tests for batched remapping operations.""" + +import pytest import torch import cuhpx -# Check if CUDA is available -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -nside = int(input('nside: ')) -m = int(input('m: ')) -n = int(input('n: ')) +@pytest.mark.cuda +def test_batched_ring2nest_matches_single(device, nside): + """Test that batched ring2nest produces same results as individual operations.""" + npix = 12 * nside**2 + m, n = 4, 8 + signal = torch.randn(m, n, npix, dtype=torch.float32, device=device) + + # Batched operation + result_batch = cuhpx.ring2nest(signal, nside) + + # Single operations + result_single = torch.zeros_like(result_batch) + for i in range(m): + for j in range(n): + result_single[i, j, :] = cuhpx.ring2nest(signal[i, j, :], nside) + + assert torch.equal( + result_batch, result_single + ), f"Batched ring2nest doesn't match single operations for nside={nside}" + + +@pytest.mark.cuda +def test_batched_nest2ring_matches_single(device, nside): + """Test that batched nest2ring produces same results as individual operations.""" + npix = 12 * nside**2 + m, n = 4, 8 + signal = torch.randn(m, n, npix, dtype=torch.float32, device=device) + + # Batched operation + result_batch = cuhpx.nest2ring(signal, nside) -npix = 12 * nside**2 -signal = torch.randn((m, n, npix), dtype=torch.float32).to(device) + # Single operations + result_single = torch.zeros_like(result_batch) + for i in range(m): + for j in range(n): + result_single[i, j, :] = cuhpx.nest2ring(signal[i, j, :], nside) -signal_dest = cuhpx.ring2nest(signal, nside) + assert torch.equal( + result_batch, result_single + ), f"Batched nest2ring doesn't match single operations for nside={nside}" -signal_1by1 = torch.zeros((m, n, npix), dtype=torch.float32).to(device) -for i in range(m): - for j in range(n): - signal_1by1[i, j, :] = cuhpx.ring2nest(signal[i, j, :], nside) +@pytest.mark.cuda +@pytest.mark.parametrize("batch_shape", [(2,), (3, 4), (2, 3, 5)]) +def 
test_batched_remap_various_shapes(device, nside, batch_shape): + """Test batched remapping with various batch dimensions.""" + npix = 12 * nside**2 + signal = torch.randn(*batch_shape, npix, dtype=torch.float32, device=device) + # Round trip should recover original + nested = cuhpx.ring2nest(signal, nside) + recovered = cuhpx.nest2ring(nested, nside) -print("whether batch and one by one the same: ", torch.equal(signal_dest, signal_1by1)) + assert torch.equal(signal, recovered), f"Round-trip failed for batch_shape={batch_shape}, nside={nside}" diff --git a/tests/test_data_remapping.py b/tests/test_data_remapping.py index a11520c..7a196e4 100644 --- a/tests/test_data_remapping.py +++ b/tests/test_data_remapping.py @@ -13,64 +13,76 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Tests for data remapping operations (ring2nest, nest2ring) against healpy reference.""" + import healpy as hp +import pytest import torch import cuhpx as hpx -# Read the order value from user input -nside = int(input("Enter the nside value: ")) -nelements = 12 * nside**2 -# Check if CUDA is available -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +@pytest.mark.cuda +@pytest.mark.parametrize("dtype", [torch.int32, torch.float64]) +def test_ring2nest_matches_healpy(device, nside, dtype): + """Test that ring2nest produces identical results to healpy.""" + npix = 12 * nside**2 + tensor_ring = torch.arange(npix, device=device, dtype=dtype) + + # cuHPX ring2nest + result_hpx = hpx.ring2nest(tensor_ring, nside).cpu() + + # healpy reference + result_healpy = torch.tensor( + hp.pixelfunc.reorder(tensor_ring.cpu().numpy(), inp="RING", out="NESTED"), + dtype=dtype, + ) + assert torch.equal(result_hpx, result_healpy), f"ring2nest mismatch for nside={nside}, dtype={dtype}" -# Function to test ring2nest and nest2ring for a given dtype -def test_ring2nest_nest2ring(dtype): - # Generate the input tensor in RING ordering - 
tensor_in_ring = torch.arange(nelements, device=device, dtype=dtype) - # Print the first five elements of the input tensor for ring2nest - print(f"Input tensor (RING, dtype={dtype}) first 5 elements:", tensor_in_ring[:5]) +@pytest.mark.cuda +@pytest.mark.parametrize("dtype", [torch.int32, torch.float64]) +def test_nest2ring_matches_healpy(device, nside, dtype): + """Test that nest2ring produces identical results to healpy.""" + npix = 12 * nside**2 + tensor_nest = torch.arange(npix, device=device, dtype=dtype) - # Use hpx.ring2nest to convert to NESTED ordering - tensor_in_nest = hpx.ring2nest(tensor_in_ring, nside) - result_tensor_hpx_nest = tensor_in_nest.to('cpu') + # cuHPX nest2ring + result_hpx = hpx.nest2ring(tensor_nest, nside).cpu() - # Use healpy.pixelfunc.reorder to convert to NESTED ordering - map_in_ring = tensor_in_ring.cpu().numpy() - result_tensor_healpy_nest = torch.tensor(hp.pixelfunc.reorder(map_in_ring, inp='RING', out='NESTED'), dtype=dtype) + # healpy reference + result_healpy = torch.tensor( + hp.pixelfunc.reorder(tensor_nest.cpu().numpy(), inp="NESTED", out="RING"), + dtype=dtype, + ) - # Compare the results - comparison_ring2nest = torch.equal(result_tensor_hpx_nest, result_tensor_healpy_nest) - print(f"Are the ring2nest results identical (dtype={dtype})?", comparison_ring2nest) - print(f"HPX ring2nest result (dtype={dtype}) first 5 elements:", result_tensor_hpx_nest[:5]) - print(f"Healpy ring2nest result (dtype={dtype}) first 5 elements:", result_tensor_healpy_nest[:5]) + assert torch.equal(result_hpx, result_healpy), f"nest2ring mismatch for nside={nside}, dtype={dtype}" - # Generate the input tensor in NEST ordering for nest2ring - tensor_in_nest = torch.arange(nelements, device=device, dtype=dtype) - # Print the first five elements of the input tensor for nest2ring - print(f"Input tensor (NEST, dtype={dtype}) first 5 elements:", tensor_in_nest[:5]) +@pytest.mark.cuda +@pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) +def 
test_ring2nest_nest2ring_roundtrip(device, nside, dtype): + """Test that ring2nest followed by nest2ring recovers the original data.""" + npix = 12 * nside**2 + original = torch.randn(npix, device=device, dtype=dtype) - # Use hpx.nest2ring to convert to RING ordering - tensor_in_ring = hpx.nest2ring(tensor_in_nest, nside) - result_tensor_hpx_ring = tensor_in_ring.to('cpu') + # Round trip: ring -> nest -> ring + nested = hpx.ring2nest(original, nside) + recovered = hpx.nest2ring(nested, nside) - # Use healpy.pixelfunc.reorder to convert to RING ordering - map_in_nest = tensor_in_nest.cpu().numpy() - result_tensor_healpy_ring = torch.tensor(hp.pixelfunc.reorder(map_in_nest, inp='NESTED', out='RING'), dtype=dtype) + assert torch.equal(original, recovered), f"Round-trip failed for nside={nside}, dtype={dtype}" - # Compare the results - comparison_nest2ring = torch.equal(result_tensor_hpx_ring, result_tensor_healpy_ring) - print(f"Are the nest2ring results identical (dtype={dtype})?", comparison_nest2ring) - print(f"HPX nest2ring result (dtype={dtype}) first 5 elements:", result_tensor_hpx_ring[:5]) - print(f"Healpy nest2ring result (dtype={dtype}) first 5 elements:", result_tensor_healpy_ring[:5]) +@pytest.mark.cuda +@pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) +def test_nest2ring_ring2nest_roundtrip(device, nside, dtype): + """Test that nest2ring followed by ring2nest recovers the original data.""" + npix = 12 * nside**2 + original = torch.randn(npix, device=device, dtype=dtype) -# Test for int32 -test_ring2nest_nest2ring(torch.int32) + # Round trip: nest -> ring -> nest + ring = hpx.nest2ring(original, nside) + recovered = hpx.ring2nest(ring, nside) -# Test for float64 (double) -test_ring2nest_nest2ring(torch.float64) + assert torch.equal(original, recovered), f"Round-trip failed for nside={nside}, dtype={dtype}" diff --git a/tests/test_differentiability.py b/tests/test_differentiability.py index 58d0035..e12045a 100644 --- 
a/tests/test_differentiability.py +++ b/tests/test_differentiability.py @@ -13,29 +13,58 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Tests for differentiability of SHT operations (gradient-based optimization).""" -import healpy as hp +import pytest import torch import torch.nn as nn -from download import download_file -from cuhpx import iSHTCUDA +from cuhpx import SHT, iSHTCUDA -download_file('http://lambda.gsfc.nasa.gov/data/map/dr4/skymaps/7yr/raw/wmap_band_iqumap_r9_7yr_W_v4.fits') -nside = int(input("Enter the nside value: ")) -lmax = int(input("Enter the lmax value: ")) # lmax = 2*nside+1 -mmax = lmax -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +def _generate_synthetic_signal(nside, lmax, mmax, device, seed=42, bandwidth_fraction=0.5): + """Generate a synthetic bandlimited signal for testing. -wmap_map_I = hp.read_map("wmap_band_iqumap_r9_7yr_W_v4.fits") -wmap = hp.ud_grade(wmap_map_I, nside) -data = torch.from_numpy(wmap) -signal = data.to(device) + Args: + nside: HEALPix nside parameter. + lmax: Maximum degree l for the transform. + mmax: Maximum order m for the transform. + device: Torch device. + seed: Random seed. + bandwidth_fraction: Fraction of lmax to use for signal bandwidth (0 < frac <= 1). + A value < 1 ensures the signal is strictly bandlimited below lmax. + + Returns: + Signal tensor on device. 
+ """ + torch.manual_seed(seed) + + # Create bandlimited signal via inverse SHT + from cuhpx import iSHT + + isht = iSHT(nside, lmax=lmax, mmax=mmax) + + # Only populate coefficients up to a fraction of lmax to ensure + # the signal is strictly bandlimited below the transform's lmax + max_l = int(lmax * bandwidth_fraction) + max_l = max(1, min(max_l, lmax - 1)) # Ensure at least l=0 and at most lmax-1 + + # Random coefficients with decay for higher l + coeff = torch.zeros((lmax, mmax), dtype=torch.complex128) + for l_idx in range(max_l): + for m_idx in range(min(l_idx + 1, mmax)): + # Decay amplitude with l for realistic signal + amplitude = 1.0 / (1 + l_idx) + coeff[l_idx, m_idx] = amplitude * (torch.randn(1) + 1j * torch.randn(1)) + + signal = isht(coeff) + return signal.to(device) class SpectralModel(nn.Module): - def __init__(self, nside, lmax, mmax): + """Neural network module with learnable spherical harmonic coefficients.""" + + def __init__(self, nside, lmax, mmax, device): super().__init__() self.coeffs = nn.Parameter(torch.randn(lmax, mmax, dtype=torch.complex128)) self.isht = iSHTCUDA(nside, lmax=lmax, mmax=mmax).to(device) @@ -44,20 +73,85 @@ def forward(self): return self.isht(self.coeffs) -sh_model = SpectralModel(nside, lmax, mmax).to(device) +@pytest.mark.cuda +@pytest.mark.slow +class TestDifferentiability: + """Test differentiability through gradient-based optimization.""" + + @pytest.fixture + def nside_opt(self): + """Return nside for optimization tests (smaller for speed).""" + return 32 + + @pytest.fixture + def lmax_opt(self, nside_opt): + """Return lmax for optimization tests.""" + return 2 * nside_opt + 1 + + @pytest.fixture + def target_signal(self, device, nside_opt, lmax_opt): + """Create target signal for optimization.""" + return _generate_synthetic_signal(nside_opt, lmax_opt, lmax_opt, device) + + def test_spectral_model_optimization(self, device, nside_opt, lmax_opt, target_signal): + """Test that spectral model can be optimized via 
gradient descent.""" + model = SpectralModel(nside_opt, lmax_opt, lmax_opt, device).to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=5e-2) + + initial_loss = None + final_loss = None + + # Run a few optimization steps + n_iterations = 50 + for i in range(n_iterations): + loss = (model() - target_signal).pow(2).mean() + optimizer.zero_grad() + loss.backward() + optimizer.step() + + if i == 0: + initial_loss = loss.item() + final_loss = loss.item() + + # Loss should decrease significantly + assert final_loss < initial_loss * 0.1, ( + f"Optimization failed: initial loss = {initial_loss}, " f"final loss = {final_loss}" + ) + + def test_gradient_exists_and_valid(self, device, nside_opt, lmax_opt, target_signal): + """Test that gradients exist and are valid (no NaN/Inf).""" + model = SpectralModel(nside_opt, lmax_opt, lmax_opt, device).to(device) + + loss = (model() - target_signal).pow(2).mean() + loss.backward() + + assert model.coeffs.grad is not None, "Gradient not computed" + assert not torch.isnan(model.coeffs.grad).any(), "NaN in gradients" + assert not torch.isinf(model.coeffs.grad).any(), "Inf in gradients" + + def test_sht_forward_differentiable(self, device, nside_opt, lmax_opt): + """Test that SHT forward pass is differentiable.""" + npix = 12 * nside_opt**2 + signal = torch.randn(npix, dtype=torch.float32, device=device, requires_grad=True) + + sht = SHT(nside_opt, lmax=lmax_opt, mmax=lmax_opt, quad_weights="ring").to(device) -optimizer = torch.optim.Adam(sh_model.parameters(), lr=5e-2) + coeff = sht(signal) + loss = coeff.abs().sum() + loss.backward() -losses = [] + assert signal.grad is not None, "SHT forward not differentiable" + assert signal.grad.shape == signal.shape, "Gradient shape mismatch" -for iter in range(500): + def test_isht_forward_differentiable(self, device, nside_opt, lmax_opt): + """Test that iSHT forward pass is differentiable.""" + coeff = torch.randn(lmax_opt, lmax_opt, dtype=torch.complex128, device=device, 
requires_grad=True) - loss = (sh_model() - signal).pow(2).mean() - optimizer.zero_grad() - loss.backward() - optimizer.step() + isht = iSHTCUDA(nside_opt, lmax=lmax_opt, mmax=lmax_opt).to(device) - losses.append(loss.item()) + signal = isht(coeff) + loss = signal.abs().sum() + loss.backward() - if iter % 10 == 0: - print(f'iteration: {iter} loss: {loss.item()}') + assert coeff.grad is not None, "iSHT forward not differentiable" + assert coeff.grad.shape == coeff.shape, "Gradient shape mismatch" diff --git a/tests/test_grad.py b/tests/test_grad.py index 6fcacd3..1bab607 100644 --- a/tests/test_grad.py +++ b/tests/test_grad.py @@ -13,65 +13,78 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Tests for gradient computation of SHT/iSHT operations.""" + +import pytest import torch from cuhpx import SHT, SHTCUDA, iSHT, iSHTCUDA -# Check if CUDA is available -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - -nside = 32 -npix = 12 * nside**2 -signal = torch.randn(npix, dtype=torch.float32).to(device) - - -quad_weights = 'ring' -lmax = 2 * nside + 1 -mmax = lmax - -sht = SHT(nside, lmax=lmax, mmax=mmax, quad_weights=quad_weights).to(device) -isht = iSHT(nside, lmax=lmax, mmax=mmax).to(device) - -coeff = sht(signal) - -sht_cuda = SHTCUDA(nside, lmax=lmax, mmax=mmax, quad_weights=quad_weights).to(device) -isht_cuda = iSHTCUDA(nside, lmax=lmax, mmax=mmax).to(device) - -signal1 = torch.clone(signal) -signal2 = torch.clone(signal) - -signal1.requires_grad_(True) -signal2.requires_grad_(True) - -c1 = sht(signal1) -c2 = sht_cuda(signal2) - -c1.backward(torch.clone(c1)) -c2.backward(torch.clone(c1)) - -diff = signal1.grad - signal2.grad - -print('-----------------------') -print('Mean of Autograd of SHT: ', torch.mean(signal1.grad.abs())) -print('Mean of manual of SHT: ', torch.mean(signal2.grad.abs())) -print('diff between the grad of SHT: ', torch.mean(diff.abs())) -print('ratio', 
torch.mean(diff.abs()) / torch.mean(signal1.grad.abs())) - -coeff1 = torch.clone(coeff) -coeff2 = torch.clone(coeff) - -coeff1.requires_grad_(True) -coeff2.requires_grad_(True) - -s1 = isht(coeff1) -s2 = isht_cuda(coeff2) - -s1.backward(torch.clone(s1)) -s2.backward(torch.clone(s1)) - -diff = coeff1.grad - coeff2.grad -print('Mean of Autograd of iSHT: ', torch.mean(coeff1.grad.abs())) -print('Mean of manual of iSHT: ', torch.mean(coeff2.grad.abs())) -print('diff between the grad of iSHT: ', torch.mean(diff.abs())) -print('ratio', torch.mean(diff.abs()) / torch.mean(coeff1.grad.abs())) +@pytest.mark.cuda +class TestSHTGradients: + """Test gradient computation for spherical harmonic transforms.""" + + @pytest.fixture + def nside_fixed(self): + """Return fixed nside for gradient tests.""" + return 32 + + @pytest.fixture + def lmax_fixed(self, nside_fixed): + """Return lmax for fixed nside.""" + return 2 * nside_fixed + 1 + + @pytest.fixture + def signal(self, device, nside_fixed, dtype): + """Create a random test signal.""" + npix = 12 * nside_fixed**2 + return torch.randn(npix, dtype=dtype, device=device) + + @pytest.fixture + def sht_autograd(self, device, nside_fixed, lmax_fixed): + """Create SHT with autograd support.""" + return SHT(nside_fixed, lmax=lmax_fixed, mmax=lmax_fixed, quad_weights="ring").to(device) + + @pytest.fixture + def isht_autograd(self, device, nside_fixed, lmax_fixed): + """Create iSHT with autograd support.""" + return iSHT(nside_fixed, lmax=lmax_fixed, mmax=lmax_fixed).to(device) + + @pytest.fixture + def sht_cuda(self, device, nside_fixed, lmax_fixed): + """Create SHTCUDA with custom backward.""" + return SHTCUDA(nside_fixed, lmax=lmax_fixed, mmax=lmax_fixed, quad_weights="ring").to(device) + + @pytest.fixture + def isht_cuda(self, device, nside_fixed, lmax_fixed): + """Create iSHTCUDA with custom backward.""" + return iSHTCUDA(nside_fixed, lmax=lmax_fixed, mmax=lmax_fixed).to(device) + + def test_sht_gradient_flow(self, signal, sht_cuda): 
+ """Test that gradients flow through SHTCUDA.""" + signal_grad = signal.clone().requires_grad_(True) + coeff = sht_cuda(signal_grad) + + # Create a scalar loss + loss = coeff.abs().sum() + loss.backward() + + assert signal_grad.grad is not None, "Gradient not computed for SHT" + assert not torch.isnan(signal_grad.grad).any(), "NaN in SHT gradient" + assert not torch.isinf(signal_grad.grad).any(), "Inf in SHT gradient" + + def test_isht_gradient_flow(self, signal, sht_autograd, isht_cuda): + """Test that gradients flow through iSHTCUDA.""" + coeff = sht_autograd(signal) + coeff_grad = coeff.clone().requires_grad_(True) + + signal_out = isht_cuda(coeff_grad) + + # Create a scalar loss + loss = signal_out.abs().sum() + loss.backward() + + assert coeff_grad.grad is not None, "Gradient not computed for iSHT" + assert not torch.isnan(coeff_grad.grad).any(), "NaN in iSHT gradient" + assert not torch.isinf(coeff_grad.grad).any(), "Inf in iSHT gradient" diff --git a/tests/test_harmonic_transform.py b/tests/test_harmonic_transform.py index 8429237..e57aa40 100644 --- a/tests/test_harmonic_transform.py +++ b/tests/test_harmonic_transform.py @@ -13,69 +13,92 @@ # See the License for the specific language governing permissions and # limitations under the License. -import random +"""Tests for spherical harmonic transforms (SHT/iSHT) accuracy.""" +import pytest import torch +from conftest import get_roundtrip_tol from cuhpx import SHT, iSHT -def random_fill_matrix(n, xmax, matrix): - for _ in range(n): - v = random.random() # Generate a random float number between 0 and 1 - x = random.randint(0, xmax - 1) # Generate a random int number, 0 <= x < xmax - y = random.randint(0, x) # Generate a random int number, 0 <= y <= x - matrix[x, y] = v # Fill the matrix at position (x, y) with the value v +@pytest.mark.cuda +def test_sht_isht_bandlimited_roundtrip(device, nside, lmax, mmax, dtype, complex_dtype): + """Test that SHT followed by iSHT recovers a band-limited signal. 
- return matrix + When a signal is strictly band-limited to lmax, the SHT->iSHT roundtrip + should recover it with high accuracy. + """ + sht = SHT(nside, lmax=lmax, mmax=mmax, quad_weights="ring").to(device) + isht = iSHT(nside, lmax=lmax, mmax=mmax).to(device) + # Create strictly band-limited signal via iSHT of known coefficients + # Only populate coefficients up to a fraction of lmax for strict band-limiting + coeffs = torch.zeros((lmax, mmax), dtype=complex_dtype, device=device) + max_l = min(lmax // 2, nside // 2) # Use half of lmax for safety margin + torch.manual_seed(42) + for l_idx in range(max_l): + for m_idx in range(min(l_idx + 1, mmax)): + coeffs[l_idx, m_idx] = torch.randn(1, dtype=dtype).item() -def generate_xyv(n, xmax, xmin): - v, x, y = [], [], [] - for _ in range(n): - vi = random.random() # Generate a random float number between 0 and 1 - xi = random.randint(xmin, xmax - 1) # Generate a random int number, 0 <= x < xmax - yi = random.randint(xmin, xi) # Generate a random int number, 0 <= y <= x + # Generate signal from coefficients + signal = isht(coeffs) - v.append(vi) - x.append(xi) - y.append(yi) + # Reconstruct via SHT + iSHT + reconstructed = isht(sht(signal)) - return x, y, v + # Roundtrip error is algorithm-limited, not precision-limited + rtol, atol = get_roundtrip_tol() + assert torch.allclose( + reconstructed, signal, rtol=rtol, atol=atol + ), f"Bandlimited roundtrip failed: max diff = {(reconstructed - signal).abs().max():.2e}" -def fill_matrix(x, y, v, matrix): +@pytest.mark.cuda +def test_sht_output_shape(device, nside, lmax, mmax, dtype): + """Test that SHT produces output with correct shape.""" + npix = 12 * nside**2 + sht = SHT(nside, lmax=lmax, mmax=mmax, quad_weights="ring").to(device) - n = len(x) - for i in range(n): - matrix[x[i], y[i]] = v[i] # Fill the matrix at position (x, y) with the value v - return matrix + signal = torch.randn(npix, dtype=dtype, device=device) + coeffs = sht(signal) + assert coeffs.shape == (lmax, 
mmax), f"Expected shape ({lmax}, {mmax}), got {coeffs.shape}" + assert coeffs.is_complex(), "SHT output should be complex" -xmax = 64 -xmin = 0 -xg, yg, vg = generate_xyv(100, xmax, xmin) -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +@pytest.mark.cuda +def test_isht_output_shape(device, nside, lmax, mmax, complex_dtype): + """Test that iSHT produces output with correct shape.""" + npix = 12 * nside**2 + isht = iSHT(nside, lmax=lmax, mmax=mmax).to(device) -nside = int(input("Enter the nside value (>=32): ")) -npix = 12 * nside**2 -quad_weights = "ring" + coeffs = torch.randn(lmax, mmax, dtype=complex_dtype, device=device) + signal = isht(coeffs) -lmax = 2 * nside + 1 -lmax = 65 -mmax = lmax + assert signal.shape == (npix,), f"Expected shape ({npix},), got {signal.shape}" -sht = SHT(nside, lmax=lmax, mmax=mmax, quad_weights=quad_weights).to(device) -isht = iSHT(nside, lmax=lmax, mmax=mmax).to(device) -coeff_ori = torch.zeros((lmax, mmax), dtype=torch.complex128) -coeff_ori = fill_matrix(xg, yg, vg, coeff_ori).to(device) +@pytest.mark.cuda +def test_sht_isht_consistency(device, nside, lmax, mmax, dtype, complex_dtype): + """Test that SHT and iSHT are consistent inverses for bandlimited signals.""" + sht = SHT(nside, lmax=lmax, mmax=mmax, quad_weights="ring").to(device) + isht = iSHT(nside, lmax=lmax, mmax=mmax).to(device) -signal_ori = isht(coeff_ori) + # Start with coefficients and verify roundtrip + torch.manual_seed(42) + coeffs = torch.zeros((lmax, mmax), dtype=complex_dtype, device=device) + max_l = min(lmax // 2, nside // 2) + for l_idx in range(max_l): + for m_idx in range(min(l_idx + 1, mmax)): + coeffs[l_idx, m_idx] = torch.randn(2, dtype=dtype).sum() -diff = isht(sht(signal_ori)) - signal_ori -rms = torch.sqrt((diff.abs().pow(2)).mean()) -max_value = torch.max(diff.abs()) + # coeffs -> signal -> coeffs_back + signal = isht(coeffs) + coeffs_back = sht(signal) -print(f'nside={nside}, rms = {rms}, max difference = {max_value}') + # 
Roundtrip error is algorithm-limited, not precision-limited + rtol, atol = get_roundtrip_tol() + assert torch.allclose( + coeffs_back[:max_l, :max_l], coeffs[:max_l, :max_l], rtol=rtol, atol=atol + ), f"Coefficient roundtrip failed: max diff = {(coeffs_back[:max_l, :max_l] - coeffs[:max_l, :max_l]).abs().max():.2e}" diff --git a/tests/test_regridding.py b/tests/test_regridding.py index b32fde6..1e42c2e 100644 --- a/tests/test_regridding.py +++ b/tests/test_regridding.py @@ -13,79 +13,148 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Tests for grid regridding operations.""" + import random +import pytest import torch import cuhpx from cuhpx import Grid, Regridding -def random_fill_matrix(n, xmax, matrix): - for _ in range(n): - v = random.random() # Generate a random float number between 0 and 1 - x = random.randint(0, xmax - 1) # Generate a random int number, 0 <= x < xmax - y = random.randint(0, x) # Generate a random int number, 0 <= y <= x - matrix[x, y] = v # Fill the matrix at position (x, y) with the value v - - return matrix - - -def generate_xyv(n, xmax, xmin): - v, x, y = [], [], [] - for _ in range(n): - vi = random.random() # Generate a random float number between 0 and 1 - xi = random.randint(xmin, xmax - 1) # Generate a random int number, 0 <= x < xmax - yi = random.randint(xmin, xi) # Generate a random int number, 0 <= y <= x - - v.append(vi) - x.append(xi) - y.append(yi) - - return x, y, v - - -def fill_matrix(x, y, v, matrix): - - n = len(x) - for i in range(n): - matrix[x[i], y[i]] = v[i] # Fill the matrix at position (x, y) with the value v - return matrix - - -nside = int(input("Enter the nside value: ")) - -xmax = 2 * nside - 1 -xmin = 0 -xg, yg, vg = generate_xyv(100, xmax, xmin) - -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - -npix = 12 * nside**2 - -lmax = 2 * nside + 1 -mmax = lmax - -# synthetic data with finite bandwidth -sht = cuhpx.SHT(nside, 
lmax=lmax, mmax=mmax) -isht = cuhpx.iSHT(nside, lmax=lmax, mmax=mmax) - -coeff = torch.zeros((lmax, mmax), dtype=torch.complex128) -coeff = fill_matrix(xg, yg, vg, coeff) -signal_hpx = isht(coeff).to(device) - -# regridding -src_grid = Grid('healpix', nside) -dest_grid = Grid('equiangular', (2 * nside, 4 * nside)) - -hpx2eq = Regridding(src_grid, dest_grid, lmax=lmax, mmax=mmax, device=device) -eq2hpx = Regridding(dest_grid, src_grid, lmax=lmax, mmax=mmax, device=device) - -signal_eq = hpx2eq.execute(signal_hpx) -signal_hpx_back = eq2hpx.execute(signal_eq) - -diff = signal_hpx_back - signal_hpx -rms = torch.sqrt((diff.pow(2)).mean()) -max_value = torch.max(diff.abs()) - -print(f'regridding error: nside={nside}, rms = {rms}, max difference = {max_value}') +def _generate_bandlimited_coeffs( + lmax, mmax, n_nonzero=100, seed=42, bandwidth_fraction=0.5, complex_dtype=torch.complex128 +): + """Generate sparse spherical harmonic coefficients for testing. + + Args: + lmax: Maximum degree l (size of coefficient array). + mmax: Maximum order m (size of coefficient array). + n_nonzero: Number of non-zero coefficients. + seed: Random seed for reproducibility. + bandwidth_fraction: Fraction of lmax/mmax to populate (default 0.5). + Using a fraction < 1.0 ensures the signal is strictly bandlimited + with headroom for perfect roundtrip through SHT. + complex_dtype: Complex dtype for coefficients (default torch.complex128). + + Returns: + Complex tensor of shape (lmax, mmax) with sparse coefficients. 
+ """ + random.seed(seed) + coeff = torch.zeros((lmax, mmax), dtype=complex_dtype) + + # Limit the bandwidth to ensure strict bandlimiting + max_l = max(1, int(lmax * bandwidth_fraction)) + max_m = max(1, int(mmax * bandwidth_fraction)) + + for _ in range(n_nonzero): + v = random.random() + x = random.randint(0, max_l - 1) + y = random.randint(0, min(x, max_m - 1)) + coeff[x, y] = v + + return coeff + + +@pytest.mark.cuda +class TestRegridding: + """Test regridding between different grid types.""" + + def test_healpix_equiangular_roundtrip(self, device, nside, dtype, complex_dtype): + """Test HEALPix -> Equiangular -> HEALPix roundtrip.""" + # Use lmax that is compatible with both grids + # For equiangular grid of size (2*nside, 4*nside), lmax should be < 2*nside + lmax = 2 * nside - 1 + mmax = lmax + + src_grid = Grid("healpix", nside) + dest_grid = Grid("equiangular", (2 * nside, 4 * nside)) + + # Create bandlimited signal + isht = cuhpx.iSHT(nside, lmax=lmax, mmax=mmax) + coeff = _generate_bandlimited_coeffs(lmax, mmax, n_nonzero=50, complex_dtype=complex_dtype) + signal = isht(coeff).to(device).to(dtype) + + hpx2eq = Regridding(src_grid, dest_grid, lmax=lmax, mmax=mmax, device=device) + eq2hpx = Regridding(dest_grid, src_grid, lmax=lmax, mmax=mmax, device=device) + + signal_eq = hpx2eq.execute(signal) + signal_back = eq2hpx.execute(signal_eq) + + # For bandlimited signals, roundtrip should be accurate + # rtol=0.02 (2% relative tolerance), atol=0.05 for regridding which has more numerical error + assert torch.allclose( + signal_back, signal, rtol=0.02, atol=0.05 + ), f"Regridding roundtrip failed: max diff = {(signal_back - signal).abs().max():.2e}" + + def test_regridding_preserves_dtype(self, device, nside_small, dtype, complex_dtype): + """Test that regridding preserves data type.""" + nside = nside_small + lmax = 2 * nside - 1 + mmax = lmax + + isht = cuhpx.iSHT(nside, lmax=lmax, mmax=mmax) + coeff = _generate_bandlimited_coeffs(lmax, mmax, n_nonzero=50, 
complex_dtype=complex_dtype) + signal = isht(coeff).to(device).to(dtype) + + src_grid = Grid("healpix", nside) + dest_grid = Grid("equiangular", (2 * nside, 4 * nside)) + regrid = Regridding(src_grid, dest_grid, lmax=lmax, mmax=mmax, device=device) + + output = regrid.execute(signal) + assert output.dtype == dtype, f"Expected {dtype}, got {output.dtype}" + + def test_grid_creation(self): + """Test Grid class instantiation.""" + # HEALPix grid + hpx_grid = Grid("healpix", 64) + assert hpx_grid.grid == "healpix" + assert hpx_grid.nside == 64 + + # Equiangular grid + eq_grid = Grid("equiangular", (128, 256)) + assert eq_grid.grid == "equiangular" + assert eq_grid.nlat == 128 + assert eq_grid.nlon == 256 + + def test_regridding_output_shape(self, device, nside_small, dtype): + """Test that regridding produces correct output shapes.""" + nside = nside_small + lmax = 2 * nside - 1 + mmax = lmax + + npix = 12 * nside**2 + signal = torch.randn(npix, dtype=dtype, device=device) + + src_grid = Grid("healpix", nside) + nlat, nlon = 2 * nside, 4 * nside + dest_grid = Grid("equiangular", (nlat, nlon)) + + regrid = Regridding(src_grid, dest_grid, lmax=lmax, mmax=mmax, device=device) + output = regrid.execute(signal) + + expected_shape = (nlat, nlon) + assert output.shape == expected_shape, f"Expected shape {expected_shape}, got {output.shape}" + + def test_healpix_to_equiangular(self, device, nside_small, dtype): + """Test one-way regridding from HEALPix to equiangular grid.""" + nside = nside_small + lmax = 2 * nside - 1 + mmax = lmax + + npix = 12 * nside**2 + torch.manual_seed(42) + signal = torch.randn(npix, dtype=dtype, device=device) + + src_grid = Grid("healpix", nside) + dest_grid = Grid("equiangular", (2 * nside, 4 * nside)) + + regrid = Regridding(src_grid, dest_grid, lmax=lmax, mmax=mmax, device=device) + output = regrid.execute(signal) + + # Basic sanity checks + assert not torch.isnan(output).any(), "NaN in regridding output" + assert not torch.isinf(output).any(), 
"Inf in regridding output" + assert output.shape == (2 * nside, 4 * nside) diff --git a/tests/test_sht_bluestein.py b/tests/test_sht_bluestein.py index 3dc7679..e27c1c0 100644 --- a/tests/test_sht_bluestein.py +++ b/tests/test_sht_bluestein.py @@ -13,35 +13,91 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Tests for Bluestein FFT-based SHT implementation.""" + +import pytest import torch +from conftest import get_bluestein_tol, get_roundtrip_tol from cuhpx import SHT, iSHT -# Check if CUDA is available -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -nside = int(input('nside: ')) -lmax = int(input('lmax: ')) +@pytest.mark.cuda +class TestSHTBluestein: + """Test Bluestein FFT-based spherical harmonic transforms.""" + + @pytest.fixture + def signal(self, device, nside, dtype): + """Create a random test signal.""" + npix = 12 * nside**2 + return torch.randn(npix, dtype=dtype, device=device) + + @pytest.fixture + def sht_torch(self, device, nside, lmax, mmax): + """Create standard SHT transform.""" + return SHT(nside, lmax=lmax, mmax=mmax, quad_weights="ring").to(device) + + @pytest.fixture + def isht_torch(self, device, nside, lmax, mmax): + """Create standard iSHT transform.""" + return iSHT(nside, lmax=lmax, mmax=mmax).to(device) + + @pytest.fixture + def sht_bluestein(self, device, nside, lmax, mmax): + """Create Bluestein SHT transform.""" + return SHT(nside, lmax=lmax, mmax=mmax, quad_weights="ring", use_bluestein=True).to(device) + + @pytest.fixture + def isht_bluestein(self, device, nside, lmax, mmax): + """Create Bluestein iSHT transform.""" + return iSHT(nside, lmax=lmax, mmax=mmax, use_bluestein=True).to(device) -npix = 12 * nside**2 -signal = torch.randn(npix, dtype=torch.float32).to(device) + def test_sht_bluestein_matches_standard(self, signal, sht_torch, sht_bluestein): + """Test that Bluestein SHT produces same results as standard SHT.""" + coeff_torch = 
sht_torch(signal.clone()) + coeff_bluestein = sht_bluestein(signal.clone()) + # Cast to same dtype for comparison (standard SHT may use different precision) + coeff_bluestein = coeff_bluestein.to(coeff_torch.dtype) -quad_weights = 'ring' -mmax = lmax + # Use Bluestein-specific tolerances (algorithm-limited, not precision-limited) + rtol, atol = get_bluestein_tol("sht") + assert torch.allclose( + coeff_torch, coeff_bluestein, rtol=rtol, atol=atol + ), f"SHT Bluestein/standard mismatch: max diff = {(coeff_torch - coeff_bluestein).abs().max():.2e}" -sht_torch = SHT(nside, lmax=lmax, mmax=mmax, quad_weights=quad_weights).to(device) -isht_torch = iSHT(nside, lmax=lmax, mmax=mmax).to(device) + def test_isht_bluestein_matches_standard(self, signal, sht_torch, isht_torch, isht_bluestein): + """Test that Bluestein iSHT produces same results as standard iSHT.""" + coeff = sht_torch(signal) -sht_bs = SHT(nside, lmax=lmax, mmax=mmax, quad_weights=quad_weights, use_bluestein=True).to(device) -isht_bs = iSHT(nside, lmax=lmax, mmax=mmax, use_bluestein=True).to(device) + signal_torch = isht_torch(coeff.clone()) + signal_bluestein = isht_bluestein(coeff.clone()) -coeff = sht_torch(signal) + # Cast to same dtype for comparison (standard iSHT may use different precision) + signal_bluestein = signal_bluestein.to(signal_torch.dtype) -diff = sht_torch(torch.clone(signal)) - sht_bs(torch.clone(signal)) + # Use Bluestein-specific tolerances (algorithm-limited, not precision-limited) + rtol, atol = get_bluestein_tol("isht") + assert torch.allclose( + signal_torch, signal_bluestein, rtol=rtol, atol=atol + ), f"iSHT Bluestein/standard mismatch: max diff = {(signal_torch - signal_bluestein).abs().max():.2e}" -print('diff between sht torch and sht bluestein: ', torch.mean(diff.abs())) + def test_bluestein_roundtrip(self, device, nside, lmax, mmax, dtype, complex_dtype, sht_bluestein, isht_bluestein): + """Test SHT -> iSHT roundtrip with Bluestein implementation on bandlimited signal.""" + # 
Create bandlimited signal + torch.manual_seed(42) + coeffs = torch.zeros((lmax, mmax), dtype=complex_dtype, device=device) + max_l = min(lmax // 2, nside // 2) + for l_idx in range(max_l): + for m_idx in range(min(l_idx + 1, mmax)): + coeffs[l_idx, m_idx] = torch.randn(1, dtype=dtype).item() -diff = isht_torch(torch.clone(coeff)) - isht_bs(torch.clone(coeff)) + signal = isht_bluestein(coeffs) + coeff_back = sht_bluestein(signal) + signal_back = isht_bluestein(coeff_back) -print('diff between isht torch and sht bluestein: ', torch.mean(diff.abs())) + # Roundtrip error is algorithm-limited, not precision-limited + rtol, atol = get_roundtrip_tol() + assert torch.allclose( + signal_back, signal, rtol=rtol, atol=atol + ), f"Bluestein roundtrip failed: max diff = {(signal_back - signal).abs().max():.2e}" diff --git a/tests/test_sht_cuda.py b/tests/test_sht_cuda.py index 49cdc77..1024b38 100644 --- a/tests/test_sht_cuda.py +++ b/tests/test_sht_cuda.py @@ -13,32 +13,87 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+"""Tests comparing CUDA-accelerated SHT/iSHT against PyTorch reference implementation.""" + +import pytest import torch +from conftest import get_impl_tol, get_roundtrip_tol from cuhpx import SHT, SHTCUDA, iSHT, iSHTCUDA -# Check if CUDA is available -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -nside = int(input('nside: ')) -lmax = int(input('lmax: ')) -npix = 12 * nside**2 -signal = torch.randn(npix, dtype=torch.float32).to(device) +@pytest.mark.cuda +def test_shtcuda_matches_sht(device, nside, lmax, mmax, dtype): + """Test that SHTCUDA produces same results as PyTorch SHT.""" + npix = 12 * nside**2 + torch.manual_seed(42) + signal = torch.randn(npix, dtype=dtype, device=device) + + sht = SHT(nside, lmax=lmax, mmax=mmax, quad_weights="ring").to(device) + sht_cuda = SHTCUDA(nside, lmax=lmax, mmax=mmax, quad_weights="ring") + + result_torch = sht(signal) + result_cuda = sht_cuda(signal) + + rtol, atol = get_impl_tol(dtype, "sht") + assert torch.allclose( + result_torch, result_cuda, rtol=rtol, atol=atol + ), f"SHTCUDA differs from SHT: max diff = {(result_torch - result_cuda).abs().max():.2e}" + + +@pytest.mark.cuda +def test_ishtcuda_matches_isht(device, nside, lmax, mmax, dtype, complex_dtype): + """Test that iSHTCUDA produces same results as PyTorch iSHT.""" + torch.manual_seed(42) + coeffs = torch.randn(lmax, mmax, dtype=complex_dtype, device=device) + + isht = iSHT(nside, lmax=lmax, mmax=mmax).to(device) + isht_cuda = iSHTCUDA(nside, lmax=lmax, mmax=mmax) + + result_torch = isht(coeffs.clone()) + result_cuda = isht_cuda(coeffs.clone()) + + rtol, atol = get_impl_tol(dtype, "isht") + assert torch.allclose( + result_torch, result_cuda, rtol=rtol, atol=atol + ), f"iSHTCUDA differs from iSHT: max diff = {(result_torch - result_cuda).abs().max():.2e}" + + +@pytest.mark.cuda +def test_shtcuda_isht_cuda_roundtrip(device, nside, lmax, mmax, dtype, complex_dtype): + """Test SHTCUDA + iSHTCUDA round-trip on bandlimited signals.""" + sht_cuda = 
SHTCUDA(nside, lmax=lmax, mmax=mmax, quad_weights="ring") + isht_cuda = iSHTCUDA(nside, lmax=lmax, mmax=mmax) + + # Create strictly band-limited signal + torch.manual_seed(42) + coeffs = torch.zeros((lmax, mmax), dtype=complex_dtype, device=device) + max_l = min(lmax // 2, nside // 2) # Use half of lmax for safety margin + for l_idx in range(max_l): + for m_idx in range(min(l_idx + 1, mmax)): + coeffs[l_idx, m_idx] = torch.randn(1, dtype=dtype).item() -quad_weights = 'ring' + signal = isht_cuda(coeffs) + reconstructed = isht_cuda(sht_cuda(signal)) -mmax = lmax + # Roundtrip error is algorithm-limited, not precision-limited + rtol, atol = get_roundtrip_tol() + assert torch.allclose( + reconstructed, signal, rtol=rtol, atol=atol + ), f"CUDA roundtrip failed: max diff = {(reconstructed - signal).abs().max():.2e}" -sht = SHT(nside, lmax=lmax, mmax=mmax, quad_weights=quad_weights).to(device) -isht = iSHT(nside, lmax=lmax, mmax=mmax).to(device) -coeff = sht(signal) +@pytest.mark.cuda +def test_shtcuda_output_consistency(device, nside, lmax, mmax, dtype): + """Test that SHTCUDA produces consistent output across multiple calls.""" + npix = 12 * nside**2 + torch.manual_seed(42) + signal = torch.randn(npix, dtype=dtype, device=device) -sht_cuda = SHTCUDA(nside, lmax=lmax, mmax=mmax, quad_weights=quad_weights) -isht_cuda = iSHTCUDA(nside, lmax=lmax, mmax=mmax) + sht_cuda = SHTCUDA(nside, lmax=lmax, mmax=mmax, quad_weights="ring") -diff = sht(signal) - sht_cuda(signal) -print('diff between pytorch and cuda, sht', torch.sqrt(torch.mean(diff.abs() ** 2))) + result1 = sht_cuda(signal.clone()) + result2 = sht_cuda(signal.clone()) -diff = isht(torch.clone(coeff)) - isht_cuda(torch.clone(coeff)) -print('diff between pytorch and cuda, isht', torch.sqrt(torch.mean(diff.abs() ** 2))) + # Results should be exactly identical for deterministic operations + assert torch.allclose(result1, result2, rtol=0, atol=0), "SHTCUDA produces inconsistent results" diff --git 
a/tests/test_sht_cuda_batch.py b/tests/test_sht_cuda_batch.py index 4a89ae7..51645e5 100644 --- a/tests/test_sht_cuda_batch.py +++ b/tests/test_sht_cuda_batch.py @@ -13,43 +13,101 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Tests for batched SHT/iSHT CUDA operations.""" + +import pytest import torch from cuhpx import SHTCUDA, iSHTCUDA -# Check if CUDA is available -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -nside = int(input('nside: ')) -m = int(input('m, the first dim: ')) -n = int(input('n, the second dim: ')) +@pytest.mark.cuda +class TestSHTCUDABatch: + """Test batched spherical harmonic transforms on CUDA.""" + + @pytest.fixture + def batch_dims(self): + """Return batch dimensions (m, n) for testing.""" + return (2, 3) + + @pytest.fixture + def batch_signal(self, device, nside, batch_dims, dtype): + """Create a batched random signal on device.""" + m, n = batch_dims + npix = 12 * nside**2 + return torch.randn(m, n, npix, dtype=dtype, device=device) + + @pytest.fixture + def sht(self, nside, lmax, mmax): + """Create SHT transform.""" + return SHTCUDA(nside, lmax=lmax, mmax=mmax, quad_weights="ring") + + @pytest.fixture + def isht(self, nside, lmax, mmax): + """Create iSHT transform.""" + return iSHTCUDA(nside, lmax=lmax, mmax=mmax) + + def test_sht_batch_matches_single(self, device, nside, batch_signal, sht, batch_dims, dtype): + """Test that batched SHT matches element-wise SHT.""" + m, n = batch_dims + + # Batched transform + coeff_batch = sht(batch_signal) + + # Element-wise transform + coeff_single = torch.zeros_like(coeff_batch) + for i in range(m): + for j in range(n): + coeff_single[i, j, :] = sht(batch_signal[i, j, :]) + + # Compare results - should be identical + # Use tighter tolerances for float64 + rtol = 1e-5 if dtype == torch.float32 else 1e-10 + atol = 1e-6 if dtype == torch.float32 else 1e-10 + assert torch.allclose( + coeff_batch, coeff_single, 
rtol=rtol, atol=atol + ), f"SHT batch/single mismatch: max diff = {(coeff_batch - coeff_single).abs().max():.2e}" + + def test_isht_batch_matches_single(self, device, nside, batch_signal, sht, isht, batch_dims, dtype): + """Test that batched iSHT matches element-wise iSHT.""" + m, n = batch_dims -npix = 12 * nside**2 -signal = torch.randn(m, n, npix, dtype=torch.float32).to(device) + # Get coefficients first + coeff = sht(batch_signal) -quad_weights = 'ring' -lmax = 2 * nside + 1 -mmax = lmax + # Batched inverse transform + signal_batch = isht(coeff) -sht = SHTCUDA(nside, lmax=lmax, mmax=mmax, quad_weights=quad_weights) -isht = iSHTCUDA(nside, lmax=lmax, mmax=mmax) + # Element-wise inverse transform + signal_single = torch.zeros_like(signal_batch) + for i in range(m): + for j in range(n): + signal_single[i, j, :] = isht(coeff[i, j, :]) -coeff = sht(signal) -c = torch.zeros_like(coeff) + # Compare results - should be identical + # Use tighter tolerances for float64 + rtol = 1e-5 if dtype == torch.float32 else 1e-10 + atol = 1e-6 if dtype == torch.float32 else 1e-10 + assert torch.allclose( + signal_batch, signal_single, rtol=rtol, atol=atol + ), f"iSHT batch/single mismatch: max diff = {(signal_batch - signal_single).abs().max():.2e}" -for i in range(m): - for j in range(n): - c[i, j, :] = sht(signal[i, j, :]) + @pytest.mark.parametrize("batch_shape", [(4,), (2, 2), (2, 2, 2)]) + def test_various_batch_shapes(self, device, nside_small, batch_shape, dtype): + """Test batched SHT with various batch dimensions.""" + npix = 12 * nside_small**2 + lmax = 3 * nside_small - 1 + mmax = lmax -diff = (coeff - c).abs() -print('diff between batch and single, sht', torch.sqrt(torch.mean(diff.abs() ** 2))) + signal = torch.randn(*batch_shape, npix, dtype=dtype, device=device) -s1 = isht(coeff) -s2 = torch.zeros_like(s1) + sht = SHTCUDA(nside_small, lmax=lmax, mmax=mmax, quad_weights="ring") + isht = iSHTCUDA(nside_small, lmax=lmax, mmax=mmax) -for i in range(m): - for j in 
range(n): - s2[i, j, :] = isht(coeff[i, j, :]) + # Forward transform + coeff = sht(signal) + assert coeff.shape[:-2] == batch_shape, f"Expected batch shape {batch_shape}, got {coeff.shape[:-2]}" -diff = s1 - s2 -print('diff between batch and single, isht', torch.sqrt(torch.mean(diff.abs() ** 2))) + # Inverse transform + signal_back = isht(coeff) + assert signal_back.shape == signal.shape, f"Shape mismatch: {signal_back.shape} vs {signal.shape}" diff --git a/tests/test_sht_cuda_stream.py b/tests/test_sht_cuda_stream.py index 9d59044..8968970 100644 --- a/tests/test_sht_cuda_stream.py +++ b/tests/test_sht_cuda_stream.py @@ -13,46 +13,72 @@ # See the License for the specific language governing permissions and # limitations under the License. -import cuhpx_fft +"""Tests for CUDA stream support in SHT operations.""" + +import pytest import torch -import torch.cuda.nvtx as nvtx - -# Check if CUDA is available -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - -# User input -nside = int(input('nside: ')) -m = int(input('m, the first dim: ')) -n = int(input('n, the second dim: ')) - -npix = 12 * nside**2 -signal1 = torch.randn(m, n, npix, dtype=torch.float32).to(device) -signal2 = torch.clone(signal1) - -quad_weights = 'ring' -lmax = 2 * nside + 1 -mmax = lmax - -# Create two CUDA streams -stream1 = torch.cuda.Stream() -stream2 = torch.cuda.Stream() - -# Perform SHT on stream1 -nvtx.range_push("Stream 1 SHT Operation") -with torch.cuda.stream(stream1): - result1 = cuhpx_fft.healpix_rfft_batch(signal1, nside, nside) -stream1.synchronize() -nvtx.range_pop() - -# Perform SHT on stream2 -nvtx.range_push("Stream 2 SHT Operation") -with torch.cuda.stream(stream2): - result2 = cuhpx_fft.healpix_rfft_batch(signal2, nside, nside) -stream2.synchronize() -nvtx.range_pop() - -# Synchronize streams and compare results -nvtx.range_push("Compare Results") -comparison = torch.allclose(result1, result2) -print("Are the results from the two streams identical?", 
comparison) -nvtx.range_pop() + +from cuhpx import cuhpx_fft + + +@pytest.mark.cuda +class TestCUDAStreams: + """Test CUDA stream support for SHT operations.""" + + @pytest.fixture + def batch_signal(self, device, nside_small, dtype): + """Create a batched signal for stream tests.""" + m, n = 2, 3 + npix = 12 * nside_small**2 + return torch.randn(m, n, npix, dtype=dtype, device=device) + + def test_different_streams_same_result(self, device, nside_small, batch_signal): + """Test that operations on different streams produce identical results.""" + signal1 = batch_signal.clone() + signal2 = batch_signal.clone() + + nside = nside_small + + # Create two CUDA streams + stream1 = torch.cuda.Stream() + stream2 = torch.cuda.Stream() + + # Perform SHT on stream1 + with torch.cuda.stream(stream1): + result1 = cuhpx_fft.healpix_rfft_batch(signal1, nside, nside) + stream1.synchronize() + + # Perform SHT on stream2 + with torch.cuda.stream(stream2): + result2 = cuhpx_fft.healpix_rfft_batch(signal2, nside, nside) + stream2.synchronize() + + # Results should be identical + assert torch.allclose(result1, result2), "Results from different streams should be identical" + + def test_stream_does_not_corrupt_data(self, device, nside_small, batch_signal): + """Test that stream execution doesn't corrupt input data.""" + original = batch_signal.clone() + signal = batch_signal.clone() + + nside = nside_small + + stream = torch.cuda.Stream() + + with torch.cuda.stream(stream): + _ = cuhpx_fft.healpix_rfft_batch(signal, nside, nside) + stream.synchronize() + + # Input should not be modified + assert torch.allclose(signal, original), "Stream operation corrupted input data" + + def test_default_stream_works(self, device, nside_small, dtype): + """Test that operations work on default stream.""" + npix = 12 * nside_small**2 + signal = torch.randn(npix, dtype=dtype, device=device) + + # Should work without explicit stream + result = cuhpx_fft.healpix_rfft_batch(signal.unsqueeze(0), nside_small, 
nside_small) + + assert result is not None + assert not torch.isnan(result).any(), "NaN in result" diff --git a/tests/test_sht_two_stream_overlap_profiling.py b/tests/test_sht_two_stream_overlap_profiling.py deleted file mode 100755 index df4d8d2..0000000 --- a/tests/test_sht_two_stream_overlap_profiling.py +++ /dev/null @@ -1,37 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import torch - -from cuhpx import SHTCUDA - -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - -nside = int(input('nside: ')) -nbatch = int(input('batch size: ')) -npix = 12 * nside**2 - -signal = torch.randn(nbatch, npix, dtype=torch.float32).to(device) - -quad_weights = 'ring' -lmax = int(input('lmax: ')) -mmax = lmax - -sht = SHTCUDA(nside, lmax=lmax, mmax=mmax, quad_weights=quad_weights) - -for _ in range(10): - torch.cuda.nvtx.range_push("SHTCUDA batch") - coeff = sht(signal) - torch.cuda.nvtx.range_pop() From 7a895af3840d1e1e4a8ef3caa4a3cbcc9b9f5510 Mon Sep 17 00:00:00 2001 From: Akshay Subramaniam <6964110+akshaysubr@users.noreply.github.com> Date: Wed, 21 Jan 2026 17:28:04 -0800 Subject: [PATCH 2/4] Adding fix for CUDA graphs compatibility and tests for that --- cuhpx/hpx_sht.py | 186 +++++++++---------- src/harmonic_transform/hpx_fft.cpp | 82 ++++++--- src/harmonic_transform/hpx_fft.h | 272 +++++++++++++++------------- tests/test_cuda_graphs.py | 279 +++++++++++++++++++++++++++++ 4 files changed, 573 insertions(+), 246 deletions(-) create mode 100644 tests/test_cuda_graphs.py diff --git a/cuhpx/hpx_sht.py b/cuhpx/hpx_sht.py index dc942e8..ef754e6 100644 --- a/cuhpx/hpx_sht.py +++ b/cuhpx/hpx_sht.py @@ -35,7 +35,6 @@ def healpix_rfft_torch(f: torch.tensor, L: int, nside: int) -> torch.tensor: - index = 0 ctype = torch.complex64 if f.dtype == torch.float32 else torch.complex128 ftm = torch.zeros(ftm_shape(L, "healpix", nside), dtype=ctype, device=f.device) @@ -56,13 +55,11 @@ def healpix_rfft_torch(f: torch.tensor, L: int, nside: int) -> torch.tensor: def healpix_irfft_torch(ftm: torch.tensor, L: int, nside: int) -> torch.tensor: - ftype = torch.float if ftm.dtype == torch.complex64 else torch.double f = torch.zeros(f_shape(sampling="healpix", nside=nside), dtype=ftype, device=ftm.device) ntheta = ftm.shape[0] index = 0 for t in range(ntheta): - phi_ring_offset = p2phi_ring(t, 0, nside) phase_shift = torch.exp(1j * torch.arange(L, 
device=ftm.device) * phi_ring_offset) ftm[t, :] *= phase_shift @@ -77,7 +74,6 @@ def healpix_irfft_torch(ftm: torch.tensor, L: int, nside: int) -> torch.tensor: def healpix_irfft_bluestein(ftm: torch.tensor, L: int, nside: int) -> torch.tensor: - f = torch.zeros(12 * nside**2, dtype=torch.double, device=ftm.device) ntheta = ftm.shape[0] @@ -126,7 +122,6 @@ def healpix_irfft_bluestein(ftm: torch.tensor, L: int, nside: int) -> torch.tens def healpix_rfft_bluestein(f: torch.tensor, L: int, nside: int) -> torch.tensor: - ftm = torch.zeros((4 * nside - 1, L), dtype=torch.complex128, device=f.device) ntheta = ftm.shape[0] @@ -171,7 +166,6 @@ def healpix_rfft_bluestein(f: torch.tensor, L: int, nside: int) -> torch.tensor: class SHT(nn.Module): - def __init__( self, nside, @@ -183,7 +177,6 @@ def __init__( csphase=True, use_bluestein=False, ): - super().__init__() self.nside = nside @@ -208,11 +201,10 @@ def __init__( pct = _precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase) pct = torch.from_numpy(pct) - weights = torch.einsum('mlk,k->mlk', pct, weights) - self.register_buffer('weights', weights, persistent=False) + weights = torch.einsum("mlk,k->mlk", pct, weights) + self.register_buffer("weights", weights, persistent=False) def forward(self, x: torch.Tensor): - if torch.is_complex(x): raise ValueError("Input tensor must be real.") @@ -229,17 +221,15 @@ def forward(self, x: torch.Tensor): xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device) # contraction - xout[..., 0] = torch.einsum('...km,mlk->...lm', x[..., : self.mmax, 0], self.weights.to(x.dtype)) - xout[..., 1] = torch.einsum('...km,mlk->...lm', x[..., : self.mmax, 1], self.weights.to(x.dtype)) + xout[..., 0] = torch.einsum("...km,mlk->...lm", x[..., : self.mmax, 0], self.weights.to(x.dtype)) + xout[..., 1] = torch.einsum("...km,mlk->...lm", x[..., : self.mmax, 1], self.weights.to(x.dtype)) x = torch.view_as_complex(xout) return x class iSHT(nn.Module): - def __init__(self, 
nside, lmax=None, mmax=None, grid="healpix", norm="ortho", csphase=True, use_bluestein=False): - super().__init__() self.nside = nside @@ -252,7 +242,7 @@ def __init__(self, nside, lmax=None, mmax=None, grid="healpix", norm="ortho", cs self.use_bluestein = use_bluestein if self.grid == "healpix": - cost, _ = healpix_weights(nside, 'none') + cost, _ = healpix_weights(nside, "none") self.lmax = lmax or self.nlat else: raise (ValueError("Unknown quadrature mode")) @@ -263,14 +253,13 @@ def __init__(self, nside, lmax=None, mmax=None, grid="healpix", norm="ortho", cs pct = _precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase) pct = torch.from_numpy(pct) - self.register_buffer('pct', pct, persistent=False) + self.register_buffer("pct", pct, persistent=False) def forward(self, x: torch.Tensor): - x = torch.view_as_real(x) - rl = torch.einsum('...lm, mlk->...km', x[..., 0], self.pct.to(x.dtype)) - im = torch.einsum('...lm, mlk->...km', x[..., 1], self.pct.to(x.dtype)) + rl = torch.einsum("...lm, mlk->...km", x[..., 0], self.pct.to(x.dtype)) + im = torch.einsum("...lm, mlk->...km", x[..., 1], self.pct.to(x.dtype)) xs = torch.stack((rl, im), -1) x = torch.view_as_complex(xs) @@ -284,7 +273,6 @@ def forward(self, x: torch.Tensor): class VectorSHT(nn.Module): - def __init__( self, nside, @@ -296,7 +284,6 @@ def __init__( csphase=True, use_bluestein=False, ): - super().__init__() self.nside = nside @@ -325,14 +312,13 @@ def __init__( l = torch.arange(0, self.lmax) # noqa: E741 norm_factor = 1.0 / l / (l + 1) norm_factor[0] = 1.0 - weights = torch.einsum('dmlk,k,l->dmlk', dpct, weights, norm_factor) + weights = torch.einsum("dmlk,k,l->dmlk", dpct, weights, norm_factor) weights[1] = -1 * weights[1] - self.register_buffer('weights', weights, persistent=False) + self.register_buffer("weights", weights, persistent=False) def forward(self, x: torch.Tensor): - if torch.is_complex(x): raise ValueError("Input tensor must be real.") @@ -353,31 
+339,29 @@ def forward(self, x: torch.Tensor): # contraction - spheroidal component # real component xout[..., 0, :, :, 0] = torch.einsum( - '...km,mlk->...lm', x[..., 0, :, : self.mmax, 0], self.weights[0].to(x.dtype) - ) - torch.einsum('...km,mlk->...lm', x[..., 1, :, : self.mmax, 1], self.weights[1].to(x.dtype)) + "...km,mlk->...lm", x[..., 0, :, : self.mmax, 0], self.weights[0].to(x.dtype) + ) - torch.einsum("...km,mlk->...lm", x[..., 1, :, : self.mmax, 1], self.weights[1].to(x.dtype)) # iamg component xout[..., 0, :, :, 1] = torch.einsum( - '...km,mlk->...lm', x[..., 0, :, : self.mmax, 1], self.weights[0].to(x.dtype) - ) + torch.einsum('...km,mlk->...lm', x[..., 1, :, : self.mmax, 0], self.weights[1].to(x.dtype)) + "...km,mlk->...lm", x[..., 0, :, : self.mmax, 1], self.weights[0].to(x.dtype) + ) + torch.einsum("...km,mlk->...lm", x[..., 1, :, : self.mmax, 0], self.weights[1].to(x.dtype)) # contraction - toroidal component # real component xout[..., 1, :, :, 0] = -torch.einsum( - '...km,mlk->...lm', x[..., 0, :, : self.mmax, 1], self.weights[1].to(x.dtype) - ) - torch.einsum('...km,mlk->...lm', x[..., 1, :, : self.mmax, 0], self.weights[0].to(x.dtype)) + "...km,mlk->...lm", x[..., 0, :, : self.mmax, 1], self.weights[1].to(x.dtype) + ) - torch.einsum("...km,mlk->...lm", x[..., 1, :, : self.mmax, 0], self.weights[0].to(x.dtype)) # imag component xout[..., 1, :, :, 1] = torch.einsum( - '...km,mlk->...lm', x[..., 0, :, : self.mmax, 0], self.weights[1].to(x.dtype) - ) - torch.einsum('...km,mlk->...lm', x[..., 1, :, : self.mmax, 1], self.weights[0].to(x.dtype)) + "...km,mlk->...lm", x[..., 0, :, : self.mmax, 0], self.weights[1].to(x.dtype) + ) - torch.einsum("...km,mlk->...lm", x[..., 1, :, : self.mmax, 1], self.weights[0].to(x.dtype)) return torch.view_as_complex(xout) class VectoriSHT(nn.Module): - def __init__(self, nside, lmax=None, mmax=None, grid="healpix", norm="ortho", csphase=True, use_bluestein=False): - super().__init__() self.nside = nside @@ -390,7 
+374,7 @@ def __init__(self, nside, lmax=None, mmax=None, grid="healpix", norm="ortho", cs self.use_bluestein = use_bluestein if self.grid == "healpix": - cost, _ = healpix_weights(nside, 'none') + cost, _ = healpix_weights(nside, "none") self.lmax = lmax or self.nlat else: raise (ValueError("Unknown quadrature mode")) @@ -401,30 +385,29 @@ def __init__(self, nside, lmax=None, mmax=None, grid="healpix", norm="ortho", cs dpct = _precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase) dpct = torch.from_numpy(dpct) - self.register_buffer('dpct', dpct, persistent=False) + self.register_buffer("dpct", dpct, persistent=False) def forward(self, x: torch.Tensor): - x = torch.view_as_real(x) # contraction - spheroidal component # real component - srl = torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 0], self.dpct[0].to(x.dtype)) - torch.einsum( - '...lm,mlk->...km', x[..., 1, :, :, 1], self.dpct[1].to(x.dtype) + srl = torch.einsum("...lm,mlk->...km", x[..., 0, :, :, 0], self.dpct[0].to(x.dtype)) - torch.einsum( + "...lm,mlk->...km", x[..., 1, :, :, 1], self.dpct[1].to(x.dtype) ) # iamg component - sim = torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 1], self.dpct[0].to(x.dtype)) + torch.einsum( - '...lm,mlk->...km', x[..., 1, :, :, 0], self.dpct[1].to(x.dtype) + sim = torch.einsum("...lm,mlk->...km", x[..., 0, :, :, 1], self.dpct[0].to(x.dtype)) + torch.einsum( + "...lm,mlk->...km", x[..., 1, :, :, 0], self.dpct[1].to(x.dtype) ) # contraction - toroidal component # real component - trl = -torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 1], self.dpct[1].to(x.dtype)) - torch.einsum( - '...lm,mlk->...km', x[..., 1, :, :, 0], self.dpct[0].to(x.dtype) + trl = -torch.einsum("...lm,mlk->...km", x[..., 0, :, :, 1], self.dpct[1].to(x.dtype)) - torch.einsum( + "...lm,mlk->...km", x[..., 1, :, :, 0], self.dpct[0].to(x.dtype) ) # imag component - tim = torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 0], self.dpct[1].to(x.dtype)) - 
torch.einsum( - '...lm,mlk->...km', x[..., 1, :, :, 1], self.dpct[0].to(x.dtype) + tim = torch.einsum("...lm,mlk->...km", x[..., 0, :, :, 0], self.dpct[1].to(x.dtype)) - torch.einsum( + "...lm,mlk->...km", x[..., 1, :, :, 1], self.dpct[0].to(x.dtype) ) # reassemble @@ -445,7 +428,6 @@ def forward(self, x: torch.Tensor): def einsum_with_chunking(x, weights, mmax, xout, nchunk, stream1): - device = torch.device("cuda") chunk_size = int(weights.size(1) / nchunk + 1) # Adjust this based on your memory constraints @@ -461,7 +443,6 @@ def einsum_with_chunking(x, weights, mmax, xout, nchunk, stream1): torch.cuda.current_stream().synchronize() for i in range(0, weights.size(1), chunk_size): - start_i, end_i = i, min(i + chunk_size, weights.size(1)) actual_chunk_size = end_i - start_i @@ -478,7 +459,7 @@ def einsum_with_chunking(x, weights, mmax, xout, nchunk, stream1): event_transfer.record(stream1) xout[..., start_j:end_j, :, :] = torch.einsum( - '...kmn,mlk->...lmn', x, current_chunk[:, : end_j - start_j, :].to(x.dtype) + "...kmn,mlk->...lmn", x, current_chunk[:, : end_j - start_j, :].to(x.dtype) ) event_computation.record(torch.cuda.current_stream()) @@ -489,7 +470,7 @@ def einsum_with_chunking(x, weights, mmax, xout, nchunk, stream1): if start_i < weights.size(1): xout[..., start_i:end_i, :, :] = torch.einsum( - '...kmn,mlk->...lmn', x, current_chunk[:, : end_i - start_i, :].to(x.dtype) + "...kmn,mlk->...lmn", x, current_chunk[:, : end_i - start_i, :].to(x.dtype) ) stream1.synchronize() @@ -499,11 +480,10 @@ def einsum_with_chunking(x, weights, mmax, xout, nchunk, stream1): class SHTFunction(Function): - @staticmethod - def forward(ctx, x, weights, pct, W, mmax, lmax, nside): - + def forward(ctx, x, pct_weights, pct, W, mmax, lmax, nside): # Init + # pct_weights is pre-computed (pct * weights) for CUDA graph compatibility ctx.save_for_backward(pct, W) ctx.mmax = mmax ctx.lmax = lmax @@ -517,20 +497,17 @@ def forward(ctx, x, weights, pct, W, mmax, lmax, nside): x = 
torch.view_as_real(x) - out_shape = list(x.size()) - out_shape[-3] = lmax - out_shape[-2] = mmax - - xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device) - - weights = pct * weights - + # Use einsum directly with pre-computed weights (no allocation for weights) if not pct.is_cuda: + out_shape = list(x.size()) + out_shape[-3] = lmax + out_shape[-2] = mmax + xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device) nchunk = 12 stream1 = torch.cuda.Stream() - xout = einsum_with_chunking(x, weights, mmax, xout, nchunk, stream1) + xout = einsum_with_chunking(x, pct_weights, mmax, xout, nchunk, stream1) else: - xout = torch.einsum('...kmn,mlk->...lmn', x, weights.to(x.dtype)) + xout = torch.einsum("...kmn,mlk->...lmn", x, pct_weights) x = torch.view_as_complex(xout.contiguous()) @@ -538,14 +515,13 @@ def forward(ctx, x, weights, pct, W, mmax, lmax, nside): @staticmethod def backward(ctx, grad_output): - # adjoint iSHT pct, W = ctx.saved_tensors mmax, nside = ctx.mmax, ctx.nside x = torch.view_as_real(grad_output) - xs = torch.einsum('...lmn, mlk->...kmn', x, pct.to(x.dtype)) + xs = torch.einsum("...lmn, mlk->...kmn", x, pct.to(x.dtype)) grad_input = torch.view_as_complex(xs.contiguous()) if grad_input.dim() == 2: @@ -559,18 +535,19 @@ def backward(ctx, grad_output): class iSHTFunction(Function): - @staticmethod - def forward(ctx, x, weights, pct, W, mmax, lmax, nside): - - ctx.save_for_backward(weights, pct, W) + def forward(ctx, x, pct_weights, pct, W, mmax, lmax, nside): + # pct_weights is pre-computed (pct * weights) for backward pass + # pct is pre-computed in correct dtype for CUDA graph compatibility + ctx.save_for_backward(pct_weights, pct, W) ctx.mmax = mmax ctx.lmax = lmax ctx.nside = nside x = torch.view_as_real(x) - xs = torch.einsum('...lmn, mlk->...kmn', x, pct.to(x.dtype)) + # pct is already in correct dtype, no conversion needed + xs = torch.einsum("...lmn, mlk->...kmn", x, pct) x = torch.view_as_complex(xs.contiguous()) @@ -583,10 +560,9 @@ 
def forward(ctx, x, weights, pct, W, mmax, lmax, nside): @staticmethod def backward(ctx, grad_output): - # adjoint SHT - weights, pct, W = ctx.saved_tensors - mmax, lmax, nside = ctx.mmax, ctx.lmax, ctx.nside + pct_weights, pct, W = ctx.saved_tensors + mmax, _lmax, nside = ctx.mmax, ctx.lmax, ctx.nside x = grad_output / W.to(grad_output.dtype) @@ -597,16 +573,8 @@ def backward(ctx, grad_output): x = torch.view_as_real(x) - out_shape = list(x.size()) - out_shape[-3] = lmax - out_shape[-2] = mmax - - xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device) - - weights = pct * weights - weights = weights.to(x.device) - - xout = torch.einsum('...kmn,mlk->...lmn', x, weights.to(x.dtype)) + # Use pre-computed pct_weights directly (no allocation) + xout = torch.einsum("...kmn,mlk->...lmn", x, pct_weights) grad_input = torch.view_as_complex(xout.contiguous()) return grad_input, None, None, None, None, None, None @@ -614,7 +582,6 @@ def backward(ctx, grad_output): class SHTCUDA(nn.Module): def __init__(self, nside, lmax=None, mmax=None, quad_weights="ring", norm="ortho", csphase=True): - super().__init__() self.nside = nside self.norm = norm @@ -629,7 +596,7 @@ def __init__(self, nside, lmax=None, mmax=None, quad_weights="ring", norm="ortho if not torch.cuda.is_available(): raise RuntimeError("CUDA device is not available. 
This class requires a GPU.") - self.device = torch.device('cuda') + self.device = torch.device("cuda") # quadrature weights cost, w = healpix_weights(nside, self.quad_weights) @@ -649,23 +616,30 @@ def __init__(self, nside, lmax=None, mmax=None, quad_weights="ring", norm="ortho W = W_helper(w, nside) W = W.to(torch.float).to(self.device) - self.register_buffer('weights', weights, persistent=False) - self.register_buffer('pct', pct, persistent=False) - self.register_buffer('W', W, persistent=False) + self.register_buffer("weights", weights, persistent=False) + self.register_buffer("pct", pct, persistent=False) + self.register_buffer("W", W, persistent=False) - def forward(self, x): + # Pre-compute pct * weights for CUDA graph compatibility (avoids allocation during forward) + pct_weights = pct * weights + pct_weights_f64 = pct_weights.double() + self.register_buffer("pct_weights", pct_weights, persistent=False) + self.register_buffer("pct_weights_f64", pct_weights_f64, persistent=False) + def forward(self, x): if torch.is_complex(x): raise ValueError("Input tensor must be real.") - with torch.cuda.stream(self.stream): - return SHTFunction.apply(x, self.weights, self.pct, self.W, self.mmax, self.lmax, self.nside) + # Use pre-computed weights based on input dtype for CUDA graph compatibility + # Note: No stream context manager - runs on caller's stream (required for CUDA graph capture) + if x.dtype == torch.float64: + return SHTFunction.apply(x, self.pct_weights_f64, self.pct, self.W, self.mmax, self.lmax, self.nside) + else: + return SHTFunction.apply(x, self.pct_weights, self.pct, self.W, self.mmax, self.lmax, self.nside) class iSHTCUDA(nn.Module): - def __init__(self, nside, lmax=None, mmax=None, quad_weights="ring", norm="ortho", csphase=True): - super().__init__() self.nside = nside self.norm = norm @@ -675,12 +649,11 @@ def __init__(self, nside, lmax=None, mmax=None, quad_weights="ring", norm="ortho self.quad_weights = quad_weights self.lmax = lmax or self.nlat 
self.mmax = mmax or (self.nlon // 2 + 1) - self.stream = torch.cuda.current_stream() if not torch.cuda.is_available(): raise RuntimeError("CUDA device is not available. This class requires a GPU.") - self.device = torch.device('cuda') + self.device = torch.device("cuda") # quadrature weights cost, w = healpix_weights(nside, self.quad_weights) @@ -696,11 +669,24 @@ def __init__(self, nside, lmax=None, mmax=None, quad_weights="ring", norm="ortho W = W_helper(w, nside) W = W.to(torch.float).to(self.device) - self.register_buffer('weights', weights, persistent=False) - self.register_buffer('pct', pct, persistent=False) - self.register_buffer('W', W, persistent=False) + self.register_buffer("weights", weights, persistent=False) + self.register_buffer("pct", pct, persistent=False) + self.register_buffer("W", W, persistent=False) - def forward(self, x): + # Pre-compute pct in both dtypes for CUDA graph compatibility + pct_f64 = pct.double() + self.register_buffer("pct_f64", pct_f64, persistent=False) + + # Pre-compute pct * weights for backward pass (adjoint SHT) + pct_weights = pct * weights + pct_weights_f64 = pct_weights.double() + self.register_buffer("pct_weights", pct_weights, persistent=False) + self.register_buffer("pct_weights_f64", pct_weights_f64, persistent=False) - with torch.cuda.stream(self.stream): - return iSHTFunction.apply(x, self.weights, self.pct, self.W, self.mmax, self.lmax, self.nside) + def forward(self, x): + # Use pre-computed pct based on input dtype for CUDA graph compatibility + # Note: No stream context manager - runs on caller's stream (required for CUDA graph capture) + if x.real.dtype == torch.float64: + return iSHTFunction.apply(x, self.pct_weights_f64, self.pct_f64, self.W, self.mmax, self.lmax, self.nside) + else: + return iSHTFunction.apply(x, self.pct_weights, self.pct, self.W, self.mmax, self.lmax, self.nside) diff --git a/src/harmonic_transform/hpx_fft.cpp b/src/harmonic_transform/hpx_fft.cpp index 9ec28ac..84c94f6 100644 --- 
a/src/harmonic_transform/hpx_fft.cpp +++ b/src/harmonic_transform/hpx_fft.cpp @@ -47,10 +47,10 @@ torch::Tensor healpix_rfft_batch(torch::Tensor f, int L, int nside) { // Create FFT object and initialize y_pad if not already done static HealpixFFT* fft = nullptr; - if (!fft || fft->needsReconfiguration(ntheta, n, padding, dtype, device)) { + if (!fft || fft->needsReconfiguration(ntheta, n, padding, L, dtype, device)) { delete fft; // Properly deallocate existing object - fft = new HealpixFFT(ntheta, n, padding, dtype, device, stream); + fft = new HealpixFFT(ntheta, n, padding, L, dtype, device, stream); } else { // If reconfiguration is not needed, ensure the stream is up-to-date fft->updateStreamIfNeeded(stream); @@ -65,14 +65,23 @@ torch::Tensor healpix_rfft_batch(torch::Tensor f, int L, int nside) { x_pad_size.push_back(ntheta); x_pad_size.push_back(padding); - auto ftm = torch::zeros(ftm_size, torch::dtype(dtype).device(device)); - auto x_pad = torch::zeros(x_pad_size, torch::dtype(dtype).device(device)); + // Use pre-allocated buffers and zero in-place (CUDA graph compatible) + torch::Tensor& x_pad_storage = fft->getXpad(); + torch::Tensor& ftm_storage = fft->getFtm(); + + // Zero in-place - this is CUDA graph compatible + x_pad_storage.zero_(); + ftm_storage.zero_(); + + // Create views with proper batch shape (slice for smaller batches) + auto x_pad = x_pad_storage.slice(0, 0, n * ntheta).view(x_pad_size); + auto ftm = ftm_storage.slice(0, 0, n * ntheta).view(ftm_size); rfft_pre_process_x_pad_batch_dispatch(x_pad, f, padding, nside, order, stream); fft->execute_forward(x_pad); - x_pad = x_pad * fft->getYpad(); + x_pad.mul_(fft->getYpad()); fft->execute_inverse(x_pad); @@ -118,27 +127,35 @@ torch::Tensor healpix_irfft_batch(torch::Tensor ftm, int L, int nside) { x_pad_size.push_back(ntheta); x_pad_size.push_back(padding); - auto f = torch::zeros(f_size, torch::dtype(ftype).device(device)); - auto x_pad = torch::zeros(x_pad_size, 
torch::dtype(dtype).device(device)); - // Instantiate FFT object static HealpixIFFT* ifft = nullptr; - if (!ifft || ifft->needsReconfiguration(ntheta, n, padding, dtype, device)) { + if (!ifft || ifft->needsReconfiguration(ntheta, n, padding, nside, dtype, device)) { delete ifft; // Properly deallocate existing object - ifft = new HealpixIFFT(ntheta, n, padding, dtype, device, stream); + ifft = new HealpixIFFT(ntheta, n, padding, nside, dtype, device, stream); } else { // If reconfiguration is not needed, ensure the stream is up-to-date ifft->updateStreamIfNeeded(stream); } ifft->initializeYpad(nside); + // Use pre-allocated buffers and zero in-place (CUDA graph compatible) + torch::Tensor& x_pad_storage = ifft->getXpad(); + torch::Tensor& f_storage = ifft->getF(); + + // Zero in-place - this is CUDA graph compatible + x_pad_storage.zero_(); + f_storage.zero_(); + + // Create views with proper batch shape (slice for smaller batches) + auto x_pad = x_pad_storage.slice(0, 0, n * ntheta).view(x_pad_size); + auto f = f_storage.slice(0, 0, n).view(f_size); + irfft_phase_shift_batch_dispatch(ftm, L, nside, stream); irfft_pre_process_x_pad_batch_dispatch(ftm, x_pad, L, padding, nside, order, stream); ifft->execute_forward(x_pad); - x_pad = x_pad * ifft->getYpad(); - //x_y_pad_conv_batch_dispatch(x_pad, ifft->getYpad(), padding, nside); + x_pad.mul_(ifft->getYpad()); ifft->execute_inverse(x_pad); @@ -164,18 +181,26 @@ torch::Tensor healpix_rfft_class(torch::Tensor f, int L, int nside) { // Create FFT object and initialize y_pad if not already done static HealpixFFT* fft = nullptr; - if (!fft || fft->needsReconfiguration(ntheta, 1, padding, dtype, device)) { + if (!fft || fft->needsReconfiguration(ntheta, 1, padding, L, dtype, device)) { delete fft; // Properly deallocate existing object - fft = new HealpixFFT(ntheta, 1, padding, dtype, device, stream); + fft = new HealpixFFT(ntheta, 1, padding, L, dtype, device, stream); } else { // If reconfiguration is not needed, ensure 
the stream is up-to-date fft->updateStreamIfNeeded(stream); } fft->initializeYpad(nside); - // Allocate tensors and perform FFT operations - auto ftm = torch::zeros({ntheta, L}, torch::dtype(dtype).device(device)); - auto x_pad = torch::zeros({ntheta, padding}, torch::dtype(dtype).device(device)); + // Use pre-allocated buffers and zero in-place (CUDA graph compatible) + torch::Tensor& x_pad_storage = fft->getXpad(); + torch::Tensor& ftm_storage = fft->getFtm(); + + // Zero in-place - this is CUDA graph compatible + x_pad_storage.zero_(); + ftm_storage.zero_(); + + // Create views with proper shape + auto x_pad = x_pad_storage.view({ntheta, padding}); + auto ftm = ftm_storage.view({ntheta, L}); rfft_pre_process_x_pad_dispatch(x_pad, f, padding, nside, stream); fft->execute_forward(x_pad); @@ -184,8 +209,6 @@ torch::Tensor healpix_rfft_class(torch::Tensor f, int L, int nside) { fft->execute_inverse(x_pad); - // x_pad.div_(padding); - rfft_post_process_dispatch(x_pad, ftm, L, padding, nside, stream); rfft_phase_shift_dispatch(ftm, L, nside, stream); @@ -199,29 +222,36 @@ torch::Tensor healpix_irfft_class(torch::Tensor ftm, int L, int nside) { auto device = ftm.device(); auto dtype = ftm.scalar_type() == torch::kComplexDouble ? torch::kComplexDouble : torch::kComplexFloat; - auto ftype = dtype == torch::kComplexDouble ? 
torch::kDouble : torch::kFloat; // Retrieve the current CUDA stream at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); // Use CUDAStreamGuard to set the current stream (though it's already set, this makes it explicit) at::cuda::CUDAStreamGuard guard(stream); - - auto f = torch::zeros({12 * nside * nside}, torch::dtype(ftype).device(device)); - auto x_pad = torch::zeros({ntheta, padding}, torch::dtype(dtype).device(device)); - // Instantiate FFT object static HealpixIFFT* ifft = nullptr; - if (!ifft || ifft->needsReconfiguration(ntheta, 1, padding, dtype, device)) { + if (!ifft || ifft->needsReconfiguration(ntheta, 1, padding, nside, dtype, device)) { delete ifft; // Properly deallocate existing object - ifft = new HealpixIFFT(ntheta, 1, padding, dtype, device, stream); + ifft = new HealpixIFFT(ntheta, 1, padding, nside, dtype, device, stream); } else { // If reconfiguration is not needed, ensure the stream is up-to-date ifft->updateStreamIfNeeded(stream); } ifft->initializeYpad(nside); + // Use pre-allocated buffers and zero in-place (CUDA graph compatible) + torch::Tensor& x_pad_storage = ifft->getXpad(); + torch::Tensor& f_storage = ifft->getF(); + + // Zero in-place - this is CUDA graph compatible + x_pad_storage.zero_(); + f_storage.zero_(); + + // Create views with proper shape + auto x_pad = x_pad_storage.view({ntheta, padding}); + auto f = f_storage.view({12 * nside * nside}); + irfft_phase_shift_dispatch(ftm, L, nside, stream); irfft_pre_process_x_pad_dispatch(ftm, x_pad, L, padding, nside, stream); diff --git a/src/harmonic_transform/hpx_fft.h b/src/harmonic_transform/hpx_fft.h index 33a3a3c..d79ff96 100644 --- a/src/harmonic_transform/hpx_fft.h +++ b/src/harmonic_transform/hpx_fft.h @@ -66,26 +66,30 @@ template inline int compute_order(I nside); void rfft_pre_process_x_pad_batch_float4_dispatch(torch::Tensor x_pad, torch::Tensor f, int padding, int nside, int order, at::cuda::CUDAStream& stream); -class HealpixFFT{ -public: - // Constructor 
initializes only the essential variables and FFT plan - HealpixFFT(int ntheta, int n, int padding, torch::Dtype dtype, torch::Device device, at::cuda::CUDAStream& stream) - : ntheta_(ntheta), n_(n), padding_(padding), dtype_(dtype), device_(device), y_pad_initialized_(false), stream_(stream){ - - if (dtype == torch::kComplexDouble) { - checkCuFFTError(cufftPlan1d(&plan_, padding_, CUFFT_Z2Z, ntheta_ * n)); - } else if (dtype == torch::kComplexFloat) { - checkCuFFTError(cufftPlan1d(&plan_, padding_, CUFFT_C2C, ntheta_ * n)); - } else { - throw std::runtime_error("Unsupported data type for FFT plan."); - } - - // Set the CUDA stream for the FFT plan - checkCuFFTError(cufftSetStream(plan_, stream_.stream())); - - // Allocate memory for y_pad tensor - y_pad_ = torch::zeros({ntheta, padding}, torch::dtype(dtype).device(device)); - } +class HealpixFFT{ +public: + // Constructor initializes only the essential variables and FFT plan + HealpixFFT(int ntheta, int n, int padding, int L, torch::Dtype dtype, torch::Device device, at::cuda::CUDAStream stream) + : ntheta_(ntheta), n_(n), padding_(padding), L_(L), dtype_(dtype), device_(device), y_pad_initialized_(false), stream_(stream){ + + if (dtype == torch::kComplexDouble) { + checkCuFFTError(cufftPlan1d(&plan_, padding_, CUFFT_Z2Z, ntheta_ * n)); + } else if (dtype == torch::kComplexFloat) { + checkCuFFTError(cufftPlan1d(&plan_, padding_, CUFFT_C2C, ntheta_ * n)); + } else { + throw std::runtime_error("Unsupported data type for FFT plan."); + } + + // Set the CUDA stream for the FFT plan + checkCuFFTError(cufftSetStream(plan_, stream_.stream())); + + // Allocate memory for y_pad tensor + y_pad_ = torch::zeros({ntheta, padding}, torch::dtype(dtype).device(device)); + + // Pre-allocate workspace buffers for CUDA graph compatibility + x_pad_ = torch::empty({ntheta * n, padding}, torch::dtype(dtype).device(device)); + ftm_ = torch::empty({ntheta * n, L}, torch::dtype(dtype).device(device)); + } ~HealpixFFT() { 
checkCuFFTError(cufftDestroy(plan_)); @@ -124,63 +128,81 @@ class HealpixFFT{ return y_pad_; } - // Method to check if reconfiguration is needed - bool needsReconfiguration(int ntheta, int n, int padding, torch::Dtype dtype, torch::Device device) const { - return ntheta_ != ntheta || n_ != n || padding_ != padding || dtype_ != dtype || device_ != device; - } - - // Method to check and update the stream (non-const) - void updateStreamIfNeeded(at::cuda::CUDAStream& stream) { - - // Set the plan to the new stream - checkCuFFTError(cufftSetStream(plan_, stream.stream())); - stream_ = stream; - } - - - // Getters for current configuration - int getNtheta() const { return ntheta_; } - int getPadding() const { return padding_; } - torch::Dtype getDtype() const { return dtype_; } - torch::Device getDevice() const { return device_; } - -private: - cufftHandle plan_; - int ntheta_; - int n_; - int padding_; - torch::Dtype dtype_; - torch::Device device_; - torch::Tensor y_pad_; - bool y_pad_initialized_; - at::cuda::CUDAStream& stream_; - - void checkCuFFTError(cufftResult result) { - if (result != CUFFT_SUCCESS) { - throw std::runtime_error("CUFFT error: " + std::to_string(result)); - } - } -}; - - -class HealpixIFFT { -public: - HealpixIFFT(int ntheta, int n, int padding, torch::Dtype dtype, torch::Device device, at::cuda::CUDAStream& stream) - : ntheta_(ntheta), n_(n), padding_(padding), dtype_(dtype), device_(device), y_pad_initialized_(false), stream_(stream){ - - if (dtype == torch::kComplexDouble) { - checkCuFFTError(cufftPlan1d(&plan_, padding_, CUFFT_Z2Z, ntheta_ * n)); - } else if (dtype == torch::kComplexFloat) { - checkCuFFTError(cufftPlan1d(&plan_, padding_, CUFFT_C2C, ntheta_ * n)); - } else { - throw std::runtime_error("Unsupported data type for FFT plan."); - } - - // Set the CUDA stream for the FFT plan - checkCuFFTError(cufftSetStream(plan_, stream_.stream())); - - y_pad_ = torch::zeros({ntheta, padding}, torch::dtype(dtype).device(device)); - } + // Method 
to check if reconfiguration is needed + // Only reconfigure when batch grows larger (n > n_) or other config changes + bool needsReconfiguration(int ntheta, int n, int padding, int L, torch::Dtype dtype, torch::Device device) const { + return ntheta_ != ntheta || n_ < n || padding_ != padding || L_ != L || dtype_ != dtype || device_ != device; + } + + // Method to check and update the stream (non-const) + void updateStreamIfNeeded(at::cuda::CUDAStream stream) { + + // Set the plan to the new stream + checkCuFFTError(cufftSetStream(plan_, stream.stream())); + stream_ = stream; + } + + + // Getters for current configuration + int getNtheta() const { return ntheta_; } + int getPadding() const { return padding_; } + int getL() const { return L_; } + int getN() const { return n_; } + torch::Dtype getDtype() const { return dtype_; } + torch::Device getDevice() const { return device_; } + + // Accessors for pre-allocated buffers + torch::Tensor& getXpad() { return x_pad_; } + torch::Tensor& getFtm() { return ftm_; } + +private: + cufftHandle plan_; + int ntheta_; + int n_; + int padding_; + int L_; + torch::Dtype dtype_; + torch::Device device_; + torch::Tensor y_pad_; + torch::Tensor x_pad_; + torch::Tensor ftm_; + bool y_pad_initialized_; + at::cuda::CUDAStream stream_; + + void checkCuFFTError(cufftResult result) { + if (result != CUFFT_SUCCESS) { + throw std::runtime_error("CUFFT error: " + std::to_string(result)); + } + } +}; + + +class HealpixIFFT { +public: + HealpixIFFT(int ntheta, int n, int padding, int nside, torch::Dtype dtype, torch::Device device, at::cuda::CUDAStream stream) + : ntheta_(ntheta), n_(n), padding_(padding), nside_(nside), dtype_(dtype), device_(device), y_pad_initialized_(false), stream_(stream){ + + if (dtype == torch::kComplexDouble) { + checkCuFFTError(cufftPlan1d(&plan_, padding_, CUFFT_Z2Z, ntheta_ * n)); + } else if (dtype == torch::kComplexFloat) { + checkCuFFTError(cufftPlan1d(&plan_, padding_, CUFFT_C2C, ntheta_ * n)); + } else { + 
throw std::runtime_error("Unsupported data type for FFT plan."); + } + + // Set the CUDA stream for the FFT plan + checkCuFFTError(cufftSetStream(plan_, stream_.stream())); + + y_pad_ = torch::zeros({ntheta, padding}, torch::dtype(dtype).device(device)); + + // Pre-allocate workspace buffers for CUDA graph compatibility + x_pad_ = torch::empty({ntheta * n, padding}, torch::dtype(dtype).device(device)); + + // Output tensor type is real (not complex) + auto ftype = dtype == torch::kComplexDouble ? torch::kDouble : torch::kFloat; + int npix = 12 * nside * nside; + f_ = torch::empty({n, npix}, torch::dtype(ftype).device(device)); + } ~HealpixIFFT() { checkCuFFTError(cufftDestroy(plan_)); @@ -214,46 +236,56 @@ class HealpixIFFT { } } - const torch::Tensor& getYpad() const { - return y_pad_; - } - - // Method to check if reconfiguration is needed - bool needsReconfiguration(int ntheta, int n, int padding, torch::Dtype dtype, torch::Device device) const { - return ntheta_ != ntheta || n_ != n || padding_ != padding || dtype_ != dtype || device_ != device; - } - - // Method to check and update the stream (non-const) - void updateStreamIfNeeded(at::cuda::CUDAStream& stream) { - // Set the plan to the new stream - checkCuFFTError(cufftSetStream(plan_, stream.stream())); - stream_ = stream; - } - - // Getters for current configuration - int getNtheta() const { return ntheta_; } - int getPadding() const { return padding_; } - torch::Dtype getDtype() const { return dtype_; } - torch::Device getDevice() const { return device_; } - -private: - cufftHandle plan_; - int ntheta_; - int n_; - int padding_; - torch::Dtype dtype_; - torch::Device device_; - torch::Tensor y_pad_; - bool y_pad_initialized_; - at::cuda::CUDAStream& stream_; - - void checkCuFFTError(cufftResult result) { - if (result != CUFFT_SUCCESS) { - throw std::runtime_error("CUFFT error: " + std::to_string(result)); - } - } -}; - - - -#endif // HPX_FFT_H + const torch::Tensor& getYpad() const { + return y_pad_; + } 
+ + // Method to check if reconfiguration is needed + // Only reconfigure when batch grows larger (n > n_) or other config changes + bool needsReconfiguration(int ntheta, int n, int padding, int nside, torch::Dtype dtype, torch::Device device) const { + return ntheta_ != ntheta || n_ < n || padding_ != padding || nside_ != nside || dtype_ != dtype || device_ != device; + } + + // Rebind the cuFFT plan to the given stream and cache it; always updates, no comparison is made (non-const) + void updateStreamIfNeeded(at::cuda::CUDAStream stream) { + // Set the plan to the new stream + checkCuFFTError(cufftSetStream(plan_, stream.stream())); + stream_ = stream; + } + + // Getters for current configuration + int getNtheta() const { return ntheta_; } + int getPadding() const { return padding_; } + int getNside() const { return nside_; } + int getN() const { return n_; } + torch::Dtype getDtype() const { return dtype_; } + torch::Device getDevice() const { return device_; } + + // Accessors for pre-allocated buffers + torch::Tensor& getXpad() { return x_pad_; } + torch::Tensor& getF() { return f_; } + +private: + cufftHandle plan_; + int ntheta_; + int n_; + int padding_; + int nside_; + torch::Dtype dtype_; + torch::Device device_; + torch::Tensor y_pad_; + torch::Tensor x_pad_; + torch::Tensor f_; + bool y_pad_initialized_; + at::cuda::CUDAStream stream_; + + void checkCuFFTError(cufftResult result) { + if (result != CUFFT_SUCCESS) { + throw std::runtime_error("CUFFT error: " + std::to_string(result)); + } + } +}; + + + +#endif // HPX_FFT_H diff --git a/tests/test_cuda_graphs.py b/tests/test_cuda_graphs.py new file mode 100644 index 0000000..8ac64b8 --- /dev/null +++ b/tests/test_cuda_graphs.py @@ -0,0 +1,279 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for CUDA graph compatibility of SHTCUDA/iSHTCUDA.""" + +import pytest +import torch + +from cuhpx import SHTCUDA, iSHTCUDA + + +@pytest.mark.cuda +def test_shtcuda_cuda_graph_capture(device, nside_small, dtype): + """Test that SHTCUDA can be captured in a CUDA graph.""" + lmax = 3 * nside_small - 1 + mmax = lmax + npix = 12 * nside_small**2 + + sht = SHTCUDA(nside_small, lmax=lmax, mmax=mmax) + + # Warmup (triggers lazy initialization) + signal = torch.randn(npix, dtype=dtype, device=device) + _ = sht(signal) + + # Prepare input for graph capture + static_input = torch.randn(npix, dtype=dtype, device=device) + + # Warmup the specific input size + _ = sht(static_input.clone()) + torch.cuda.synchronize() + + g = torch.cuda.CUDAGraph() + + # Capture graph - no explicit stream is passed, so torch.cuda.graph captures on its own internal side stream + with torch.cuda.graph(g): + static_output = sht(static_input) + + torch.cuda.synchronize() + + # Prepare test input (the expected result is computed after the replay, below) + new_input = torch.randn(npix, dtype=dtype, device=device) + + # Copy to static input and replay + static_input.copy_(new_input) + g.replay() + torch.cuda.synchronize() + + # Copy result before computing expected (to avoid interference) + graph_result = static_output.clone() + + # Compute expected using direct call + expected = sht(new_input) + + rtol = 1e-4 if dtype == torch.float32 else 1e-8 + atol = 1e-5 if dtype == torch.float32 else 1e-10 + assert torch.allclose( + graph_result, expected, rtol=rtol, atol=atol + ), f"CUDA graph output differs from direct call: max diff = {(graph_result - 
expected).abs().max():.2e}" + + +@pytest.mark.cuda +def test_ishtcuda_cuda_graph_capture(device, nside_small, dtype, complex_dtype): + """Test that iSHTCUDA can be captured in a CUDA graph.""" + lmax = 3 * nside_small - 1 + mmax = lmax + + isht = iSHTCUDA(nside_small, lmax=lmax, mmax=mmax) + + # Warmup + coeffs = torch.randn(lmax, mmax, dtype=complex_dtype, device=device) + _ = isht(coeffs.clone()) + torch.cuda.synchronize() + + # Capture graph + static_input = torch.randn(lmax, mmax, dtype=complex_dtype, device=device) + + g = torch.cuda.CUDAGraph() + + with torch.cuda.graph(g): + static_output = isht(static_input) + + torch.cuda.synchronize() + + # Replay with new data + new_input = torch.randn(lmax, mmax, dtype=complex_dtype, device=device) + static_input.copy_(new_input) + g.replay() + torch.cuda.synchronize() + + # Copy result before computing expected + graph_result = static_output.clone() + + # Verify correctness + expected = isht(new_input) + rtol = 1e-3 if dtype == torch.float32 else 1e-5 + atol = 1e-2 if dtype == torch.float32 else 1e-5 + assert torch.allclose( + graph_result, expected, rtol=rtol, atol=atol + ), f"CUDA graph output differs from direct call: max diff = {(graph_result - expected).abs().max():.2e}" + + +@pytest.mark.cuda +def test_shtcuda_batch_cuda_graph(device, nside_small, dtype): + """Test SHTCUDA with batched input in CUDA graph.""" + lmax = 3 * nside_small - 1 + mmax = lmax + npix = 12 * nside_small**2 + batch_size = 4 + + sht = SHTCUDA(nside_small, lmax=lmax, mmax=mmax) + + # Warmup with batch + signal = torch.randn(batch_size, npix, dtype=dtype, device=device) + _ = sht(signal) + torch.cuda.synchronize() + + # Capture graph + static_input = torch.randn(batch_size, npix, dtype=dtype, device=device) + + g = torch.cuda.CUDAGraph() + + with torch.cuda.graph(g): + static_output = sht(static_input) + + torch.cuda.synchronize() + + # Replay + new_input = torch.randn(batch_size, npix, dtype=dtype, device=device) + 
static_input.copy_(new_input) + g.replay() + torch.cuda.synchronize() + + # Copy result before computing expected + graph_result = static_output.clone() + + expected = sht(new_input) + rtol = 1e-4 if dtype == torch.float32 else 1e-8 + atol = 1e-5 if dtype == torch.float32 else 1e-10 + assert torch.allclose( + graph_result, expected, rtol=rtol, atol=atol + ), f"CUDA graph batch output differs: max diff = {(graph_result - expected).abs().max():.2e}" + + +@pytest.mark.cuda +def test_ishtcuda_batch_cuda_graph(device, nside_small, dtype, complex_dtype): + """Test iSHTCUDA with batched input in CUDA graph.""" + lmax = 3 * nside_small - 1 + mmax = lmax + batch_size = 4 + + isht = iSHTCUDA(nside_small, lmax=lmax, mmax=mmax) + + # Warmup with batch + coeffs = torch.randn(batch_size, lmax, mmax, dtype=complex_dtype, device=device) + _ = isht(coeffs.clone()) + torch.cuda.synchronize() + + # Capture graph + static_input = torch.randn(batch_size, lmax, mmax, dtype=complex_dtype, device=device) + + g = torch.cuda.CUDAGraph() + + with torch.cuda.graph(g): + static_output = isht(static_input) + + torch.cuda.synchronize() + + # Replay + new_input = torch.randn(batch_size, lmax, mmax, dtype=complex_dtype, device=device) + static_input.copy_(new_input) + g.replay() + torch.cuda.synchronize() + + # Copy result before computing expected + graph_result = static_output.clone() + + expected = isht(new_input) + rtol = 1e-3 if dtype == torch.float32 else 1e-5 + atol = 1e-2 if dtype == torch.float32 else 1e-5 + assert torch.allclose( + graph_result, expected, rtol=rtol, atol=atol + ), f"CUDA graph batch output differs: max diff = {(graph_result - expected).abs().max():.2e}" + + +@pytest.mark.cuda +def test_roundtrip_cuda_graph(device, nside_small, dtype): + """Test SHT+iSHT roundtrip in CUDA graph.""" + lmax = 3 * nside_small - 1 + mmax = lmax + npix = 12 * nside_small**2 + + sht = SHTCUDA(nside_small, lmax=lmax, mmax=mmax) + isht = iSHTCUDA(nside_small, lmax=lmax, mmax=mmax) + + # Warmup + 
signal = torch.randn(npix, dtype=dtype, device=device) + _ = isht(sht(signal)) + torch.cuda.synchronize() + + # Capture roundtrip + static_input = torch.randn(npix, dtype=dtype, device=device) + + g = torch.cuda.CUDAGraph() + + with torch.cuda.graph(g): + coeffs = sht(static_input) + static_output = isht(coeffs) + + torch.cuda.synchronize() + + # Replay + new_input = torch.randn(npix, dtype=dtype, device=device) + static_input.copy_(new_input) + g.replay() + torch.cuda.synchronize() + + # Copy result before computing expected + graph_result = static_output.clone() + + expected = isht(sht(new_input)) + rtol = 1e-3 if dtype == torch.float32 else 1e-5 + atol = 1e-2 if dtype == torch.float32 else 1e-5 + assert torch.allclose( + graph_result, expected, rtol=rtol, atol=atol + ), f"CUDA graph roundtrip differs: max diff = {(graph_result - expected).abs().max():.2e}" + + +@pytest.mark.cuda +def test_cuda_graph_multiple_replays(device, nside_small, dtype): + """Test that CUDA graph can be replayed multiple times correctly.""" + lmax = 3 * nside_small - 1 + mmax = lmax + npix = 12 * nside_small**2 + + sht = SHTCUDA(nside_small, lmax=lmax, mmax=mmax) + + # Warmup + _ = sht(torch.randn(npix, dtype=dtype, device=device)) + torch.cuda.synchronize() + + # Capture + static_input = torch.randn(npix, dtype=dtype, device=device) + + g = torch.cuda.CUDAGraph() + + with torch.cuda.graph(g): + static_output = sht(static_input) + + torch.cuda.synchronize() + + # Multiple replays + rtol = 1e-4 if dtype == torch.float32 else 1e-8 + atol = 1e-5 if dtype == torch.float32 else 1e-10 + + for i in range(5): + new_input = torch.randn(npix, dtype=dtype, device=device) + static_input.copy_(new_input) + g.replay() + torch.cuda.synchronize() + + # Copy result before computing expected + graph_result = static_output.clone() + + expected = sht(new_input) + assert torch.allclose( + graph_result, expected, rtol=rtol, atol=atol + ), f"CUDA graph replay {i} failed: max diff = {(graph_result - 
expected).abs().max():.2e}" From 6a1f247b6a63c96eeb3da37db974f9f7d0dc1915 Mon Sep 17 00:00:00 2001 From: Akshay Subramaniam <6964110+akshaysubr@users.noreply.github.com> Date: Tue, 24 Feb 2026 21:30:04 -0800 Subject: [PATCH 3/4] Updating error tolerances to pass on multiple system and GPU architectures --- tests/conftest.py | 2 +- tests/test_sht_cuda_batch.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 3a9784e..2276605 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -83,7 +83,7 @@ def complex_dtype(dtype): # For CUDA vs PyTorch implementation comparison (same algorithm, different impl) IMPL_COMPARISON_TOL = { - torch.float32: {"sht": (1e-4, 1e-5), "isht": (1e-3, 1e-2)}, + torch.float32: {"sht": (1e-4, 2e-5), "isht": (1e-3, 1e-1)}, torch.float64: {"sht": (1e-8, 1e-8), "isht": (1e-5, 1e-5)}, } diff --git a/tests/test_sht_cuda_batch.py b/tests/test_sht_cuda_batch.py index 51645e5..a6f67c1 100644 --- a/tests/test_sht_cuda_batch.py +++ b/tests/test_sht_cuda_batch.py @@ -63,7 +63,7 @@ def test_sht_batch_matches_single(self, device, nside, batch_signal, sht, batch_ # Compare results - should be identical # Use tighter tolerances for float64 rtol = 1e-5 if dtype == torch.float32 else 1e-10 - atol = 1e-6 if dtype == torch.float32 else 1e-10 + atol = 1e-4 if dtype == torch.float32 else 1e-10 assert torch.allclose( coeff_batch, coeff_single, rtol=rtol, atol=atol ), f"SHT batch/single mismatch: max diff = {(coeff_batch - coeff_single).abs().max():.2e}" From 00c3958a7c2b68b3f206d3201b2db60da8821495 Mon Sep 17 00:00:00 2001 From: Akshay Subramaniam <6964110+akshaysubr@users.noreply.github.com> Date: Thu, 5 Mar 2026 16:58:25 -0800 Subject: [PATCH 4/4] Addressing review comments --- cuhpx/hpx_sht.py | 6 +++++- src/harmonic_transform/hpx_fft.h | 2 +- tests/test_sht_bluestein.py | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/cuhpx/hpx_sht.py b/cuhpx/hpx_sht.py index ef754e6..193f59d 
100644 --- a/cuhpx/hpx_sht.py +++ b/cuhpx/hpx_sht.py @@ -620,6 +620,10 @@ def __init__(self, nside, lmax=None, mmax=None, quad_weights="ring", norm="ortho self.register_buffer("pct", pct, persistent=False) self.register_buffer("W", W, persistent=False) + # Pre-compute pct in both dtypes for CUDA graph compatibility + pct_f64 = pct.double() + self.register_buffer("pct_f64", pct_f64, persistent=False) + # Pre-compute pct * weights for CUDA graph compatibility (avoids allocation during forward) pct_weights = pct * weights pct_weights_f64 = pct_weights.double() @@ -633,7 +637,7 @@ def forward(self, x): # Use pre-computed weights based on input dtype for CUDA graph compatibility # Note: No stream context manager - runs on caller's stream (required for CUDA graph capture) if x.dtype == torch.float64: - return SHTFunction.apply(x, self.pct_weights_f64, self.pct, self.W, self.mmax, self.lmax, self.nside) + return SHTFunction.apply(x, self.pct_weights_f64, self.pct_f64, self.W, self.mmax, self.lmax, self.nside) else: return SHTFunction.apply(x, self.pct_weights, self.pct, self.W, self.mmax, self.lmax, self.nside) diff --git a/src/harmonic_transform/hpx_fft.h b/src/harmonic_transform/hpx_fft.h index d79ff96..4e54ee8 100644 --- a/src/harmonic_transform/hpx_fft.h +++ b/src/harmonic_transform/hpx_fft.h @@ -69,7 +69,7 @@ void rfft_pre_process_x_pad_batch_float4_dispatch(torch::Tensor x_pad, torch::Te class HealpixFFT{ public: // Constructor initializes only the essential variables and FFT plan - HealpixFFT(int ntheta, int n, int padding, int L, torch::Dtype dtype, torch::Device device, at::cuda::CUDAStream stream) + HealpixFFT(int ntheta, int n, int padding, int L, torch::Dtype dtype, torch::Device device, const at::cuda::CUDAStream &stream) : ntheta_(ntheta), n_(n), padding_(padding), L_(L), dtype_(dtype), device_(device), y_pad_initialized_(false), stream_(stream){ if (dtype == torch::kComplexDouble) { diff --git a/tests/test_sht_bluestein.py b/tests/test_sht_bluestein.py 
index e27c1c0..e0a75b5 100644 --- a/tests/test_sht_bluestein.py +++ b/tests/test_sht_bluestein.py @@ -90,7 +90,7 @@ def test_bluestein_roundtrip(self, device, nside, lmax, mmax, dtype, complex_dty max_l = min(lmax // 2, nside // 2) for l_idx in range(max_l): for m_idx in range(min(l_idx + 1, mmax)): - coeffs[l_idx, m_idx] = torch.randn(1, dtype=dtype).item() + coeffs[l_idx, m_idx] = torch.randn(1, dtype=dtype).item() + 1j * torch.randn(1, dtype=dtype).item() signal = isht_bluestein(coeffs) coeff_back = sht_bluestein(signal)