diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fb4763cd..56a6137a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -44,7 +44,7 @@ jobs: fail-fast: false matrix: os: [macos-latest, ubuntu-latest, windows-latest] - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] + python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] env: MPLBACKEND: Agg # non-interactive backend for matplotlib diff --git a/docs/_static/panda_to_star.gif b/docs/_static/panda-to-star-classic.gif similarity index 100% rename from docs/_static/panda_to_star.gif rename to docs/_static/panda-to-star-classic.gif diff --git a/docs/_static/panda-to-star-eased.gif b/docs/_static/panda-to-star-eased.gif index 5659a7fd..84c53fae 100644 Binary files a/docs/_static/panda-to-star-eased.gif and b/docs/_static/panda-to-star-eased.gif differ diff --git a/docs/_static/panda-to-star.gif b/docs/_static/panda-to-star.gif index c25d7dae..db19f587 100644 Binary files a/docs/_static/panda-to-star.gif and b/docs/_static/panda-to-star.gif differ diff --git a/docs/quickstart.rst b/docs/quickstart.rst index 355e8076..c7776667 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -44,10 +44,24 @@ This produces the following animation in the newly-created ``morphed_data`` dire within your current working directory: .. figure:: _static/panda-to-star.gif - :alt: Morphing the panda dataset into the star shape. + :alt: Morphing the panda dataset into the star shape with marginal plots. :align: center - Morphing the panda :class:`.Dataset` into the star :class:`.Shape`. + Morphing the panda :class:`.Dataset` into the star :class:`.Shape` with marginal plots. + +If you don't want the marginal plots (the histograms on the sides), you can use classic mode: + +.. code:: console + + $ data-morph --start-shape panda --target-shape star --classic + +Animations generated in classic mode include only the scatter plot and the summary statistics: + +.. figure:: _static/panda-to-star-classic.gif + :alt: Morphing the panda dataset into the star shape using classic mode. + :align: center + + Morphing the panda :class:`.Dataset` into the star :class:`.Shape` using classic mode. You can smooth the transition with the ``--ease`` or ``--ease-in`` and ``--ease-out`` flags. The ``--freeze`` flag allows you to start the animation with the specified number of frames diff --git a/docs/tutorials/custom-datasets.rst b/docs/tutorials/custom-datasets.rst index e317e010..a6a08bfa 100644 --- a/docs/tutorials/custom-datasets.rst +++ b/docs/tutorials/custom-datasets.rst @@ -99,7 +99,7 @@ Pass the path to the CSV file to use those points as the starting shape: .. code:: console - $ data-morph --start-shape path/to/points.csv --target-shape wide_lines + $ data-morph --start-shape path/to/points.csv --target-shape wide_lines --classic Here is an example animation generated from a custom dataset: diff --git a/pyproject.toml b/pyproject.toml index 40221912..80bee93c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,14 +24,13 @@ authors = [ { name = "Aaron Stevens", email = "bheklilr2@gmail.com" }, { name = "Justin Matejka", email = "Justin.Matejka@Autodesk.com" }, ] -requires-python = ">=3.9" +requires-python = ">=3.10" classifiers = [ "Development Status :: 4 - Beta", "Framework :: Matplotlib", "Intended Audience :: Education", "Operating System :: OS Independent", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", @@ -45,9 +44,9 @@ dynamic = [ ] dependencies = [ - "matplotlib>=3.7", - "numpy>=1.20", - "pandas>=1.2", + "matplotlib>=3.10", + "numpy>=1.23.0", + "pandas>=2.1", "rich>=13.9.4", ] diff --git a/src/data_morph/bounds/_utils.py b/src/data_morph/bounds/_utils.py index df17efe7..1f850aee 100644 --- a/src/data_morph/bounds/_utils.py +++ b/src/data_morph/bounds/_utils.py @@ -21,7 +21,7 @@ def _validate_2d(data: Iterable[Number], name: str) -> Iterable[Number]: The validated data. """ if not ( - isinstance(data, (tuple, list)) + isinstance(data, tuple | list) and len(data) == 2 and all(isinstance(x, Number) and not isinstance(x, bool) for x in data) ): diff --git a/src/data_morph/bounds/bounding_box.py b/src/data_morph/bounds/bounding_box.py index 38804601..f85aab2e 100644 --- a/src/data_morph/bounds/bounding_box.py +++ b/src/data_morph/bounds/bounding_box.py @@ -38,7 +38,7 @@ def __init__( if isinstance(inclusive, bool): inclusive = [inclusive] * 2 if not ( - isinstance(inclusive, (tuple, list)) + isinstance(inclusive, tuple | list) and len(inclusive) == 2 and all(isinstance(x, bool) for x in inclusive) ): @@ -47,19 +47,19 @@ def __init__( ' or a single Boolean value' ) - self.x_bounds = ( + self.x_bounds: Interval = ( x_bounds.clone() if isinstance(x_bounds, Interval) else Interval(x_bounds, inclusive[0]) ) - """Interval: The bounds for the x direction.""" + """The bounds for the x direction.""" - self.y_bounds = ( + self.y_bounds: Interval = ( y_bounds.clone() if isinstance(y_bounds, Interval) else Interval(y_bounds, inclusive[1]) ) - """Interval: The bounds for the y direction.""" + """The bounds for the y direction.""" self._bounds = (self.x_bounds, self.y_bounds) diff --git a/src/data_morph/cli.py b/src/data_morph/cli.py index 8f7cc817..229a9548 100644 --- a/src/data_morph/cli.py +++ b/src/data_morph/cli.py @@ -198,6 +198,16 @@ def generate_parser() -> argparse.ArgumentParser: frame_group = parser.add_argument_group( 'Animation Configuration', description='Customize aspects of the animation.' ) + frame_group.add_argument( + '--classic', + default=False, + action='store_true', + help=( + 'Whether to plot the original visualization, which consists of a scatter plot ' + 'and the summary statistics. Otherwise, marginal plots will be included in ' + 'addition to the classic plot.' + ), + ) frame_group.add_argument( '--ease', default=False, @@ -294,6 +304,7 @@ def _morph( forward_only_animation=args.forward_only, num_frames=100, in_notebook=False, + classic=args.classic, with_median=args.with_median, ) @@ -409,6 +420,7 @@ def _serialize(args: argparse.Namespace, target_shapes: Sequence[str]) -> None: forward_only_animation=args.forward_only, num_frames=100, in_notebook=False, + classic=args.classic, with_median=args.with_median, ) diff --git a/src/data_morph/data/dataset.py b/src/data_morph/data/dataset.py index 20d50173..48a31702 100644 --- a/src/data_morph/data/dataset.py +++ b/src/data_morph/data/dataset.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING import matplotlib.pyplot as plt +import numpy as np from ..bounds.bounding_box import BoundingBox from ..bounds.interval import Interval @@ -55,19 +56,27 @@ def __init__( self.data: pd.DataFrame = self._validate_data(data).pipe( self._scale_data, scale ) - """pandas.DataFrame: DataFrame containing columns x and y.""" + """DataFrame containing columns x and y.""" self.name: str = name - """str: The name to use for the dataset.""" + """The name to use for the dataset.""" self.data_bounds: BoundingBox = self._derive_data_bounds() - """BoundingBox: The bounds of the data.""" + """The bounds of the data.""" self.morph_bounds: BoundingBox = self._derive_morphing_bounds() - """BoundingBox: The limits for the morphing process.""" + """The limits for the morphing process.""" self.plot_bounds: BoundingBox = self._derive_plotting_bounds() - """BoundingBox: The bounds to use when plotting the morphed data.""" + """The bounds to use when plotting the morphed data.""" + + self.marginals: tuple[ + tuple[np.ndarray, np.ndarray], tuple[np.ndarray, np.ndarray] + ] = ( + np.histogram(self.data.x, bins=30, range=self.plot_bounds.x_bounds), + np.histogram(self.data.y, bins=30, range=self.plot_bounds.y_bounds), + ) + """The counts per bin and bin boundaries for generating marginal plots.""" def __repr__(self) -> str: return f'<{self.__class__.__name__} name={self.name} scaled={self._scaled}>' diff --git a/src/data_morph/morpher.py b/src/data_morph/morpher.py index a9509f5a..8867f9b6 100644 --- a/src/data_morph/morpher.py +++ b/src/data_morph/morpher.py @@ -60,6 +60,10 @@ class DataMorpher: forward_only_animation : bool, default ``False`` Whether to generate the animation in the forward direction only. By default, the animation will play forward and then reverse. + classic : bool, default ``False`` + Whether to plot the original visualization, which consists of a scatter plot + and the summary statistics. When this is ``False``, marginal plots will be + included in addition to the classic plot. with_median : bool, default ``False`` Whether to preserve the median in addition to the other summary statistics. Note that this will be a little slower. @@ -77,6 +81,7 @@ def __init__( num_frames: int = 100, keep_frames: bool = False, forward_only_animation: bool = False, + classic: bool = False, with_median: bool = False, ) -> None: self._rng = np.random.default_rng(seed) @@ -133,6 +138,7 @@ def __init__( self._ProgressTracker = partial(DataMorphProgress, not self._in_notebook) + self._classic = classic self._with_median = with_median def _select_frames( @@ -204,6 +210,8 @@ def _record_frames( self, data: pd.DataFrame, bounds: BoundingBox, + marginals: tuple[tuple[np.ndarray, np.ndarray], tuple[np.ndarray, np.ndarray]] + | None, base_file_name: str, frame_number: str, ) -> None: @@ -216,6 +224,8 @@ def _record_frames( The DataFrame of the data for morphing. bounds : BoundingBox The plotting limits. + marginals : tuple[tuple[np.ndarray, np.ndarray], tuple[np.ndarray, np.ndarray]] | None + The counts per bin and bin boundaries for generating marginal plots. base_file_name : str The prefix to the file names for both the PNG and GIF files. frame_number : str @@ -228,6 +238,7 @@ def _record_frames( decimals=self.decimals, x_bounds=bounds.x_bounds, y_bounds=bounds.y_bounds, + marginals=marginals, with_median=self._with_median, dpi=150, ) @@ -456,6 +467,7 @@ def morph( self._record_frames, base_file_name=base_file_name, bounds=start_shape.plot_bounds, + marginals=None if self._classic else start_shape.marginals, ) frame_number_format = f'{{:0{len(str(iterations))}d}}'.format diff --git a/src/data_morph/plotting/animation.py b/src/data_morph/plotting/animation.py index f3d12d80..dc59500a 100644 --- a/src/data_morph/plotting/animation.py +++ b/src/data_morph/plotting/animation.py @@ -6,11 +6,13 @@ import re from functools import wraps from pathlib import Path -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING from PIL import Image if TYPE_CHECKING: + from collections.abc import Callable + from ..shapes.bases.shape import Shape @@ -115,7 +117,7 @@ def wrapper(step: int | float) -> int | float: int or float The eased value at the current step, from 0.0 to 1.0. """ - if not (isinstance(step, (int, float)) and 0 <= step <= 1): + if not (isinstance(step, int | float) and 0 <= step <= 1): raise ValueError('Step must be an integer or float, between 0 and 1.') return easing_function(step) diff --git a/src/data_morph/plotting/static.py b/src/data_morph/plotting/static.py index f937367f..42b12bca 100644 --- a/src/data_morph/plotting/static.py +++ b/src/data_morph/plotting/static.py @@ -8,7 +8,7 @@ import matplotlib.pyplot as plt import numpy as np -from matplotlib.ticker import EngFormatter +from matplotlib.ticker import EngFormatter, MaxNLocator from ..data.stats import get_summary_statistics from .style import plot_with_custom_style @@ -36,6 +36,8 @@ def plot( data: pd.DataFrame, x_bounds: Iterable[Number], y_bounds: Iterable[Number], + marginals: tuple[tuple[np.ndarray, np.ndarray], tuple[np.ndarray, np.ndarray]] + | None, save_to: str | Path, decimals: int, with_median: bool, @@ -50,6 +52,8 @@ def plot( The dataset to plot. x_bounds, y_bounds : Iterable[numbers.Number] The plotting limits. + marginals : tuple[tuple[np.ndarray, np.ndarray], tuple[np.ndarray, np.ndarray]] | None + The counts per bin and bin boundaries for generating marginal plots. save_to : str or pathlib.Path Path to save the plot frame to. decimals : int @@ -65,8 +69,12 @@ def plot( matplotlib.axes.Axes or None When ``save_to`` is falsey, an :class:`~matplotlib.axes.Axes` object is returned. """ + add_marginals = marginals is not None + fig, ax = plt.subplots( - figsize=(7, 3), layout='constrained', subplot_kw={'aspect': 'equal'} + figsize=(9 if add_marginals else 7, 3), + layout='constrained', + subplot_kw={'aspect': 'equal'}, ) fig.get_layout_engine().set(w_pad=1.4, h_pad=0.2, wspace=0) @@ -77,6 +85,40 @@ def plot( ax.xaxis.set_major_formatter(tick_formatter) ax.yaxis.set_major_formatter(tick_formatter) + if add_marginals: + ax_histx = ax.inset_axes([0, 1.05, 1, 0.25], sharex=ax) + ax_histy = ax.inset_axes([1.05, 0, 0.25, 1], sharey=ax) + + ax_histy.xaxis.set_major_formatter(tick_formatter) + ax_histx.yaxis.set_major_formatter(tick_formatter) + + (x_marginal_counts, x_marginal_bins), (y_marginal_counts, y_marginal_bins) = ( + marginals + ) + + ax_histx.set(xlim=x_bounds, ylim=(0, np.ceil(x_marginal_counts.max() * 2))) + ax_histy.set(xlim=(0, np.ceil(y_marginal_counts.max() * 2)), ylim=y_bounds) + + # no labels on marginal axis that shares with scatter plot + ax_histx.tick_params(axis='x', labelbottom=False) + ax_histy.tick_params(axis='y', labelleft=False) + + # move marginal axis ticks that are visible to the corner and only show the non-zero label + locator = MaxNLocator(2, integer=True, prune='lower') + ax_histx.tick_params(axis='y', labelleft=False, labelright=True) + ax_histx.yaxis.set_major_locator(locator) + ax_histy.tick_params(axis='x', labelbottom=False, labeltop=True) + ax_histy.xaxis.set_major_locator(locator) + + ax_histx.hist(data.x, bins=x_marginal_bins, color='slategray', ec='black') + ax_histy.hist( + data.y, + bins=y_marginal_bins, + orientation='horizontal', + color='slategray', + ec='black', + ) + res = get_summary_statistics(data, with_median=with_median) if with_median: @@ -89,16 +131,31 @@ def plot( 'y_stdev', 'correlation', ) - locs = [0.9, 0.78, 0.66, 0.5, 0.38, 0.26, 0.1] + locs = ( + [0.94, 0.8, 0.66, 0.49, 0.35, 0.21, 0.04] + if add_marginals + else [0.9, 0.78, 0.66, 0.5, 0.38, 0.26, 0.1] + ) else: fields = ('x_mean', 'y_mean', 'x_stdev', 'y_stdev', 'correlation') - locs = np.linspace(0.8, 0.2, num=len(fields)) + locs = ( + np.linspace(0.85, 0.15, num=len(fields)) + if add_marginals + else np.linspace(0.8, 0.2, num=len(fields)) + ) labels = [_STATISTIC_DISPLAY_NAME_MAPPING[field] for field in fields] max_label_length = max([len(label) for label in labels]) max_stat = int(np.log10(np.max(np.abs(res)))) + 1 mean_x_digits, mean_y_digits = ( - int(x) + 1 for x in np.log10(np.abs([res.x_mean, res.y_mean])) + int(x) + 1 + for x in np.log10( + np.abs( + [max(res.x_mean, res.x_median), max(res.y_mean, res.y_median)] + if with_median + else [res.x_mean, res.y_mean] + ) + ) ) # If `max_label_length = 10`, this string will be "{:<10}: {:0.7f}", then we @@ -117,25 +174,21 @@ def plot( add_stat_text = partial( ax.text, - 1.05, + 1.4 if add_marginals else 1.05, fontsize=15, transform=ax.transAxes, va='center', ) - for loc, field in zip(locs, fields): + for loc, field in zip(locs, fields, strict=False): label = _STATISTIC_DISPLAY_NAME_MAPPING[field] stat = getattr(res, field) if field == 'correlation': - correlation_str = corr_formatter(labels[-1], res.correlation) + correlation_str = corr_formatter(label, res.correlation) for alpha, text in zip( - [0.3, 1], [correlation_str, correlation_str[:-stat_clip]] + [0.3, 1], [correlation_str, correlation_str[:-stat_clip]], strict=False ): - add_stat_text( - locs[-1], - text, - alpha=alpha, - ) + add_stat_text(loc, text, alpha=alpha) else: add_stat_text(loc, formatter(label, stat), alpha=0.3) add_stat_text(loc, formatter(label, stat)[:-stat_clip]) diff --git a/src/data_morph/plotting/style.py b/src/data_morph/plotting/style.py index 42b7db55..e46c415b 100644 --- a/src/data_morph/plotting/style.py +++ b/src/data_morph/plotting/style.py @@ -1,11 +1,11 @@ """Utility functions for styling Matplotlib plots.""" -from collections.abc import Generator +from collections.abc import Callable, Generator from contextlib import contextmanager from functools import wraps from importlib.resources import as_file, files from pathlib import Path -from typing import Any, Callable +from typing import Any import matplotlib.pyplot as plt diff --git a/src/data_morph/shapes/bases/line_collection.py b/src/data_morph/shapes/bases/line_collection.py index f19e6950..4e1469fe 100644 --- a/src/data_morph/shapes/bases/line_collection.py +++ b/src/data_morph/shapes/bases/line_collection.py @@ -121,5 +121,5 @@ def plot(self, ax: Axes | None = None) -> Axes: fig.get_layout_engine().set(w_pad=0.2, h_pad=0.2) _ = ax.axis('equal') for start, end in self.lines: - ax.plot(*list(zip(start, end)), 'k-') + ax.plot(*list(zip(start, end, strict=True)), 'k-') return ax diff --git a/src/data_morph/shapes/lines/star.py b/src/data_morph/shapes/lines/star.py index b5b339d1..16ae1f19 100644 --- a/src/data_morph/shapes/lines/star.py +++ b/src/data_morph/shapes/lines/star.py @@ -1,5 +1,7 @@ """Star shape.""" +import itertools + from ...data.dataset import Dataset from ..bases.line_collection import LineCollection @@ -46,4 +48,4 @@ def __init__(self, dataset: Dataset) -> None: [xmin, ymin + y_range * 0.625], ] - super().__init__(*list(zip(pts[:-1], pts[1:]))) + super().__init__(*list(itertools.pairwise(pts))) diff --git a/tests/bounds/test_interval.py b/tests/bounds/test_interval.py index 4c744dcb..43eca8e4 100644 --- a/tests/bounds/test_interval.py +++ b/tests/bounds/test_interval.py @@ -91,7 +91,7 @@ def test_getitem(self): def test_iter(self): """Test that the __iter__() method is working.""" limits = [0, 1] - for bound, limit in zip(Interval(limits), limits): + for bound, limit in zip(Interval(limits), limits, strict=True): assert bound == limit @pytest.mark.parametrize( diff --git a/tests/data/test_loader.py b/tests/data/test_loader.py index a1a5d1bb..25925a2d 100644 --- a/tests/data/test_loader.py +++ b/tests/data/test_loader.py @@ -67,7 +67,9 @@ def test_plot_available_datasets(self, monkeypatch, subset): assert len(populated_axs) == len(DataLoader.AVAILABLE_DATASETS) assert all(ax.get_xlabel() == ax.get_ylabel() == '' for ax in populated_axs) - for dataset, ax in zip(DataLoader.AVAILABLE_DATASETS, populated_axs): + for dataset, ax in zip( + DataLoader.AVAILABLE_DATASETS, populated_axs, strict=True + ): subplot_title = ax.get_title() assert subplot_title.startswith(dataset) assert subplot_title.endswith(' points)') diff --git a/tests/plotting/test_animation.py b/tests/plotting/test_animation.py index 4fa67a6d..8def9dbc 100644 --- a/tests/plotting/test_animation.py +++ b/tests/plotting/test_animation.py @@ -30,6 +30,7 @@ def test_frame_stitching(sample_data, tmp_path, forward_only): save_to=(tmp_path / f'{start_shape}-to-{target_shape}-{frame}.png'), decimals=2, with_median=False, + marginals=None, ) duration_multipliers = [0, 0, 0, 0, 1, 1, *frame_numbers[2:], frame_numbers[-1]] diff --git a/tests/plotting/test_static.py b/tests/plotting/test_static.py index 9b1064e5..d0f7b024 100644 --- a/tests/plotting/test_static.py +++ b/tests/plotting/test_static.py @@ -1,5 +1,7 @@ """Test the static module.""" +import matplotlib.pyplot as plt +import numpy as np import pytest from data_morph.plotting.static import plot @@ -8,16 +10,23 @@ @pytest.mark.parametrize( - ('file_path', 'with_median'), + ('file_path', 'with_median', 'classic'), [ - ('test_plot.png', False), - (None, True), - (None, False), + ('test_plot.png', False, True), + (None, True, True), + (None, False, True), + (None, True, False), + (None, False, False), ], ) -def test_plot(sample_data, tmp_path, file_path, with_median): +def test_plot(sample_data, tmp_path, file_path, with_median, classic): """Test static plot creation.""" bounds = (-5.0, 105.0) + + marginals = ( + None if classic else (np.histogram(sample_data.x), np.histogram(sample_data.y)) + ) + if file_path: save_to = tmp_path / 'another-level' / file_path @@ -28,6 +37,7 @@ def test_plot(sample_data, tmp_path, file_path, with_median): save_to=save_to, decimals=2, with_median=with_median, + marginals=marginals, ) assert save_to.is_file() @@ -39,6 +49,7 @@ def test_plot(sample_data, tmp_path, file_path, with_median): save_to=None, decimals=2, with_median=with_median, + marginals=marginals, ) # confirm that the stylesheet was used @@ -52,3 +63,10 @@ def test_plot(sample_data, tmp_path, file_path, with_median): expected_stats = 7 if with_median else 5 expected_texts = 2 * expected_stats # label and the number assert len(ax.texts) == expected_texts + + # if marginals should be there, check for two inset Axes + expected_insets = 0 if classic else 2 + inset_axes = [ + child for child in ax.get_children() if isinstance(child, plt.Axes) + ] + assert len(inset_axes) == expected_insets diff --git a/tests/shapes/circles/test_circle.py b/tests/shapes/circles/test_circle.py index 2daa251d..626d92a3 100644 --- a/tests/shapes/circles/test_circle.py +++ b/tests/shapes/circles/test_circle.py @@ -32,5 +32,6 @@ def test_is_circle(self, shape): for x, y in zip( cx + shape.radius * np.cos(angles), cy + shape.radius * np.sin(angles), + strict=True, ): assert pytest.approx(shape.distance(x, y)) == 0 diff --git a/tests/test_morpher.py b/tests/test_morpher.py index b4892191..2b5b09d4 100644 --- a/tests/test_morpher.py +++ b/tests/test_morpher.py @@ -263,8 +263,9 @@ def test_record_frames(self, write_images, start_frame, tmp_path): morpher._record_frames( dataset.data, dataset.plot_bounds, - base_path, - frame_number, + marginals=None, + base_file_name=base_path, + frame_number=frame_number, ) images = list(tmp_path.glob(f'{base_path}*.png'))