diff --git a/src/data_morph/shapes/bases/point_collection.py b/src/data_morph/shapes/bases/point_collection.py index a71a1ddb..0b2c773b 100644 --- a/src/data_morph/shapes/bases/point_collection.py +++ b/src/data_morph/shapes/bases/point_collection.py @@ -16,6 +16,8 @@ from matplotlib.axes import Axes + from ...bounds.bounding_box import BoundingBox + class PointCollection(Shape): """ @@ -37,6 +39,32 @@ def __init__(self, *points: Iterable[Number]) -> None: def __repr__(self) -> str: return f'<{self.__class__.__name__} of {len(self.points)} points>' + @staticmethod + def _center(points: np.ndarray, bounds: BoundingBox) -> np.ndarray: + """ + Center the points within the bounding box. + + Parameters + ---------- + points : np.ndarray + The points to center. + bounds : BoundingBox + The bounding box within which to center the points. + + Returns + ------- + np.ndarray + The centered points. + """ + maxes = points.max(axis=0) + span = maxes - points.min(axis=0) + gap = (np.array(bounds.range) - span) / 2 + + (_, xmax), (_, ymax) = bounds + shift = np.array([xmax, ymax]) - maxes - gap + + return points + shift + def distance(self, x: Number, y: Number) -> float: """ Calculate the minimum distance from the points of this shape diff --git a/src/data_morph/shapes/points/heart.py b/src/data_morph/shapes/points/heart.py index 14580a38..9570e91b 100644 --- a/src/data_morph/shapes/points/heart.py +++ b/src/data_morph/shapes/points/heart.py @@ -35,8 +35,9 @@ class Heart(PointCollection): """ def __init__(self, dataset: Dataset) -> None: - _, xmax = dataset.data_bounds.x_bounds - x_shift, y_shift = dataset.data_bounds.center + data_bounds = dataset.data_bounds + (_, xmax), (_, ymax) = data_bounds + x_shift, y_shift = data_bounds.center t = np.linspace(-3, 3, num=80) @@ -44,8 +45,10 @@ def __init__(self, dataset: Dataset) -> None: y = 13 * np.cos(t) - 5 * np.cos(2 * t) - 2 * np.cos(3 * t) - np.cos(4 * t) # scale by the half the widest width of the heart - scale_factor = (xmax - x_shift) / 16 + scale_factor = min((xmax - x_shift), (ymax - y_shift)) / 16 super().__init__( - *np.stack([x * scale_factor + x_shift, y * scale_factor + y_shift], axis=1) + *self._center( + np.stack([x * scale_factor, y * scale_factor], axis=1), data_bounds + ) ) diff --git a/src/data_morph/shapes/points/spade.py b/src/data_morph/shapes/points/spade.py index 654c6e3a..2fd59a71 100644 --- a/src/data_morph/shapes/points/spade.py +++ b/src/data_morph/shapes/points/spade.py @@ -30,8 +30,9 @@ class Spade(PointCollection): """ def __init__(self, dataset: Dataset) -> None: - _, xmax = dataset.data_bounds.x_bounds - x_shift, y_shift = dataset.data_bounds.center + data_bounds = dataset.data_bounds + _, xmax = data_bounds.x_bounds + x_shift, y_shift = data_bounds.center # upside-down heart heart_points = self._get_inverted_heart(dataset, y_shift) @@ -43,7 +44,7 @@ def __init__(self, dataset: Dataset) -> None: x = np.concatenate((heart_points[:, 0], base_x), axis=0) y = np.concatenate((heart_points[:, 1], base_y), axis=0) - super().__init__(*np.stack([x, y], axis=1)) + super().__init__(*self._center(np.stack([x, y], axis=1), data_bounds)) @staticmethod def _get_inverted_heart(dataset: Dataset, y_shift: Number) -> np.ndarray: diff --git a/tests/shapes/bases/test_point_collection.py b/tests/shapes/bases/test_point_collection.py index 5f3068ac..3a2c2b26 100644 --- a/tests/shapes/bases/test_point_collection.py +++ b/tests/shapes/bases/test_point_collection.py @@ -3,8 +3,10 @@ import re import matplotlib.pyplot as plt +import numpy as np import pytest +from data_morph.bounds.bounding_box import BoundingBox from data_morph.shapes.bases.point_collection import PointCollection @@ -18,6 +20,19 @@ def point_collection(self): """An instance of PointCollection.""" return PointCollection([0, 0], [20, 50]) + @pytest.mark.parametrize( + 'bounding_box', + [BoundingBox([0, 100], [-50, 50]), BoundingBox([0, 20], [0, 50])], + ) + def test_center(self, point_collection, bounding_box): + """Test that points are centered within the bounding box.""" + points = point_collection._center(point_collection.points, bounding_box) + + (xmin, xmax), (ymin, ymax) = bounding_box + upper = np.array([xmax, ymax]) - points.max(axis=0) + lower = points.min(axis=0) - np.array([xmin, ymin]) + assert np.array_equal(upper, lower) + def test_distance_zero(self, point_collection): """Test the distance() method on points in the collection.""" for point in point_collection.points: diff --git a/tests/shapes/points/test_heart.py b/tests/shapes/points/test_heart.py index 59d16167..604a7e73 100644 --- a/tests/shapes/points/test_heart.py +++ b/tests/shapes/points/test_heart.py @@ -12,10 +12,10 @@ class TestHeart(PointsModuleTestBase): shape_name = 'heart' distance_test_cases = ( - ((19.89946048, 54.82281916), 0.0), - ((10.84680454, 70.18556376), 0.0), - ((29.9971295, 67.66402445), 0.0), - ((27.38657942, 62.417184), 0.0), - ((20, 50), 4.567369), - ((10, 80), 8.564365), + ((22.424114, 59.471779), 0.0), + ((10.405462, 70.897342), 0.0), + ((21.064032, 72.065253), 0.0), + ((16.035166, 60.868470), 0.0), + ((20, 50), 6.065782511791651), + ((10, 80), 7.173013322704914), ) diff --git a/tests/shapes/points/test_spade.py b/tests/shapes/points/test_spade.py index de385eb9..95b4c395 100644 --- a/tests/shapes/points/test_spade.py +++ b/tests/shapes/points/test_spade.py @@ -12,10 +12,11 @@ class TestSpade(PointsModuleTestBase): shape_name = 'spade' distance_test_cases = ( - ((19.97189615, 75.43271708), 0), - ((23.75, 55), 0), - ((11.42685318, 59.11304904), 0), - ((20, 75), 0.2037185), - ((0, 0), 57.350348), - ((10, 80), 10.968080), + ((19.818701, 60.065370), 0), + ((23.750000, 55.532859), 0), + ((20.067229, 60.463689), 0), + ((18.935968, 58.467606), 0), + ((20, 75), 0.5335993101603015), + ((0, 0), 57.861566654807596), + ((10, 80), 11.404000978114487), )