Skip to content

Commit f574f2f

Browse files
committed
made all changes requested
1 parent 8431f6d commit f574f2f

2 files changed

Lines changed: 71 additions & 59 deletions

File tree

pynumdiff/basis_fit.py

Lines changed: 70 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -136,18 +136,21 @@ def rbfdiff(x, dt_or_t, sigma=1, lmbd=0.01, axis=0):
136136
return np.moveaxis(x_hat_flattened.reshape(plump), 0, axis), np.moveaxis(dxdt_hat_flattened.reshape(plump), 0, axis)
137137

138138

139-
def waveletdiff(x, dt_or_t, wavelet='db4', level=None, threshold=1.0, axis=0, mode='periodization'):
139+
def waveletdiff(x, dt, wavelet='db4', level=None, threshold=1.0, axis=0, mode='periodization'):
140140
"""Smooth and differentiate noisy data via discrete wavelet denoising.
141141
142142
Decomposes x into wavelet detail and approximation coefficients, soft-thresholds
143143
the detail coefficients to remove noise using the Donoho & Johnstone (1994)
144144
universal threshold estimator, reconstructs a smoothed signal, then
145-
differentiates with finite differences via np.gradient.
145+
differentiates analytically by applying derivative reconstruction filters to
146+
the denoised wavelet coefficients.
146147
147-
:param np.array x: data to differentiate
148-
:param float or array dt_or_t: scalar dt or array of sample times. If an
149-
array is provided it is passed directly to np.gradient, giving correct
150-
results for non-uniformly sampled data.
148+
Because the DWT requires uniform spacing, this method only accepts a scalar
149+
time step dt (not a vector of sample times). For non-uniformly sampled data,
150+
use :func:`rbfdiff` or :func:`splinediff` instead.
151+
152+
:param np.array x: data to differentiate. May be multidimensional; see :code:`axis`.
153+
:param float dt: uniform time step between samples.
151154
:param str wavelet: PyWavelets wavelet name, e.g. 'db4', 'sym4', 'coif2'.
152155
'db4' is a solid general-purpose default. Biorthogonal wavelets such as
153156
'bior2.2' or 'bior4.4' are symmetric and designed for smooth reconstruction
@@ -170,39 +173,32 @@ def waveletdiff(x, dt_or_t, wavelet='db4', level=None, threshold=1.0, axis=0, mo
170173
:return: - **x_hat** (np.array) -- estimated (smoothed) x
171174
- **dxdt_hat** (np.array) -- estimated derivative of x
172175
"""
173-
N = x.shape[axis]
176+
if not np.isscalar(dt):
177+
raise ValueError(
178+
"`dt` must be a scalar. The DWT requires uniformly sampled data. "
179+
"For variable step sizes, use rbfdiff or splinediff instead."
180+
)
174181

175-
# Axis normalisation — bring target axis to front.
176-
# Skip moveaxis when axis is already 0 to avoid an unnecessary allocation.
177-
# When we do move, call ascontiguousarray immediately so the subsequent
178-
# reshape is guaranteed zero-copy.
179-
if axis == 0:
180-
x_work = x if x.flags['C_CONTIGUOUS'] else np.ascontiguousarray(x)
181-
else:
182-
x_work = np.ascontiguousarray(np.moveaxis(x, axis, 0))
182+
N = x.shape[axis]
183183

184+
# Bring axis of differentiation to front so each column of x_flat is one
185+
# signal to differentiate. moveaxis returns a view with updated strides,
186+
# so ascontiguousarray ensures the subsequent reshape is zero-copy.
187+
x_work = np.ascontiguousarray(np.moveaxis(x, axis, 0))
184188
shape = x_work.shape
185-
x_flat = x_work.reshape(N, -1) # (N, M) contiguous, no hidden copy
189+
x_flat = x_work.reshape(N, -1) # (N, M)
186190
M = x_flat.shape[1]
187191

188-
if np.isscalar(dt_or_t):
189-
grad_arg = dt_or_t
190-
else:
191-
if len(dt_or_t) != N:
192-
raise ValueError(
193-
"`dt_or_t` array must have the same length as x along `axis`."
194-
)
195-
grad_arg = dt_or_t # np.gradient accepts a full coordinate array
196-
197-
# Conservative level default avoids over-decomposing short signals
198-
# (pywt default uses the maximum possible level).
192+
# Conservative level cap: pywt's default uses the maximum possible level,
193+
# which can over-decompose short signals and wash out meaningful detail.
194+
# Capping at 5 keeps at least 2^5 = 32 samples in the coarsest subband,
195+
# which is enough to represent a smooth approximation without artefacts.
199196
if level is None:
200197
max_level = pywt.dwt_max_level(N, wavelet)
201198
level = min(max_level, 5)
202199

