Skip to content
2 changes: 1 addition & 1 deletion examples/network_compression/mnist_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def main():
)

if args.compression == "Wavelet":
CustomWavelet = collections.namedtuple(
collections.namedtuple(
"Wavelet", ["dec_lo", "dec_hi", "rec_lo", "rec_hi", "name"]
)
# init_wavelet = ProductFilter(
Expand Down
50 changes: 49 additions & 1 deletion src/ptwt/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def _get_pad(data_len: int, filt_len: int) -> tuple[int, int]:
padr = (2 * filt_len - 3) // 2
padl = (2 * filt_len - 3) // 2

# pad to even singal length.
# pad to even signal length.
padr += data_len % 2

return padl, padr
Expand Down Expand Up @@ -817,3 +817,51 @@ def _get_padding_n(
for i in range(1, n + 1):
rv.extend(_get_pad(data.shape[-i], wavelet_length))
return tuple(rv)


def fwt_pad_n(
data: torch.Tensor,
wavelet: Union[Wavelet, str],
n: int,
*,
mode: BoundaryMode | None = None,
padding: Optional[tuple[int, ...]] = None,
) -> torch.Tensor:
"""Pad data for the n-dimensional FWT.

This function pads the last n axes.

Args:
data (torch.Tensor): Input data with N+1 dimensions.
wavelet : A pywt wavelet compatible object or
the name of a pywt wavelet.
Refer to the output from ``pywt.wavelist(kind='discrete')``
for possible choices.
n : the number of axes to pad
mode: The desired padding mode for extending the signal along the edges.
See :data:`ptwt.constants.BoundaryMode`.
padding : A tuple
with the number of padded values on the respective side of the
last ``n`` axes of `data`.
If None, the padding values are computed based
on the signal shape and the wavelet length. Defaults to None.

Returns:
The padded output tensor.

Raises:
ValueError: if the padding is the wrong size
"""
x = len(data.shape) - 2
if n != x:
raise ValueError
if padding is None:
padding = _get_padding_n(data, wavelet, n)
elif len(padding) != n:
raise ValueError

match _translate_boundary_strings(mode):
case "symmetric":
return _pad_symmetric(data, _group_for_symmetric(padding))
case _ as pytorch_mode:
return torch.nn.functional.pad(data, padding, mode=pytorch_mode)
17 changes: 3 additions & 14 deletions src/ptwt/conv_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from __future__ import annotations

from typing import Optional, Union, cast
from typing import Optional, Union

import pywt
import torch
Expand All @@ -15,14 +15,11 @@
_adjust_padding_at_reconstruction,
_check_same_device_dtype,
_get_filter_tensors,
_get_padding_n,
_group_for_symmetric,
_pad_symmetric,
_postprocess_coeffs,
_postprocess_tensor,
_preprocess_coeffs,
_preprocess_tensor,
_translate_boundary_strings,
fwt_pad_n,
)
from .constants import BoundaryMode, Wavelet, WaveletCoeff1d

Expand Down Expand Up @@ -54,15 +51,7 @@ def _fwt_pad(
Returns:
A PyTorch tensor with the padded input data
"""
pytorch_mode = _translate_boundary_strings(mode)

if padding is None:
padding = cast(tuple[int, int], _get_padding_n(data, wavelet, n=1))
if pytorch_mode == "symmetric":
data_pad = _pad_symmetric(data, _group_for_symmetric(padding))
else:
data_pad = torch.nn.functional.pad(data, padding, mode=pytorch_mode)
return data_pad
return fwt_pad_n(data, wavelet, n=1, mode=mode, padding=padding)


def wavedec(
Expand Down
17 changes: 3 additions & 14 deletions src/ptwt/conv_transform_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from __future__ import annotations

from typing import Optional, Union, cast
from typing import Optional, Union

import pywt
import torch
Expand All @@ -15,15 +15,12 @@
_adjust_padding_at_reconstruction,
_check_same_device_dtype,
_get_filter_tensors,
_get_padding_n,
_group_for_symmetric,
_outer,
_pad_symmetric,
_postprocess_coeffs,
_postprocess_tensor,
_preprocess_coeffs,
_preprocess_tensor,
_translate_boundary_strings,
fwt_pad_n,
)
from .constants import BoundaryMode, Wavelet, WaveletCoeff2d, WaveletDetailTuple2d

Expand Down Expand Up @@ -83,15 +80,7 @@ def _fwt_pad2(
The padded output tensor.

"""
pytorch_mode = _translate_boundary_strings(mode)

if padding is None:
padding = cast(tuple[int, int, int, int], _get_padding_n(data, wavelet, n=2))
if pytorch_mode == "symmetric":
data_pad = _pad_symmetric(data, _group_for_symmetric(padding))
else:
data_pad = torch.nn.functional.pad(data, padding, mode=pytorch_mode)
return data_pad
return fwt_pad_n(data, wavelet, n=2, mode=mode, padding=padding)


def wavedec2(
Expand Down
19 changes: 3 additions & 16 deletions src/ptwt/conv_transform_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from __future__ import annotations

from typing import Optional, Union, cast
from typing import Optional, Union

import pywt
import torch
Expand All @@ -15,15 +15,12 @@
_as_wavelet,
_check_same_device_dtype,
_get_filter_tensors,
_get_padding_n,
_group_for_symmetric,
_outer,
_pad_symmetric,
_postprocess_coeffs,
_postprocess_tensor,
_preprocess_coeffs,
_preprocess_tensor,
_translate_boundary_strings,
fwt_pad_n,
)
from .constants import BoundaryMode, Wavelet, WaveletCoeffNd, WaveletDetailDict

Expand Down Expand Up @@ -88,17 +85,7 @@ def _fwt_pad3(
Returns:
The padded output tensor.
"""
pytorch_mode = _translate_boundary_strings(mode)

if padding is None:
padding = cast(
tuple[int, int, int, int, int, int], _get_padding_n(data, wavelet, n=3)
)
if pytorch_mode == "symmetric":
data_pad = _pad_symmetric(data, _group_for_symmetric(padding))
else:
data_pad = torch.nn.functional.pad(data, padding, mode=pytorch_mode)
return data_pad
return fwt_pad_n(data, wavelet, n=3, mode=mode, padding=padding)


def wavedec3(
Expand Down
Loading