11from fractions import Fraction
2+ from typing import Generator , Iterable , cast
23
34import numpy as np
45from numpy import e , inf , pi
6+ from numpy .typing import NDArray
57from scipy .signal import sosfilt
68
7- from ._waveform import (_D , COS , COSH , DRAG , ERF , EXP , EXPONENTIALCHIRP ,
8- GAUSSIAN , HYPERBOLICCHIRP , INTERP , LINEAR , LINEARCHIRP ,
9- NDIGITS , SINC , SINH , _baseFunc , _baseFunc_latex ,
10- _const , _half , _one , _zero , add , basic_wave ,
11- calc_parts , filter , is_const , merge_waveform , mul , pow ,
12- registerBaseFunc , registerBaseFuncLatex ,
13- registerDerivative , shift , simplify , wave_sum )
9+ from ._waveform import (
10+ _D , COS , COSH , DRAG , ERF , EXP , EXPONENTIALCHIRP , GAUSSIAN , HYPERBOLICCHIRP ,
11+ INTERP , LINEAR , LINEARCHIRP , MOLLIFIER , NDIGITS , SINC , SINH , _baseFunc ,
12+ _baseFunc_latex , _const , _half , _one , _zero , add , basic_wave , calc_parts ,
13+ filter , is_const , merge_waveform , mul , pow , registerBaseFunc ,
14+ registerBaseFuncLatex , registerDerivative , shift , simplify , wave_sum )
1415
1516
1617def _test_spec_num (num , spec ):
@@ -124,7 +125,7 @@ def __init__(self, bounds=(+inf, ), seq=(_zero, ), min=-inf, max=inf):
124125 self .start = None
125126 self .stop = None
126127 self .sample_rate = None
127- self .filters = None
128+ self .filters : tuple [ np . ndarray , float ] | None = None
128129 self .label = None
129130
130131 @staticmethod
@@ -160,12 +161,14 @@ def end(self):
160161 else :
161162 return min (self .stop , self ._end (self .bounds , self .seq ))
162163
163- def sample (self ,
164- sample_rate = None ,
165- out = None ,
166- chunk_size = None ,
167- function_lib = None ,
168- filters = None ):
164+ def sample (
165+ self ,
166+ sample_rate = None ,
167+ out : np .ndarray | None = None ,
168+ chunk_size = None ,
169+ function_lib = None ,
170+ filters : tuple [np .ndarray , float ] | None = None
171+ ) -> np .ndarray | Iterable [np .ndarray ]:
169172 if sample_rate is None :
170173 sample_rate = self .sample_rate
171174 if self .start is None or self .stop is None or sample_rate is None :
@@ -184,17 +187,20 @@ def sample(self,
184187 elif not sos .flags .writeable :
185188 sos = sos .copy ()
186189 if initial :
187- sig = sosfilt (sos , sig - initial ) + initial
190+ sig = cast (np .ndarray , sosfilt (sos ,
191+ sig - initial )) + initial
188192 else :
189- sig = sosfilt (sos , sig )
190- return sig
193+ sig = cast ( np . ndarray , sosfilt (sos , sig ) )
194+ return cast ( np . ndarray , sig )
191195 else :
192196 return self ._sample_iter (sample_rate , chunk_size , out ,
193197 function_lib , filters )
194198
195- def _sample_iter (self , sample_rate , chunk_size , out , function_lib ,
196- filters ):
197- start = self .start
199+ def _sample_iter (
200+ self , sample_rate , chunk_size , out : np .ndarray | None , function_lib ,
201+ filters : tuple [np .ndarray , float ] | None
202+ ) -> Generator [np .ndarray , None , None ]:
203+ start = cast (float , self .start )
198204 start_n = 0
199205 if filters is not None :
200206 sos , initial = filters
@@ -205,10 +211,10 @@ def _sample_iter(self, sample_rate, chunk_size, out, function_lib,
205211 # zi = sosfilt_zi(sos)
206212 zi = np .zeros ((sos .shape [0 ], 2 ))
207213 length = chunk_size / sample_rate
208- while start < self .stop :
209- if start + length > self .stop :
210- length = self .stop - start
211- stop = self .stop
214+ while start < cast ( float , self .stop ) :
215+ if start + length > cast ( float , self .stop ) :
216+ length = cast ( float , self .stop ) - start
217+ stop = cast ( float , self .stop )
212218 size = round ((stop - start ) * sample_rate )
213219 else :
214220 stop = start + length
@@ -217,21 +223,25 @@ def _sample_iter(self, sample_rate, chunk_size, out, function_lib,
217223
218224 if filters is None :
219225 if out is not None :
220- yield self .__call__ (x ,
221- out = out [start_n :],
222- function_lib = function_lib )
226+ yield cast (
227+ np .ndarray ,
228+ self .__call__ (x ,
229+ out = out [start_n :],
230+ function_lib = function_lib ))
223231 else :
224- yield self .__call__ (x , function_lib = function_lib )
232+ yield cast (np .ndarray ,
233+ self .__call__ (x , function_lib = function_lib ))
225234 else :
226- sig = self .__call__ (x , function_lib = function_lib )
235+ sig = cast (np .ndarray ,
236+ self .__call__ (x , function_lib = function_lib ))
227237 if initial :
228238 sig -= initial
229239 sig , zi = sosfilt (sos , sig , zi = zi )
230240 if initial :
231241 sig += initial
232242 if out is not None :
233243 out [start_n :start_n + size ] = sig
234- yield sig
244+ yield cast ( np . ndarray , sig )
235245
236246 start = stop
237247 start_n += chunk_size
@@ -506,16 +516,21 @@ def _fill_parts(parts, out):
506516 for start , stop , part in parts :
507517 out [start :stop ] += part
508518
509- def __call__ (self ,
510- x ,
511- frag = False ,
512- out = None ,
513- accumulate = False ,
514- function_lib = None ):
519+ def __call__ (
520+ self ,
521+ x ,
522+ frag = False ,
523+ out : np .ndarray | None = None ,
524+ accumulate = False ,
525+ function_lib = None
526+ ) -> NDArray [np .float64 ] | list [tuple [int , int ,
527+ NDArray [np .float64 ]]] | np .float64 :
515528 if function_lib is None :
516529 function_lib = _baseFunc
517530 if isinstance (x , (int , float , complex )):
518- return self .__call__ (np .array ([x ]), function_lib = function_lib )[0 ]
531+ return cast (
532+ NDArray [np .float64 ],
533+ self .__call__ (np .array ([x ]), function_lib = function_lib ))[0 ]
519534 parts , dtype = calc_parts (self .bounds , self .seq , x , function_lib ,
520535 self .min , self .max )
521536 if not frag :
@@ -965,6 +980,25 @@ def _format_DRAG(shift, *args):
965980 return f"DRAG(...)"
966981
967982
983+ def _format_MOLLIFIER (shift , * args ):
984+ r = _num_latex (args [0 ])
985+ d = _num_latex (args [1 ])
986+ shift_str = _num_latex (- shift )
987+ if shift_str == '0' :
988+ shift_str = ''
989+ elif shift_str [0 ] != '-' :
990+ shift_str = '+' + shift_str
991+
992+ if d == '0' :
993+ return f"\\ mathrm{{Mollifier}}\\ left(t{ shift_str } , r={ r } \\ right)"
994+ elif d == '1' :
995+ return f"\\ mathrm{{Mollifier}}'\\ left(t{ shift_str } , r={ r } \\ right)"
996+ elif d == '2' :
997+ return f"\\ mathrm{{Mollifier}}''\\ left(t{ shift_str } , r={ r } \\ right)"
998+ else :
999+ return f"\\ mathrm{{Mollifier}}^{{({ d } )}}\\ left(t{ shift_str } , r={ r } \\ right)"
1000+
1001+
9681002registerBaseFuncLatex (LINEAR , _format_LINEAR )
9691003registerBaseFuncLatex (GAUSSIAN , _format_GAUSSIAN )
9701004registerBaseFuncLatex (ERF , _format_ERF )
@@ -974,12 +1008,26 @@ def _format_DRAG(shift, *args):
9741008registerBaseFuncLatex (COSH , _format_COSH )
9751009registerBaseFuncLatex (SINH , _format_SINH )
9761010registerBaseFuncLatex (DRAG , _format_DRAG )
1011+ registerBaseFuncLatex (MOLLIFIER , _format_MOLLIFIER )
9771012
9781013
979- def D (wav ) :
1014+ def D (wav : Waveform , d : int = 1 ) -> Waveform :
9801015 """derivative
1016+
1017+ Parameters
1018+ ----------
1019+ wav : Waveform
1020+ The waveform to take the derivative of.
1021+ d : int, optional
1022+ The order of the derivative, by default 1.
9811023 """
982- return Waveform (bounds = wav .bounds , seq = tuple (_D (x ) for x in wav .seq ))
1024+ assert d >= 0 and isinstance (d , int ), "d must be a non-negative integer"
1025+ if d == 0 :
1026+ return wav
1027+ elif d == 1 :
1028+ return Waveform (bounds = wav .bounds , seq = tuple (_D (x ) for x in wav .seq ))
1029+ else :
1030+ return D (D (wav , d - 1 ), 1 )
9831031
9841032
9851033def convolve (a , b ):
@@ -1189,6 +1237,40 @@ def slepian(duration, *arg):
11891237 return wav * square (duration )
11901238
11911239
1240+ def mollifier (width , plateau : float = 0.0 , d : int = 0 ):
1241+ """
1242+ Mollifier function is a smooth function that is 1 at the origin and 0 outside a certain radius.
1243+ It is defined as:
1244+
1245+ f(x) = exp(1 / ((x / r) ^ 2 - 1) + 1) in case |x| < r
1246+ = 0 in case |x| >= r
1247+ where r = width / 2 is the radius of the mollifier.
1248+
1249+ The parameter plateau is the width of the plateau.
1250+ The parameter d is the order of the derivative.
1251+ """
1252+ assert d >= 0 and isinstance (d , int ), "d must be a non-negative integer"
1253+ assert width > 0 , "width must be positive"
1254+
1255+ if plateau <= 0 :
1256+ return Waveform (bounds = (- 0.5 * width , 0.5 * width , inf ),
1257+ seq = (_zero , basic_wave (MOLLIFIER , width / 2 ,
1258+ d ), _zero ))
1259+ else :
1260+ return Waveform (bounds = (- 0.5 * width - 0.5 * plateau , - 0.5 * plateau ,
1261+ 0.5 * plateau , 0.5 * width + 0.5 * plateau ,
1262+ inf ),
1263+ seq = (_zero ,
1264+ basic_wave (MOLLIFIER ,
1265+ width / 2 ,
1266+ d ,
1267+ shift = - 0.5 * plateau ), _one ,
1268+ basic_wave (MOLLIFIER ,
1269+ width / 2 ,
1270+ d ,
1271+ shift = 0.5 * plateau ), _zero ))
1272+
1273+
11921274def _poly (* a ):
11931275 """
11941276 a[0] + a[1] * t + a[2] * t**2 + ...
@@ -1384,7 +1466,7 @@ def mixing(I,
13841466__all__ = [
13851467 'D' , 'Waveform' , 'chirp' , 'const' , 'cos' , 'cosh' , 'coshPulse' , 'cosPulse' ,
13861468 'cut' , 'drag' , 'exp' , 'function' , 'gaussian' , 'general_cosine' , 'hanning' ,
1387- 'interp' , 'mixing' , 'one' , 'poly' , 'registerBaseFunc' ,
1469+ 'interp' , 'mixing' , 'mollifier' , ' one' , 'poly' , 'registerBaseFunc' ,
13881470 'registerDerivative' , 'samplingPoints' , 'sign' , 'sin' , 'sinc' , 'sinh' ,
13891471 'square' , 'step' , 't' , 'zero'
13901472]
0 commit comments