diff --git a/pyproject.toml b/pyproject.toml index aaf5587a..dbb5550b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ dependencies = [ "matplotlib>=3.3", "numpy>=1.20", "pandas>=1.2", + "sortedcountercpp", "tqdm>=4.64.1", ] optional-dependencies.dev = [ diff --git a/src/data_morph/data/dataset.py b/src/data_morph/data/dataset.py index 4a0ac297..c8079482 100644 --- a/src/data_morph/data/dataset.py +++ b/src/data_morph/data/dataset.py @@ -50,6 +50,9 @@ def __init__( self.df: pd.DataFrame = self._validate_data(df).pipe(self._scale_data, scale) """pandas.DataFrame: DataFrame containing columns x and y.""" + self._x = self.df['x'].to_numpy() + self._y = self.df['y'].to_numpy() + self.name: str = name """str: The name to use for the dataset.""" diff --git a/src/data_morph/data/stats.py b/src/data_morph/data/stats.py index d3c52669..38e44758 100644 --- a/src/data_morph/data/stats.py +++ b/src/data_morph/data/stats.py @@ -1,25 +1,538 @@ """Utility functions for calculating summary statistics.""" from collections import namedtuple +from collections.abc import Iterable, Sequence +from numbers import Number -import pandas as pd +import numpy as np +from sortedcounter import SortedCounter SummaryStatistics = namedtuple( - 'SummaryStatistics', ['x_mean', 'y_mean', 'x_stdev', 'y_stdev', 'correlation'] + 'SummaryStatistics', + ['x_mean', 'y_mean', 'x_med', 'y_med', 'x_stdev', 'y_stdev', 'correlation'], ) SummaryStatistics.__doc__ = ( 'Named tuple containing the summary statistics for plotting/analysis.' ) -def get_values(df: pd.DataFrame) -> SummaryStatistics: +def create_median_tree(data: Sequence, /) -> tuple[SortedCounter, SortedCounter]: + """ + Return a tuple of low and high ``SortedCounter``s from input data. + + Parameters + ---------- + data : Sequence + The input data as an iterable. + + Returns + ------- + tuple[SortedCounter, SortedCounter] + The low and high ``SortedCounter``s. 
+
+    Notes
+    -----
+    The time complexity of the execution of the function is O(n log n) due to
+    the sorting operation done beforehand.
+    """
+    # make sure the data is sorted
+    data = sorted(data)
+    half = len(data) // 2
+    return SortedCounter(data[:half]), SortedCounter(data[half:])
+
+
+def shifted_mean(
+    mean_old: float,
+    value_old: float,
+    value_new: float,
+    size: int,
+) -> float:
+    """
+    Return the shifted mean by perturbing one point.
+
+    Parameters
+    ----------
+    mean_old : float
+        The old value of the mean of the data.
+    value_old : float
+        The old value of the point (before perturbation).
+    value_new : float
+        The new value of the point (after perturbation).
+    size : int
+        The size of the dataset.
+
+    Returns
+    -------
+    float
+        The new value of the mean of the data.
+    """
+    return (mean_old - value_old / size) + value_new / size
+
+
+def shifted_var(
+    mean_old: float,
+    var_old: float,
+    value_old: float,
+    value_new: float,
+    size: int,
+    *,
+    ddof: float = 0,
+) -> float:
+    """
+    Compute the shifted variance by perturbing one point.
+
+    Parameters
+    ----------
+    mean_old : float
+        The old value of the mean of the data.
+    var_old : float
+        The old value of the variance of the data.
+    value_old : float
+        The old value of the point (before perturbation).
+    value_new : float
+        The new value of the point (after perturbation).
+    size : int
+        The size of the dataset.
+    ddof : float, optional
+        “Delta Degrees of Freedom”: the divisor used in the calculation is ``N
+        - ddof``, where ``N`` represents the number of elements. By default
+        ``ddof`` is zero.
+
+    Returns
+    -------
+    float
+        The new value of the variance of the data.
+    """
+    return (
+        var_old
+        + 2 * (value_new - value_old) * (value_old - mean_old) / (size - ddof)
+        + (value_new - value_old) ** 2 * (1 / (size - ddof) - 1 / (size - ddof) / size)
+    )
+
+
+def shifted_stdev(*args: float, **kwargs: int) -> float:
+    """
+    Compute the shifted standard deviation by perturbing one point.
+
+    Parameters
+    ----------
+    *args
+        The positional arguments passed to :func:`shifted_var`.
+    **kwargs
+        The keyword arguments passed to :func:`shifted_var`.
+
+    Returns
+    -------
+    float
+        The new value of the standard deviation of the data.
+    """
+    return np.sqrt(shifted_var(*args, **kwargs))
+
+
+def shifted_corrcoef(
+    x_old: float,
+    y_old: float,
+    x_new: float,
+    y_new: float,
+    meanx_old: float,
+    meany_old: float,
+    xy_old: float,
+    varx_old: float,
+    vary_old: float,
+    size: int,
+) -> float:
+    """
+    Compute the shifted correlation of the data by perturbing one point.
+
+    Parameters
+    ----------
+    x_old : float
+        The old value of the point ``x`` (before perturbation).
+    y_old : float
+        The old value of the point ``y`` (before perturbation).
+    x_new : float
+        The new value of the point ``x`` (after perturbation).
+    y_new : float
+        The new value of the point ``y`` (after perturbation).
+    meanx_old : float
+        The old value of the mean of the data ``x``.
+    meany_old : float
+        The old value of the mean of the data ``y``.
+    xy_old : float
+        The old value of the mean of ``x * y``.
+    varx_old : float
+        The old value of the variance of ``x``.
+    vary_old : float
+        The old value of the variance of ``y``.
+    size : int
+        The size of the dataset.
+
+    Returns
+    -------
+    float
+        The new correlation coefficient of the data.
+ """ + deltax = x_new - x_old + deltay = y_new - y_old + + numerator = ( + xy_old + + (deltax * y_old + deltay * x_old + deltax * deltay) / size + - shifted_mean(mean_old=meanx_old, value_old=x_old, value_new=x_new, size=size) + * shifted_mean(mean_old=meany_old, value_old=y_old, value_new=y_new, size=size) + ) + + denominator = np.sqrt( + shifted_var( + mean_old=meanx_old, + var_old=varx_old, + value_old=x_old, + value_new=x_new, + size=size, + ) + * shifted_var( + mean_old=meany_old, + var_old=vary_old, + value_old=y_old, + value_new=y_new, + size=size, + ) + ) + + return numerator / denominator + + +def shifted_median( + xlow: SortedCounter, + xhigh: SortedCounter, + value_old: float, + value_new: float, +) -> float: + """ + Compute the shifted median using two ``SortedCounter``s. + + Parameters + ---------- + xlow : SortedCounter + The lower part of the data (below the median). + + xhigh : SortedCounter + The higher part of the data (above the median). + + value_old : float + The old value of the point (before perturbation). + + value_new : float + The new value of the point (after perturbation). + + Returns + ------- + float + The new value of the median of the data. + + Notes + ----- + Modifies ``xlow`` and ``xhigh`` in-place. + """ + + # notation: + # S1 = lower half of data values + # S2 = upper half of data values + # G = range of values . 
Note that it can be empty in case of duplicates + # L = range <-inf, min(S1)> + # H = range + # xi = old value of the data point + # xi'= new value of the data point + # + # constraints: + # at the end of the computation, we need abs(len(S2) - len(S1)) = 0 + # if len(S1) + len(S2) is even, 1 if odd + + low_max = xlow.maximum() + high_min = xhigh.minimum() + + # xi is guaranteed to be in either S1 or S2 + if value_old <= low_max: + xlow.remove(value_old) + elif value_old >= high_min: + xhigh.remove(value_old) + + # it doesn't really matter where we insert it since we are rebalancing + # later anyway + if value_new <= xlow.maximum(): + xlow.add(value_new) + else: + xhigh.add(value_new) + + # Rebalance the two SortedCounters if their sizes differ by more than 1 + # NOTE: this operation is O(log n) since we are always doing it only a + # handful (fixed number) of times + + # remove items from xlow and add them in xhigh + while len(xlow) > len(xhigh): + low_max = xlow.maximum() + xlow.remove(low_max) + xhigh.add(low_max) + + # remove items from xhigh and add them in xlow + while len(xhigh) > len(xlow): + high_min = xhigh.minimum() + xhigh.remove(high_min) + xlow.add(high_min) + + # Compute the median based on the sizes of xlow and xhigh + if len(xlow) == len(xhigh): + return (xlow.maximum() + xhigh.minimum()) / 2 + if len(xlow) > len(xhigh): + return xlow.maximum() + + return xhigh.minimum() + + +class Statistics: + """ + Container for computing various statistics of the data. + + Parameters + ---------- + x : iterable of float + The ``x`` value of the data as an iterable. + y : iterable of float + The ``y`` value of the data as an iterable. 
+ """ + + def __init__(self, x: Iterable[Number], y: Iterable[Number]) -> None: + if len(x) != len(y): + raise ValueError('The two datasets should have the same size') + + self._x = np.copy(x) + self._y = np.copy(y) + self._size = len(self._x) + self._x_mean = np.mean(self._x) + self._y_mean = np.mean(self._y) + self._x_median = np.median(self._x) + self._y_median = np.median(self._y) + self._x_low, self._x_high = create_median_tree(self._x) + self._y_low, self._y_high = create_median_tree(self._y) + self._x_var = np.var(self._x, ddof=0) + self._x_stdev = np.sqrt(self._x_var) + self._y_var = np.var(self._y, ddof=0) + self._y_stdev = np.sqrt(self._y_var) + self._corrcoef = np.corrcoef(self._x, self._y)[0, 1] + self._xy_mean = np.mean(self._x * self._y) + + @property + def x_mean(self) -> float: + """ + Return the mean of the ``x`` data. + + Returns + ------- + float + The mean of the ``x`` data. + """ + return self._x_mean + + @property + def y_mean(self) -> float: + """ + Return the mean of the ``y`` data. + + Returns + ------- + float + The mean of the ``y`` data. + """ + return self._y_mean + + @property + def x_stdev(self) -> float: + """ + Return the std of the ``x`` data. + + Returns + ------- + float + The standard deviation of the ``x`` data. + """ + return self._x_stdev + + @property + def y_stdev(self) -> float: + """ + Return the std of the ``y`` data. + + Returns + ------- + float + The standard deviation of the ``y`` data. + """ + return self._y_stdev + + @property + def corrcoef(self) -> float: + """ + Return the correlation coefficient of the ``x`` and ``y`` data. + + Returns + ------- + float + The correlation coefficient between ``x`` and ``y`` data. + """ + return self._corrcoef + + def __len__(self) -> int: + """ + Return the size of the dataset. + + Returns + ------- + int + The size of the dataset. 
+ """ + return len(self._x) + + def perturb( + self, + index: int, + deltax: float, + deltay: float, + *, + update: bool = False, + ) -> SummaryStatistics: + """ + Perturb a single point and return the new ``SummaryStatistics``. + + Parameters + ---------- + index : int + The index of the point we wish to perturb. + + deltax : float + The amount by which to perturb the ``x`` point. + + deltay : float + The amount by which to perturb the ``y`` point. + + update : bool, optional + Whether to actually update the data (default: False). + + Returns + ------- + SummaryStatistics + The new summary statistics. + """ + x_mean = shifted_mean( + mean_old=self.x_mean, + value_old=self._x[index], + value_new=self._x[index] + deltax, + size=len(self), + ) + y_mean = shifted_mean( + mean_old=self.y_mean, + value_old=self._y[index], + value_new=self._y[index] + deltay, + size=len(self), + ) + + x_var = shifted_var( + mean_old=self.x_mean, + var_old=self._x_var, + value_old=self._x[index], + value_new=self._x[index] + deltax, + size=len(self), + ) + + y_var = shifted_var( + mean_old=self.y_mean, + var_old=self._y_var, + value_old=self._y[index], + value_new=self._y[index] + deltay, + size=len(self), + ) + + corrcoef = shifted_corrcoef( + x_old=self._x[index], + y_old=self._y[index], + x_new=self._x[index] + deltax, + y_new=self._y[index] + deltay, + meanx_old=self.x_mean, + meany_old=self.y_mean, + xy_old=self._xy_mean, + varx_old=self._x_var, + vary_old=self._y_var, + size=len(self), + ) + + # `shifted_median` updates the containers in-place; in case + # `update=False`, we put it back + x_median = shifted_median( + xlow=self._x_low, + xhigh=self._x_high, + value_old=self._x[index], + value_new=self._x[index] + deltax, + ) + + y_median = shifted_median( + xlow=self._y_low, + xhigh=self._y_high, + value_old=self._y[index], + value_new=self._y[index] + deltay, + ) + + if update: + self._x_mean = x_mean + self._y_mean = y_mean + self._x_median = x_median + self._y_median = y_median + 
self._x_var = x_var + self._y_var = y_var + self._x_stdev = np.sqrt(x_var) + self._y_stdev = np.sqrt(y_var) + self._corrcoef = corrcoef + self._xy_mean += ( + deltax * self._y[index] + deltay * self._x[index] + deltax * deltay + ) / len(self) + + self._x[index] += deltax + self._y[index] += deltay + else: + shifted_median( + xlow=self._x_low, + xhigh=self._x_high, + value_old=self._x[index] + deltax, + value_new=self._x[index], + ) + + shifted_median( + xlow=self._y_low, + xhigh=self._y_high, + value_old=self._y[index] + deltay, + value_new=self._y[index], + ) + + return SummaryStatistics( + x_mean, + y_mean, + x_median, + y_median, + np.sqrt(x_var), + np.sqrt(y_var), + corrcoef, + ) + + +def get_values(x: Iterable[Number], y: Iterable[Number]) -> SummaryStatistics: """ Calculate the summary statistics for the given set of points. Parameters ---------- - df : pandas.DataFrame - A dataset with columns x and y. + x : Iterable[Number] + The ``x`` value of the dataset. + + y : Iterable[Number] + The ``y`` value of the dataset. Returns ------- @@ -28,9 +541,11 @@ def get_values(df: pd.DataFrame) -> SummaryStatistics: along with the Pearson correlation coefficient between the two. 
""" return SummaryStatistics( - df.x.mean(), - df.y.mean(), - df.x.std(), - df.y.std(), - df.corr().x.y, + np.mean(x), + np.mean(y), + np.median(x), + np.median(y), + np.std(x, ddof=1), + np.std(y, ddof=1), + np.corrcoef(x, y)[0, 1], ) diff --git a/src/data_morph/morpher.py b/src/data_morph/morpher.py index 5164786b..48c019fb 100644 --- a/src/data_morph/morpher.py +++ b/src/data_morph/morpher.py @@ -2,6 +2,7 @@ from __future__ import annotations +from collections.abc import MutableSequence from functools import partial from numbers import Number from pathlib import Path @@ -12,7 +13,7 @@ from .bounds.bounding_box import BoundingBox from .data.dataset import Dataset -from .data.stats import get_values +from .data.stats import Statistics, SummaryStatistics from .plotting.animation import ( ease_in_out_quadratic, ease_in_out_sine, @@ -240,16 +241,22 @@ def _record_frames( frame_number += 1 return frame_number - def _is_close_enough(self, df1: pd.DataFrame, df2: pd.DataFrame) -> bool: + def _is_close_enough( + self, + item1: SummaryStatistics, + item2: SummaryStatistics, + /, + ) -> bool: """ Check whether the statistics are within the acceptable bounds. Parameters ---------- - df1 : pandas.DataFrame - The original DataFrame. - df2 : pandas.DataFrame - The DataFrame after the latest perturbation. + item1 : SummaryStatistics + The first summary statistic. + + item2 : SummaryStatistics + The second summary statistic. 
Returns ------- @@ -259,10 +266,8 @@ def _is_close_enough(self, df1: pd.DataFrame, df2: pd.DataFrame) -> bool: return np.all( np.abs( np.subtract( - *( - np.floor(np.array(get_values(data)) * 10**self.decimals) - for data in [df1, df2] - ) + np.floor(np.array(item1) * 10**self.decimals), + np.floor(np.array(item2) * 10**self.decimals), ) ) == 0 @@ -270,21 +275,24 @@ def _is_close_enough(self, df1: pd.DataFrame, df2: pd.DataFrame) -> bool: def _perturb( self, - df: pd.DataFrame, + x: MutableSequence[Number], + y: MutableSequence[Number], target_shape: Shape, *, shake: Number, allowed_dist: Number, temp: Number, bounds: BoundingBox, - ) -> pd.DataFrame: + ) -> tuple[int, MutableSequence[Number], MutableSequence[Number]]: """ Perform one round of perturbation. Parameters ---------- - df : pandas.DataFrame - The data to perturb. + x : MutableSequence[Number] + The ``x`` part of the dataset. + y : MutableSequence[Number] + The ``y`` part of the dataset. target_shape : Shape The shape to morph the data into. shake : numbers.Number @@ -301,12 +309,12 @@ def _perturb( Returns ------- - pandas.DataFrame - The input dataset with one point perturbed. + tuple[int, MutableSequence[Number], MutableSequence[Number]] + The index and input dataset with one point perturbed. 
""" - row = self._rng.integers(0, len(df)) - initial_x = df.at[row, 'x'] - initial_y = df.at[row, 'y'] + row = self._rng.integers(0, len(x)) + initial_x = x[row] + initial_y = y[row] # this is the simulated annealing step, if "do_bad", then we are willing to # accept a new state which is worse than the current one @@ -325,10 +333,10 @@ def _perturb( within_bounds = [new_x, new_y] in bounds done = close_enough and within_bounds - df.loc[row, 'x'] = new_x - df.loc[row, 'y'] = new_y + x[row] = new_x + y[row] = new_y - return df + return row, x, y def morph( self, @@ -471,11 +479,23 @@ def _tweening( max_value=max_shake, ) + x, y = ( + start_shape.df['x'].to_numpy(copy=True), + start_shape.df['y'].to_numpy(copy=True), + ) + + # the starting dataset statistics + stats = Statistics(x, y) + + # the summary statistics of the above + summary_stats = stats.perturb(0, 0, 0) + for i in self._looper( iterations, leave=True, ascii=True, desc=f'{target_shape} pattern' ): - perturbed_data = self._perturb( - morphed_data.copy(), + index, *perturbed_data = self._perturb( + np.copy(x), + np.copy(y), target_shape=target_shape, shake=get_current_shake(i), allowed_dist=allowed_dist, @@ -483,8 +503,21 @@ def _tweening( bounds=start_shape.morph_bounds, ) - if self._is_close_enough(start_shape.df, perturbed_data): - morphed_data = perturbed_data + new_summary_stats = stats.perturb( + index, + perturbed_data[0][index] - x[index], + perturbed_data[1][index] - y[index], + ) + + if self._is_close_enough(summary_stats, new_summary_stats): + summary_stats = stats.perturb( + index, + perturbed_data[0][index] - x[index], + perturbed_data[1][index] - y[index], + update=True, + ) + x, y = perturbed_data + morphed_data = pd.DataFrame({'x': x, 'y': y}) frame_number = record_frames( data=morphed_data, diff --git a/src/data_morph/plotting/static.py b/src/data_morph/plotting/static.py index 39deea14..17a5dc36 100644 --- a/src/data_morph/plotting/static.py +++ b/src/data_morph/plotting/static.py @@ -61,11 
+61,11 @@ def plot(
     ax.xaxis.set_major_formatter(tick_formatter)
     ax.yaxis.set_major_formatter(tick_formatter)
 
-    res = get_values(df)
+    res = get_values(df['x'].to_numpy(), df['y'].to_numpy())
     labels = ('X Mean', 'Y Mean', 'X SD', 'Y SD', 'Corr.')
     locs = np.linspace(0.8, 0.2, num=len(labels))
-    max_label_length = max([len(label) for label in labels])
+    max_label_length = max(len(label) for label in labels)
     max_stat = int(np.log10(np.max(np.abs(res)))) + 1
     mean_x_digits, mean_y_digits = (
         int(x) + 1 for x in np.log10(np.abs([res.x_mean, res.y_mean]))
diff --git a/tests/data/test_stats.py b/tests/data/test_stats.py
index c99134ed..bbad1c56 100644
--- a/tests/data/test_stats.py
+++ b/tests/data/test_stats.py
@@ -1,7 +1,18 @@
 """Test the stats module."""
 
+import numpy as np
+import pytest
+from numpy.testing import assert_allclose, assert_equal
+
 from data_morph.data.loader import DataLoader
-from data_morph.data.stats import get_values
+from data_morph.data.stats import (
+    create_median_tree,
+    get_values,
+    shifted_corrcoef,
+    shifted_mean,
+    shifted_median,
+    shifted_var,
+)
 
 
 def test_stats():
@@ -9,10 +20,224 @@ def test_stats():
 
     data = DataLoader.load_dataset('dino').df
 
-    stats = get_values(data)
+    stats = get_values(data['x'], data['y'])
 
     assert stats.x_mean == data.x.mean()
     assert stats.y_mean == data.y.mean()
     assert stats.x_stdev == data.x.std()
     assert stats.y_stdev == data.y.std()
-    assert stats.correlation == data.corr().x.y
+    assert np.allclose(stats.correlation, data.corr().x.y)
+
+
+def test_new_mean():
+    data = DataLoader.load_dataset('dino').df
+
+    # make sure if we don't do anything to the data that we retrieve the same results
+    x = data['x'].to_numpy()
+
+    assert_equal(np.mean(x), shifted_mean(np.mean(x), x[0], x[0], len(x)))
+
+    # we want to test both very large and very small displacements
+    for scale in [0.1, 10]:
+        x_old = data['x'].to_numpy()
+        y_old = data['y'].to_numpy()
+
+        # scaling the data
+        x_old /= np.max(np.abs(x_old))
+        y_old /= 
np.max(np.abs(y_old)) + + x_new = np.copy(x_old) + y_new = np.copy(y_old) + + rng = np.random.default_rng(42) + + for _ in range(100_000): + row = rng.integers(0, len(x_old)) + jitter_x, jitter_y = rng.normal(loc=0, scale=scale, size=2) + + x_new[row] += jitter_x + y_new[row] += jitter_y + + meanx = np.mean(x_new) + new_meanx = shifted_mean(np.mean(x_old), x_old[row], x_new[row], len(x_old)) + + meany = np.mean(y_new) + new_meany = shifted_mean(np.mean(y_old), y_old[row], y_new[row], len(y_old)) + + assert_allclose(meanx, new_meanx) + assert_allclose(meany, new_meany) + + x_old = np.copy(x_new) + y_old = np.copy(y_new) + + +def test_new_var(): + data = DataLoader.load_dataset('dino').df + + # make sure if we don't do anything to the data that we retrieve the same results + x = data['x'].to_numpy() + + assert_equal( + np.var(x, ddof=0), shifted_var(np.mean(x), np.var(x), x[0], x[0], len(x)) + ) + + # we want to test both very large and very small displacements + for scale in [0.1, 10]: + x_old = data['x'].to_numpy() + y_old = data['y'].to_numpy() + + # scaling the data + x_old /= np.max(np.abs(x_old)) + y_old /= np.max(np.abs(y_old)) + + x_new = np.copy(x_old) + y_new = np.copy(y_old) + + rng = np.random.default_rng(42) + + for _ in range(100_000): + row = rng.integers(0, len(x_old)) + jitter_x, jitter_y = rng.normal(loc=0, scale=scale, size=2) + + x_new[row] += jitter_x + y_new[row] += jitter_y + + varx = np.var(x_new) + new_varx = shifted_var( + np.mean(x_old), np.var(x_old), x_old[row], x_new[row], len(x_old) + ) + + vary = np.var(y_new) + new_vary = shifted_var( + np.mean(y_old), np.var(y_old), y_old[row], y_new[row], len(y_old) + ) + + assert_allclose(varx, new_varx) + assert_allclose(vary, new_vary) + + x_old = np.copy(x_new) + y_old = np.copy(y_new) + + +def test_new_corrcoef(): + data = DataLoader.load_dataset('dino').df + + # make sure if we don't do anything to the data that we retrieve the same results + x = data['x'].to_numpy() + y = data['y'].to_numpy() 
+ corrcoef = np.corrcoef(x, y)[0, 1] + row = 0 + new_corrcoef = shifted_corrcoef( + x_old=x[row], + y_old=y[row], + x_new=x[row], + y_new=y[row], + meanx_old=np.mean(x), + meany_old=np.mean(y), + xy_old=np.mean(x * y), + varx_old=np.var(x), + vary_old=np.var(y), + size=len(x), + ) + + corrcoef_by_hand = np.cov(x, y, ddof=0) / np.sqrt(np.var(x) * np.var(y)) + + assert_allclose(corrcoef, corrcoef_by_hand[0, 1]) + assert_allclose(corrcoef_by_hand[0, 1], new_corrcoef) + + # we want to test both very large and very small displacements + for scale in [0.1, 10]: + x_old = data['x'].to_numpy() + y_old = data['y'].to_numpy() + + # scaling the data + x_old /= np.max(np.abs(x_old)) + y_old /= np.max(np.abs(y_old)) + + x_new = np.copy(x_old) + y_new = np.copy(y_old) + + rng = np.random.default_rng(42) + + for _ in range(100_000): + row = rng.integers(0, len(x_old)) + jitter_x, jitter_y = rng.normal(loc=0, scale=scale, size=2) + + x_new[row] += jitter_x + y_new[row] += jitter_y + + corrcoef = ( + np.cov(x_new, y_new, ddof=0) / np.sqrt(np.var(x_new) * np.var(y_new)) + )[0, 1] + new_corrcoef = shifted_corrcoef( + x_old=x_old[row], + y_old=y_old[row], + x_new=x_new[row], + y_new=y_new[row], + meanx_old=np.mean(x_old), + meany_old=np.mean(y_old), + xy_old=np.mean(x_old * y_old), + varx_old=np.var(x_old), + vary_old=np.var(y_old), + size=len(x_old), + ) + + assert_allclose(corrcoef, new_corrcoef) + + x_old = np.copy(x_new) + y_old = np.copy(y_new) + + +@pytest.mark.parametrize('data', [list(range(5)), list(range(10))]) +def test_shifted_median(data): + """ + Check that the ``shifted_median`` function works properly. 
+ """ + data = np.sort(np.array(data, dtype=float)) + size = len(data) + xlow, xhigh = create_median_tree(data) + # make sure it works if we don't do anything + ref = np.median(data) + actual = shifted_median(xlow, xhigh, data[0], data[0]) + assert_equal(len(data), len(xlow) + len(xhigh)) + assert_equal(ref, actual) + + # move lower part of data into itself + x = np.copy(data) + xlow, xhigh = create_median_tree(data) + index = size // 2 - 1 + x[index] -= 0.3 + ref = np.median(x) + actual = shifted_median(xlow, xhigh, data[index], x[index]) + assert_equal(len(x), len(xlow) + len(xhigh)) + assert_equal(ref, actual) + + # move higher part of data into itself + x = np.copy(data) + xlow, xhigh = create_median_tree(data) + index = size // 2 + 1 + x[index] += 0.3 + ref = np.median(x) + actual = shifted_median(xlow, xhigh, data[index], x[index]) + assert_equal(len(x), len(xlow) + len(xhigh)) + assert_equal(ref, actual) + + # move higher part of data into lower part + x = np.copy(data) + xlow, xhigh = create_median_tree(data) + index = size // 2 + 1 + x[index] -= 5.2 + ref = np.median(x) + actual = shifted_median(xlow, xhigh, data[index], x[index]) + assert_equal(len(x), len(xlow) + len(xhigh)) + assert_equal(ref, actual) + + # move lower part of data into higher part + x = np.copy(data) + xlow, xhigh = create_median_tree(data) + index = size // 2 - 1 + x[index] += 5.2 + ref = np.median(x) + actual = shifted_median(xlow, xhigh, data[index], x[index]) + assert_equal(len(x), len(xlow) + len(xhigh)) + assert_equal(ref, actual) diff --git a/tests/test_morpher.py b/tests/test_morpher.py index 01ad373b..350811ce 100644 --- a/tests/test_morpher.py +++ b/tests/test_morpher.py @@ -10,6 +10,7 @@ from pandas.testing import assert_frame_equal from data_morph.data.loader import DataLoader +from data_morph.data.stats import Statistics from data_morph.morpher import DataMorpher from data_morph.shapes.factory import ShapeFactory @@ -169,9 +170,17 @@ def test_no_writing(self, capsys): 
freeze_for=0, ) + stats = Statistics(dataset.df['x'], dataset.df['y']).perturb(0, 0, 0) + perturbed_stats = Statistics(morphed_data['x'], morphed_data['y']).perturb( + 0, 0, 0 + ) + with pytest.raises(AssertionError): assert_frame_equal(morphed_data, dataset.df) - assert morpher._is_close_enough(dataset.df, morphed_data) + assert morpher._is_close_enough( + stats, + perturbed_stats, + ) _, err = capsys.readouterr() assert f'{target_shape} pattern: 100%' in err