Skip to content

Commit b3206f1

Browse files
marcorudolphflexdaquinteroflex
authored andcommitted
feat(tidy3d): FXC-4260-allow-users-to-swap-custom-cmaps-in-field-visualization
1 parent dc4f57c commit b3206f1

File tree

8 files changed

+88
-12
lines changed

8 files changed

+88
-12
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
4242
- Added `smoothed_projection` for topology optimization of completely binarized designs.
4343
- Added more RF-specific mode characteristics to `MicrowaveModeData`, including propagation constants (alpha, beta, gamma), phase/group velocities, wave impedance, and automatic mode classification with configurable polarization thresholds in `MicrowaveModeSpec`.
4444
- Introduce `tidy3d.rf` namespace to consolidate all RF classes.
45+
- Added support for custom colormaps in `plot_field`.
4546

4647
### Breaking Changes
4748
- Edge singularity correction at PEC and lossy metal edges defaults to `True`.

tests/test_components/test_eme.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1093,6 +1093,19 @@ def test_eme_sim_data():
10931093
_ = sim_data.plot_field(
10941094
"field", "E", eme_port_index=0, val="abs^2", f=td.C_0, mode_index=0, ax=AX
10951095
)
1096+
_ = sim_data.plot_field(
1097+
"field", "Ex", eme_port_index=0, val="real", f=td.C_0, mode_index=0, cmap="plasma", ax=AX
1098+
)
1099+
_ = sim_data.plot_field(
1100+
"field",
1101+
"Ex",
1102+
eme_port_index=0,
1103+
val="real",
1104+
f=td.C_0,
1105+
mode_index=0,
1106+
cmap=plt.get_cmap("cividis"),
1107+
ax=AX,
1108+
)
10961109

10971110
# test smatrix in basis with sweep
10981111
smatrix = _get_eme_smatrix_dataset(num_modes_1=5, num_modes_2=5, num_sweep=10)

tests/test_data/test_sim_data.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
import pydantic.v1 as pydantic
88
import pytest
9+
from matplotlib import colors as mcolors
910

1011
import tidy3d as td
1112
from tidy3d.components.data.data_array import ScalarFieldTimeDataArray
@@ -194,6 +195,22 @@ def test_plot(phase):
194195
plt.close()
195196

196197

198+
def test_plot_field_custom_cmap():
199+
sim_data = make_sim_data()
200+
_ = sim_data.plot_field("field", "Ex", val="real", f=2e14, z=0.10, cmap="viridis")
201+
plt.close()
202+
custom_cmap = mcolors.LinearSegmentedColormap.from_list("two", ["black", "white"])
203+
_ = sim_data.plot_field(
204+
"field",
205+
"Ez",
206+
val="imag",
207+
f=2e14,
208+
z=0.10,
209+
cmap=custom_cmap,
210+
)
211+
plt.close()
212+
213+
197214
def test_plot_field_missing_derived_data():
198215
sim_data = make_sim_data()
199216
with pytest.raises(Tidy3dKeyError):

tidy3d/components/data/sim_data.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from abc import ABC
99
from collections import defaultdict
1010
from os import PathLike
11-
from typing import Any, Callable, Optional, Union
11+
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
1212

1313
import h5py
1414
import numpy as np
@@ -34,6 +34,9 @@
3434
from .data_array import FreqDataArray, TimeDataArray
3535
from .monitor_data import AbstractFieldData, FieldTimeData
3636

37+
if TYPE_CHECKING:
38+
from matplotlib.colors import Colormap
39+
3740
DATA_TYPE_MAP = {data.__fields__["monitor"].type_: data for data in MonitorDataTypes}
3841

