diff --git a/src/stratify/_vinterp.pyx b/src/stratify/_vinterp.pyx index f580487..5c45a59 100644 --- a/src/stratify/_vinterp.pyx +++ b/src/stratify/_vinterp.pyx @@ -53,7 +53,8 @@ cdef inline int relative_sign(double z, double z_base) nogil: @cython.boundscheck(False) @cython.wraparound(False) cdef long gridwise_interpolation(double[:] z_target, double[:] z_src, - double[:, :] fz_src, bint increasing, + double[:, :] fz_src, bint rising, + bint aligned, Interpolator interpolation, Extrapolator extrapolation, double [:, :] fz_target) nogil except -1: @@ -65,7 +66,8 @@ cdef long gridwise_interpolation(double[:] z_target, double[:] z_src, z_target - the levels to interpolate the source data ``fz_src`` to. z_src - the levels that the source data ``fz_src`` is interpolated from. fz_src - the source data to be interpolated. - increasing - true when increasing Z index generally implies increasing Z values + rising - true when rising Z index generally implies rising Z values + aligned - true when both src and tgt increase/decrease in the same direction interpolation - the inner interpolation functionality. See the definition of Interpolator. extrapolation - the inner extrapolation functionality. See the definition of @@ -91,7 +93,7 @@ cdef long gridwise_interpolation(double[:] z_target, double[:] z_src, cdef unsigned int i_src, i_target, n_src, n_target, i, m cdef bint all_nans = True cdef double z_before, z_current, z_after, z_last - cdef int sign_after, sign_before, extrapolating + cdef int sign_after, sign_before, extrapolating, z_final n_src = z_src.shape[0] n_target = z_target.shape[0] @@ -110,13 +112,12 @@ cdef long gridwise_interpolation(double[:] z_target, double[:] z_src, fz_target[i, i_target] = NAN return 0 - interpolation.prepare_column(z_target, z_src, fz_src, increasing) - extrapolation.prepare_column(z_target, z_src, fz_src, increasing) + interpolation.prepare_column(z_target, z_src, fz_src, rising) + extrapolation.prepare_column(z_target, z_src, fz_src, rising) + with gil: + z_src = np.asarray(z_src) - if increasing: - z_before = -INFINITY - else: - z_before = INFINITY + z_before = -INFINITY if rising else INFINITY z_last = -z_before @@ -125,7 +126,11 @@ cdef long gridwise_interpolation(double[:] z_target, double[:] z_src, # first window value (typically -inf, but may be +inf) and the first z_src. # This search window will be moved along until a crossing is detected, at # which point we will do an interpolation. - z_after = z_src[0] + with gil: + z_final = z_src.size - 1 + + + z_after = z_src[0] if aligned else z_src[z_final] # We start in extrapolation mode. This will be turned off as soon as we # start increasing i_src. @@ -151,7 +156,12 @@ cdef long gridwise_interpolation(double[:] z_target, double[:] z_src, i_src += 1 if i_src < n_src: extrapolating = 0 - z_after = z_src[i_src] + with gil: + if aligned: + z_after = z_src[i_src] + else: + dummy = z_src.size - (i_src + 1) + z_after = z_src[dummy] if isnan(z_after): with gil: raise ValueError('The source coordinate may not contain NaN values.') @@ -201,7 +211,7 @@ cdef class Interpolator(object): 'the kernel function.') cdef bint prepare_column(self, double[:] z_target, double[:] z_src, - double[:, :] fz_src, bint increasing) nogil except -1: + double[:, :] fz_src, bint rising) nogil except -1: # Called before all levels are interpolated. pass @@ -262,7 +272,7 @@ cdef class PyFuncInterpolator(Interpolator): def __init__(self, use_column_prep=True): self.use_column_prep = use_column_prep - def column_prep(self, z_target, z_src, fz_src, increasing): + def column_prep(self, z_target, z_src, fz_src, rising): """ Called each time this interpolator sees a new data array. This method may be used for validation of a column, or for column @@ -274,10 +284,10 @@ cdef class PyFuncInterpolator(Interpolator): pass cdef bint prepare_column(self, double[:] z_target, double[:] z_src, - double[:, :] fz_src, bint increasing) nogil except -1: + double[:, :] fz_src, bint rising) nogil except -1: if self.use_column_prep: with gil: - self.column_prep(z_target, z_src, fz_src, increasing) + self.column_prep(z_target, z_src, fz_src, rising) def interp_kernel(self, index, z_src, fz_src, level, output_array): # Fill the output array with the fz_src data at the given index. @@ -319,7 +329,7 @@ cdef class Extrapolator(object): 'the kernel function.') cdef bint prepare_column(self, double[:] z_target, double[:] z_src, - double[:, :] fz_src, bint increasing) nogil except -1: + double[:, :] fz_src, bint rising) nogil except -1: pass @@ -359,7 +369,7 @@ cdef class NearestNExtrapolator(Extrapolator): cdef class LinearExtrapolator(Extrapolator): cdef bint prepare_column(self, double[:] z_target, double[:] z_src, - double[:, :] fz_src, bint increasing) nogil except -1: + double[:, :] fz_src, bint rising) nogil except -1: cdef unsigned int n_src_pts = z_src.shape[0] if n_src_pts < 2: @@ -402,7 +412,7 @@ cdef class PyFuncExtrapolator(Extrapolator): def __init__(self, use_column_prep=True): self.use_column_prep = use_column_prep - def column_prep(self, z_target, z_src, fz_src, increasing): + def column_prep(self, z_target, z_src, fz_src, rising): """ Called each time this extrapolator sees a new data array. This method may be used for validation of a column, or for column @@ -414,10 +424,10 @@ cdef class PyFuncExtrapolator(Extrapolator): pass cdef bint prepare_column(self, double[:] z_target, double[:] z_src, - double[:, :] fz_src, bint increasing) nogil except -1: + double[:, :] fz_src, bint rising) nogil except -1: if self.use_column_prep: with gil: - self.column_prep(z_target, z_src, fz_src, increasing) + self.column_prep(z_target, z_src, fz_src, rising) def extrap_kernel(self, direction, z_src, fz_src, level, output_array): # Fill the output array with nans. @@ -449,7 +459,7 @@ EXTRAPOLATE_NEAREST = extrap_schemes['nearest']() EXTRAPOLATE_LINEAR = extrap_schemes['linear']() -def interpolate(z_target, z_src, fz_src, axis=-1, rising=None, +def interpolate(z_target, z_src, fz_src, rising=None, axis=-1, interpolation='linear', extrapolation='nan'): """ Interface for optimised 1d interpolation across multiple dimensions. @@ -486,16 +496,6 @@ def interpolate(z_target, z_src, fz_src, axis=-1, rising=None, the same as the shape of ``z_src``. axis: int (default -1) The ``fz_src`` axis to perform the interpolation over. - rising: bool (default None) - Whether the values of the source's interpolation coordinate values - are generally rising or generally falling. For example, values of - pressure levels will be generally falling as the z coordinate - increases. - This will determine whether extrapolation needs to occur for - ``z_target`` below the first and above the last ``z_src``. - If rising is None, the first two interpolation coordinate values - will be used to determine the general direction. In most cases, - this is a good option. interpolation: :class:`.Interpolator` instance or valid scheme name The core interpolation operation to use. :attr:`.INTERPOLATE_LINEAR` and :attr:`_INTERPOLATE_NEAREST` are provided for convenient @@ -509,7 +509,6 @@ def interpolate(z_target, z_src, fz_src, axis=-1, rising=None, func = functools.partial( _interpolate, axis=axis, - rising=rising, interpolation=interpolation, extrapolation=extrapolation ) @@ -564,14 +563,14 @@ def interpolate(z_target, z_src, fz_src, axis=-1, rising=None, meta=np.array((), dtype=fz_src.dtype)) -def _interpolate(z_target, z_src, fz_src, axis=-1, rising=None, +def _interpolate(z_target, z_src, fz_src, axis=-1, interpolation='linear', extrapolation='nan'): if interpolation in interp_schemes: interpolation = interp_schemes[interpolation]() if extrapolation in extrap_schemes: extrapolation = extrap_schemes[extrapolation]() - interp = _Interpolation(z_target, z_src, fz_src, rising=rising, axis=axis, + interp = _Interpolation(z_target, z_src, fz_src, axis=axis, interpolation=interpolation, extrapolation=extrapolation) if interp.z_target.ndim == 1: @@ -583,16 +582,14 @@ def _interpolate(z_target, z_src, fz_src, axis=-1, rising=None, cdef class _Interpolation(object): """ Where the magic happens for gridwise_interp. The work of this __init__ is - mostly for putting the input nd arrays into a 3 and 4 dimensional form for - convenient (read: efficient) Cython form. Inline comments should help with - understanding. + mostly for putting the input nd. """ cdef Interpolator interpolation cdef Extrapolator extrapolation cdef public np.dtype _target_dtype - cdef int rising + cdef rising, aligned cdef public z_target, orig_shape, axis, _zp_reshaped, _fp_reshaped cdef public _result_working_shape, result_shape @@ -692,17 +689,27 @@ cdef class _Interpolation(object): #: The shape of the interpolated data. self.result_shape = tuple(result_shape) - if rising is None: - if z_src.shape[zp_axis] < 2: - raise ValueError('The rising keyword must be defined when ' - 'the size of the source array is <2 in ' - 'the interpolation axis.') - z_src_indexer = [0] * z_src.ndim - z_src_indexer[zp_axis] = slice(0, 2) - first_two = z_src[tuple(z_src_indexer)] - rising = first_two[0] <= first_two[1] + if z_src.shape[zp_axis] < 2: + raise ValueError('The rising keyword must be defined when ' + 'the size of the source array is <2 in ' + 'the interpolation axis.') + - self.rising = bool(rising) + z_src_indexer = [0] * z_src.ndim + z_src_indexer[zp_axis] = slice(0, 2) + src_first_two = z_src[tuple(z_src_indexer)] + src_rising = src_first_two[0] <= src_first_two[1] + src_rise = bool(src_rising) + + z_tgt_indexer = [0] * z_target.ndim + z_tgt_indexer[zp_axis] = slice(0, 2) + tgt_first_two = z_target[tuple(z_tgt_indexer)] + tgt_rising = tgt_first_two[0] <= tgt_first_two[1] + tgt_rise = bool(tgt_rising) + + + self.rising = bool(tgt_rising) + self.aligned = src_rise == tgt_rise # Sometimes we want to add additional constraints on our interpolation # and extrapolation - for example, linear extrapolation requires there @@ -733,13 +740,17 @@ cdef class _Interpolation(object): # Construct a memory view of the fz_target array. cdef double[:, :, :, :] fz_target_view = fz_target + cdef int rising = self.rising + cdef int aligned = self.aligned + # Release the GIL and do the for loop over the left-hand, and # right-hand dimensions. The loop optimised for row-major data (C). with nogil: for j in range(nj): for i in range(ni): gridwise_interpolation(z_target, z_src[i, :, j], fz_src[:, i, :, j], - self.rising, + rising, + aligned, self.interpolation, self.extrapolation, fz_target_view[:, i, :, j]) @@ -755,6 +766,8 @@ cdef class _Interpolation(object): fz_target = np.empty(self._result_working_shape, dtype=np.float64) cdef unsigned int i, j, ni, nj + cdef int rising = self.rising + cdef int aligned = self.aligned ni = fz_target.shape[1] nj = fz_target.shape[3] @@ -775,7 +788,8 @@ cdef class _Interpolation(object): for j in range(nj): for i in range(ni): gridwise_interpolation(z_target[i, :, j], z_src[i, :, j], fz_src[:, i, :, j], - self.rising, + rising, + aligned, self.interpolation, self.extrapolation, fz_target_view[:, i, :, j]) diff --git a/src/stratify/tests/test_vinterp.py b/src/stratify/tests/test_vinterp.py index 002bbc4..4ebd48b 100644 --- a/src/stratify/tests/test_vinterp.py +++ b/src/stratify/tests/test_vinterp.py @@ -19,7 +19,7 @@ def extrap_kernel(self, direction, z_src, fz_src, level, output_array): class TestColumnInterpolation(unittest.TestCase): - def interpolate(self, x_target, x_src, rising=None): + def interpolate(self, x_target, x_src): x_target = np.array(x_target) x_src = np.array(x_src) fx_src = np.empty(x_src.shape) @@ -31,28 +31,28 @@ def interpolate(self, x_target, x_src, rising=None): x_target, x_src, fx_src, - rising=rising, + # rising=rising, interpolation=index_interp, extrapolation=extrap_direct, ) - if rising is not None: - r2 = stratify.interpolate( - -1 * x_target, - -1 * x_src, - fx_src, - rising=not rising, - interpolation=index_interp, - extrapolation=extrap_direct, - ) - assert_array_equal(r1, r2) + # if rising is not None: + # r2 = stratify.interpolate( + # -1 * x_target, + # -1 * x_src, + # fx_src, + # rising=not rising, + # interpolation=index_interp, + # extrapolation=extrap_direct, + # ) + # assert_array_equal(r1, r2) lazy_fx_src = da.asarray(fx_src, chunks=tuple(range(1, x_src.ndim + 1))) r3 = stratify.interpolate( x_target, x_src, lazy_fx_src, - rising=rising, + # rising=rising, interpolation=index_interp, extrapolation=extrap_direct, ) @@ -77,7 +77,7 @@ def test_lower_extrap_only(self): assert_array_equal(r, [-np.inf, -np.inf, -np.inf]) def test_upper_extrap_only(self): - r = self.interpolate([1, 2, 3], [-4, -5], rising=True) + r = self.interpolate([1, 2, 3], [-4, -5]) assert_array_equal(r, [np.inf, np.inf, np.inf]) def test_extrap_on_both_sides_only(self): @@ -96,7 +96,7 @@ def test_nan_in_target(self): def test_nan_in_src(self): msg = "The source coordinate .* NaN" with self.assertRaisesRegex(ValueError, msg): - self.interpolate([1], [0, np.nan], rising=True) + self.interpolate([1], [0, np.nan]) def test_all_nan_in_src(self): r = self.interpolate([1, 2, 3, 4], [np.nan, np.nan, np.nan]) @@ -117,13 +117,13 @@ def test_wrong_rising_target(self): assert_array_equal(r, [1, np.inf]) def test_wrong_rising_source(self): - r = self.interpolate([1, 2], [2, 1], rising=True) + r = self.interpolate([1, 2], [2, 1]) assert_array_equal(r, [-np.inf, 0]) def test_wrong_rising_source_and_target(self): # If we overshoot the first level, there is no hope, # so we end up extrapolating. - r = self.interpolate([3, 2, 1, 0], [2, 1], rising=True) + r = self.interpolate([3, 2, 1, 0], [2, 1]) assert_array_equal(r, [np.inf, np.inf, np.inf, np.inf]) def test_non_monotonic_coordinate_interp(self): @@ -135,7 +135,7 @@ def test_non_monotonic_coordinate_extrap(self): assert_array_equal(result, [-np.inf, 1, 1, 1, 2, 3, np.inf]) def test_length_one_interp(self): - r = self.interpolate([1], [2], rising=True) + r = self.interpolate([1], [2]) assert_array_equal(r, [-np.inf]) def test_auto_rising_not_enough_values(self): @@ -200,7 +200,7 @@ def test_single_point(self): [20], interpolation=interpolation, extrapolation=extrapolation, - rising=True, + # rising=True, ) self.assertEqual(r, 20) @@ -324,7 +324,7 @@ def test_npts(self): [20], interpolation=interpolation, extrapolation=extrapolation, - rising=True, + # rising=True, ) @@ -346,7 +346,7 @@ def test(self): [10, 20], interpolation=interpolation, extrapolation=extrapolation, - rising=True, + # rising=True, ) assert_array_equal(r, [0, -10])