203-
# Decompose all columns and stack coefficients into 2-D arrays of shape
204-
# (coeff_len_i, M). Probing column 0 first lets us pre-allocate correctly;
205-
# the probe result is reused for col 0 so we pay N+1 wavedec calls total.
200+
# Decompose all columns; probe column 0 first to learn coefficient lengths
201+
# and pre-allocate, reusing that result so we only pay N+1 wavedec calls.
206202
_probe = pywt.wavedec(x_flat[:, 0], wavelet, level=level, mode=mode)
207203
coeff_lengths = [len(c) for c in _probe]
208204
n_levels = len(_probe)
@@ -221,10 +217,19 @@ def waveletdiff(x, dt_or_t, wavelet='db4', level=None, threshold=1.0, axis=0, mo
221217
coeffs_all[i][:, col] = c
222218

223219
# Vectorised noise estimation and soft-thresholding over all columns at once.
224-
# sigma: robust MAD estimator from finest detail level, shape (M,).
225-
# thresh: per-column universal threshold, shape (M,).
226-
# Approximation coefficients (index 0) are left untouched; only detail
227-
# levels (indices 1..n_levels-1) are thresholded.
220+
#
221+
# Soft-thresholding achieves smoothing by shrinking wavelet detail
222+
# coefficients toward zero: coefficients whose magnitude is below the
223+
# threshold (mostly noise) are zeroed out, while large coefficients (true
224+
# signal features) are kept but reduced by the threshold amount. Only detail
225+
# levels (indices 1..n_levels-1) are thresholded; the coarse approximation
226+
# coefficients (index 0) are left untouched.
227+
#
228+
# sigma: robust noise-level estimate via the median absolute deviation of
229+
# the finest detail level. Dividing by 0.6745 converts MAD to an
230+
# estimate of the Gaussian standard deviation.
231+
# thresh: per-column Donoho & Johnstone (1994) universal threshold,
232+
# sigma * sqrt(2 * log(N)), scaled by the user-supplied `threshold`.
228233
sigma = np.median(np.abs(coeffs_all[-1]), axis=0) / 0.6745
229234
np.maximum(sigma, 1e-10, out=sigma) # floor avoids zero threshold on clean signals
230235

@@ -235,24 +240,43 @@ def waveletdiff(x, dt_or_t, wavelet='db4', level=None, threshold=1.0, axis=0, mo
235240
for c in coeffs_all[1:]
236241
]
237242

238-
# Reconstruct and differentiate — pywt.waverec is 1-D only so a column
239-
# loop remains, but all Python-level arithmetic has been moved out above.
243+
# Build derivative reconstruction filters from the wavelet's reconstruction
244+
# filters. Because the DWT reconstructs a signal as a linear combination of
245+
# shifted scaling/wavelet functions, the derivative of the reconstruction is
246+
# the same linear combination of the *derivatives* of those basis functions.
247+
# We obtain derivative filters by finite-differencing the reconstruction
248+
# lowpass filter (rec_lo), then scaling by 1/dt to convert discrete
249+
# differences to continuous-time derivatives.
250+
w = pywt.Wavelet(wavelet)
251+
rec_lo = np.array(w.rec_lo)
252+
# First-order finite difference of the filter gives the derivative filter.
253+
# np.diff shortens by 1; padding with a leading zero keeps the filter length
254+
# and phase consistent with the original so waverec alignment is preserved.
255+
d_rec_lo = np.concatenate(([0.0], np.diff(rec_lo))) / dt
256+
d_rec_hi = np.concatenate(([0.0], np.diff(np.array(w.rec_hi)))) / dt
257+
258+
# Reconstruct x_hat and dxdt_hat column by column.
259+
# pywt.waverec is 1-D only, so the column loop is unavoidable here;
260+
# the vectorised operations above have already moved all Python-level
261+
# arithmetic outside this loop.
240262
x_hat_flat = np.empty_like(x_flat)
241263
dxdt_hat_flat = np.empty_like(x_flat)
242264

243265
for col in range(M):
244266
col_coeffs = [coeffs_denoised[i][:, col] for i in range(n_levels)]
245-
x_hat_col = pywt.waverec(col_coeffs, wavelet, mode=mode)[:N]
246-
x_hat_flat[:, col] = x_hat_col
247-
dxdt_hat_flat[:, col] = np.gradient(x_hat_col, grad_arg)
248267

249-
# Restore original shape and axis order.
250-
# moveaxis on the way out is only needed when we moved on the way in.
251-
x_hat = x_hat_flat.reshape(shape)
252-
dxdt_hat = dxdt_hat_flat.reshape(shape)
268+
# Standard reconstruction for the smoothed signal.
269+
x_hat_flat[:, col] = pywt.waverec(col_coeffs, wavelet, mode=mode)[:N]
253270

254-
if axis != 0:
255-
x_hat = np.moveaxis(x_hat, 0, axis)
256-
dxdt_hat = np.moveaxis(dxdt_hat, 0, axis)
271+
# Derivative reconstruction: replace the wavelet's reconstruction
272+
# filters with their finite-difference derivatives and run waverec.
273+
d_wavelet = pywt.Wavelet(
274+
filter_bank=(w.dec_lo, w.dec_hi, d_rec_lo, d_rec_hi)
275+
)
276+
dxdt_hat_flat[:, col] = pywt.waverec(col_coeffs, d_wavelet, mode=mode)[:N]
277+
278+
# Restore original shape and axis order.
279+
x_hat = np.moveaxis(x_hat_flat.reshape(shape), 0, axis)
280+
dxdt_hat = np.moveaxis(dxdt_hat_flat.reshape(shape), 0, axis)
257281

258282
return x_hat, dxdt_hat

pynumdiff/tests/test_diff_methods.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -332,19 +332,15 @@ def test_diff_method(diff_method_and_params, test_func_and_deriv, request): # re
332332
(kerneldiff, {'kernel': 'gaussian', 'window_size': 5}),
333333
(butterdiff, {'filter_order': 3, 'cutoff_freq': 1 - 1e-6}),
334334
(finitediff, {}),
335-
<<<<<<< HEAD
336-
(savgoldiff, {'degree': 3, 'window_size': 11, 'smoothing_win': 3}),
337-
(waveletdiff, {'wavelet': 'db4', 'threshold': 1.0}),
338-
=======
339335
(polydiff, {'degree': 2, 'window_size': 5}),
340336
(savgoldiff, {'degree': 3, 'window_size': 11, 'smoothing_win': 3}),
337+
(waveletdiff, {'wavelet': 'db4', 'threshold': 1.0}),
341338
(rtsdiff, {'order':2, 'log_qr_ratio':7, 'forwardbackward':True}),
342339
(spectraldiff, {'high_freq_cutoff': 0.25, 'pad_to_zero_dxdt': False}),
343340
(rbfdiff, {'sigma': 0.5, 'lmbd': 1e-6}),
344341
(splinediff, {'degree': 9, 's': 1e-6}),
345342
(robustdiff, {'order':2, 'log_q':7, 'log_r':2}),
346343
(tvrdiff, {'order': 3, 'gamma': 1e-4})
347-
>>>>>>> b38199f982cb4036065f599b3fe00f6076671a6a
348344
]
349345

