def waveletdiff(x, dt, wavelet='db4', level=None, threshold=1.0, axis=0, mode='periodization'):
    """Smooth and differentiate noisy data via discrete wavelet denoising.

    Decomposes x into wavelet approximation and detail coefficients, soft-thresholds
    the detail coefficients with the Donoho & Johnstone (1994) universal threshold
    to remove noise, reconstructs a smoothed signal, and then differentiates that
    smoothed signal with :code:`np.gradient` (central finite differences in the
    interior, one-sided differences at the two ends).

    Because the DWT requires uniform spacing, this method only accepts a scalar
    time step dt (not a vector of sample times). For non-uniformly sampled data,
    use :func:`rbfdiff` or :func:`splinediff` instead.

    :param np.array x: data to differentiate. May be multidimensional; see :code:`axis`.
        Non-floating input (e.g. integers) is converted to float before processing.
    :param float dt: uniform time step between samples. Must be a scalar
        (Python number, NumPy scalar, or 0-d array).
    :param str wavelet: PyWavelets wavelet name, e.g. 'db4', 'sym4', 'coif2'.
    :param int level: decomposition depth. None (default) resolves to
        :code:`min(pywt.dwt_max_level(N, wavelet), 5)`, where N is the number of
        samples along :code:`axis`.
    :param float threshold: non-negative scale factor applied to the universal
        threshold :code:`sigma * sqrt(2 * log(N))`. 1.0 is the classical universal
        threshold; smaller values smooth less, larger values smooth more, and 0
        disables thresholding entirely.
        NOTE(review): the original description says this parameter plays the role
        of tvgamma in the pynumdiff.optimize framework -- confirm against that module.
    :param int axis: axis along which to differentiate (default 0).
    :param str mode: PyWavelets signal extension mode passed to wavedec/waverec
        (default 'periodization'). See :code:`pywt.Modes.modes` for all options.
    :raises ValueError: if :code:`dt` is not a scalar, if :code:`threshold` is not a
        non-negative scalar, or if x has fewer than 2 samples along :code:`axis`.
    :return: - **x_hat** (np.array) -- estimated (smoothed) x
             - **dxdt_hat** (np.array) -- estimated derivative of x
    """
    # np.ndim (rather than np.isscalar) also accepts 0-d arrays as scalars.
    if np.ndim(dt) != 0:
        raise ValueError(
            "`dt` must be a scalar. The DWT requires uniformly sampled data. "
            "For variable step sizes, use rbfdiff or splinediff instead."
        )
    # Written as `not (... >= 0)` so that NaN is rejected as well as negatives.
    if not (np.ndim(threshold) == 0 and threshold >= 0):
        raise ValueError("`threshold` must be a non-negative scalar.")

    x = np.asarray(x)
    if not np.issubdtype(x.dtype, np.inexact):
        # Integer/bool data must be promoted, otherwise downstream results would
        # be forced back into a dtype that cannot hold fractional values.
        x = x.astype(float)
    if x.ndim == 0 or x.shape[axis] < 2:
        raise ValueError("`x` must contain at least 2 samples along `axis`.")

    # Imported lazily so PyWavelets is only required when this function is called.
    import pywt

    N = x.shape[axis]

    # pywt's own default uses the deepest possible decomposition; cap it at 5
    # levels so the coarsest subband keeps roughly N / 2**level samples.
    if level is None:
        level = min(pywt.dwt_max_level(N, wavelet), 5)

    # wavedec/waverec both accept `axis`, so every 1-D signal in x is transformed
    # in a single call; no flattening or per-column Python loop is needed.
    coeffs = pywt.wavedec(x, wavelet, mode=mode, level=level, axis=axis)

    # coeffs[0] is the coarse approximation and is never thresholded; coeffs[1:]
    # are detail levels, the last one being the finest. A zero threshold is the
    # identity and is skipped, which also avoids a 0/0 inside pywt.threshold for
    # coefficients that are exactly zero.
    if threshold > 0 and len(coeffs) > 1:
        # Robust noise estimate per signal: median absolute value of the finest
        # detail level, divided by 0.6745 to convert it to a Gaussian std-dev.
        sigma = np.median(np.abs(coeffs[-1]), axis=axis, keepdims=True) / 0.6745
        sigma = np.maximum(sigma, 1e-10)  # floor avoids a zero threshold on clean signals
        cutoff = threshold * sigma * np.sqrt(2 * np.log(N))  # broadcasts along `axis`
        coeffs = [coeffs[0]] + [pywt.threshold(c, cutoff, mode='soft') for c in coeffs[1:]]

    x_hat = pywt.waverec(coeffs, wavelet, mode=mode, axis=axis)

    # The reconstruction can be longer than the input (e.g. odd N), so trim it
    # back to N samples along the differentiation axis.
    keep = [slice(None)] * x_hat.ndim
    keep[axis] = slice(N)
    x_hat = x_hat[tuple(keep)]

    dxdt_hat = np.gradient(x_hat, dt, axis=axis)

    return x_hat, dxdt_hat
(-18, -19), (0, -1), (1, 0)], [(-12, -12), (-11, -12), (-1, -1), (-1, -2)], [(0, -1), (1, 0), (0, -1), (1, 0)], @@ -327,6 +334,7 @@ def test_diff_method(diff_method_and_params, test_func_and_deriv, request): # re (finitediff, {}), (polydiff, {'degree': 2, 'window_size': 5}), (savgoldiff, {'degree': 3, 'window_size': 11, 'smoothing_win': 3}), + (waveletdiff, {'wavelet': 'db4', 'threshold': 1.0}), (rtsdiff, {'order':2, 'log_qr_ratio':7, 'forwardbackward':True}), (spectraldiff, {'high_freq_cutoff': 0.25, 'pad_to_zero_dxdt': False}), (rbfdiff, {'sigma': 0.5, 'lmbd': 1e-6}), @@ -343,6 +351,7 @@ def test_diff_method(diff_method_and_params, test_func_and_deriv, request): # re kerneldiff: [(2, 1), (3, 2)], butterdiff: [(0, -1), (1, -1)], finitediff: [(0, -1), (1, -1)], + waveletdiff: [(1, 0), (2, 1)], polydiff: [(1, -1), (1, 0)], savgoldiff: [(0, -1), (1, 1)], rtsdiff: [(1, -1), (1, 0)], diff --git a/pyproject.toml b/pyproject.toml index a0d7d55..736c508 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,8 @@ classifiers = [ dependencies = [ "numpy", "scipy", - "matplotlib" + "matplotlib", + "pywavelets" ] [project.urls]