This repository implements DWT 3DGS, an enhanced version of 3D Gaussian Splatting that incorporates Discrete Wavelet Transform (DWT) loss functions to improve high-frequency detail preservation and reconstruction quality.
DWT 3DGS extends the standard 3D Gaussian Splatting pipeline by incorporating wavelet-domain loss functions. The method applies a 2-level Haar wavelet decomposition to both predicted and ground truth images, computing Charbonnier losses on selected subbands. This enables better preservation of fine details and improved reconstruction of high-frequency content.
The wavelet decomposition separates images into multiple frequency subbands:
- LL (Low-Low): Low-frequency approximation containing the main structure
- LH (Low-High): Horizontal high-frequency details
- HL (High-Low): Vertical high-frequency details
- HH (High-High): Diagonal high-frequency details
By weighting different subbands, the method can emphasize low-frequency structure while still capturing important high-frequency details.
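For a concrete picture of the four subbands, the sketch below applies one level of an orthonormal Haar transform to a single 2x2 pixel block. The divide-by-2 normalization and the assignment of LH to horizontal edges (HL to vertical edges) are assumptions chosen to match the descriptions above. The repository's code may use a different sign or scaling convention. A 2-level decomposition repeats the same step on the LL output.

```python
def haar_2x2(a, b, c, d):
    """One-level orthonormal Haar transform of a 2x2 block laid out as
    [[a, b],
     [c, d]]   (rows top to bottom, columns left to right).

    Assumed convention: LH responds to horizontal edges (top vs bottom row),
    HL to vertical edges (left vs right column), HH to diagonal structure.
    """
    ll = (a + b + c + d) / 2  # local average (scaled by 2)
    lh = (a + b - c - d) / 2  # top row minus bottom row
    hl = (a - b + c - d) / 2  # left column minus right column
    hh = (a - b - c + d) / 2  # diagonal difference
    return {"LL": ll, "LH": lh, "HL": hl, "HH": hh}


# A block containing a horizontal edge: bright top row, dark bottom row.
print(haar_2x2(1.0, 1.0, 0.0, 0.0))  # {'LL': 1.0, 'LH': 1.0, 'HL': 0.0, 'HH': 0.0}
```

Because the transform is orthonormal, the sum of squared coefficients equals the sum of squared pixel values. This is why a loss on the subbands measures the same total error as a pixel-space loss, only redistributed by frequency.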
DWT 3DGS enhances 3D Gaussian Splatting by:
- Wavelet Decomposition: Applying a 2-level Haar wavelet transform to decompose images into frequency subbands
- Multi-scale Loss: Computing Charbonnier losses on selected wavelet subbands (LL1, LL2, and optionally the high-frequency bands)
- Adaptive Scaling: Using running-mean ratio scaling to balance the DWT loss against the base L1 + SSIM loss
- GPU Optimization: Fast GPU-accelerated wavelet decomposition implemented with pure PyTorch operations
The method is particularly effective at preserving high-frequency details that are often lost in standard reconstruction approaches, while maintaining the real-time rendering capabilities of 3D Gaussian Splatting.
This codebase supports training on multispectral datasets, which capture information across multiple spectral bands beyond the visible RGB spectrum. Multispectral imaging enables enhanced analysis and reconstruction of scenes with rich spectral information, making it valuable for applications in agriculture, remote sensing, and scientific imaging.
Important: For multispectral datasets, you should run the multispectral DWT 3DGS variant. The multispectral version extends the standard DWT loss computation to work across all spectral bands, ensuring consistent quality and detail preservation across the full spectrum.
- Python 3.8+
- PyTorch (CUDA-enabled recommended)
- CUDA SDK 11+
- Conda (recommended)
- Clone the repository with submodules:
```bash
git clone <repository-url> --recursive
cd gaussian-splatting-highfrequncy-in-low-frequncy-3
```
- Create and activate the conda environment:
```bash
conda env create --file environment.yml
conda activate gaussian_splatting
```
- Install the CUDA extensions:
```bash
pip install submodules/diff-gaussian-rasterization
pip install submodules/simple-knn
```
To train a model with DWT loss enabled (default):
```bash
python train.py -s <path to COLMAP or NeRF Synthetic dataset>
```
To set the DWT parameters explicitly:
```bash
python train.py -s <path to dataset> \
--dwt_enable True \
--dwt_weight 0.5 \
--dwt_ll1_weight 1.0 \
--dwt_ll2_weight 0.5 \
--dwt_lh1_weight 0.0 \
--dwt_hl1_weight 0.0 \
--dwt_hh1_weight 0.0 \
--dwt_lh2_weight 0.0 \
--dwt_hl2_weight 0.0 \
--dwt_hh2_weight 0.0
```
DWT Parameters:
- --dwt_enable: Enable or disable the DWT loss (default: True)
- --dwt_weight: Global weight for the DWT loss (default: 0.5)
- --dwt_ll1_weight: Weight for the Level 1 LL subband (default: 1.0)
- --dwt_ll2_weight: Weight for the Level 2 LL subband (default: 0.5)
- --dwt_lh1_weight, --dwt_hl1_weight, --dwt_hh1_weight: Weights for the Level 1 high-frequency subbands (default: 0.0)
- --dwt_lh2_weight, --dwt_hl2_weight, --dwt_hh2_weight: Weights for the Level 2 high-frequency subbands (default: 0.0)
The default configuration emphasizes low-frequency components (LL1 and LL2) which typically contain the most important structural information. High-frequency subbands can be enabled for enhanced detail preservation.
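The sketch below shows how the per-subband weights might combine into a single DWT term. It assumes a plain weighted sum. `combine_subband_losses` and `DEFAULT_SUBBAND_WEIGHTS` are hypothetical names that only mirror the flag defaults listed above. `--dwt_weight` (and the adaptive scaling) would then be applied on top of this sum.

```python
# Hypothetical defaults mirroring the --dwt_*_weight flags listed above.
DEFAULT_SUBBAND_WEIGHTS = {
    "LL1": 1.0, "LH1": 0.0, "HL1": 0.0, "HH1": 0.0,
    "LL2": 0.5, "LH2": 0.0, "HL2": 0.0, "HH2": 0.0,
}


def combine_subband_losses(subband_losses, weights=DEFAULT_SUBBAND_WEIGHTS):
    """Weighted sum of per-subband losses (assumed combination rule).

    subband_losses maps subband names ("LL1", ..., "HH2") to scalar losses
    (floats or 0-dim tensors). Zero-weight subbands are skipped, so with the
    defaults only LL1 and LL2 contribute.
    """
    return sum(w * subband_losses[name] for name, w in weights.items() if w != 0.0)


losses = {"LL1": 0.02, "LH1": 0.3, "HL1": 0.3, "HH1": 0.3,
          "LL2": 0.04, "LH2": 0.2, "HL2": 0.2, "HH2": 0.2}
print(combine_subband_losses(losses))  # 1.0 * 0.02 + 0.5 * 0.04, about 0.04

# Enabling one high-frequency band, e.g. --dwt_lh1_weight 0.1:
detail_weights = dict(DEFAULT_SUBBAND_WEIGHTS, LH1=0.1)
print(combine_subband_losses(losses, detail_weights))  # adds 0.1 * 0.3
```

With the defaults, the large high-frequency losses in the example are ignored entirely. This is what "emphasizes low-frequency components" means in practice.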
After training, render the model:
```bash
python render.py -m <path to trained model>
```
Compute metrics on rendered images:
```bash
python metrics.py -m <path to trained model>
```
To train with a train/test split for evaluation:
```bash
python train.py -s <path to dataset> --eval
python render.py -m <path to trained model>
python metrics.py -m <path to trained model>
```
This codebase includes several custom packages and utilities specifically created for DWT-based training.
The main DWT functionality is implemented in utils/loss_utils.py:
- get_dwt_subbands(x): Fast GPU-accelerated 2-level Haar wavelet decomposition
  - Input: PyTorch tensor of shape (N, C, H, W)
  - Returns: Dictionary with 8 subbands: {"LL1", "LH1", "HL1", "HH1", "LL2", "LH2", "HL2", "HH2"}
  - Optimized for GPU computation using depthwise convolutions
  - No external dependencies beyond PyTorch
- charbonnier_loss(pred, target, epsilon=1e-3): Robust loss function for subband comparison
  - A smooth approximation of the L1 loss: large errors are penalized less aggressively than with L2, which makes it more robust on high-frequency content
  - epsilon keeps the loss differentiable at zero error and numerically stable
  - Formula: sqrt((pred - target)^2 + epsilon^2)
- Wavelet Error Field (WEF) utilities:
  - compute_wef_maps(): Compute error maps in wavelet space
  - make_heatmap_rgb(): Visualize error maps as RGB heatmaps
  - compute_wef_all_subbands(): Compute errors for all subbands
  - make_wef_grid_image(): Create grid visualizations of wavelet errors
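The utilities above can be sketched in plain PyTorch as follows. This is a hypothetical reconstruction, not the code in utils/loss_utils.py. Several details are assumptions:

- the LH/HL filter orientation and signs
- the orthonormal 1/2 scaling
- the mean reduction in `charbonnier_loss`
- the two helpers `wef_map` and `heatmap_rgb`, which are illustrative stand-ins for `compute_wef_maps` and `make_heatmap_rgb` (the real signatures are not documented here)

```python
import torch
import torch.nn.functional as F


def haar_level(x):
    """One level of an orthonormal 2D Haar DWT as a depthwise (grouped) conv.

    x: (N, C, H, W) with even H and W. Returns LL, LH, HL, HH, each (N, C, H/2, W/2).
    """
    n, c, h, w = x.shape
    # 2x2 analysis filters; the LH/HL orientation and signs are assumptions.
    filters = torch.tensor([
        [[1.0, 1.0], [1.0, 1.0]],     # LL: local average
        [[1.0, 1.0], [-1.0, -1.0]],   # LH: top row minus bottom row
        [[1.0, -1.0], [1.0, -1.0]],   # HL: left column minus right column
        [[1.0, -1.0], [-1.0, 1.0]],   # HH: diagonal difference
    ], dtype=x.dtype, device=x.device) / 2.0
    # Weight shape (4*C, 1, 2, 2): the same 4 filters for every input channel.
    weight = filters.unsqueeze(1).repeat(c, 1, 1, 1)
    # groups=C makes the conv depthwise; stride 2 gives non-overlapping 2x2 blocks.
    out = F.conv2d(x, weight, stride=2, groups=c)  # (N, 4*C, H/2, W/2)
    out = out.view(n, c, 4, h // 2, w // 2)
    return out[:, :, 0], out[:, :, 1], out[:, :, 2], out[:, :, 3]


def get_dwt_subbands(x):
    """2-level Haar decomposition; level 2 is applied to LL1."""
    ll1, lh1, hl1, hh1 = haar_level(x)
    ll2, lh2, hl2, hh2 = haar_level(ll1)
    return {"LL1": ll1, "LH1": lh1, "HL1": hl1, "HH1": hh1,
            "LL2": ll2, "LH2": lh2, "HL2": hl2, "HH2": hh2}


def charbonnier_loss(pred, target, epsilon=1e-3):
    """Mean of sqrt((pred - target)^2 + epsilon^2); the mean reduction is assumed."""
    diff = pred - target
    return torch.sqrt(diff * diff + epsilon * epsilon).mean()


def wef_map(pred_band, gt_band):
    """Hypothetical error map for one subband: |pred - gt| averaged over
    channels, then min-max normalised to [0, 1]. (N, C, h, w) -> (N, h, w)."""
    err = (pred_band - gt_band).abs().mean(dim=1)
    lo = err.amin(dim=(1, 2), keepdim=True)
    hi = err.amax(dim=(1, 2), keepdim=True)
    return (err - lo) / (hi - lo).clamp_min(1e-8)


def heatmap_rgb(m):
    """Blue (low error) to red (high error) ramp: (N, h, w) -> (N, 3, h, w)."""
    return torch.stack([m, torch.zeros_like(m), 1.0 - m], dim=1)


# Usage: compare a prediction and a ground truth subband by subband.
pred = torch.rand(1, 3, 64, 64)
gt = pred.clone()
gt[..., :32, :] += 0.1                     # perturb the top half only
p_bands, g_bands = get_dwt_subbands(pred), get_dwt_subbands(gt)
losses = {k: charbonnier_loss(p_bands[k], g_bands[k]) for k in p_bands}
ll1_heatmap = heatmap_rgb(wef_map(p_bands["LL1"], g_bands["LL1"]))
# Level 1 subbands are (1, 3, 32, 32); level 2 subbands are (1, 3, 16, 16).
```

The convolution is grouped per channel, so the same code accepts a multispectral tensor with any number of bands, for example (N, 5, H, W). In the usage example, the perturbed region ends at row 32, which lies on a 2x2 block boundary at both levels. The error therefore appears only in LL1 and LL2, and the high-frequency losses stay at the Charbonnier floor of epsilon.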
The DWT loss is integrated into the training loop as follows:
- Automatic Scaling: Running-mean ratio scaling balances DWT loss with base L1 + SSIM loss
- TensorBoard Logging: All DWT subband losses are logged for monitoring
- Efficient Computation: GPU-accelerated wavelet decomposition during training
- Flexible Weighting: Per-subband weights allow fine-grained control
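The running-mean ratio scaling can be sketched as below. The exact rule used by train.py is not documented here, so `RunningRatioScaler`, its exponential moving average, and the momentum value are all assumptions. The factor rescales the DWT term to the magnitude of the base L1 + SSIM loss, and `--dwt_weight` then sets its relative contribution.

```python
class RunningRatioScaler:
    """Hypothetical running-mean ratio scaler (not the repository's code).

    Tracks exponential moving averages of the base (L1 + SSIM) loss and the
    raw DWT loss, and returns base_mean / dwt_mean so the scaled DWT term has
    roughly the same magnitude as the base term.
    """

    def __init__(self, momentum=0.99, eps=1e-8):
        self.momentum = momentum  # assumed EMA momentum
        self.eps = eps            # guards against division by zero
        self.base_mean = None
        self.dwt_mean = None

    def update(self, base_loss, dwt_loss):
        # Initialise the running means with the first observed values.
        if self.base_mean is None:
            self.base_mean, self.dwt_mean = base_loss, dwt_loss
        else:
            m = self.momentum
            self.base_mean = m * self.base_mean + (1 - m) * base_loss
            self.dwt_mean = m * self.dwt_mean + (1 - m) * dwt_loss
        return self.base_mean / (self.dwt_mean + self.eps)


scaler = RunningRatioScaler()
dwt_weight = 0.5          # corresponds to --dwt_weight
base, dwt = 0.12, 0.03    # example scalar loss values for one iteration
scale = scaler.update(base, dwt)
total = base + dwt_weight * scale * dwt
```

In PyTorch you would pass detached scalars (for example `loss.item()`) to `update`. The scale factor then stays out of the gradient and acts as an adaptive constant weight on the DWT term.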
- test_pytorch_wavelets.py: Validation script for DWT subband computation
  - Tests wavelet decomposition correctness
  - Validates subband shapes and properties
  - Includes fallback implementation testing
- DWT_Scaling_Test.ipynb: Jupyter notebook for interactive testing
  - Test DWT loss scaling on real images
  - Visualize wavelet subbands
  - Experiment with different weight configurations
Extended dataset readers support:
- Standard COLMAP datasets
- NeRF Synthetic datasets
- Multispectral datasets (with proper channel handling)
The Gaussian model implementation supports:
- Standard 3DGS optimization
- DWT-enhanced loss computation
- Exposure compensation (optional)
- Depth regularization (optional)
The code expects COLMAP datasets in the following structure:
```
<dataset_path>/
├── images/
│   ├── image1.jpg
│   ├── image2.jpg
│   └── ...
└── sparse/
    └── 0/
        ├── cameras.bin
        ├── images.bin
        └── points3D.bin
```
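Checking that a folder matches this layout before training can prevent a failed run. The helper below, `check_colmap_layout`, is hypothetical and is not part of the repository. It checks only that the directories and the three binary model files shown above exist. It does not parse the files.

```python
import tempfile
from pathlib import Path

REQUIRED_SPARSE_FILES = ("cameras.bin", "images.bin", "points3D.bin")


def check_colmap_layout(dataset_path):
    """Return a list of layout problems for a COLMAP dataset (empty if OK)."""
    root = Path(dataset_path)
    problems = []
    images = root / "images"
    if not images.is_dir():
        problems.append("missing images/ directory")
    elif not any(p.is_file() for p in images.iterdir()):
        problems.append("images/ is empty")
    sparse = root / "sparse" / "0"
    for name in REQUIRED_SPARSE_FILES:
        if not (sparse / name).is_file():
            problems.append(f"missing sparse/0/{name}")
    return problems


# Example: build a minimal valid layout in a temporary directory and check it.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "images").mkdir()
    (root / "images" / "image1.jpg").touch()
    (root / "sparse" / "0").mkdir(parents=True)
    for name in REQUIRED_SPARSE_FILES:
        (root / "sparse" / "0" / name).touch()
    ok = check_colmap_layout(root)
    (root / "sparse" / "0" / "points3D.bin").unlink()
    broken = check_colmap_layout(root)

print(ok)      # []
print(broken)  # ['missing sparse/0/points3D.bin']
```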
Use the provided converter script:
```bash
python convert.py -s <location> [--resize]
```
This will:
- Run COLMAP to extract camera poses
- Undistort images
- Optionally resize images (creates 1/2, 1/4, 1/8 resolution versions)
The method has been evaluated on standard 3DGS benchmarks. The DWT loss improves reconstruction quality, particularly for high-frequency details, while maintaining real-time rendering performance.
If you use this code, please cite the original 3D Gaussian Splatting paper and our DWT extension:
```bibtex
@Article{kerbl3Dgaussians,
author = {Kerbl, Bernhard and Kopanas, Georgios and Leimk{\"u}hler, Thomas and Drettakis, George},
title = {3D Gaussian Splatting for Real-Time Radiance Field Rendering},
journal = {ACM Transactions on Graphics},
number = {4},
volume = {42},
month = {July},
year = {2023},
url = {https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/}
}
```
This project is licensed under the same license as the original 3D Gaussian Splatting codebase. See LICENSE.md for details.
This work extends the original 3D Gaussian Splatting implementation. We thank the original authors for their excellent work and open-source release.







