Skip to content

Commit ce659cb

Browse files
committed
update
1 parent fa5ba7b commit ce659cb

1 file changed

Lines changed: 93 additions & 33 deletions

File tree

FiberFusing/geometry.py

Lines changed: 93 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
from FiberFusing.coordinate_system import CoordinateSystem
1818
from FiberFusing.utils import config_dict
1919

20+
2021
class DomainAlignment(Enum):
2122
"""Boundary positioning modes."""
23+
2224
AUTO = "auto"
2325
LEFT = "left"
2426
RIGHT = "right"
@@ -28,7 +30,7 @@ class DomainAlignment(Enum):
2830

2931

3032
@dataclass(config=config_dict)
31-
class Geometry():
33+
class Geometry:
3234
"""
3335
Represents the refractive index (RI) geometric profile including background and fiber structures.
3436
@@ -75,28 +77,28 @@ def add_structure(self, *structure: object) -> None:
7577
"""
7678
self.structure_list.extend(structure)
7779

78-
@field_validator('resolution')
80+
@field_validator("resolution")
7981
@classmethod
8082
def validate_resolution(cls, v: int) -> int:
8183
"""Validate resolution is positive."""
8284
if v <= 0:
83-
raise ValueError('Resolution must be positive')
85+
raise ValueError("Resolution must be positive")
8486
return v
8587

86-
@field_validator('boundary_pad_factor')
88+
@field_validator("boundary_pad_factor")
8789
@classmethod
8890
def validate_boundary_pad_factor(cls, v: float) -> float:
8991
"""Validate boundary pad factor is positive."""
9092
if v <= 0:
91-
raise ValueError('Boundary pad factor must be positive')
93+
raise ValueError("Boundary pad factor must be positive")
9294
return v
9395

94-
@field_validator('index_scrambling')
96+
@field_validator("index_scrambling")
9597
@classmethod
9698
def validate_index_scrambling(cls, v: float) -> float:
9799
"""Validate index scrambling is non-negative."""
98100
if v < 0:
99-
raise ValueError('Index scrambling must be non-negative')
101+
raise ValueError("Index scrambling must be non-negative")
100102
return v
101103

102104
def initialize(self):
@@ -107,7 +109,12 @@ def initialize(self):
107109
x_min, y_min, x_max, y_max = self.get_boundaries()
108110

109111
self.coordinate_system = CoordinateSystem(
110-
x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max, nx=self.resolution, ny=self.resolution
112+
x_min=x_min,
113+
x_max=x_max,
114+
y_min=y_min,
115+
y_max=y_max,
116+
nx=self.resolution,
117+
ny=self.resolution,
111118
)
112119

113120
self.coordinate_system.center(factor=self.boundary_pad_factor)
@@ -126,14 +133,21 @@ def update_coordinate_system(self) -> None:
126133
x_min, y_min, x_max, y_max = self.get_boundaries()
127134

128135
self.coordinate_system = CoordinateSystem(
129-
x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max, nx=self.resolution, ny=self.resolution
136+
x_min=x_min,
137+
x_max=x_max,
138+
y_min=y_min,
139+
y_max=y_max,
140+
nx=self.resolution,
141+
ny=self.resolution,
130142
)
131143
self.coordinate_system.center(factor=self.boundary_pad_factor)
132144
self.apply_boundary_settings()
133145

134-
@field_validator('x_bounds')
146+
@field_validator("x_bounds")
135147
@classmethod
136-
def validate_x_bounds(cls, v: Union[DomainAlignment, Tuple[float, float]]) -> Union[DomainAlignment, Tuple[float, float]]:
148+
def validate_x_bounds(
149+
cls, v: Union[DomainAlignment, Tuple[float, float]]
150+
) -> Union[DomainAlignment, Tuple[float, float]]:
137151
"""Validate x_bounds parameter."""
138152
if isinstance(v, (list, tuple)):
139153
if len(v) != 2:
@@ -144,9 +158,11 @@ def validate_x_bounds(cls, v: Union[DomainAlignment, Tuple[float, float]]) -> Un
144158
raise ValueError("x_bounds must be a tuple or DomainAlignment")
145159
return v
146160

