@@ -136,18 +136,21 @@ def rbfdiff(x, dt_or_t, sigma=1, lmbd=0.01, axis=0):
136136 return np .moveaxis (x_hat_flattened .reshape (plump ), 0 , axis ), np .moveaxis (dxdt_hat_flattened .reshape (plump ), 0 , axis )
137137
138138
139- def waveletdiff (x , dt_or_t , wavelet = 'db4' , level = None , threshold = 1.0 , axis = 0 , mode = 'periodization' ):
139+ def waveletdiff (x , dt , wavelet = 'db4' , level = None , threshold = 1.0 , axis = 0 , mode = 'periodization' ):
140140 """Smooth and differentiate noisy data via discrete wavelet denoising.
141141
142142 Decomposes x into wavelet detail and approximation coefficients, soft-thresholds
143143 the detail coefficients to remove noise using the Donoho & Johnstone (1994)
144144 universal threshold estimator, reconstructs a smoothed signal, then
145- differentiates with finite differences via np.gradient.
145+ differentiates by reconstructing the denoised wavelet coefficients with
146+ finite-differenced reconstruction filters scaled by 1/dt.
146147
147- :param np.array x: data to differentiate
148- :param float or array dt_or_t: scalar dt or array of sample times. If an
149- array is provided it is passed directly to np.gradient, giving correct
150- results for non-uniformly sampled data.
148+ Because the DWT requires uniform spacing, this method only accepts a scalar
149+ time step dt (not a vector of sample times). For non-uniformly sampled data,
150+ use :func:`rbfdiff` or :func:`splinediff` instead.
151+
152+ :param np.array x: data to differentiate. May be multidimensional; see :code:`axis`.
153+ :param float dt: uniform time step between samples.
151154 :param str wavelet: PyWavelets wavelet name, e.g. 'db4', 'sym4', 'coif2'.
152155 'db4' is a solid general-purpose default. Biorthogonal wavelets such as
153156 'bior2.2' or 'bior4.4' are symmetric and designed for smooth reconstruction
@@ -170,39 +173,32 @@ def waveletdiff(x, dt_or_t, wavelet='db4', level=None, threshold=1.0, axis=0, mo
170173 :return: - **x_hat** (np.array) -- estimated (smoothed) x
171174 - **dxdt_hat** (np.array) -- estimated derivative of x
172175 """
173- N = x .shape [axis ]
176+ if not np .isscalar (dt ):
177+ raise ValueError (
178+ "`dt` must be a scalar. The DWT requires uniformly sampled data. "
179+ "For variable step sizes, use rbfdiff or splinediff instead."
180+ )
174181
175- # Axis normalisation — bring target axis to front.
176- # Skip moveaxis when axis is already 0 to avoid an unnecessary allocation.
177- # When we do move, call ascontiguousarray immediately so the subsequent
178- # reshape is guaranteed zero-copy.
179- if axis == 0 :
180- x_work = x if x .flags ['C_CONTIGUOUS' ] else np .ascontiguousarray (x )
181- else :
182- x_work = np .ascontiguousarray (np .moveaxis (x , axis , 0 ))
182+ N = x .shape [axis ]
183183
184+ # Bring axis of differentiation to front so each column of x_flat is one
185+ # signal to differentiate. moveaxis returns a view with updated strides,
186+ # so ascontiguousarray ensures the subsequent reshape is zero-copy.
187+ x_work = np .ascontiguousarray (np .moveaxis (x , axis , 0 ))
184188 shape = x_work .shape
185- x_flat = x_work .reshape (N , - 1 ) # (N, M) contiguous, no hidden copy
189+ x_flat = x_work .reshape (N , - 1 ) # (N, M)
186190 M = x_flat .shape [1 ]
187191
188- if np .isscalar (dt_or_t ):
189- grad_arg = dt_or_t
190- else :
191- if len (dt_or_t ) != N :
192- raise ValueError (
193- "`dt_or_t` array must have the same length as x along `axis`."
194- )
195- grad_arg = dt_or_t # np.gradient accepts a full coordinate array
196-
197- # Conservative level default avoids over-decomposing short signals
198- # (pywt default uses the maximum possible level).
192+ # Conservative level cap: pywt's default uses the maximum possible level,
193+ # which can over-decompose short signals and wash out meaningful detail.
194+ # Capping at 5 limits downsampling to a factor of 2^5 = 32, so the coarsest
195+ # subband still holds roughly N/32 coefficients for the smooth approximation.
199196 if level is None :
200197 max_level = pywt .dwt_max_level (N , wavelet )
201198 level = min (max_level , 5 )
202199
203- # Decompose all columns and stack coefficients into 2-D arrays of shape
204- # (coeff_len_i, M). Probing column 0 first lets us pre-allocate correctly;
205- # the probe result is reused for col 0 so we pay N+1 wavedec calls total.
200+ # Decompose all columns; probe column 0 first to learn coefficient lengths
201+ # and pre-allocate, reusing that result so column 0 is not decomposed twice.
206202 _probe = pywt .wavedec (x_flat [:, 0 ], wavelet , level = level , mode = mode )
207203 coeff_lengths = [len (c ) for c in _probe ]
208204 n_levels = len (_probe )
@@ -221,10 +217,19 @@ def waveletdiff(x, dt_or_t, wavelet='db4', level=None, threshold=1.0, axis=0, mo
221217 coeffs_all [i ][:, col ] = c
222218
223219 # Vectorised noise estimation and soft-thresholding over all columns at once.
224- # sigma: robust MAD estimator from finest detail level, shape (M,).
225- # thresh: per-column universal threshold, shape (M,).
226- # Approximation coefficients (index 0) are left untouched; only detail
227- # levels (indices 1..n_levels-1) are thresholded.
220+ #
221+ # Soft-thresholding achieves smoothing by shrinking wavelet detail
222+ # coefficients toward zero: coefficients whose magnitude is below the
223+ # threshold (mostly noise) are zeroed out, while large coefficients (true
224+ # signal features) are kept but reduced by the threshold amount. Only detail
225+ # levels (indices 1..n_levels-1) are thresholded; the coarse approximation
226+ # coefficients (index 0) are left untouched.
227+ #
228+ # sigma: robust noise-level estimate via the median absolute deviation of
229+ # the finest detail level. Dividing by 0.6745 converts MAD to an
230+ # estimate of the Gaussian standard deviation.
231+ # thresh: per-column Donoho & Johnstone (1994) universal threshold,
232+ # sigma * sqrt(2 * log(N)), scaled by the user-supplied `threshold`.
228233 sigma = np .median (np .abs (coeffs_all [- 1 ]), axis = 0 ) / 0.6745
229234 np .maximum (sigma , 1e-10 , out = sigma ) # floor avoids zero threshold on clean signals
230235
@@ -235,24 +240,43 @@ def waveletdiff(x, dt_or_t, wavelet='db4', level=None, threshold=1.0, axis=0, mo
235240 for c in coeffs_all [1 :]
236241 ]
237242
238- # Reconstruct and differentiate — pywt.waverec is 1-D only so a column
239- # loop remains, but all Python-level arithmetic has been moved out above.
243+ # Build derivative reconstruction filters from the wavelet's reconstruction
244+ # filters. Because the DWT reconstructs a signal as a linear combination of
245+ # shifted scaling/wavelet functions, the derivative of the reconstruction is
246+ # the same linear combination of the *derivatives* of those basis functions.
247+ # We obtain derivative filters by finite-differencing the reconstruction
248+ # lowpass and highpass filters (rec_lo, rec_hi), then scaling by 1/dt to
249+ # convert discrete differences to continuous-time derivatives.
250+ w = pywt .Wavelet (wavelet )
251+ rec_lo = np .array (w .rec_lo )
252+ # First-order finite difference of the filter gives the derivative filter.
253+ # np.diff shortens by 1; padding with a leading zero keeps the filter length
254+ # and phase consistent with the original so waverec alignment is preserved.
255+ d_rec_lo = np .concatenate (([0.0 ], np .diff (rec_lo ))) / dt
256+ d_rec_hi = np .concatenate (([0.0 ], np .diff (np .array (w .rec_hi )))) / dt
257+
258+ # Reconstruct x_hat and dxdt_hat column by column.
259+ # pywt.waverec is 1-D only, so the column loop is unavoidable here;
260+ # the vectorised operations above have already moved all Python-level
261+ # arithmetic outside this loop.
240262 x_hat_flat = np .empty_like (x_flat )
241263 dxdt_hat_flat = np .empty_like (x_flat )
242264
243265 for col in range (M ):
244266 col_coeffs = [coeffs_denoised [i ][:, col ] for i in range (n_levels )]
245- x_hat_col = pywt .waverec (col_coeffs , wavelet , mode = mode )[:N ]
246- x_hat_flat [:, col ] = x_hat_col
247- dxdt_hat_flat [:, col ] = np .gradient (x_hat_col , grad_arg )
248267
249- # Restore original shape and axis order.
250- # moveaxis on the way out is only needed when we moved on the way in.
251- x_hat = x_hat_flat .reshape (shape )
252- dxdt_hat = dxdt_hat_flat .reshape (shape )
268+ # Standard reconstruction for the smoothed signal.
269+ x_hat_flat [:, col ] = pywt .waverec (col_coeffs , wavelet , mode = mode )[:N ]
253270
254- if axis != 0 :
255- x_hat = np .moveaxis (x_hat , 0 , axis )
256- dxdt_hat = np .moveaxis (dxdt_hat , 0 , axis )
271+ # Derivative reconstruction: replace the wavelet's reconstruction
272+ # filters with their finite-difference derivatives and run waverec.
273+ d_wavelet = pywt .Wavelet (
274+ filter_bank = (w .dec_lo , w .dec_hi , d_rec_lo , d_rec_hi )
275+ )
276+ dxdt_hat_flat [:, col ] = pywt .waverec (col_coeffs , d_wavelet , mode = mode )[:N ]
277+
278+ # Restore original shape and axis order.
279+ x_hat = np .moveaxis (x_hat_flat .reshape (shape ), 0 , axis )
280+ dxdt_hat = np .moveaxis (dxdt_hat_flat .reshape (shape ), 0 , axis )
257281
258282 return x_hat , dxdt_hat
0 commit comments