350346
# Similar to the error_bounds table, index by method first. But then we test against only one 2D function,
@@ -355,10 +351,7 @@ def test_diff_method(diff_method_and_params, test_func_and_deriv, request): # re
355351
kerneldiff: [(2, 1), (3, 2)],
356352
butterdiff: [(0, -1), (1, -1)],
357353
finitediff: [(0, -1), (1, -1)],
358-
<<<<<<< HEAD
359-
savgoldiff: [(0, -1), (1, 1)],
360354
waveletdiff: [(1, 0), (2, 1)],
361-
=======
362355
polydiff: [(1, -1), (1, 0)],
363356
savgoldiff: [(0, -1), (1, 1)],
364357
rtsdiff: [(1, -1), (1, 0)],
@@ -367,7 +360,6 @@ def test_diff_method(diff_method_and_params, test_func_and_deriv, request): # re
367360
splinediff: [(0, -1), (1, 0)],
368361
robustdiff: [(-2, -3), (0, -1)],
369362
tvrdiff: [(0, -1), (1, 0)]
370-
>>>>>>> b38199f982cb4036065f599b3fe00f6076671a6a
371363
}
372364

373365
@mark.parametrize("multidim_method_and_params", multidim_methods_and_params)
@@ -420,9 +412,6 @@ def test_multidimensionality(multidim_method_and_params, request):
420412
ax2.plot_wireframe(T1, T2, computed_d2)
421413
ax3.plot_wireframe(T1, T2, computed_laplacian, label='computed')
422414
legend = ax3.legend(bbox_to_anchor=(0.7, 0.8)); legend.legend_handles[0].set_facecolor(pyplot.cm.viridis(0.6))
423-
<<<<<<< HEAD
424-
fig.suptitle(f'{diff_method.__name__}', fontsize=16)
425-
=======
426415
fig.suptitle(f'{diff_method.__name__}', fontsize=16)
427416

428417

@@ -475,4 +464,3 @@ def test_missing_data(diff_method_and_params):
475464

476465
assert np.all(np.isfinite(x_hat))
477466
assert np.all(np.isfinite(dxdt_hat))
478-
>>>>>>> b38199f982cb4036065f599b3fe00f6076671a6a

0 commit comments

Comments
 (0)