From 405f4ba5b6a43754c33b8ce8035fbaf6cf909798 Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Tue, 16 Dec 2025 14:26:44 +0100 Subject: [PATCH 01/12] Generalize FWT padding to N dimensions --- .../network_compression/mnist_compression.py | 4 +- .../network_compression/wavelet_linear.py | 3 +- src/ptwt/_util.py | 86 ++++++++++++++++--- src/ptwt/conv_transform.py | 19 +--- src/ptwt/conv_transform_2.py | 22 +---- src/ptwt/conv_transform_3.py | 21 +---- src/ptwt/matmul_transform.py | 4 +- src/ptwt/matmul_transform_2.py | 4 +- src/ptwt/matmul_transform_3.py | 4 +- tests/test_util.py | 2 +- 10 files changed, 91 insertions(+), 78 deletions(-) diff --git a/examples/network_compression/mnist_compression.py b/examples/network_compression/mnist_compression.py index 5c40a602..551a4f95 100644 --- a/examples/network_compression/mnist_compression.py +++ b/examples/network_compression/mnist_compression.py @@ -20,9 +20,9 @@ import argparse import collections +from pathlib import Path from typing import Literal -from pathlib import Path import matplotlib.pyplot as plt import numpy as np import pystow @@ -287,7 +287,7 @@ def main(): ) if args.compression == "Wavelet": - CustomWavelet = collections.namedtuple( + collections.namedtuple( "Wavelet", ["dec_lo", "dec_hi", "rec_lo", "rec_hi", "name"] ) # init_wavelet = ProductFilter( diff --git a/examples/network_compression/wavelet_linear.py b/examples/network_compression/wavelet_linear.py index f2f3573c..94031185 100644 --- a/examples/network_compression/wavelet_linear.py +++ b/examples/network_compression/wavelet_linear.py @@ -4,8 +4,9 @@ import numpy as np import pywt import torch -from torch.nn.parameter import Parameter import torch.nn +from torch.nn.parameter import Parameter + from ptwt import wavedec, waverec from ptwt.wavelets_learnable import WaveletFilter diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index f6be783c..20aa9d4e 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -26,8 +26,23 @@ WaveletDetailTuple2d, ) - -def _translate_boundary_strings(pywt_mode: BoundaryMode) -> str: +PyTorchBoundaryMode = Literal["replicate", "constant", "reflect", "circular"] +ExtendedPyTorchBoundaryMode = PyTorchBoundaryMode | Literal["symmetric"] + +translation_dict: dict[BoundaryMode, ExtendedPyTorchBoundaryMode] = { + "constant": "replicate", + "zero": "constant", + "reflect": "reflect", + "periodic": "circular", + # pytorch does not support symmetric mode, + # we have our own implementation. + "symmetric": "symmetric", +} + + +def _translate_boundary_strings( + pywt_mode: BoundaryMode | None, +) -> ExtendedPyTorchBoundaryMode: """Translate pywt mode strings to PyTorch mode strings. We support ``constant``, ``zero``, ``reflect``, @@ -38,15 +53,8 @@ def _translate_boundary_strings(pywt_mode: BoundaryMode) -> str: Raises: ValueError: If the padding mode is not supported. """ - translation_dict = { - "constant": "replicate", - "zero": "constant", - "reflect": "reflect", - "periodic": "circular", - # pytorch does not support symmetric mode, - # we have our own implementation. - "symmetric": "symmetric", - } + if pywt_mode is None: + return translation_dict["reflect"] if pywt_mode in translation_dict: return translation_dict[pywt_mode] else: @@ -207,7 +215,7 @@ def _get_pad(data_len: int, filt_len: int) -> tuple[int, int]: padr = (2 * filt_len - 3) // 2 padl = (2 * filt_len - 3) // 2 - # pad to even singal length. + # pad to even signal length. padr += data_len % 2 return padl, padr @@ -795,3 +803,57 @@ def wrapper(*args: Param.args, **kwargs: Param.kwargs) -> RetType: def _group_for_symmetric(padding: tuple[int, ...]) -> list[tuple[int, int]]: """Repack the padding tuple for symmetric padding.""" return list(reversed(list(grouper(padding, 2)))) # type:ignore[arg-type] + + +def fwt_pad_n( + data: torch.Tensor, + wavelet: Union[Wavelet, str], + n: int, + *, + mode: BoundaryMode | None = None, + padding: Optional[tuple[int, ...]] = None, +) -> torch.Tensor: + """Pad data for the n-dimensional FWT. + + This function pads the last n axes. + + Args: + data (torch.Tensor): Input data with N+1 dimensions. + wavelet : A pywt wavelet compatible object or + the name of a pywt wavelet. + Refer to the output from ``pywt.wavelist(kind='discrete')`` + for possible choices. + n : the number of axes to pad + mode: The desired padding mode for extending the signal along the edges. + See :data:`ptwt.constants.BoundaryMode`. + padding : A tuple + with the number of padded values on the respective side of the + last ``n`` axes of `data`. + If None, the padding values are computed based + on the signal shape and the wavelet length. Defaults to None. + + Returns: + The padded output tensor. + """ + if padding is None: + padding = _unpack_padding(data, wavelet, n) + elif len(padding) != n: + raise ValueError + + # TODO check dimensions of data? + + match _translate_boundary_strings(mode): + case "symmetric": + return _pad_symmetric(data, _group_for_symmetric(padding)) + case _ as pytorch_mode: + return torch.nn.functional.pad(data, padding, mode=pytorch_mode) + + +def _unpack_padding(data, wavelet, d: int) -> tuple[int, ...]: + wl = _get_len(wavelet) + rv = [] + for i in range(1, d + 1): + a, b = _get_pad(data.shape[-i], wl) + rv.append(b) + rv.append(a) + return tuple(rv) diff --git a/src/ptwt/conv_transform.py b/src/ptwt/conv_transform.py index 39002a11..36dd98d5 100644 --- a/src/ptwt/conv_transform.py +++ b/src/ptwt/conv_transform.py @@ -15,15 +15,11 @@ _adjust_padding_at_reconstruction, _check_same_device_dtype, _get_filter_tensors, - _get_len, - _get_pad, - _pad_symmetric, _postprocess_coeffs, _postprocess_tensor, _preprocess_coeffs, _preprocess_tensor, - _translate_boundary_strings, - _group_for_symmetric, + fwt_pad_n, ) from .constants import BoundaryMode, Wavelet, WaveletCoeff1d @@ -55,18 +51,7 @@ def _fwt_pad( Returns: A PyTorch tensor with the padded input data """ - # convert pywt to pytorch convention. - if mode is None: - mode = "reflect" - pytorch_mode = _translate_boundary_strings(mode) - - if padding is None: - padding = _get_pad(data.shape[-1], _get_len(wavelet)) - if pytorch_mode == "symmetric": - data_pad = _pad_symmetric(data, _group_for_symmetric(padding)) - else: - data_pad = torch.nn.functional.pad(data, padding, mode=pytorch_mode) - return data_pad + return fwt_pad_n(data, wavelet, n=2, mode=mode, padding=padding) def wavedec( diff --git a/src/ptwt/conv_transform_2.py b/src/ptwt/conv_transform_2.py index f0e99472..a767b472 100644 --- a/src/ptwt/conv_transform_2.py +++ b/src/ptwt/conv_transform_2.py @@ -15,16 +15,12 @@ _adjust_padding_at_reconstruction, _check_same_device_dtype, _get_filter_tensors, - _get_len, - _get_pad, _outer, - _pad_symmetric, _postprocess_coeffs, _postprocess_tensor, _preprocess_coeffs, _preprocess_tensor, - _translate_boundary_strings, - _group_for_symmetric, + fwt_pad_n, ) from .constants import BoundaryMode, Wavelet, WaveletCoeff2d, WaveletDetailTuple2d @@ -84,21 +80,7 @@ def _fwt_pad2( The padded output tensor. """ - if mode is None: - mode = "reflect" - pytorch_mode = _translate_boundary_strings(mode) - - if padding is None: - _len_wavelet = _get_len(wavelet) - padding = ( - *_get_pad(data.shape[-1], _len_wavelet), - *_get_pad(data.shape[-2], _len_wavelet), - ) - if pytorch_mode == "symmetric": - data_pad = _pad_symmetric(data, _group_for_symmetric(padding)) - else: - data_pad = torch.nn.functional.pad(data, padding, mode=pytorch_mode) - return data_pad + return fwt_pad_n(data, wavelet, n=2, mode=mode, padding=padding) def wavedec2( diff --git a/src/ptwt/conv_transform_3.py b/src/ptwt/conv_transform_3.py index b73a4332..4808383a 100644 --- a/src/ptwt/conv_transform_3.py +++ b/src/ptwt/conv_transform_3.py @@ -15,16 +15,12 @@ _as_wavelet, _check_same_device_dtype, _get_filter_tensors, - _get_len, - _get_pad, _outer, - _pad_symmetric, _postprocess_coeffs, _postprocess_tensor, _preprocess_coeffs, _preprocess_tensor, - _translate_boundary_strings, - _group_for_symmetric, + fwt_pad_n, ) from .constants import BoundaryMode, Wavelet, WaveletCoeffNd, WaveletDetailDict @@ -89,20 +85,7 @@ def _fwt_pad3( Returns: The padded output tensor. """ - pytorch_mode = _translate_boundary_strings(mode) - - if padding is None: - _len_wavelet = _get_len(wavelet) - padding = ( - *_get_pad(data.shape[-1], _len_wavelet), - *_get_pad(data.shape[-2], _len_wavelet), - *_get_pad(data.shape[-3], _len_wavelet), - ) - if pytorch_mode == "symmetric": - data_pad = _pad_symmetric(data, _group_for_symmetric(padding)) - else: - data_pad = torch.nn.functional.pad(data, padding, mode=pytorch_mode) - return data_pad + return fwt_pad_n(data, wavelet, n=3, mode=mode, padding=padding) def wavedec3( diff --git a/src/ptwt/matmul_transform.py b/src/ptwt/matmul_transform.py index 1087c59c..fe17a990 100644 --- a/src/ptwt/matmul_transform.py +++ b/src/ptwt/matmul_transform.py @@ -329,7 +329,7 @@ def _construct_analysis_matrices( f"level {curr_level}, the current signal length {curr_length} is " f"smaller than the filter length {filt_len}. Therefore, the " "transformation is only computed up to the decomposition level " - f"{curr_level-1}.\n" + f"{curr_level - 1}.\n" ) break @@ -626,7 +626,7 @@ def _construct_synthesis_matrices( f"level {curr_level}, the current signal length {curr_length} is " f"smaller than the filter length {filt_len}. Therefore, the " "transformation is only computed up to the decomposition level " - f"{curr_level-1}.\n" + f"{curr_level - 1}.\n" ) break diff --git a/src/ptwt/matmul_transform_2.py b/src/ptwt/matmul_transform_2.py index 6ca64a9a..3fcc4d0c 100644 --- a/src/ptwt/matmul_transform_2.py +++ b/src/ptwt/matmul_transform_2.py @@ -408,7 +408,7 @@ def _construct_analysis_matrices( f". At level {curr_level}, at least one of the current signal " f"height and width ({current_height}, {current_width}) is smaller " f"then the filter length {filt_len}. Therefore, the transformation " - f"is only computed up to the decomposition level {curr_level-1}.\n" + f"is only computed up to the decomposition level {curr_level - 1}.\n" ) break # the conv matrices require even length inputs. @@ -704,7 +704,7 @@ def _construct_synthesis_matrices( f". At level {curr_level}, at least one of the current signal " f"height and width ({current_height}, {current_width}) is smaller " f"then the filter length {filt_len}. Therefore, the transformation " - f"is only computed up to the decomposition level {curr_level-1}.\n" + f"is only computed up to the decomposition level {curr_level - 1}.\n" ) break current_height, current_width, pad_tuple = _matrix_pad_2( diff --git a/src/ptwt/matmul_transform_3.py b/src/ptwt/matmul_transform_3.py index 376ed4c2..17a5729f 100644 --- a/src/ptwt/matmul_transform_3.py +++ b/src/ptwt/matmul_transform_3.py @@ -162,7 +162,7 @@ def _construct_analysis_matrices( f"depth, height, and width ({current_depth}, {current_height}," f"{current_width}) is smaller " f"then the filter length {filt_len}. Therefore, the transformation " - f"is only computed up to the decomposition level {curr_level-1}.\n" + f"is only computed up to the decomposition level {curr_level - 1}.\n" ) break # the conv matrices require even length inputs. @@ -369,7 +369,7 @@ def _construct_synthesis_matrices( f" depth, height and width ({current_depth}, {current_height}, " f"{current_width}) is smaller than the filter length {filt_len}." f" Therefore, the transformation " - f"is only computed up to the decomposition level {curr_level-1}.\n" + f"is only computed up to the decomposition level {curr_level - 1}.\n" ) break # the conv matrices require even length inputs. diff --git a/tests/test_util.py b/tests/test_util.py index f6f2c0f9..bf7b6302 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -8,9 +8,9 @@ from ptwt._util import ( _as_wavelet, _fold_axes, + _group_for_symmetric, _pad_symmetric, _pad_symmetric_1d, - _group_for_symmetric, _unfold_axes, ) From cb6f1641622e247a364afc10c6756f249bf1553f Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Tue, 16 Dec 2025 14:46:00 +0100 Subject: [PATCH 02/12] Fix and test symmetric padding repack function --- src/ptwt/_util.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index 20aa9d4e..c5600827 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -836,7 +836,7 @@ def fwt_pad_n( The padded output tensor. """ if padding is None: - padding = _unpack_padding(data, wavelet, n) + padding = _get_padding_n(data, wavelet, n) elif len(padding) != n: raise ValueError @@ -849,11 +849,16 @@ def fwt_pad_n( return torch.nn.functional.pad(data, padding, mode=pytorch_mode) -def _unpack_padding(data, wavelet, d: int) -> tuple[int, ...]: - wl = _get_len(wavelet) +def _repack_symmetric(padding: tuple[int, ...]) -> list[tuple[int, int]]: + """Repack the padding tuple for symmetric padding.""" + return list(reversed(list(grouper(padding, 2)))) + + +def _get_padding_n(data, wavelet, n: int) -> tuple[int, ...]: + wavelet_length = _get_len(wavelet) rv = [] - for i in range(1, d + 1): - a, b = _get_pad(data.shape[-i], wl) - rv.append(b) - rv.append(a) + for i in range(1, n + 1): + right, left = _get_pad(data.shape[-i], wavelet_length) + rv.append(left) + rv.append(right) return tuple(rv) From 1f661ed099a30a68302b65654d1cd91cddabb90a Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Tue, 16 Dec 2025 14:53:27 +0100 Subject: [PATCH 03/12] Update --- src/ptwt/_util.py | 3 +++ tests/test_util.py | 13 +++++++++++++ 2 files changed, 16 insertions(+) diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index c5600827..c1f2a359 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -834,6 +834,9 @@ def fwt_pad_n( Returns: The padded output tensor. + + Raises: + ValueError: if the padding is the wrong size """ if padding is None: padding = _get_padding_n(data, wavelet, n) diff --git a/tests/test_util.py b/tests/test_util.py index bf7b6302..193ebce9 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -12,6 +12,9 @@ _pad_symmetric, _pad_symmetric_1d, _unfold_axes, + _get_pad, + _get_len, + _get_padding_n, ) @@ -97,3 +100,13 @@ def test_repack_symmetric() -> None: (pad_top, pad_bottom), (pad_left, pad_right), ] + + +def test_get_padding_n() -> None: + """Ensure padding works as expected.""" + wavelet = pywt.Wavelet("sym4") + data = torch.randn(3, 3, 3) + padb, padt = _get_pad(data.shape[-2], _get_len(wavelet)) + padr, padl = _get_pad(data.shape[-1], _get_len(wavelet)) + padding = _get_padding_n(data, wavelet, 2) + assert padding == (padl, padr, padt, padb) From 8c713dc85ea49dc2cbce5e2414f28352da80c238 Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Tue, 16 Dec 2025 15:56:45 +0100 Subject: [PATCH 04/12] Sort --- tests/test_util.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_util.py b/tests/test_util.py index 193ebce9..78415214 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -8,13 +8,13 @@ from ptwt._util import ( _as_wavelet, _fold_axes, + _get_len, + _get_pad, + _get_padding_n, _group_for_symmetric, _pad_symmetric, _pad_symmetric_1d, _unfold_axes, - _get_pad, - _get_len, - _get_padding_n, ) From 6cf670cc6df8aeaed679a700a418c438250f20b3 Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Tue, 16 Dec 2025 15:57:24 +0100 Subject: [PATCH 05/12] Update _util.py --- src/ptwt/_util.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index c1f2a359..c0ed7f30 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -861,7 +861,5 @@ def _get_padding_n(data, wavelet, n: int) -> tuple[int, ...]: wavelet_length = _get_len(wavelet) rv = [] for i in range(1, n + 1): - right, left = _get_pad(data.shape[-i], wavelet_length) - rv.append(left) - rv.append(right) + rv.extend(_get_pad(data.shape[-i], wavelet_length)) return tuple(rv) From e706b5aa7af6871cf88bb32fc24160ffbb7421c8 Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Tue, 16 Dec 2025 15:58:18 +0100 Subject: [PATCH 06/12] Update test_util.py --- tests/test_util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_util.py b/tests/test_util.py index 78415214..c0e70c72 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -106,7 +106,7 @@ def test_get_padding_n() -> None: """Ensure padding works as expected.""" wavelet = pywt.Wavelet("sym4") data = torch.randn(3, 3, 3) - padb, padt = _get_pad(data.shape[-2], _get_len(wavelet)) - padr, padl = _get_pad(data.shape[-1], _get_len(wavelet)) + padt, padb = _get_pad(data.shape[-2], _get_len(wavelet)) + padl, padr = _get_pad(data.shape[-1], _get_len(wavelet)) padding = _get_padding_n(data, wavelet, 2) assert padding == (padl, padr, padt, padb) From 2408f0addfe2b600a1a2b9059cd3fc9e9a4679ba Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Tue, 16 Dec 2025 16:10:56 +0100 Subject: [PATCH 07/12] Update _util.py --- src/ptwt/_util.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index c0ed7f30..4f133894 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -857,9 +857,11 @@ def _repack_symmetric(padding: tuple[int, ...]) -> list[tuple[int, int]]: return list(reversed(list(grouper(padding, 2)))) -def _get_padding_n(data, wavelet, n: int) -> tuple[int, ...]: +def _get_padding_n( + data: torch.Tensor, wavelet: Union[Wavelet, str], n: int +) -> tuple[int, ...]: wavelet_length = _get_len(wavelet) - rv = [] + rv: list[int] = [] for i in range(1, n + 1): rv.extend(_get_pad(data.shape[-i], wavelet_length)) return tuple(rv) From ed361c8df6e54f3a4963b8b1993d388fd640786a Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Tue, 16 Dec 2025 23:54:15 +0100 Subject: [PATCH 08/12] Remove unused --- src/ptwt/conv_transform.py | 6 +----- src/ptwt/conv_transform_2.py | 2 +- src/ptwt/conv_transform_3.py | 5 +---- 3 files changed, 3 insertions(+), 10 deletions(-) diff --git a/src/ptwt/conv_transform.py b/src/ptwt/conv_transform.py index a7dba859..36dd98d5 100644 --- a/src/ptwt/conv_transform.py +++ b/src/ptwt/conv_transform.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import Optional, Union, cast +from typing import Optional, Union import pywt import torch @@ -15,14 +15,10 @@ _adjust_padding_at_reconstruction, _check_same_device_dtype, _get_filter_tensors, - _get_padding_n, - _group_for_symmetric, - _pad_symmetric, _postprocess_coeffs, _postprocess_tensor, _preprocess_coeffs, _preprocess_tensor, - _translate_boundary_strings, fwt_pad_n, ) from .constants import BoundaryMode, Wavelet, WaveletCoeff1d diff --git a/src/ptwt/conv_transform_2.py b/src/ptwt/conv_transform_2.py index 28285b00..a767b472 100644 --- a/src/ptwt/conv_transform_2.py +++ b/src/ptwt/conv_transform_2.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import Optional, Union, cast +from typing import Optional, Union import pywt import torch diff --git a/src/ptwt/conv_transform_3.py b/src/ptwt/conv_transform_3.py index be17aa68..4808383a 100644 --- a/src/ptwt/conv_transform_3.py +++ b/src/ptwt/conv_transform_3.py @@ -5,7 +5,7 @@ from __future__ import annotations -from typing import Optional, Union, cast +from typing import Optional, Union import pywt import torch @@ -15,14 +15,11 @@ _as_wavelet, _check_same_device_dtype, _get_filter_tensors, - _get_padding_n, - _group_for_symmetric, _outer, _postprocess_coeffs, _postprocess_tensor, _preprocess_coeffs, _preprocess_tensor, - _translate_boundary_strings, fwt_pad_n, ) from .constants import BoundaryMode, Wavelet, WaveletCoeffNd, WaveletDetailDict From d4d57b03d049577cf59e5d1bded5ffd141aa400e Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Tue, 16 Dec 2025 23:56:10 +0100 Subject: [PATCH 09/12] Lint --- src/ptwt/matmul_transform_2.py | 4 ++-- src/ptwt/matmul_transform_3.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ptwt/matmul_transform_2.py b/src/ptwt/matmul_transform_2.py index 3fcc4d0c..6ca64a9a 100644 --- a/src/ptwt/matmul_transform_2.py +++ b/src/ptwt/matmul_transform_2.py @@ -408,7 +408,7 @@ def _construct_analysis_matrices( f". At level {curr_level}, at least one of the current signal " f"height and width ({current_height}, {current_width}) is smaller " f"then the filter length {filt_len}. Therefore, the transformation " - f"is only computed up to the decomposition level {curr_level - 1}.\n" + f"is only computed up to the decomposition level {curr_level-1}.\n" ) break # the conv matrices require even length inputs. @@ -704,7 +704,7 @@ def _construct_synthesis_matrices( f". At level {curr_level}, at least one of the current signal " f"height and width ({current_height}, {current_width}) is smaller " f"then the filter length {filt_len}. Therefore, the transformation " - f"is only computed up to the decomposition level {curr_level - 1}.\n" + f"is only computed up to the decomposition level {curr_level-1}.\n" ) break current_height, current_width, pad_tuple = _matrix_pad_2( diff --git a/src/ptwt/matmul_transform_3.py b/src/ptwt/matmul_transform_3.py index 17a5729f..376ed4c2 100644 --- a/src/ptwt/matmul_transform_3.py +++ b/src/ptwt/matmul_transform_3.py @@ -162,7 +162,7 @@ def _construct_analysis_matrices( f"depth, height, and width ({current_depth}, {current_height}," f"{current_width}) is smaller " f"then the filter length {filt_len}. Therefore, the transformation " - f"is only computed up to the decomposition level {curr_level - 1}.\n" + f"is only computed up to the decomposition level {curr_level-1}.\n" ) break # the conv matrices require even length inputs. @@ -369,7 +369,7 @@ def _construct_synthesis_matrices( f" depth, height and width ({current_depth}, {current_height}, " f"{current_width}) is smaller than the filter length {filt_len}." f" Therefore, the transformation " - f"is only computed up to the decomposition level {curr_level - 1}.\n" + f"is only computed up to the decomposition level {curr_level-1}.\n" ) break # the conv matrices require even length inputs. From f2465ef9777fdb5c0c417ae1697589fd790a8158 Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Wed, 17 Dec 2025 00:10:49 +0100 Subject: [PATCH 10/12] Update conv_transform.py --- src/ptwt/conv_transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ptwt/conv_transform.py b/src/ptwt/conv_transform.py index 36dd98d5..0484dde9 100644 --- a/src/ptwt/conv_transform.py +++ b/src/ptwt/conv_transform.py @@ -51,7 +51,7 @@ def _fwt_pad( Returns: A PyTorch tensor with the padded input data """ - return fwt_pad_n(data, wavelet, n=2, mode=mode, padding=padding) + return fwt_pad_n(data, wavelet, n=1, mode=mode, padding=padding) def wavedec( From 882e6c04fce65f0dcff647e0febe91d85465e0c4 Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Wed, 17 Dec 2025 00:11:43 +0100 Subject: [PATCH 11/12] Cleanup --- src/ptwt/_util.py | 6 ++++-- src/ptwt/matmul_transform.py | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index 55aed7d3..23faf40c 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -852,13 +852,15 @@ def fwt_pad_n( Raises: ValueError: if the padding is the wrong size """ + x = len(data.shape) - 2 + if n != x: + raise ValueError(f'{n=} but shoulda been shape={x}') + if padding is None: padding = _get_padding_n(data, wavelet, n) elif len(padding) != n: raise ValueError - # TODO check dimensions of data? - match _translate_boundary_strings(mode): case "symmetric": return _pad_symmetric(data, _group_for_symmetric(padding)) diff --git a/src/ptwt/matmul_transform.py b/src/ptwt/matmul_transform.py index fe17a990..1087c59c 100644 --- a/src/ptwt/matmul_transform.py +++ b/src/ptwt/matmul_transform.py @@ -329,7 +329,7 @@ def _construct_analysis_matrices( f"level {curr_level}, the current signal length {curr_length} is " f"smaller than the filter length {filt_len}. Therefore, the " "transformation is only computed up to the decomposition level " - f"{curr_level - 1}.\n" + f"{curr_level-1}.\n" ) break @@ -626,7 +626,7 @@ def _construct_synthesis_matrices( f"level {curr_level}, the current signal length {curr_length} is " f"smaller than the filter length {filt_len}. Therefore, the " "transformation is only computed up to the decomposition level " - f"{curr_level - 1}.\n" + f"{curr_level-1}.\n" ) break From 1599011bfc7c644e91d5d6d84b5448094daf9dd4 Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Wed, 17 Dec 2025 00:14:47 +0100 Subject: [PATCH 12/12] Update _util.py --- src/ptwt/_util.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index 23faf40c..124e806a 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -854,8 +854,7 @@ def fwt_pad_n( """ x = len(data.shape) - 2 if n != x: - raise ValueError(f'{n=} but shoulda been shape={x}') - + raise ValueError if padding is None: padding = _get_padding_n(data, wavelet, n) elif len(padding) != n: