From 45b8bf4c3d1e7ac3f0933fba43617b6f172d1e6b Mon Sep 17 00:00:00 2001
From: yemeen
Date: Mon, 15 Sep 2025 16:10:11 -0400
Subject: [PATCH 1/2] add distance to ectresult and test cases

---
 src/ect/results.py       |  68 ++++++++++++++++++++-
 tests/test_ect_result.py | 129 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 195 insertions(+), 2 deletions(-)

diff --git a/src/ect/results.py b/src/ect/results.py
index 095e7c7..9ed0821 100644
--- a/src/ect/results.py
+++ b/src/ect/results.py
@@ -1,6 +1,8 @@
 import matplotlib.pyplot as plt
 import numpy as np
 from ect.directions import Sampling
+from scipy.spatial.distance import cdist
+from typing import Union, List, Callable
 
 
 class ECTResult(np.ndarray):
@@ -47,8 +49,7 @@ def plot(self, ax=None):
                 self.directions.sampling == Sampling.UNIFORM
                 and not self.directions.endpoint
             ):
-                plot_thetas = np.concatenate(
-                    [self.directions.thetas, [2 * np.pi]])
+                plot_thetas = np.concatenate([self.directions.thetas, [2 * np.pi]])
                 ect_data = np.hstack([self.T, self.T[:, [0]]])
             else:
                 plot_thetas = self.directions.thetas
@@ -113,3 +114,66 @@ def _plot_ecc(self, theta):
         plt.title(r"ECC for $\omega = " + theta_round + "$")
         plt.xlabel("$a$")
         plt.ylabel(r"$\chi(K_a)$")
+
+    def dist(
+        self,
+        other: Union["ECTResult", List["ECTResult"]],
+        metric: Union[str, Callable] = "cityblock",
+        **kwargs,
+    ) -> Union[float, np.ndarray]:
+        """
+        Compute distance to another ECTResult or list of ECTResults.
+
+        Args:
+            other: Another ECTResult, or an iterable (e.g. list) of ECTResults
+            metric: Distance metric to use. Can be:
+                - String: any metric supported by scipy.spatial.distance
+                  (e.g., 'euclidean', 'cityblock', 'chebyshev', 'cosine', etc.)
+                - Callable: a custom distance function that takes two 1D arrays
+                  and returns a scalar distance. The function should have signature:
+                  func(u, v) -> float
+            **kwargs: Additional keyword arguments passed to the metric function
+                (e.g., p=3 for minkowski distance, w=weights for weighted metrics)
+
+        Returns:
+            float or np.ndarray: Single distance if other is an ECTResult,
+                otherwise a 1D array of distances (empty if other is empty)
+
+        Raises:
+            ValueError: If the shapes of the ECTResults don't match
+
+        Examples:
+            >>> # Built-in metrics
+            >>> dist1 = ect1.dist(ect2, metric='euclidean')
+            >>> dist2 = ect1.dist(ect2, metric='minkowski', p=3)
+            >>>
+            >>> # Custom distance function
+            >>> def my_distance(u, v):
+            ...     return np.sum(np.abs(u - v) ** 0.5)
+            >>> dist3 = ect1.dist(ect2, metric=my_distance)
+            >>>
+            >>> # Batch distances with custom function
+            >>> dists = ect1.dist([ect2, ect3, ect4], metric=my_distance)
+        """
+        # normalize input to a list (materialized so any iterable is accepted)
+        single = isinstance(other, ECTResult)
+        others = [other] if single else list(other)
+
+        if not others:
+            return np.array([])
+
+        for i, ect in enumerate(others):
+            if ect.shape != self.shape:
+                raise ValueError(
+                    f"Shape mismatch at index {i}: {self.shape} vs {ect.shape}"
+                )
+
+        # flatten each ECT to one row; cdist compares self against every row
+        distances = cdist(
+            self.ravel()[np.newaxis, :],
+            np.vstack([ect.ravel() for ect in others]),
+            metric=metric,
+            **kwargs,
+        )[0]
+
+        return distances[0] if single else distances
diff --git a/tests/test_ect_result.py b/tests/test_ect_result.py
index 15212c1..eb90b0f 100644
--- a/tests/test_ect_result.py
+++ b/tests/test_ect_result.py
@@ -78,6 +78,135 @@ def test_array_finalize(self):
         self.assertEqual(sliced.directions, self.result.directions)
         self.assertTrue(np.array_equal(sliced.thresholds, self.result.thresholds))
 
