Skip to content

Commit 236d70b

Browse files
committed
Add mollifier function and update version to 2.2.0
- Introduced a new mollifier function for smooth transitions in waveforms. - Added support for the mollifier in the Waveform class and its associated formatting. - Updated version number to 2.2.0 to reflect the new feature addition.
1 parent 1a36e76 commit 236d70b

6 files changed

Lines changed: 159 additions & 58 deletions

File tree

waveforms/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from .version import __version__
55
from .waveform import (D, Waveform, WaveVStack, chirp, const, cos, cosh,
66
coshPulse, cosPulse, cut, drag, exp, function, gaussian,
7-
general_cosine, hanning, interp, mixing, one, poly,
8-
registerBaseFunc, registerDerivative, samplingPoints,
9-
sign, sin, sinc, sinh, square, step, t, zero)
7+
general_cosine, hanning, interp, mixing, mollifier, one,
8+
poly, registerBaseFunc, registerDerivative,
9+
samplingPoints, sign, sin, sinc, sinh, square, step, t,
10+
zero)
1011
from .waveform_parser import wave_eval

waveforms/_waveform.pyi

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def is_const(x: tuple[tuple, tuple]) -> bool:
2828
pass
2929

3030

31-
def basic_wave(Type, *args, shift=0) -> tuple[tuple, tuple]:
31+
def basic_wave(Type: int, *args, shift: float = 0) -> tuple[tuple, tuple]:
3232
pass
3333

3434

@@ -53,7 +53,7 @@ def calc_parts(bounds: tuple,
5353
x: np.ndarray,
5454
function_lib: dict,
5555
min=-inf,
56-
max=inf) -> tuple[list[np.ndarray], type]:
56+
max=inf) -> tuple[list[tuple[int, int, np.ndarray]], type]:
5757
pass
5858

5959

@@ -103,6 +103,7 @@ HYPERBOLICCHIRP: int = ...
103103
COSH: int = ...
104104
SINH: int = ...
105105
DRAG: int = ...
106+
MOLLIFIER: int = ...
106107

107108

108109
def simplify(expr: tuple[tuple, tuple], eps: float) -> tuple[tuple, tuple]:

waveforms/_waveform.pyx

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,18 @@ def _drag(t: np.ndarray, t0: float, freq: float, width: float, delta: float,
351351
return Omega_x * np.cos(wt) + Omega_y * np.sin(wt)
352352

353353

354+
def _mollifier(t: np.ndarray, r: float, d: int):
355+
x = t / r
356+
if d == 0:
357+
return np.exp(1 / (np.abs(x)**2 - 1) + 1)
358+
else:
359+
p = np.poly1d([-2, 0])
360+
for n in range(1, d):
361+
p = np.poly1d([1, 0, -2, 0, 1]) * p.deriv() + np.poly1d(
362+
[-4 * n, 0, 4 * n - 2, 0]) * p
363+
return np.exp(1 / (np.abs(x)**2 - 1) + 1) * p(x) / (1 - x**2)**(2 * d) / r**d
364+
365+
354366
LINEAR = registerBaseFunc(_LINEAR)
355367
GAUSSIAN = registerBaseFunc(_GAUSSIAN)
356368
ERF = registerBaseFunc(_ERF)
@@ -364,6 +376,7 @@ HYPERBOLICCHIRP = registerBaseFunc(_HYPERBOLICCHIRP)
364376
COSH = registerBaseFunc(_COSH)
365377
SINH = registerBaseFunc(_SINH)
366378
DRAG = registerBaseFunc(_drag)
379+
MOLLIFIER = registerBaseFunc(_mollifier)
367380

368381

369382
def _d_LINEAR(shift, *args):
@@ -433,6 +446,10 @@ def _d_HYPERBOLICCHIRP(shift, f0, k, phi0):
433446
shift)), (-1, 1)), ), (2 * pi * f0, ))
434447

435448

449+
def _d_MOLLIFIER(shift, r, d):
450+
return (((((MOLLIFIER, r, d+1, shift), ), (1, )), ), (1, ))
451+
452+
436453
# register derivative
437454
registerDerivative(LINEAR, _d_LINEAR)
438455
registerDerivative(GAUSSIAN, _d_GAUSSIAN)
@@ -446,6 +463,7 @@ registerDerivative(SINH, _d_SINH)
446463
registerDerivative(LINEARCHIRP, _d_LINEARCHIRP)
447464
registerDerivative(EXPONENTIALCHIRP, _d_EXPONENTIALCHIRP)
448465
registerDerivative(HYPERBOLICCHIRP, _d_HYPERBOLICCHIRP)
466+
registerDerivative(MOLLIFIER, _d_MOLLIFIER)
449467

