Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "ect"
version = "1.1.1"
version = "1.1.2"
authors = [
{ name="Liz Munch", email="muncheli@msu.edu" },
]
Expand Down
68 changes: 66 additions & 2 deletions src/ect/results.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import matplotlib.pyplot as plt
import numpy as np
from ect.directions import Sampling
from scipy.spatial.distance import cdist
from typing import Union, List, Callable


class ECTResult(np.ndarray):
Expand Down Expand Up @@ -47,8 +49,7 @@ def plot(self, ax=None):
self.directions.sampling == Sampling.UNIFORM
and not self.directions.endpoint
):
plot_thetas = np.concatenate(
[self.directions.thetas, [2 * np.pi]])
plot_thetas = np.concatenate([self.directions.thetas, [2 * np.pi]])
ect_data = np.hstack([self.T, self.T[:, [0]]])
else:
plot_thetas = self.directions.thetas
Expand Down Expand Up @@ -113,3 +114,66 @@ def _plot_ecc(self, theta):
plt.title(r"ECC for $\omega = " + theta_round + "$")
plt.xlabel("$a$")
plt.ylabel(r"$\chi(K_a)$")

def dist(
    self,
    other: Union["ECTResult", List["ECTResult"]],
    metric: Union[str, Callable] = "cityblock",
    **kwargs,
):
    """
    Compute distance to another ECTResult or a collection of ECTResults.

    Each ECT matrix is flattened to a 1D vector and the distances are
    computed with scipy.spatial.distance.cdist.

    Args:
        other: Another ECTResult object, or an iterable (list, tuple,
            generator, ...) of ECTResult objects
        metric: Distance metric to use. Can be:
            - String: any metric supported by scipy.spatial.distance
              (e.g., 'euclidean', 'cityblock', 'chebyshev', 'cosine', etc.)
            - Callable: a custom distance function that takes two 1D arrays
              and returns a scalar distance. The function should have signature:
              func(u, v) -> float
        **kwargs: Additional keyword arguments passed to the metric function
            (e.g., p=3 for minkowski distance, w=weights for weighted metrics)

    Returns:
        float or np.ndarray: Single distance if other is an ECTResult,
            array of distances (one per element, in order) otherwise.
            An empty collection yields an empty array.

    Raises:
        ValueError: If the shapes of the ECTResults don't match

    Examples:
        >>> # Built-in metrics
        >>> dist1 = ect1.dist(ect2, metric='euclidean')
        >>> dist2 = ect1.dist(ect2, metric='minkowski', p=3)
        >>>
        >>> # Custom distance function
        >>> def my_distance(u, v):
        ...     return np.sum(np.abs(u - v) ** 0.5)
        >>> dist3 = ect1.dist(ect2, metric=my_distance)
        >>>
        >>> # Batch distances with custom function
        >>> dists = ect1.dist([ect2, ect3, ect4], metric=my_distance)
    """
    # Normalize the input to a list. Materializing with list() matters for
    # one-shot iterables (generators): they are walked twice below, once for
    # the shape check and once for stacking.
    single = isinstance(other, ECTResult)
    others = [other] if single else list(other)

    if not others:
        return np.array([])

    # Validate every candidate up front so the error names the offending index.
    for i, ect in enumerate(others):
        if ect.shape != self.shape:
            raise ValueError(
                f"Shape mismatch at index {i}: {self.shape} vs {ect.shape}"
            )

    # Flatten each matrix into one row of a 2D array. np.asarray drops the
    # ECTResult subclass first so cdist only sees plain ndarrays; vstack
    # copies the rows into a fresh (len(others), self.size) array.
    stacked = np.vstack([np.asarray(ect).ravel() for ect in others])
    distances = cdist(
        np.asarray(self).ravel()[np.newaxis, :],
        stacked,
        metric=metric,
        **kwargs,
    )[0]

    # cdist returns a (1, len(others)) matrix; row 0 holds our distances.
    return distances[0] if single else distances
129 changes: 129 additions & 0 deletions tests/test_ect_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,135 @@ def test_array_finalize(self):
self.assertEqual(sliced.directions, self.result.directions)
self.assertTrue(np.array_equal(sliced.thresholds, self.result.thresholds))