3942
# maps monitor type (string) to the class of the corresponding data
@@ -456,6 +459,7 @@ def plot_field_monitor_data(
456459
vmax: Optional[float] = None,
457460
ax: Ax = None,
458461
shading: str = "flat",
462+
cmap: Optional[Union[str, Colormap]] = None,
459463
**sel_kwargs: Any,
460464
) -> Ax:
461465
"""Plot the field data for a monitor with simulation plot overlaid.
@@ -492,6 +496,8 @@ def plot_field_monitor_data(
492496
matplotlib axes to plot on, if not specified, one is created.
493497
shading: str = 'flat'
494498
Shading argument for Xarray plot method ('flat','nearest','goraud')
499+
cmap : Optional[Union[str, Colormap]] = None
500+
Colormap for visualizing the field values. ``None`` uses the default which infers it from the data.
495501
sel_kwargs : keyword arguments used to perform ``.sel()`` selection in the monitor data.
496502
These kwargs can select over the spatial dimensions (``x``, ``y``, ``z``),
497503
frequency or time dimensions (``f``, ``t``) or ``mode_index``, if applicable.
@@ -656,6 +662,7 @@ def plot_field_monitor_data(
656662
cmap_type=cmap_type,
657663
ax=ax,
658664
shading=shading,
665+
cmap=cmap,
659666
infer_intervals=True if shading == "flat" else False,
660667
)
661668

@@ -672,6 +679,7 @@ def plot_field(
672679
vmax: Optional[float] = None,
673680
ax: Ax = None,
674681
shading: str = "flat",
682+
cmap: Optional[Union[str, Colormap]] = None,
675683
**sel_kwargs: Any,
676684
) -> Ax:
677685
"""Plot the field data for a monitor with simulation plot overlaid.
@@ -709,6 +717,8 @@ def plot_field(
709717
matplotlib axes to plot on, if not specified, one is created.
710718
shading: str = 'flat'
711719
Shading argument for Xarray plot method ('flat','nearest','goraud')
720+
cmap : Optional[Union[str, Colormap]] = None
721+
Colormap for visualizing the field values. ``None`` uses the default which infers it from the data.
712722
sel_kwargs : keyword arguments used to perform ``.sel()`` selection in the monitor data.
713723
These kwargs can select over the spatial dimensions (``x``, ``y``, ``z``),
714724
frequency or time dimensions (``f``, ``t``) or ``mode_index``, if applicable.
@@ -736,6 +746,7 @@ def plot_field(
736746
vmax=vmax,
737747
ax=ax,
738748
shading=shading,
749+
cmap=cmap,
739750
**sel_kwargs,
740751
)
741752

@@ -752,6 +763,7 @@ def plot_scalar_array(
752763
vmin: Optional[float] = None,
753764
vmax: Optional[float] = None,
754765
cmap_type: ColormapType = "divergent",
766+
cmap: Optional[Union[str, Colormap]] = None,
755767
ax: Ax = None,
756768
**kwargs: Any,
757769
) -> Ax:
@@ -784,6 +796,8 @@ def plot_scalar_array(
784796
inferred from the data and other keyword arguments.
785797
cmap_type : Literal["divergent", "sequential", "cyclic"] = "divergent"
786798
Type of color map to use for plotting.
799+
cmap : Optional[Union[str, Colormap]] = None
800+
Colormap for visualizing the field values. ``None`` uses the default which infers it from the data. Overrides inferred colormap from `cmap_type`.
787801
ax : matplotlib.axes._subplots.Axes = None
788802
matplotlib axes to plot on, if not specified, one is created.
789803
**kwargs : Extra arguments to ``DataArray.plot``.
@@ -798,19 +812,23 @@ def plot_scalar_array(
798812
interp_kwarg = {"xyz"[axis]: position}
799813

800814
if cmap_type == "divergent":
801-
cmap = "RdBu"
815+
default_cmap = "RdBu"
802816
center = 0.0
803817
eps_reverse = False
804818
elif cmap_type == "sequential":
805-
cmap = "magma"
819+
default_cmap = "magma"
806820
center = False
807821
eps_reverse = True
808822
elif cmap_type == "cyclic":
809-
cmap = "twilight"
823+
default_cmap = "twilight"
810824
vmin = -np.pi
811825
vmax = np.pi
812826
center = False
813827
eps_reverse = False
828+
else:
829+
default_cmap = None
830+
831+
cmap_to_use = default_cmap if cmap is None else cmap
814832

815833
# plot the field
816834
xy_coord_labels = list("xyz")
@@ -820,7 +838,7 @@ def plot_scalar_array(
820838
ax=ax,
821839
x=x_coord_label,
822840
y=y_coord_label,
823-
cmap=cmap,
841+
cmap=cmap_to_use,
824842
vmin=vmin,
825843
vmax=vmax,
826844
robust=robust,

tidy3d/components/mode/data/sim_data.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from typing import Any, Literal, Optional, Union
5+
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
66

77
import pydantic.v1 as pd
88

@@ -16,6 +16,9 @@
1616

1717
ModeSimulationMonitorDataType = Union[PermittivityData, MediumData]
1818

19+
if TYPE_CHECKING:
20+
from matplotlib.colors import Colormap
21+
1922

2023
class ModeSimulationData(AbstractYeeGridSimulationData):
2124
"""Data associated with a mode solver simulation."""
@@ -53,6 +56,7 @@ def plot_field(
5356
vmin: Optional[float] = None,
5457
vmax: Optional[float] = None,
5558
ax: Ax = None,
59+
cmap: Optional[Union[str, Colormap]] = None,
5660
**sel_kwargs: Any,
5761
) -> Ax:
5862
"""Plot the field for a :class:`.ModeSolverData` with :class:`.Simulation` plot overlaid.
@@ -80,6 +84,8 @@ def plot_field(
8084
inferred from the data and other keyword arguments.
8185
ax : matplotlib.axes._subplots.Axes = None
8286
matplotlib axes to plot on, if not specified, one is created.
87+
cmap : Optional[Union[str, Colormap]] = None
88+
Colormap for visualizing the field values. ``None`` uses the default which infers it from the data.
8389
sel_kwargs : keyword arguments used to perform ``.sel()`` selection in the monitor data.
8490
These kwargs can select over the spatial dimensions (``x``, ``y``, ``z``),
8591
frequency or time dimensions (``f``, ``t``) or `mode_index`, if applicable.
@@ -102,6 +108,7 @@ def plot_field(
102108
vmin=vmin,
103109
vmax=vmax,
104110
ax=ax,
111+
cmap=cmap,
105112
**sel_kwargs,
106113
)
107114

tidy3d/components/mode/mode_solver.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from functools import wraps
88
from math import isclose
9-
from typing import Any, Literal, Optional, Union, get_args
9+
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, get_args
1010

1111
import numpy as np
1212
import pydantic.v1 as pydantic
@@ -82,6 +82,9 @@
8282
from tidy3d.constants import C_0, fp_eps
8383
from tidy3d.exceptions import SetupError, ValidationError
8484
from tidy3d.log import log
85+
86+
if TYPE_CHECKING:
87+
from matplotlib.colors import Colormap
8588
from tidy3d.packaging import supports_local_subpixel, tidy3d_extras
8689

8790
# Importing the local solver may not work if e.g. scipy is not installed
@@ -2249,6 +2252,7 @@ def plot_field(
22492252
vmin: Optional[float] = None,
22502253
vmax: Optional[float] = None,
22512254
ax: Ax = None,
2255+
cmap: Optional[Union[str, Colormap]] = None,
22522256
**sel_kwargs: Any,
22532257
) -> Ax:
22542258
"""Plot the field for a :class:`.ModeSolverData` with :class:`.Simulation` plot overlaid.
@@ -2276,6 +2280,8 @@ def plot_field(
22762280
inferred from the data and other keyword arguments.
22772281
ax : matplotlib.axes._subplots.Axes = None
22782282
matplotlib axes to plot on, if not specified, one is created.
2283+
cmap : Optional[Union[str, Colormap]] = None
2284+
Colormap for visualizing the field values. ``None`` uses the default which infers it from the data.
22792285
sel_kwargs : keyword arguments used to perform ``.sel()`` selection in the monitor data.
22802286
These kwargs can select over the spatial dimensions (``x``, ``y``, ``z``),
22812287
frequency or time dimensions (``f``, ``t``) or `mode_index`, if applicable.
@@ -2300,6 +2306,7 @@ def plot_field(
23002306
vmin=vmin,
23012307
vmax=vmax,
23022308
ax=ax,
2309+
cmap=cmap,
23032310
**sel_kwargs,
23042311
)
23052312

tidy3d/components/tcad/data/sim_data.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
from abc import ABC
6-
from typing import Any, Literal, Optional
6+
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
77

88
import numpy as np
99
import pydantic.v1 as pd
@@ -35,6 +35,9 @@
3535
from tidy3d.exceptions import DataError, Tidy3dKeyError
3636
from tidy3d.log import log
3737

38+
if TYPE_CHECKING:
39+
from matplotlib.colors import Colormap
40+
3841

3942
class DeviceCharacteristics(Tidy3dBaseModel):
4043
"""Stores device characteristics. For example, in steady-state it stores
@@ -281,6 +284,7 @@ def plot_field(
281284
vmin: Optional[float] = None,
282285
vmax: Optional[float] = None,
283286
ax: Ax = None,
287+
cmap: Optional[Union[str, Colormap]] = None,
284288
**sel_kwargs: Any,
285289
) -> Ax:
286290
"""Plot the data for a monitor with simulation structures overlaid.
@@ -310,6 +314,8 @@ def plot_field(
310314
inferred from the data and other keyword arguments.
311315
ax : matplotlib.axes._subplots.Axes = None
312316
matplotlib axes to plot on, if not specified, one is created.
317+
cmap : Optional[Union[str, Colormap]] = None
318+
Colormap for visualizing the field values. ``None`` uses the default which infers it from the data.
313319
sel_kwargs : keyword arguments used to perform ``.sel()`` selection in the monitor data.
314320
These kwargs can select over the spatial dimensions (``x``, ``y``, ``z``),
315321
or time dimension (``t``) if applicable.
@@ -345,7 +351,7 @@ def plot_field(
345351
if scale == "log":
346352
field_data = np.log10(np.abs(field_data))
347353

348-
cmap = "coolwarm"
354+
cmap_to_use = "coolwarm" if cmap is None else cmap
349355

350356
# do sel on unstructured data
351357
# it could produce either SpatialDataArray or UnstructuredGridDatasetType
@@ -361,7 +367,7 @@ def plot_field(
361367
if isinstance(field_data, TriangularGridDataset):
362368
field_data.plot(
363369
ax=ax,
364-
cmap=cmap,
370+
cmap=cmap_to_use,
365371
vmin=vmin,
366372
vmax=vmax,
367373
cbar_kwargs={"label": field_name},
@@ -436,7 +442,7 @@ def plot_field(
436442
ax=ax,
437443
x=x_coord_label,
438444
y=y_coord_label,
439-
cmap=cmap,
445+
cmap=cmap_to_use,
440446
vmin=vmin,
441447
vmax=vmax,
442448
robust=robust,

tidy3d/plugins/waveguide/rectangular_dielectric.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from typing import Annotated, Any, Literal, Optional, Union
5+
from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, Union
66

77
import numpy
88
import pydantic.v1 as pydantic
@@ -27,6 +27,9 @@
2727
from tidy3d.log import log
2828
from tidy3d.plugins.mode.mode_solver import ModeSolver
2929

30+
if TYPE_CHECKING:
31+
from matplotlib.colors import Colormap
32+
3033
AnnotatedMedium = Annotated[MediumType, pydantic.Field(discriminator=TYPE_TAG_STR)]
3134

3235

@@ -1099,6 +1102,7 @@ def plot_field(
10991102
vmax: Optional[float] = None,
11001103
ax: Ax = None,
11011104
geometry_edges: Optional[str] = None,
1105+
cmap: Optional[Union[str, Colormap]] = None,
11021106
**sel_kwargs: Any,
11031107
) -> Ax:
11041108
"""Plot the field for a :class:`.ModeSolverData` with :class:`.Simulation` plot overlaid.
@@ -1127,6 +1131,8 @@ def plot_field(
11271131
ax : matplotlib.axes._subplots.Axes = None
11281132
matplotlib axes to plot on, if not specified, one is created.
11291133
geometry_edges : Optional color to use for the geometry edges overlaid on the fields.
1134+
cmap : Optional[Union[str, Colormap]] = None
1135+
Colormap for visualizing the field values. ``None`` uses the default which infers it from the data.
11301136
sel_kwargs : keyword arguments used to perform ``.sel()`` selection in the monitor data.
11311137
These kwargs can select over the spatial dimensions (``x``, ``y``, ``z``),
11321138
frequency or time dimensions (``f``, ``t``) or `mode_index`, if applicable.
@@ -1147,6 +1153,7 @@ def plot_field(
11471153
vmin=vmin,
11481154
vmax=vmax,
11491155
ax=ax,
1156+
cmap=cmap,
11501157
**sel_kwargs,
11511158
)
11521159
if geometry_edges is not None:

0 commit comments

Comments
 (0)