+    def test_dist_single_ectresult(self):
+        """Test distance computation between two ECTResults"""
+        # Create a second ECTResult with same shape
+        result2 = self.ect.calculate(self.graph)
+        # Modify it slightly
+        result2_modified = result2 + 1
+        result2_modified.directions = result2.directions
+        result2_modified.thresholds = result2.thresholds
+
+        # Test L1 distance (default)
+        dist_l1 = self.result.dist(result2_modified)
+        expected_l1 = np.abs(self.result - result2_modified).sum()
+        self.assertAlmostEqual(dist_l1, expected_l1)
+        self.assertIsInstance(dist_l1, (float, np.floating))
+
+        # Test L2 distance
+        dist_l2 = self.result.dist(result2_modified, metric='euclidean')
+        expected_l2 = np.sqrt(((self.result - result2_modified) ** 2).sum())
+        self.assertAlmostEqual(dist_l2, expected_l2)
+
+        # Test L-inf distance
+        dist_linf = self.result.dist(result2_modified, metric='chebyshev')
+        expected_linf = np.abs(self.result - result2_modified).max()
+        self.assertAlmostEqual(dist_linf, expected_linf)
+
+    def test_dist_list_of_ectresults(self):
+        """Test batch distance computation with list of ECTResults"""
+        # Create multiple ECTResults
+        result2 = self.result + 1
+        result3 = self.result + 2
+        result4 = self.result + 3
+
+        # Preserve metadata
+        for r in (result2, result3, result4):
+            r.directions = self.result.directions
+            r.thresholds = self.result.thresholds
+
+        # Test batch distances
+        distances = self.result.dist([result2, result3, result4])
+
+        # Check return type is array
+        self.assertIsInstance(distances, np.ndarray)
+        self.assertEqual(len(distances), 3)
+
+        # Check individual distances are correct
+        expected_dists = [
+            np.abs(self.result - result2).sum(),
+            np.abs(self.result - result3).sum(),
+            np.abs(self.result - result4).sum()
+        ]
+        np.testing.assert_array_almost_equal(distances, expected_dists)
+
+    def test_dist_custom_metric(self):
+        """Test distance with custom metric function"""
+        result2 = self.result + 1
+        result2.directions = self.result.directions
+        result2.thresholds = self.result.thresholds
+
+        # Define custom metric - L0.5 norm
+        def custom_metric(u, v):
+            return np.sum(np.abs(u - v) ** 0.5)
+
+        # Test with custom metric
+        dist_custom = self.result.dist(result2, metric=custom_metric)
+        expected = custom_metric(self.result.ravel(), result2.ravel())
+        self.assertAlmostEqual(dist_custom, expected)
+
+        # Test custom metric with batch
+        result3 = self.result + 2
+        result3.directions = self.result.directions
+        result3.thresholds = self.result.thresholds
+
+        distances = self.result.dist([result2, result3], metric=custom_metric)
+        expected_batch = [
+            custom_metric(self.result.ravel(), result2.ravel()),
+            custom_metric(self.result.ravel(), result3.ravel())
+        ]
+        np.testing.assert_array_almost_equal(distances, expected_batch)
+
+    def test_dist_additional_kwargs(self):
+        """Test passing additional kwargs to metric functions"""
+        result2 = self.result + 1
+        result2.directions = self.result.directions
+        result2.thresholds = self.result.thresholds
+
+        # Test minkowski with different p values
+        dist_p3 = self.result.dist(result2, metric='minkowski', p=3)
+        expected_p3 = np.sum(np.abs(self.result - result2) ** 3) ** (1/3)
+        self.assertAlmostEqual(dist_p3, expected_p3, places=5)
+
+        dist_p5 = self.result.dist(result2, metric='minkowski', p=5)
+        expected_p5 = np.sum(np.abs(self.result - result2) ** 5) ** (1/5)
+        self.assertAlmostEqual(dist_p5, expected_p5, places=5)
+
+    def test_dist_empty_list(self):
+        """Test that empty list returns empty array"""
+        distances = self.result.dist([])
+        self.assertIsInstance(distances, np.ndarray)
+        self.assertEqual(len(distances), 0)
+
+    def test_dist_shape_mismatch(self):
+        """Test that shape mismatch raises ValueError"""
+        # Create ECTResult with different shape
+        ect_different = ECT(num_dirs=5, num_thresh=7)
+        result_different = ect_different.calculate(self.graph)
+
+        # Single ECTResult with wrong shape
+        with self.assertRaises(ValueError) as cm:
+            self.result.dist(result_different)
+        self.assertIn("Shape mismatch", str(cm.exception))
+
+        # List with one wrong shape
+        result2 = self.result + 1
+        result2.directions = self.result.directions
+        result2.thresholds = self.result.thresholds
+
+        with self.assertRaises(ValueError) as cm:
+            self.result.dist([result2, result_different])
+        self.assertIn("Shape mismatch at index 1", str(cm.exception))
+
+    def test_dist_self(self):
+        """Test distance to self is zero"""
+        dist_self = self.result.dist(self.result)
+        self.assertEqual(dist_self, 0.0)
+
+        # Also test with different metrics
+        self.assertEqual(self.result.dist(self.result, metric='euclidean'), 0.0)
+        self.assertEqual(self.result.dist(self.result, metric='chebyshev'), 0.0)
+
 
 if __name__ == '__main__':
     unittest.main()
\ No newline at end of file

From b8354ef4b347c9683a5c5ff48ed864acbe679b5f Mon Sep 17 00:00:00 2001
From: yemeen
Date: Mon, 15 Sep 2025 16:14:41 -0400
Subject: [PATCH 2/2] Bump version to 1.1.2 in pyproject.toml

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 05b497a..f68f6cb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "ect"
-version = "1.1.1"
+version = "1.1.2"
 authors = [
   { name="Liz Munch", email="muncheli@msu.edu" },
 ]