1717from FiberFusing .coordinate_system import CoordinateSystem
1818from FiberFusing .utils import config_dict
1919
20+
2021class DomainAlignment (Enum ):
2122 """Boundary positioning modes."""
23+
2224 AUTO = "auto"
2325 LEFT = "left"
2426 RIGHT = "right"
@@ -28,7 +30,7 @@ class DomainAlignment(Enum):
2830
2931
3032@dataclass (config = config_dict )
31- class Geometry () :
33+ class Geometry :
3234 """
3335 Represents the refractive index (RI) geometric profile including background and fiber structures.
3436
@@ -75,28 +77,28 @@ def add_structure(self, *structure: object) -> None:
7577 """
7678 self .structure_list .extend (structure )
7779
78- @field_validator (' resolution' )
80+ @field_validator (" resolution" )
7981 @classmethod
8082 def validate_resolution (cls , v : int ) -> int :
8183 """Validate resolution is positive."""
8284 if v <= 0 :
83- raise ValueError (' Resolution must be positive' )
85+ raise ValueError (" Resolution must be positive" )
8486 return v
8587
86- @field_validator (' boundary_pad_factor' )
88+ @field_validator (" boundary_pad_factor" )
8789 @classmethod
8890 def validate_boundary_pad_factor (cls , v : float ) -> float :
8991 """Validate boundary pad factor is positive."""
9092 if v <= 0 :
91- raise ValueError (' Boundary pad factor must be positive' )
93+ raise ValueError (" Boundary pad factor must be positive" )
9294 return v
9395
94- @field_validator (' index_scrambling' )
96+ @field_validator (" index_scrambling" )
9597 @classmethod
9698 def validate_index_scrambling (cls , v : float ) -> float :
9799 """Validate index scrambling is non-negative."""
98100 if v < 0 :
99- raise ValueError (' Index scrambling must be non-negative' )
101+ raise ValueError (" Index scrambling must be non-negative" )
100102 return v
101103
102104 def initialize (self ):
@@ -107,7 +109,12 @@ def initialize(self):
107109 x_min , y_min , x_max , y_max = self .get_boundaries ()
108110
109111 self .coordinate_system = CoordinateSystem (
110- x_min = x_min , x_max = x_max , y_min = y_min , y_max = y_max , nx = self .resolution , ny = self .resolution
112+ x_min = x_min ,
113+ x_max = x_max ,
114+ y_min = y_min ,
115+ y_max = y_max ,
116+ nx = self .resolution ,
117+ ny = self .resolution ,
111118 )
112119
113120 self .coordinate_system .center (factor = self .boundary_pad_factor )
@@ -126,14 +133,21 @@ def update_coordinate_system(self) -> None:
126133 x_min , y_min , x_max , y_max = self .get_boundaries ()
127134
128135 self .coordinate_system = CoordinateSystem (
129- x_min = x_min , x_max = x_max , y_min = y_min , y_max = y_max , nx = self .resolution , ny = self .resolution
136+ x_min = x_min ,
137+ x_max = x_max ,
138+ y_min = y_min ,
139+ y_max = y_max ,
140+ nx = self .resolution ,
141+ ny = self .resolution ,
130142 )
131143 self .coordinate_system .center (factor = self .boundary_pad_factor )
132144 self .apply_boundary_settings ()
133145
134- @field_validator (' x_bounds' )
146+ @field_validator (" x_bounds" )
135147 @classmethod
136- def validate_x_bounds (cls , v : Union [DomainAlignment , Tuple [float , float ]]) -> Union [DomainAlignment , Tuple [float , float ]]:
148+ def validate_x_bounds (
149+ cls , v : Union [DomainAlignment , Tuple [float , float ]]
150+ ) -> Union [DomainAlignment , Tuple [float , float ]]:
137151 """Validate x_bounds parameter."""
138152 if isinstance (v , (list , tuple )):
139153 if len (v ) != 2 :
@@ -144,9 +158,11 @@ def validate_x_bounds(cls, v: Union[DomainAlignment, Tuple[float, float]]) -> Un
144158 raise ValueError ("x_bounds must be a tuple or DomainAlignment" )
145159 return v
146160
147- @field_validator (' y_bounds' )
161+ @field_validator (" y_bounds" )
148162 @classmethod
149- def validate_y_bounds (cls , v : Union [DomainAlignment , Tuple [float , float ]]) -> Union [DomainAlignment , Tuple [float , float ]]:
163+ def validate_y_bounds (
164+ cls , v : Union [DomainAlignment , Tuple [float , float ]]
165+ ) -> Union [DomainAlignment , Tuple [float , float ]]:
150166 """Validate y_bounds parameter."""
151167 if isinstance (v , (list , tuple )):
152168 if len (v ) != 2 :
@@ -159,14 +175,18 @@ def validate_y_bounds(cls, v: Union[DomainAlignment, Tuple[float, float]]) -> Un
159175
160176 def apply_boundary_settings (self ) -> None :
161177 """Apply boundary settings to coordinate system."""
162- if hasattr (self , ' coordinate_system' ) and self .coordinate_system is not None :
178+ if hasattr (self , " coordinate_system" ) and self .coordinate_system is not None :
163179 if isinstance (self .x_bounds , (list , tuple )):
164- self .coordinate_system .x_min , self .coordinate_system .x_max = self .x_bounds
180+ self .coordinate_system .x_min , self .coordinate_system .x_max = (
181+ self .x_bounds
182+ )
165183 elif isinstance (self .x_bounds , DomainAlignment ):
166184 self ._apply_x_boundary_mode (self .x_bounds )
167185
168186 if isinstance (self .y_bounds , (list , tuple )):
169- self .coordinate_system .y_min , self .coordinate_system .y_max = self .y_bounds
187+ self .coordinate_system .y_min , self .coordinate_system .y_max = (
188+ self .y_bounds
189+ )
170190 elif isinstance (self .y_bounds , DomainAlignment ):
171191 self ._apply_y_boundary_mode (self .y_bounds )
172192
@@ -208,10 +228,16 @@ def get_boundaries(self) -> Tuple[float, float, float, float]:
208228 The boundaries as (x_min, y_min, x_max, y_max).
209229
210230 """
211- filtered_structures = [obj for obj in self .structure_list if not isinstance (obj , FiberFusing .background .BackGround )]
231+ filtered_structures = [
232+ obj
233+ for obj in self .structure_list
234+ if not isinstance (obj , FiberFusing .background .BackGround )
235+ ]
212236
213237 if len (filtered_structures ) == 0 :
214- raise ValueError ('No structures provided (other than background) for computing the mesh.' )
238+ raise ValueError (
239+ "No structures provided (other than background) for computing the mesh."
240+ )
215241
216242 x_min , y_min , x_max , y_max = zip (
217243 * (obj .get_structure_max_min_boundaries () for obj in filtered_structures )
@@ -228,7 +254,11 @@ def refractive_index_maximum(self) -> float:
228254 float
229255 Maximum refractive index.
230256 """
231- return max (refractive_index for obj in self .structure_list for refractive_index in obj .refractive_index_list )
257+ return max (
258+ refractive_index
259+ for obj in self .structure_list
260+ for refractive_index in obj .refractive_index_list
261+ )
232262
233263 @property
234264 def refractive_index_minimum (self ) -> float :
@@ -240,7 +270,12 @@ def refractive_index_minimum(self) -> float:
240270 float
241271 Minimum refractive index.
242272 """
243- return min (refractive_index for obj in self .structure_list if not isinstance (obj , FiberFusing .background .BackGround ) for refractive_index in obj .refractive_index_list )
273+ return min (
274+ refractive_index
275+ for obj in self .structure_list
276+ if not isinstance (obj , FiberFusing .background .BackGround )
277+ for refractive_index in obj .refractive_index_list
278+ )
244279
245280 def get_index_range (self ) -> List [float ]:
246281 """
@@ -278,7 +313,12 @@ def randomize_fiber_structures_index(self, random_factor: float) -> None:
278313 """
279314 for fiber in self .fiber_list :
280315 for structure in fiber .inner_structure :
281- adjustment = structure .refractive_index * self .refractive_index_scrambling * numpy .random .rand () * random_factor
316+ adjustment = (
317+ structure .refractive_index
318+ * self .refractive_index_scrambling
319+ * numpy .random .rand ()
320+ * random_factor
321+ )
282322 structure .refractive_index += adjustment
283323
284324 self .mesh = self .generate_mesh ()
@@ -300,7 +340,9 @@ def rasterize_polygons(self) -> numpy.ndarray:
300340 mesh = numpy .zeros (self .coordinate_system .shape )
301341
302342 for structure in self .structure_list :
303- structure .overlay_structures_on_mesh (mesh = mesh , coordinate_system = self .coordinate_system )
343+ structure .overlay_structures_on_mesh (
344+ mesh = mesh , coordinate_system = self .coordinate_system
345+ )
304346
305347 return mesh
306348
@@ -318,8 +360,10 @@ def generate_mesh(self) -> numpy.ndarray:
318360 AttributeError
319361 If the coordinate system has not been generated before calling this method.
320362 """
321- if not hasattr (self , 'coordinate_system' ):
322- raise AttributeError ("Coordinate system has not been generated. Call generate_coordinate_system() first." )
363+ if not hasattr (self , "coordinate_system" ):
364+ raise AttributeError (
365+ "Coordinate system has not been generated. Call generate_coordinate_system() first."
366+ )
323367
324368 mesh = self .rasterize_polygons ()
325369
@@ -349,12 +393,21 @@ def plot_patch(self, axes: plt.Axes) -> None:
349393 continue
350394
351395 if isinstance (structure , FiberFusing .profile .Profile ):
352- structure .plot (axes = axes , show = False , show_added = False , show_removed = False , show_centers = False , show_fibers = True )
396+ structure .plot (
397+ axes = axes ,
398+ show = False ,
399+ show_added = False ,
400+ show_removed = False ,
401+ show_centers = False ,
402+ show_fibers = True ,
403+ )
353404 continue
354405
355406 structure .plot (axes = axes , show = False )
356407
357- axes .set (title = 'Fiber structure' , xlabel = r'x-distance [m]' , ylabel = r'y-distance [m]' )
408+ axes .set (
409+ title = "Fiber structure" , xlabel = r"x-distance [m]" , ylabel = r"y-distance [m]"
410+ )
358411
359412 @helper .pre_plot (nrows = 1 , ncols = 1 )
360413 def plot_raster (self , axes : plt .Axes , gamma : float = 5 ) -> None :
@@ -378,17 +431,25 @@ def plot_raster(self, axes: plt.Axes, gamma: float = 5) -> None:
378431 self .coordinate_system .x_vector ,
379432 self .coordinate_system .y_vector ,
380433 self .mesh ,
381- cmap = ' Blues' ,
382- norm = colors .PowerNorm (gamma = gamma )
434+ cmap = " Blues" ,
435+ norm = colors .PowerNorm (gamma = gamma ),
383436 )
384437
385438 divider = make_axes_locatable (axes )
386- cax = divider .append_axes (' right' , size = '5%' , pad = 0.05 )
387- axes .get_figure ().colorbar (image , cax = cax , orientation = ' vertical' )
439+ cax = divider .append_axes (" right" , size = "5%" , pad = 0.05 )
440+ axes .get_figure ().colorbar (image , cax = cax , orientation = " vertical" )
388441
389- axes .set (title = 'Fiber structure' , xlabel = r'x-distance [m]' , ylabel = r'y-distance [m]' )
442+ axes .set (
443+ title = "Fiber structure" , xlabel = r"x-distance [m]" , ylabel = r"y-distance [m]"
444+ )
390445
391- @helper .pre_plot (nrows = 1 , ncols = 2 , subplot_kw = dict (aspect = 'equal' , xlabel = 'x-distance [m]' , ylabel = 'y-distance [m]' ))
446+ @helper .pre_plot (
447+ nrows = 1 ,
448+ ncols = 2 ,
449+ subplot_kw = dict (
450+ aspect = "equal" , xlabel = "x-distance [m]" , ylabel = "y-distance [m]"
451+ ),
452+ )
392453 def plot (self , axes , gamma : float = 5 ) -> plt .Figure :
393454 """
394455 Plot the different representations (patch and mesh) of the geometry.
@@ -408,7 +469,6 @@ def plot(self, axes, gamma: float = 5) -> plt.Figure:
408469 axes [0 ].sharex (axes [1 ])
409470 axes [0 ].sharey (axes [1 ])
410471
411-
412472 self .plot_patch (axes = axes [0 ], show = False )
413473
414474 self .plot_raster (axes = axes [1 ], show = False , gamma = gamma )
0 commit comments