def test_dist_single_ectresult(self):
    """Check single-result distances for the L1, L2 and L-infinity metrics."""
    # Recompute a result from the same graph, then offset every entry by one
    # and carry the metadata over so the shapes match self.result.
    recomputed = self.ect.calculate(self.graph)
    shifted = recomputed + 1
    shifted.directions = recomputed.directions
    shifted.thresholds = recomputed.thresholds

    # Default metric (L1): sum of absolute differences, returned as a scalar.
    l1_value = self.result.dist(shifted)
    self.assertAlmostEqual(l1_value, np.abs(self.result - shifted).sum())
    self.assertIsInstance(l1_value, (float, np.floating))

    # L2: square root of the summed squared differences.
    l2_value = self.result.dist(shifted, metric='euclidean')
    self.assertAlmostEqual(
        l2_value, np.sqrt(((self.result - shifted) ** 2).sum())
    )

    # L-infinity: largest absolute difference.
    linf_value = self.result.dist(shifted, metric='chebyshev')
    self.assertAlmostEqual(linf_value, np.abs(self.result - shifted).max())

def test_dist_list_of_ectresults(self):
    """Verify that a list input yields one distance per list entry."""
    # Offset copies of the reference result, with its metadata carried over.
    shifted = [self.result + offset for offset in (1, 2, 3)]
    for candidate in shifted:
        candidate.directions = self.result.directions
        candidate.thresholds = self.result.thresholds

    batch = self.result.dist(shifted)

    # A list input must come back as an array with one entry per element.
    self.assertIsInstance(batch, np.ndarray)
    self.assertEqual(len(batch), 3)

    # No metric is passed, so compare against summed absolute differences.
    reference = [np.abs(self.result - candidate).sum() for candidate in shifted]
    np.testing.assert_array_almost_equal(batch, reference)

def test_dist_custom_metric(self):
    """Exercise dist() with a user-supplied callable as the metric."""

    def custom_metric(u, v):
        # Sum of square roots of absolute differences ("L0.5"-style).
        return np.sum(np.abs(u - v) ** 0.5)

    def shifted_copy(offset):
        # Offset copy of the reference result with its metadata carried over.
        copy = self.result + offset
        copy.directions = self.result.directions
        copy.thresholds = self.result.thresholds
        return copy

    flat_reference = self.result.ravel()

    # Single comparison: must equal the metric applied to the flattened arrays.
    first = shifted_copy(1)
    single_value = self.result.dist(first, metric=custom_metric)
    self.assertAlmostEqual(
        single_value, custom_metric(flat_reference, first.ravel())
    )

    # Batch comparison: one metric evaluation per list entry, in order.
    second = shifted_copy(2)
    batch_values = self.result.dist([first, second], metric=custom_metric)
    wanted = [
        custom_metric(flat_reference, item.ravel()) for item in (first, second)
    ]
    np.testing.assert_array_almost_equal(batch_values, wanted)

def test_dist_additional_kwargs(self):
    """Confirm that extra keyword arguments reach the metric (minkowski p)."""
    shifted = self.result + 1
    shifted.directions = self.result.directions
    shifted.thresholds = self.result.thresholds

    # Same check for each exponent: compare against the p-norm computed by hand.
    for exponent in (3, 5):
        measured = self.result.dist(shifted, metric='minkowski', p=exponent)
        by_hand = np.sum(np.abs(self.result - shifted) ** exponent) ** (
            1 / exponent
        )
        self.assertAlmostEqual(measured, by_hand, places=5)

def test_dist_empty_list(self):
    """An empty list of results must produce an empty ndarray."""
    nothing = self.result.dist([])
    self.assertIsInstance(nothing, np.ndarray)
    self.assertEqual(len(nothing), 0)

def test_dist_shape_mismatch(self):
    """dist() must raise ValueError when any operand has a different shape."""
    # Result computed with a different ECT configuration (5 directions,
    # 7 thresholds), used here as the mismatching operand.
    mismatched = ECT(num_dirs=5, num_thresh=7).calculate(self.graph)

    # Case 1: a lone mismatching result.
    with self.assertRaises(ValueError) as caught:
        self.result.dist(mismatched)
    self.assertIn("Shape mismatch", str(caught.exception))

    # Case 2: a list whose second entry mismatches; the message names index 1.
    compatible = self.result + 1
    compatible.directions = self.result.directions
    compatible.thresholds = self.result.thresholds

    with self.assertRaises(ValueError) as caught:
        self.result.dist([compatible, mismatched])
    self.assertIn("Shape mismatch at index 1", str(caught.exception))

def test_dist_self(self):
    """The distance from a result to itself is exactly zero."""
    # Default metric first, then the same check for two explicit metrics.
    self.assertEqual(self.result.dist(self.result), 0.0)

    for metric_name in ('euclidean', 'chebyshev'):
        self.assertEqual(
            self.result.dist(self.result, metric=metric_name), 0.0
        )


if __name__ == '__main__':
unittest.main()
Loading