11import numpy as np
22from numba import prange , njit
3+ from numba .typed import List
34from typing import Optional , Union
45
56from .embed_cw import EmbeddedCW
@@ -55,8 +56,7 @@ def _ensure_directions(self, graph_dim, theta=None):
5556 """Ensures directions is a valid Directions object of correct dimension"""
5657 if self .directions is None :
5758 if self .num_dirs is None :
58- raise ValueError (
59- "Either 'directions' or 'num_dirs' must be provided." )
59+ raise ValueError ("Either 'directions' or 'num_dirs' must be provided." )
6060 self .directions = Directions .uniform (self .num_dirs , dim = graph_dim )
6161 elif isinstance (self .directions , list ):
6262 # if list of vectors, convert to Directions object
@@ -127,12 +127,10 @@ def calculate(
127127
128128 # override with theta if provided
129129 directions = (
130- self .directions if theta is None else Directions .from_angles ([
131- theta ])
130+ self .directions if theta is None else Directions .from_angles ([theta ])
132131 )
133132
134- simplex_projections = self ._compute_simplex_projections (
135- graph , directions )
133+ simplex_projections = self ._compute_simplex_projections (graph , directions )
136134
137135 ect_matrix = self ._compute_directional_transform (
138136 simplex_projections , self .thresholds , self .shape_descriptor , self .dtype
@@ -148,7 +146,7 @@ def _compute_simplex_projections(
148146 self , graph : Union [EmbeddedGraph , EmbeddedCW ], directions
149147 ):
150148 """Compute projections of each simplex (vertices, edges, faces)"""
151- simplex_projections = []
149+ simplex_projections = List ()
152150 node_projections = self ._compute_node_projections (
153151 graph .coord_matrix , directions
154152 )
@@ -162,11 +160,9 @@ def _compute_simplex_projections(
162160
163161 if isinstance (graph , EmbeddedCW ) and len (graph .faces ) > 0 :
164162 node_to_index = {n : i for i , n in enumerate (graph .node_list )}
165- face_indices = [[node_to_index [v ] for v in face ]
166- for face in graph .faces ]
163+ face_indices = [[node_to_index [v ] for v in face ] for face in graph .faces ]
167164 face_maxes = np .array (
168- [np .max (node_projections [face , :], axis = 0 )
169- for face in face_indices ]
165+ [np .max (node_projections [face , :], axis = 0 ) for face in face_indices ]
170166 )
171167 simplex_projections .append (face_maxes )
172168
@@ -192,7 +188,7 @@ def _compute_directional_transform(
192188 num_thresh = thresholds .shape [0 ]
193189 result = np .empty ((num_dir , num_thresh ), dtype = dtype )
194190
195- sorted_projections = []
191+ sorted_projections = List ()
196192 for proj in simplex_projections_list :
197193 sorted_proj = np .empty_like (proj )
198194 for i in prange (num_dir ):
@@ -202,7 +198,7 @@ def _compute_directional_transform(
202198 for j in prange (num_thresh ):
203199 thresh = thresholds [j ]
204200 for i in range (num_dir ):
205- simplex_counts_list = []
201+ simplex_counts_list = List ()
206202 for k in range (len (sorted_projections )):
207203 projs = sorted_projections [k ][:, i ]
208204 simplex_counts_list .append (
@@ -212,7 +208,7 @@ def _compute_directional_transform(
212208 return result
213209
214210 @staticmethod
215- @njit (parallel = True , fastmath = True )
211+ @njit (fastmath = True )
216212 def shape_descriptor (simplex_counts_list ):
217213 """Calculate shape descriptor from simplex counts (Euler characteristic)"""
218214 chi = 0
0 commit comments