450468

451469
def _cos_power_n(x, n):

waveforms/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
"""Define version number here and read it from setup.py automatically"""
2-
__version__ = "2.1.1"
2+
__version__ = "2.2.0"

waveforms/waveform.py

Lines changed: 122 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
from fractions import Fraction
2+
from typing import Generator, Iterable, cast
23

34
import numpy as np
45
from numpy import e, inf, pi
6+
from numpy.typing import NDArray
57
from scipy.signal import sosfilt
68

7-
from ._waveform import (_D, COS, COSH, DRAG, ERF, EXP, EXPONENTIALCHIRP,
8-
GAUSSIAN, HYPERBOLICCHIRP, INTERP, LINEAR, LINEARCHIRP,
9-
NDIGITS, SINC, SINH, _baseFunc, _baseFunc_latex,
10-
_const, _half, _one, _zero, add, basic_wave,
11-
calc_parts, filter, is_const, merge_waveform, mul, pow,
12-
registerBaseFunc, registerBaseFuncLatex,
13-
registerDerivative, shift, simplify, wave_sum)
9+
from ._waveform import (
10+
_D, COS, COSH, DRAG, ERF, EXP, EXPONENTIALCHIRP, GAUSSIAN, HYPERBOLICCHIRP,
11+
INTERP, LINEAR, LINEARCHIRP, MOLLIFIER, NDIGITS, SINC, SINH, _baseFunc,
12+
_baseFunc_latex, _const, _half, _one, _zero, add, basic_wave, calc_parts,
13+
filter, is_const, merge_waveform, mul, pow, registerBaseFunc,
14+
registerBaseFuncLatex, registerDerivative, shift, simplify, wave_sum)
1415

1516

1617
def _test_spec_num(num, spec):
@@ -124,7 +125,7 @@ def __init__(self, bounds=(+inf, ), seq=(_zero, ), min=-inf, max=inf):
124125
self.start = None
125126
self.stop = None
126127
self.sample_rate = None
127-
self.filters = None
128+
self.filters: tuple[np.ndarray, float] | None = None
128129
self.label = None
129130

130131
@staticmethod
@@ -160,12 +161,14 @@ def end(self):
160161
else:
161162
return min(self.stop, self._end(self.bounds, self.seq))
162163