147-
@field_validator('y_bounds')
161+
@field_validator("y_bounds")
148162
@classmethod
149-
def validate_y_bounds(cls, v: Union[DomainAlignment, Tuple[float, float]]) -> Union[DomainAlignment, Tuple[float, float]]:
163+
def validate_y_bounds(
164+
cls, v: Union[DomainAlignment, Tuple[float, float]]
165+
) -> Union[DomainAlignment, Tuple[float, float]]:
150166
"""Validate y_bounds parameter."""
151167
if isinstance(v, (list, tuple)):
152168
if len(v) != 2:
@@ -159,14 +175,18 @@ def validate_y_bounds(cls, v: Union[DomainAlignment, Tuple[float, float]]) -> Un
159175

160176
def apply_boundary_settings(self) -> None:
161177
"""Apply boundary settings to coordinate system."""
162-
if hasattr(self, 'coordinate_system') and self.coordinate_system is not None:
178+
if hasattr(self, "coordinate_system") and self.coordinate_system is not None:
163179
if isinstance(self.x_bounds, (list, tuple)):
164-
self.coordinate_system.x_min, self.coordinate_system.x_max = self.x_bounds
180+
self.coordinate_system.x_min, self.coordinate_system.x_max = (
181+
self.x_bounds
182+
)
165183
elif isinstance(self.x_bounds, DomainAlignment):
166184
self._apply_x_boundary_mode(self.x_bounds)
167185

168186
if isinstance(self.y_bounds, (list, tuple)):
169-
self.coordinate_system.y_min, self.coordinate_system.y_max = self.y_bounds
187+
self.coordinate_system.y_min, self.coordinate_system.y_max = (
188+
self.y_bounds
189+
)
170190
elif isinstance(self.y_bounds, DomainAlignment):
171191
self._apply_y_boundary_mode(self.y_bounds)
172192

@@ -208,10 +228,16 @@ def get_boundaries(self) -> Tuple[float, float, float, float]:
208228
The boundaries as (x_min, y_min, x_max, y_max).
209229
210230
"""
211-
filtered_structures = [obj for obj in self.structure_list if not isinstance(obj, FiberFusing.background.BackGround)]
231+
filtered_structures = [
232+
obj
233+
for obj in self.structure_list
234+
if not isinstance(obj, FiberFusing.background.BackGround)
235+
]
212236

213237
if len(filtered_structures) == 0:
214-
raise ValueError('No structures provided (other than background) for computing the mesh.')
238+
raise ValueError(
239+
"No structures provided (other than background) for computing the mesh."
240+
)
215241

216242
x_min, y_min, x_max, y_max = zip(
217243
*(obj.get_structure_max_min_boundaries() for obj in filtered_structures)
@@ -228,7 +254,11 @@ def refractive_index_maximum(self) -> float:
228254
float
229255
Maximum refractive index.
230256
"""
231-
return max(refractive_index for obj in self.structure_list for refractive_index in obj.refractive_index_list)
257+
return max(
258+
refractive_index
259+
for obj in self.structure_list
260+
for refractive_index in obj.refractive_index_list
261+
)
232262

233263
@property
234264
def refractive_index_minimum(self) -> float:
@@ -240,7 +270,12 @@ def refractive_index_minimum(self) -> float:
240270
float
241271
Minimum refractive index.
242272
"""
243-
return min(refractive_index for obj in self.structure_list if not isinstance(obj, FiberFusing.background.BackGround) for refractive_index in obj.refractive_index_list)
273+
return min(
274+
refractive_index
275+
for obj in self.structure_list
276+
if not isinstance(obj, FiberFusing.background.BackGround)
277+
for refractive_index in obj.refractive_index_list
278+
)
244279

245280
def get_index_range(self) -> List[float]:
246281
"""
@@ -278,7 +313,12 @@ def randomize_fiber_structures_index(self, random_factor: float) -> None:
278313
"""
279314
for fiber in self.fiber_list:
280315
for structure in fiber.inner_structure:
281-
adjustment = structure.refractive_index * self.refractive_index_scrambling * numpy.random.rand() * random_factor
316+
adjustment = (
317+
structure.refractive_index
318+
* self.refractive_index_scrambling
319+
* numpy.random.rand()
320+
* random_factor
321+
)
282322
structure.refractive_index += adjustment
283323

284324
self.mesh = self.generate_mesh()
@@ -300,7 +340,9 @@ def rasterize_polygons(self) -> numpy.ndarray:
300340
mesh = numpy.zeros(self.coordinate_system.shape)
301341

302342
for structure in self.structure_list:
303-
structure.overlay_structures_on_mesh(mesh=mesh, coordinate_system=self.coordinate_system)
343+
structure.overlay_structures_on_mesh(
344+
mesh=mesh, coordinate_system=self.coordinate_system
345+
)
304346

305347
return mesh
306348

@@ -318,8 +360,10 @@ def generate_mesh(self) -> numpy.ndarray:
318360
AttributeError
319361
If the coordinate system has not been generated before calling this method.
320362
"""
321-
if not hasattr(self, 'coordinate_system'):
322-
raise AttributeError("Coordinate system has not been generated. Call generate_coordinate_system() first.")
363+
if not hasattr(self, "coordinate_system"):
364+
raise AttributeError(
365+
"Coordinate system has not been generated. Call generate_coordinate_system() first."
366+
)
323367

