diff --git a/examples/network_compression/mnist_compression.py b/examples/network_compression/mnist_compression.py
index cc984714..551a4f95 100644
--- a/examples/network_compression/mnist_compression.py
+++ b/examples/network_compression/mnist_compression.py
@@ -287,7 +287,7 @@ def main():
         )
 
     if args.compression == "Wavelet":
-        CustomWavelet = collections.namedtuple(
+        collections.namedtuple(
             "Wavelet", ["dec_lo", "dec_hi", "rec_lo", "rec_hi", "name"]
         )
         # init_wavelet = ProductFilter(
diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py
index 24e1e9b0..124e806a 100644
--- a/src/ptwt/_util.py
+++ b/src/ptwt/_util.py
@@ -219,7 +219,7 @@ def _get_pad(data_len: int, filt_len: int) -> tuple[int, int]:
     padr = (2 * filt_len - 3) // 2
     padl = (2 * filt_len - 3) // 2
 
-    # pad to even singal length.
+    # pad to even signal length.
    padr += data_len % 2
 
     return padl, padr
@@ -817,3 +817,53 @@ def _get_padding_n(
     for i in range(1, n + 1):
         rv.extend(_get_pad(data.shape[-i], wavelet_length))
     return tuple(rv)
+
+
+def fwt_pad_n(
+    data: torch.Tensor,
+    wavelet: Union[Wavelet, str],
+    n: int,
+    *,
+    mode: BoundaryMode | None = None,
+    padding: Optional[tuple[int, ...]] = None,
+) -> torch.Tensor:
+    """Pad data for the n-dimensional FWT.
+
+    This function pads the last n axes.
+
+    Args:
+        data (torch.Tensor): Input data with ``n + 2`` dimensions.
+        wavelet : A pywt wavelet compatible object or
+            the name of a pywt wavelet.
+            Refer to the output from ``pywt.wavelist(kind='discrete')``
+            for possible choices.
+        n : the number of axes to pad
+        mode: The desired padding mode for extending the signal along the edges.
+            See :data:`ptwt.constants.BoundaryMode`.
+        padding : A tuple
+            with the number of padded values on the respective side of the
+            last ``n`` axes of `data`, i.e. ``2 * n`` entries.
+            If None, the padding values are computed based
+            on the signal shape and the wavelet length. Defaults to None.
+
+    Returns:
+        The padded output tensor.
+
+    Raises:
+        ValueError: if the padding length or data dimensionality is wrong
+    """
+    n_data_axes = len(data.shape) - 2
+    if n != n_data_axes:
+        raise ValueError(f"expected {n + 2}-dimensional data, got {len(data.shape)}")
+    if padding is None:
+        padding = _get_padding_n(data, wavelet, n)
+    elif len(padding) != 2 * n:
+        # _get_padding_n emits a (left, right) pair per padded axis, so a
+        # caller-supplied tuple must also have two entries per axis.
+        raise ValueError(f"padding must have {2 * n} entries, got {len(padding)}")
+
+    # Symmetric padding needs the dedicated helper; every other boundary
+    # mode maps directly onto a torch.nn.functional.pad mode string.
+    match _translate_boundary_strings(mode):
+        case "symmetric":
+            return _pad_symmetric(data, _group_for_symmetric(padding))
+        case _ as pytorch_mode:
+            return torch.nn.functional.pad(data, padding, mode=pytorch_mode)
diff --git a/src/ptwt/conv_transform.py b/src/ptwt/conv_transform.py
index 1ba3cec5..0484dde9 100644
--- a/src/ptwt/conv_transform.py
+++ b/src/ptwt/conv_transform.py
@@ -6,7 +6,7 @@
 
 from __future__ import annotations
 
-from typing import Optional, Union, cast
+from typing import Optional, Union
 
 import pywt
 import torch
@@ -15,14 +15,11 @@
     _adjust_padding_at_reconstruction,
     _check_same_device_dtype,
     _get_filter_tensors,
-    _get_padding_n,
-    _group_for_symmetric,
-    _pad_symmetric,
     _postprocess_coeffs,
     _postprocess_tensor,
     _preprocess_coeffs,
     _preprocess_tensor,
-    _translate_boundary_strings,
+    fwt_pad_n,
 )
 from .constants import BoundaryMode, Wavelet, WaveletCoeff1d
 
@@ -54,15 +51,7 @@ def _fwt_pad(
     Returns:
         A PyTorch tensor with the padded input data
     """
-    pytorch_mode = _translate_boundary_strings(mode)
-
-    if padding is None:
-        padding = cast(tuple[int, int], _get_padding_n(data, wavelet, n=1))
-    if pytorch_mode == "symmetric":
-        data_pad = _pad_symmetric(data, _group_for_symmetric(padding))
-    else:
-        data_pad = torch.nn.functional.pad(data, padding, mode=pytorch_mode)
-    return data_pad
+    return fwt_pad_n(data, wavelet, n=1, mode=mode, padding=padding)
 
 
 def wavedec(
diff --git a/src/ptwt/conv_transform_2.py b/src/ptwt/conv_transform_2.py
index c1c7d270..a767b472 100644
--- a/src/ptwt/conv_transform_2.py
+++ b/src/ptwt/conv_transform_2.py
@@ -6,7 +6,7 @@
 
 from __future__ import annotations
 
-from typing import Optional, Union, cast
+from typing import Optional, Union
 
 import pywt
 import torch
@@ -15,15 +15,12 @@
     _adjust_padding_at_reconstruction,
     _check_same_device_dtype,
     _get_filter_tensors,
-    _get_padding_n,
-    _group_for_symmetric,
     _outer,
-    _pad_symmetric,
     _postprocess_coeffs,
     _postprocess_tensor,
     _preprocess_coeffs,
     _preprocess_tensor,
-    _translate_boundary_strings,
+    fwt_pad_n,
 )
 from .constants import BoundaryMode, Wavelet, WaveletCoeff2d, WaveletDetailTuple2d
 
@@ -83,15 +80,7 @@ def _fwt_pad2(
         The padded output tensor.
 
     """
-    pytorch_mode = _translate_boundary_strings(mode)
-
-    if padding is None:
-        padding = cast(tuple[int, int, int, int], _get_padding_n(data, wavelet, n=2))
-    if pytorch_mode == "symmetric":
-        data_pad = _pad_symmetric(data, _group_for_symmetric(padding))
-    else:
-        data_pad = torch.nn.functional.pad(data, padding, mode=pytorch_mode)
-    return data_pad
+    return fwt_pad_n(data, wavelet, n=2, mode=mode, padding=padding)
 
 
 def wavedec2(
diff --git a/src/ptwt/conv_transform_3.py b/src/ptwt/conv_transform_3.py
index 97682e62..4808383a 100644
--- a/src/ptwt/conv_transform_3.py
+++ b/src/ptwt/conv_transform_3.py
@@ -5,7 +5,7 @@
 
 from __future__ import annotations
 
-from typing import Optional, Union, cast
+from typing import Optional, Union
 
 import pywt
 import torch
@@ -15,15 +15,12 @@
     _as_wavelet,
     _check_same_device_dtype,
     _get_filter_tensors,
-    _get_padding_n,
-    _group_for_symmetric,
     _outer,
-    _pad_symmetric,
     _postprocess_coeffs,
     _postprocess_tensor,
     _preprocess_coeffs,
     _preprocess_tensor,
-    _translate_boundary_strings,
+    fwt_pad_n,
 )
 from .constants import BoundaryMode, Wavelet, WaveletCoeffNd, WaveletDetailDict
 
@@ -88,17 +85,7 @@ def _fwt_pad3(
     Returns:
         The padded output tensor.
     """
-    pytorch_mode = _translate_boundary_strings(mode)
-
-    if padding is None:
-        padding = cast(
-            tuple[int, int, int, int, int, int], _get_padding_n(data, wavelet, n=3)
-        )
-    if pytorch_mode == "symmetric":
-        data_pad = _pad_symmetric(data, _group_for_symmetric(padding))
-    else:
-        data_pad = torch.nn.functional.pad(data, padding, mode=pytorch_mode)
-    return data_pad
+    return fwt_pad_n(data, wavelet, n=3, mode=mode, padding=padding)
 
 
 def wavedec3(