Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions doc_source/directions.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Directions

```{eval-rst}
.. automodule:: ect.directions
:members:
```
5 changes: 5 additions & 0 deletions doc_source/ect_on_graphs.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,8 @@
.. automodule:: ect.ect_graph
:members:
```

```{eval-rst}
.. automodule:: ect.sect
:members:
```
3 changes: 2 additions & 1 deletion doc_source/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ Table of Contents

Embedded graphs <embed_graph.md>
Embedded CW complex <embed_cw.md>
ECT on graphs <ect_on_graphs.md>
ECT on graphs <ect_on_graphs.md>
Directions <directions.md>
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "ect"
version = "1.0.1"
version = "1.0.2"
authors = [
{ name="Liz Munch", email="muncheli@msu.edu" },
]
Expand Down
12 changes: 7 additions & 5 deletions src/ect/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
from .embed_graph import EmbeddedGraph
from .embed_cw import EmbeddedCW
from .directions import Directions
from .sect import SECT
from .utils import examples

__all__ = [
'ECT',
'EmbeddedGraph',
'EmbeddedCW',
'Directions',
'examples',
"ECT",
"SECT",
"EmbeddedGraph",
"EmbeddedCW",
"Directions",
"examples",
]
59 changes: 59 additions & 0 deletions src/ect/sect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from ect import ECT
from .embed_graph import EmbeddedGraph
from .embed_cw import EmbeddedCW
from .directions import Directions
from .results import ECTResult
from typing import Optional, Union
import numpy as np


class SECT(ECT):
"""
A class to calculate the Smooth Euler Characteristic Transform (SECT).
Inherits from ECT and applies smoothing to the final result.
"""

def __init__(
self,
directions: Optional[Directions] = None,
num_dirs: Optional[int] = None,
num_thresh: Optional[int] = None,
bound_radius: Optional[float] = None,
thresholds: Optional[np.ndarray] = None,
dtype=np.float32,
):
"""Initialize SECT calculator with smoothing parameter

Args:
directions: Optional pre-configured Directions object
num_dirs: Number of directions to sample (ignored if directions provided)
num_thresh: Number of threshold values (required if directions not provided)
bound_radius: Optional radius for bounding circle
thresholds: Optional array of thresholds
dtype: Data type for output array
"""
super().__init__(
directions, num_dirs, num_thresh, bound_radius, thresholds, dtype
)

def calculate(
self,
graph: Union[EmbeddedGraph, EmbeddedCW],
theta: Optional[float] = None,
override_bound_radius: Optional[float] = None,
) -> ECTResult:
"""Calculate Smooth Euler Characteristic Transform (SECT)

Args:
graph: The input graph to calculate the SECT for
theta: The angle in [0,2π] for the direction to calculate the SECT
override_bound_radius: Optional override for bounding radius

Returns:
ECTResult: The smoothed transform result containing the matrix,
directions, and thresholds
"""
ect_result = super().calculate(graph, theta, override_bound_radius)
return ECTResult(
ect_result, ect_result.directions, ect_result.thresholds
).smooth()
76 changes: 76 additions & 0 deletions tests/test_sect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import unittest
import numpy as np
from ect import SECT, ECT
from ect.utils.examples import create_example_graph
from ect.directions import Directions


class TestSECT(unittest.TestCase):
def setUp(self):
"""Set up test fixtures"""
self.graph = create_example_graph()
self.num_dirs = 8
self.num_thresh = 10
self.sect = SECT(num_dirs=self.num_dirs, num_thresh=self.num_thresh)

def test_inheritance(self):
"""Test that SECT properly inherits from ECT"""
self.assertIsInstance(self.sect, ECT)
self.assertTrue(hasattr(self.sect, "calculate"))

def test_calculate_output_shape(self):
"""Test that SECT calculation returns correct shape"""
result = self.sect.calculate(self.graph)

self.assertEqual(result.shape[0], self.num_dirs)
self.assertEqual(result.shape[1], self.num_thresh)
self.assertEqual(len(result.thresholds), self.num_thresh)
self.assertEqual(len(result.directions), self.num_dirs)

def test_smoothing_effect(self):
"""Test that smoothing is actually applied"""
# Calculate both ECT and SECT
ect = ECT(num_dirs=self.num_dirs, num_thresh=self.num_thresh)
sect = SECT(num_dirs=self.num_dirs, num_thresh=self.num_thresh)

ect_result = ect.calculate(self.graph)
sect_result = sect.calculate(self.graph)

# Verify results are different due to smoothing
self.assertFalse(np.allclose(ect_result, sect_result))

# Verify smoothing preserves direction count
self.assertEqual(
np.sum(ect_result, axis=1).shape,
np.sum(sect_result, axis=1).shape,
)

def test_with_theta(self):
"""Test SECT calculation with specific theta value"""
theta = np.pi / 4
result = self.sect.calculate(self.graph, theta=theta)

# Should only have one direction when theta is specified
self.assertEqual(result.shape[0], 1)
self.assertEqual(result.shape[1], self.num_thresh)

def test_with_override_radius(self):
"""Test SECT calculation with override_bound_radius"""
override_radius = 2.0
result = self.sect.calculate(self.graph, override_bound_radius=override_radius)

# Check that thresholds are within the override radius
self.assertLessEqual(np.max(np.abs(result.thresholds)), override_radius)

def test_smooth_matrix_properties(self):
"""Test properties of the smoothed matrix"""
result = self.sect.calculate(self.graph)

# Smoothed values should be finite
self.assertTrue(np.all(np.isfinite(result)))

# Shape should be preserved after smoothing
self.assertEqual(result.shape, (self.num_dirs, self.num_thresh))

# Verify result is float type after smoothing
self.assertTrue(np.issubdtype(result.dtype, np.floating))