Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions src/data_morph/shapes/bases/point_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

from matplotlib.axes import Axes

from ...bounds.bounding_box import BoundingBox


class PointCollection(Shape):
"""
Expand All @@ -37,6 +39,32 @@ def __init__(self, *points: Iterable[Number]) -> None:
def __repr__(self) -> str:
return f'<{self.__class__.__name__} of {len(self.points)} points>'

@staticmethod
def _center(points: np.ndarray, bounds: BoundingBox) -> np.ndarray:
"""
Center the points within the bounding box.

Parameters
----------
points : np.ndarray
The points to center.
bounds : BoundingBox
The bounding box within which to center the points.

Returns
-------
np.ndarray
The centered points.
"""
maxes = points.max(axis=0)
span = maxes - points.min(axis=0)
gap = (np.array(bounds.range) - span) / 2

(_, xmax), (_, ymax) = bounds
shift = np.array([xmax, ymax]) - maxes - gap

return points + shift

def distance(self, x: Number, y: Number) -> float:
"""
Calculate the minimum distance from the points of this shape
Expand Down
11 changes: 7 additions & 4 deletions src/data_morph/shapes/points/heart.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,20 @@ class Heart(PointCollection):
"""

def __init__(self, dataset: Dataset) -> None:
_, xmax = dataset.data_bounds.x_bounds
x_shift, y_shift = dataset.data_bounds.center
data_bounds = dataset.data_bounds
(_, xmax), (_, ymax) = data_bounds
x_shift, y_shift = data_bounds.center

t = np.linspace(-3, 3, num=80)

x = 16 * np.sin(t) ** 3
y = 13 * np.cos(t) - 5 * np.cos(2 * t) - 2 * np.cos(3 * t) - np.cos(4 * t)

# scale by the half the widest width of the heart
scale_factor = (xmax - x_shift) / 16
scale_factor = min((xmax - x_shift), (ymax - y_shift)) / 16

super().__init__(
*np.stack([x * scale_factor + x_shift, y * scale_factor + y_shift], axis=1)
*self._center(
np.stack([x * scale_factor, y * scale_factor], axis=1), data_bounds
)
)
7 changes: 4 additions & 3 deletions src/data_morph/shapes/points/spade.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ class Spade(PointCollection):
"""

def __init__(self, dataset: Dataset) -> None:
_, xmax = dataset.data_bounds.x_bounds
x_shift, y_shift = dataset.data_bounds.center
data_bounds = dataset.data_bounds
_, xmax = data_bounds.x_bounds
x_shift, y_shift = data_bounds.center

# upside-down heart
heart_points = self._get_inverted_heart(dataset, y_shift)
Expand All @@ -43,7 +44,7 @@ def __init__(self, dataset: Dataset) -> None:
x = np.concatenate((heart_points[:, 0], base_x), axis=0)
y = np.concatenate((heart_points[:, 1], base_y), axis=0)

super().__init__(*np.stack([x, y], axis=1))
super().__init__(*self._center(np.stack([x, y], axis=1), data_bounds))

@staticmethod
def _get_inverted_heart(dataset: Dataset, y_shift: Number) -> np.ndarray:
Expand Down
15 changes: 15 additions & 0 deletions tests/shapes/bases/test_point_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import re

import matplotlib.pyplot as plt
import numpy as np
import pytest

from data_morph.bounds.bounding_box import BoundingBox
from data_morph.shapes.bases.point_collection import PointCollection


Expand All @@ -18,6 +20,19 @@ def point_collection(self):
"""An instance of PointCollection."""
return PointCollection([0, 0], [20, 50])

@pytest.mark.parametrize(
'bounding_box',
[BoundingBox([0, 100], [-50, 50]), BoundingBox([0, 20], [0, 50])],
)
def test_center(self, point_collection, bounding_box):
"""Test that points are centered within the bounding box."""
points = point_collection._center(point_collection.points, bounding_box)

(xmin, xmax), (ymin, ymax) = bounding_box
upper = np.array([xmax, ymax]) - points.max(axis=0)
lower = points.min(axis=0) - np.array([xmin, ymin])
assert np.array_equal(upper, lower)

def test_distance_zero(self, point_collection):
"""Test the distance() method on points in the collection."""
for point in point_collection.points:
Expand Down
12 changes: 6 additions & 6 deletions tests/shapes/points/test_heart.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ class TestHeart(PointsModuleTestBase):

shape_name = 'heart'
distance_test_cases = (
((19.89946048, 54.82281916), 0.0),
((10.84680454, 70.18556376), 0.0),
((29.9971295, 67.66402445), 0.0),
((27.38657942, 62.417184), 0.0),
((20, 50), 4.567369),
((10, 80), 8.564365),
((22.424114, 59.471779), 0.0),
((10.405462, 70.897342), 0.0),
((21.064032, 72.065253), 0.0),
((16.035166, 60.868470), 0.0),
((20, 50), 6.065782511791651),
((10, 80), 7.173013322704914),
)
13 changes: 7 additions & 6 deletions tests/shapes/points/test_spade.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ class TestSpade(PointsModuleTestBase):

shape_name = 'spade'
distance_test_cases = (
((19.97189615, 75.43271708), 0),
((23.75, 55), 0),
((11.42685318, 59.11304904), 0),
((20, 75), 0.2037185),
((0, 0), 57.350348),
((10, 80), 10.968080),
((19.818701, 60.065370), 0),
((23.750000, 55.532859), 0),
((20.067229, 60.463689), 0),
((18.935968, 58.467606), 0),
((20, 75), 0.5335993101603015),
((0, 0), 57.861566654807596),
((10, 80), 11.404000978114487),
)
Loading