324368
mesh = self.rasterize_polygons()
325369

@@ -349,12 +393,21 @@ def plot_patch(self, axes: plt.Axes) -> None:
349393
continue
350394

351395
if isinstance(structure, FiberFusing.profile.Profile):
352-
structure.plot(axes=axes, show=False, show_added=False, show_removed=False, show_centers=False, show_fibers=True)
396+
structure.plot(
397+
axes=axes,
398+
show=False,
399+
show_added=False,
400+
show_removed=False,
401+
show_centers=False,
402+
show_fibers=True,
403+
)
353404
continue
354405

355406
structure.plot(axes=axes, show=False)
356407

357-
axes.set(title='Fiber structure', xlabel=r'x-distance [m]', ylabel=r'y-distance [m]')
408+
axes.set(
409+
title="Fiber structure", xlabel=r"x-distance [m]", ylabel=r"y-distance [m]"
410+
)
358411

359412
@helper.pre_plot(nrows=1, ncols=1)
360413
def plot_raster(self, axes: plt.Axes, gamma: float = 5) -> None:
@@ -378,17 +431,25 @@ def plot_raster(self, axes: plt.Axes, gamma: float = 5) -> None:
378431
self.coordinate_system.x_vector,
379432
self.coordinate_system.y_vector,
380433
self.mesh,
381-
cmap='Blues',
382-
norm=colors.PowerNorm(gamma=gamma)
434+
cmap="Blues",
435+
norm=colors.PowerNorm(gamma=gamma),
383436
)
384437

385438
divider = make_axes_locatable(axes)
386-
cax = divider.append_axes('right', size='5%', pad=0.05)
387-
axes.get_figure().colorbar(image, cax=cax, orientation='vertical')
439+
cax = divider.append_axes("right", size="5%", pad=0.05)
440+
axes.get_figure().colorbar(image, cax=cax, orientation="vertical")
388441

389-
axes.set(title='Fiber structure', xlabel=r'x-distance [m]', ylabel=r'y-distance [m]')
442+
axes.set(
443+
title="Fiber structure", xlabel=r"x-distance [m]", ylabel=r"y-distance [m]"
444+
)
390445

391-
@helper.pre_plot(nrows=1, ncols=2, subplot_kw=dict(aspect='equal', xlabel='x-distance [m]', ylabel='y-distance [m]'))
446+
@helper.pre_plot(
447+
nrows=1,
448+
ncols=2,
449+
subplot_kw=dict(
450+
aspect="equal", xlabel="x-distance [m]", ylabel="y-distance [m]"
451+
),
452+
)
392453
def plot(self, axes, gamma: float = 5) -> plt.Figure:
393454
"""
394455
Plot the different representations (patch and mesh) of the geometry.
@@ -408,7 +469,6 @@ def plot(self, axes, gamma: float = 5) -> plt.Figure:
408469
axes[0].sharex(axes[1])
409470
axes[0].sharey(axes[1])
410471

411-
412472
self.plot_patch(axes=axes[0], show=False)
413473

414474
self.plot_raster(axes=axes[1], show=False, gamma=gamma)

0 commit comments

Comments
 (0)