163-
def sample(self,
164-
sample_rate=None,
165-
out=None,
166-
chunk_size=None,
167-
function_lib=None,
168-
filters=None):
164+
def sample(
165+
self,
166+
sample_rate=None,
167+
out: np.ndarray | None = None,
168+
chunk_size=None,
169+
function_lib=None,
170+
filters: tuple[np.ndarray, float] | None = None
171+
) -> np.ndarray | Iterable[np.ndarray]:
169172
if sample_rate is None:
170173
sample_rate = self.sample_rate
171174
if self.start is None or self.stop is None or sample_rate is None:
@@ -184,17 +187,20 @@ def sample(self,
184187
elif not sos.flags.writeable:
185188
sos = sos.copy()
186189
if initial:
187-
sig = sosfilt(sos, sig - initial) + initial
190+
sig = cast(np.ndarray, sosfilt(sos,
191+
sig - initial)) + initial
188192
else:
189-
sig = sosfilt(sos, sig)
190-
return sig
193+
sig = cast(np.ndarray, sosfilt(sos, sig))
194+
return cast(np.ndarray, sig)
191195
else:
192196
return self._sample_iter(sample_rate, chunk_size, out,
193197
function_lib, filters)
194198

195-
def _sample_iter(self, sample_rate, chunk_size, out, function_lib,
196-
filters):
197-
start = self.start
199+
def _sample_iter(
200+
self, sample_rate, chunk_size, out: np.ndarray | None, function_lib,
201+
filters: tuple[np.ndarray, float] | None
202+
) -> Generator[np.ndarray, None, None]:
203+
start = cast(float, self.start)
198204
start_n = 0
199205
if filters is not None:
200206
sos, initial = filters
@@ -205,10 +211,10 @@ def _sample_iter(self, sample_rate, chunk_size, out, function_lib,
205211
# zi = sosfilt_zi(sos)
206212
zi = np.zeros((sos.shape[0], 2))
207213
length = chunk_size / sample_rate
208-
while start < self.stop:
209-
if start + length > self.stop:
210-
length = self.stop - start
211-
stop = self.stop
214+
while start < cast(float, self.stop):
215+
if start + length > cast(float, self.stop):
216+
length = cast(float, self.stop) - start
217+
stop = cast(float, self.stop)
212218
size = round((stop - start) * sample_rate)
213219
else:
214220
stop = start + length
@@ -217,21 +223,25 @@ def _sample_iter(self, sample_rate, chunk_size, out, function_lib,
217223

218224
if filters is None:
219225
if out is not None:
220-
yield self.__call__(x,
221-
out=out[start_n:],
222-
function_lib=function_lib)
226+
yield cast(
227+
np.ndarray,
228+
self.__call__(x,
229+
out=out[start_n:],
230+
function_lib=function_lib))
223231
else:
224-
yield self.__call__(x, function_lib=function_lib)
232+
yield cast(np.ndarray,
233+
self.__call__(x, function_lib=function_lib))
225234
else:
226-
sig = self.__call__(x, function_lib=function_lib)
235+
sig = cast(np.ndarray,
236+
self.__call__(x, function_lib=function_lib))
227237
if initial:
228238
sig -= initial
229239
sig, zi = sosfilt(sos, sig, zi=zi)
230240
if initial:
231241
sig += initial
232242
if out is not None:
233243
out[start_n:start_n + size] = sig
234-
yield sig
244+
yield cast(np.ndarray, sig)
235245

236246
start = stop
237247
start_n += chunk_size
@@ -506,16 +516,21 @@ def _fill_parts(parts, out):
506516
for start, stop, part in parts:
507517
out[start:stop] += part
508518

509-
def __call__(self,
510-
x,
511-
frag=False,
512-
out=None,
513-
accumulate=False,
514-
function_lib=None):
519+
def __call__(
520+
self,
521+
x,
522+
frag=False,
523+
out: np.ndarray | None = None,
524+
accumulate=False,
525+
function_lib=None
526+
) -> NDArray[np.float64] | list[tuple[int, int,
527+
NDArray[np.float64]]] | np.float64:
515528
if function_lib is None:
516529
function_lib = _baseFunc
517530
if isinstance(x, (int, float, complex)):
518-
return self.__call__(np.array([x]), function_lib=function_lib)[0]
531+
return cast(
532+
NDArray[np.float64],
533+
self.__call__(np.array([x]), function_lib=function_lib))[0]
519534
parts, dtype = calc_parts(self.bounds, self.seq, x, function_lib,
520535
self.min, self.max)
521536
if not frag:
@@ -965,6 +980,25 @@ def _format_DRAG(shift, *args):
965980
return f"DRAG(...)"
966981

967982

983+
def _format_MOLLIFIER(shift, *args):
984+
r = _num_latex(args[0])
985+
d = _num_latex(args[1])
986+
shift_str = _num_latex(-shift)
987+
if shift_str == '0':
988+
shift_str = ''
989+
elif shift_str[0] != '-':
990+
shift_str = '+' + shift_str
991+
992+
if d == '0':
993+
return f"\\mathrm{{Mollifier}}\\left(t{shift_str}, r={r}\\right)"
994+
elif d == '1':
995+
return f"\\mathrm{{Mollifier}}'\\left(t{shift_str}, r={r}\\right)"
996+
elif d == '2':
997+
return f"\\mathrm{{Mollifier}}''\\left(t{shift_str}, r={r}\\right)"
998+
else:
999+
return f"\\mathrm{{Mollifier}}^{{({d})}}\\left(t{shift_str}, r={r}\\right)"
1000+
1001+
9681002
registerBaseFuncLatex(LINEAR, _format_LINEAR)
9691003
registerBaseFuncLatex(GAUSSIAN, _format_GAUSSIAN)
9701004
registerBaseFuncLatex(ERF, _format_ERF)
@@ -974,12 +1008,26 @@ def _format_DRAG(shift, *args):
9741008
registerBaseFuncLatex(COSH, _format_COSH)
9751009
registerBaseFuncLatex(SINH, _format_SINH)
9761010
registerBaseFuncLatex(DRAG, _format_DRAG)
1011+
registerBaseFuncLatex(MOLLIFIER, _format_MOLLIFIER)
9771012

9781013

979-
def D(wav):
1014+
def D(wav: Waveform, d: int = 1) -> Waveform:
9801015
"""derivative
1016+
1017+
Parameters
1018+
----------
1019+
wav : Waveform
1020+
The waveform to take the derivative of.
1021+
d : int, optional
1022+
The order of the derivative, by default 1.
9811023
"""
982-
return Waveform(bounds=wav.bounds, seq=tuple(_D(x) for x in wav.seq))
1024+
assert d >= 0 and isinstance(d, int), "d must be a non-negative integer"
1025+
if d == 0:
1026+
return wav
1027+
elif d == 1:
1028+
return Waveform(bounds=wav.bounds, seq=tuple(_D(x) for x in wav.seq))
1029+
else:
1030+
return D(D(wav, d - 1), 1)
9831031

9841032

9851033
def convolve(a, b):
@@ -1189,6 +1237,40 @@ def slepian(duration, *arg):
11891237
return wav * square(duration)
11901238

11911239

1240+
def mollifier(width, plateau: float = 0.0, d: int = 0):
1241+
"""
1242+
Mollifier function is a smooth function that is 1 at the origin and 0 outside a certain radius.
1243+
It is defined as:
1244+
1245+
f(x) = exp(1 / ((x / r) ^ 2 - 1) + 1) in case |x| < r
1246+
= 0 in case |x| >= r
1247+
where r = width / 2 is the radius of the mollifier.
1248+
1249+
The parameter plateau is the width of the plateau.
1250+
The parameter d is the order of the derivative.
1251+
"""
1252+
assert d >= 0 and isinstance(d, int), "d must be a non-negative integer"
1253+
assert width > 0, "width must be positive"
1254+
1255+
if plateau <= 0:
1256+
return Waveform(bounds=(-0.5 * width, 0.5 * width, inf),
1257+
seq=(_zero, basic_wave(MOLLIFIER, width / 2,
1258+
d), _zero))
1259+
else:
1260+
return Waveform(bounds=(-0.5 * width - 0.5 * plateau, -0.5 * plateau,
1261+
0.5 * plateau, 0.5 * width + 0.5 * plateau,
1262+
inf),
1263+
seq=(_zero,
1264+
basic_wave(MOLLIFIER,
1265+
width / 2,
1266+
d,
1267+
shift=-0.5 * plateau), _one,
1268+
basic_wave(MOLLIFIER,
1269+
width / 2,
1270+
d,
1271+
shift=0.5 * plateau), _zero))
1272+
1273+
11921274
def _poly(*a):
11931275
"""
11941276
a[0] + a[1] * t + a[2] * t**2 + ...
@@ -1384,7 +1466,7 @@ def mixing(I,
13841466
__all__ = [
13851467
'D', 'Waveform', 'chirp', 'const', 'cos', 'cosh', 'coshPulse', 'cosPulse',
13861468
'cut', 'drag', 'exp', 'function', 'gaussian', 'general_cosine', 'hanning',
1387-
'interp', 'mixing', 'one', 'poly', 'registerBaseFunc',
1469+
'interp', 'mixing', 'mollifier', 'one', 'poly', 'registerBaseFunc',
13881470
'registerDerivative', 'samplingPoints', 'sign', 'sin', 'sinc', 'sinh',
13891471
'square', 'step', 't', 'zero'
13901472
]

waveforms/waveform_parser.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ def __init__(self):
3030
self.functions = [
3131
'D', 'chirp', 'const', 'cos', 'cosh', 'coshPulse', 'cosPulse',
3232
'cut', 'drag', 'drag_sin', 'drag_sinx', 'exp', 'gaussian',
33-
'general_cosine', 'hanning', 'interp', 'mixing', 'one', 'poly',
34-
'samplingPoints', 'sign', 'sin', 'sinc', 'sinh', 'square', 'step',
35-
't', 'zero'
33+
'general_cosine', 'hanning', 'interp', 'mixing', 'mollifier',
34+
'one', 'poly', 'samplingPoints', 'sign', 'sin', 'sinc', 'sinh',
35+
'square', 'step', 't', 'zero'
3636
]
3737
self.constants = {
3838
'pi': waveform.pi,
@@ -226,14 +226,13 @@ def _generate_antlr_parser():
226226

227227
# Generate ANTLR files
228228
try:
229-
result = subprocess.run([
230-
"antlr4", "-Dlanguage=Python3",
231-
str(grammar_file)
232-
],
233-
cwd=str(current_dir),
234-
capture_output=True,
235-
text=True,
236-
check=True)
229+
result = subprocess.run(
230+
["antlr4", "-Dlanguage=Python3",
231+
str(grammar_file)],
232+
cwd=str(current_dir),
233+
capture_output=True,
234+
text=True,
235+
check=True)
237236
except (subprocess.CalledProcessError, FileNotFoundError) as e:
238237
# Fall back to java command if antlr4 command is not available
239238
try:
@@ -258,7 +257,7 @@ def parse_waveform_expression(expr: str) -> waveform.Waveform:
258257
try:
259258
# Generate parser files if they don't exist
260259
# _generate_antlr_parser()
261-
260+
262261
# Import generated ANTLR classes
263262
from .WaveformLexer import WaveformLexer
264263
from .WaveformParser import WaveformParser

0 commit comments

Comments
 (0)