Skip to content

Commit ffcf228

Browse files
committed
util: Make time_to_sample_ceil/sample_to_time_floor accept arrays
This reuses the logic from the equivalent C functions to write these Python functions in a way that lets them handle array inputs.
1 parent 7bddb64 commit ffcf228

3 files changed

Lines changed: 535 additions & 504 deletions

File tree

python/digital_rf/util.py

Lines changed: 83 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
# ----------------------------------------------------------------------------
99
"""Utility functions for Digital RF and Digital Metadata."""
1010

11-
from __future__ import absolute_import, division, print_function
11+
from __future__ import annotations
1212

1313
import ast
1414
import datetime
@@ -17,7 +17,6 @@
1717

1818
import dateutil.parser
1919
import numpy as np
20-
2120
import six
2221

2322
__all__ = (
@@ -27,8 +26,8 @@
2726
"get_samplerate_frac",
2827
"parse_identifier_to_sample",
2928
"parse_identifier_to_time",
30-
"sample_to_time_floor",
3129
"sample_to_datetime",
30+
"sample_to_time_floor",
3231
"samples_to_timedelta",
3332
"time_to_sample",
3433
"time_to_sample_ceil",
@@ -88,41 +87,83 @@ def time_to_sample_ceil(timedelta, sample_rate):
8887
Parameters
8988
----------
9089
timedelta : (second, picosecond) tuple | np.timedelta64 | datetime.timedelta | float
91-
Time span to convert to a number of samples. To represent large time spans
92-
with high accuracy, pass a 2-tuple of ints containing the number of whole
93-
seconds and additional picoseconds. Float values are interpreted as a
94-
number of seconds.
90+
Time span to convert to a number of samples, either scalar or array_like.
91+
To represent large time spans with high accuracy, pass a 2-tuple containing
92+
the number of whole seconds and additional picoseconds. Floating point
93+
values are interpreted as a number of seconds.
9594
9695
sample_rate : fractions.Fraction | first argument to ``get_samplerate_frac``
9796
Sample rate in Hz.
9897
9998
10099
Returns
101100
-------
102-
nsamples : int
101+
nsamples : array_like
103102
Number of samples in the `timedelta` time span at a rate of
104103
`sample_rate`, using ceiling rounding (up to the next whole sample).
105104
106105
"""
107106
if isinstance(timedelta, tuple):
108107
t_sec, t_psec = timedelta
109-
elif isinstance(timedelta, np.timedelta64):
110-
onesec = np.timedelta64(1, "s")
111-
t_sec = timedelta // onesec
112-
t_psec = (timedelta % onesec) // np.timedelta64(1, "ps")
108+
elif hasattr(timedelta, "dtype"):
109+
if np.issubdtype(timedelta.dtype, "timedelta64"):
110+
onesec = np.timedelta64(1, "s")
111+
t_sec = timedelta // onesec
112+
t_psec = (timedelta % onesec) // np.timedelta64(1, "ps")
113+
else:
114+
# floating point seconds
115+
t_sec = np.int64(timedelta)
116+
t_psec = np.int64(np.round((timedelta % 1) * 1e12))
113117
elif isinstance(timedelta, datetime.timedelta):
114118
t_sec = int(timedelta.total_seconds())
115119
t_psec = 1000000 * timedelta.microseconds
116120
else:
121+
# float seconds
117122
t_sec = int(timedelta)
118123
t_psec = int(np.round((timedelta % 1) * 1e12))
119124
# ensure that sample_rate is a fractions.Fraction
120125
if not isinstance(sample_rate, fractions.Fraction):
121126
sample_rate = get_samplerate_frac(sample_rate)
122-
# calculate rational values for the second and picosecond parts
123-
s_frac = t_sec * sample_rate + t_psec * sample_rate / 10**12
124-
# get an integer value through ceiling rounding
125-
return int(s_frac) + ((s_frac % 1) != 0)
127+
128+
srn = sample_rate.numerator
129+
srd = sample_rate.denominator
130+
# calculate with divide/modulus split to avoid overflow
131+
# (divide by denominator and track remainder *before* multiplying by numerator)
132+
# sample_idx = t * n / d = (sec + (psec / 1e12)) * n / d
133+
134+
# start with picosecond part
135+
tmp_div = (t_psec // srd) * srn
136+
tmp_mod = (t_psec % srd) * srn
137+
tmp_div += tmp_mod // srd
138+
tmp_mod = tmp_mod % srd
139+
# quotient and remainder are in terms of (samples * 1e12)
140+
quotient = tmp_div
141+
remainder = tmp_mod # remainder w.r.t. sample_rate_denominator
142+
143+
# multiply and divide second part
144+
tmp_div = (t_sec // srd) * srn
145+
tmp_mod = (t_sec % srd) * srn
146+
tmp_div += tmp_mod // srd
147+
tmp_mod = tmp_mod % srd
148+
# consolidate: tmp_div + tmp_mod / d + quotient / 1e12 + remainder / 1e12 / d
149+
# add second remainder to picosecond part in terms of (samples * 1e12)
150+
remainder += tmp_mod * 1_000_000_000_000
151+
quotient += remainder // srd
152+
remainder = remainder % srd
153+
# now have: tmp_div + quotient / 1e12 + remainder / 1e12 / d
154+
# consolidate into single quotient and remainder
155+
tmp_div += quotient // 1_000_000_000_000
156+
quotient = quotient % 1_000_000_000_000
157+
remainder *= quotient * srd
158+
quotient = tmp_div
159+
# now have: quotient + remainder / 1e12 / d
160+
# update remainder to be in terms of samples using ceiling rounding
161+
remainder = remainder // 1_000_000_000_000 + ((remainder % 1_000_000_000_000) != 0)
162+
# now hav in terms of samples: quotient + remainder / d
163+
# finally ceiling round remainder into quotient
164+
quotient += remainder != 0
165+
166+
return quotient
126167

127168

128169
def sample_to_time_floor(nsamples, sample_rate):
@@ -137,7 +178,7 @@ def sample_to_time_floor(nsamples, sample_rate):
137178
138179
Parameters
139180
----------
140-
nsamples : int
181+
nsamples : array_like
141182
Whole number of samples to convert into a span of time.
142183
143184
sample_rate : fractions.Fraction | first argument to ``get_samplerate_frac``
@@ -146,25 +187,36 @@ def sample_to_time_floor(nsamples, sample_rate):
146187
147188
Returns
148189
-------
149-
seconds : int
190+
seconds : array_like
150191
Number of whole seconds in the time span covered by `nsamples` at a rate
151192
of `sample_rate`.
152193
153-
picoseconds : int
194+
picoseconds : array_like
154195
Number of additional picoseconds in the time span covered by `nsamples`,
155196
using floor rounding (down to the previous whole number of picoseconds).
156197
157198
"""
158-
nsamples = int(nsamples)
159199
# ensure that sample_rate is a fractions.Fraction
160200
if not isinstance(sample_rate, fractions.Fraction):
161201
sample_rate = get_samplerate_frac(sample_rate)
162202

163-
# get the timedelta as a Fraction
164-
t_frac = nsamples / sample_rate
165-
166-
seconds = int(t_frac)
167-
picoseconds = int((t_frac % 1) * 10**12)
203+
srn = sample_rate.numerator
204+
srd = sample_rate.denominator
205+
# calculate with divide/modulus split to avoid overflow
206+
# second = s * d // n == ((s // n) * d) + ((si % n) * d) // n
207+
tmp_div = nsamples // srn
208+
tmp_mod = nsamples % srn
209+
seconds = tmp_div * srd
210+
tmp = tmp_mod * srd
211+
tmp_div = tmp // srn
212+
tmp_mod = tmp % srn
213+
seconds += tmp_div
214+
# picoseconds calculated from remainder of division to calculate seconds
215+
# picosecond = rem * 1e12 // n = rem * (1e12 // n) + (rem * (1e12 % n)) // n
216+
tmp = tmp_mod
217+
tmp_div = 1_000_000_000_000 // srn
218+
tmp_mod = 1_000_000_000_000 % srn
219+
picoseconds = (tmp * tmp_div) + ((tmp * tmp_mod) // srn)
168220

169221
return (seconds, picoseconds)
170222

@@ -210,8 +262,7 @@ def time_to_sample(time, samples_per_second, epoch=None):
210262
tfrac = 1e-6 * td.microseconds
211263
tidx = int(np.uint64(tsec * samples_per_second + tfrac * samples_per_second))
212264
return tidx
213-
else:
214-
return int(np.uint64(time * samples_per_second))
265+
return int(np.uint64(time * samples_per_second))
215266

216267

217268
def sample_to_datetime(sample, sample_rate, epoch=None):
@@ -387,7 +438,7 @@ def parse_identifier_to_sample(iden, sample_rate=None, ref_index=None, epoch=Non
387438
is_relative = False
388439
if iden is None or iden == "":
389440
return None
390-
elif isinstance(iden, six.string_types):
441+
if isinstance(iden, six.string_types):
391442
if iden.startswith("+"):
392443
is_relative = True
393444
iden = iden.lstrip("+")
@@ -429,8 +480,7 @@ def parse_identifier_to_sample(iden, sample_rate=None, ref_index=None, epoch=Non
429480
if ref_index is None:
430481
raise ValueError('ref_index required when relative "+" identifier is used.')
431482
return idx + ref_index
432-
else:
433-
return idx
483+
return idx
434484

435485

436486
def parse_identifier_to_time(iden, sample_rate=None, ref_datetime=None, epoch=None):
@@ -478,7 +528,7 @@ def parse_identifier_to_time(iden, sample_rate=None, ref_datetime=None, epoch=No
478528
is_relative = False
479529
if iden is None or iden == "":
480530
return None
481-
elif isinstance(iden, six.string_types):
531+
if isinstance(iden, six.string_types):
482532
if iden.startswith("+"):
483533
is_relative = True
484534
iden = iden.lstrip("+")
@@ -519,11 +569,10 @@ def parse_identifier_to_time(iden, sample_rate=None, ref_datetime=None, epoch=No
519569
raise ValueError(
520570
'ref_datetime required when relative "+" identifier is used.'
521571
)
522-
elif (
572+
if (
523573
not isinstance(ref_datetime, datetime.datetime)
524574
or ref_datetime.tzinfo is None
525575
):
526576
raise ValueError("ref_datetime must be a timezone-aware datetime.")
527577
return td + ref_datetime
528-
else:
529-
return td + epoch
578+
return td + epoch

0 commit comments

Comments
 (0)