Skip to content

Commit 87f4819

Browse files
committed
Add option to drop filtered modes
1 parent 603c8a8 commit 87f4819

File tree

3 files changed

+198
-8
lines changed

3 files changed

+198
-8
lines changed

tests/test_plugins/test_mode_solver.py

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@
1010

1111
import tidy3d as td
1212
import tidy3d.plugins.mode.web as msweb
13-
from tidy3d import ScalarFieldDataArray
13+
from tidy3d import ModeIndexDataArray, ScalarFieldDataArray
1414
from tidy3d.components.data.monitor_data import ModeSolverData
1515
from tidy3d.components.mode.derivatives import create_sfactor_b, create_sfactor_f
1616
from tidy3d.components.mode.solver import compute_modes
1717
from tidy3d.components.mode_spec import MODE_DATA_KEYS
18-
from tidy3d.exceptions import DataError, SetupError
18+
from tidy3d.exceptions import DataError, SetupError, ValidationError
1919
from tidy3d.plugins.mode import ModeSolver
2020
from tidy3d.plugins.mode.mode_solver import MODE_MONITOR_NAME
2121
from tidy3d.web.core.environment import Env
@@ -1503,3 +1503,78 @@ def test_sort_spec_track_freq():
15031503
assert np.allclose(modes_lowest.Ex.abs, modes_lowest_retracked.Ex.abs)
15041504
assert np.all(modes_lowest.n_eff == modes_lowest_retracked.n_eff)
15051505
assert np.all(modes_lowest.n_group == modes_lowest_retracked.n_group)
1506+
1507+
1508+
def test_mode_sort_spec_drop_modes_reduces_modes():
1509+
freqs = np.array([2e14, 4e14])
1510+
mode_spec = td.ModeSpec(num_modes=3)
1511+
monitor = td.ModeSolverMonitor(
1512+
size=(1.0, 0.0, 1.0),
1513+
center=(0.0, 0.0, 0.0),
1514+
freqs=freqs,
1515+
mode_spec=mode_spec,
1516+
name="drop_modes",
1517+
)
1518+
n_complex = ModeIndexDataArray(
1519+
np.array(
1520+
[
1521+
[1.6 + 0.6j, 1.5 + 0.2j, 1.1 + 0.5j],
1522+
[1.7 + 0.4j, 1.4 + 0.3j, 1.0 + 0.1j],
1523+
]
1524+
),
1525+
coords={"f": freqs, "mode_index": np.arange(3)},
1526+
)
1527+
data = ModeSolverData(monitor=monitor, n_complex=n_complex)
1528+
1529+
sort_spec = td.ModeSortSpec(
1530+
filter_key="n_eff",
1531+
filter_reference=1.3,
1532+
filter_order="over",
1533+
sort_key="k_eff",
1534+
sort_order="ascending",
1535+
drop_modes=True,
1536+
)
1537+
1538+
sorted_data = data.sort_modes(sort_spec)
1539+
1540+
assert sorted_data.n_eff.sizes["mode_index"] == 2
1541+
assert np.allclose(sorted_data.n_eff.isel(f=0).values, [1.5, 1.6])
1542+
assert np.allclose(sorted_data.n_eff.isel(f=1).values, [1.4, 1.7])
1543+
assert sorted_data.monitor.mode_spec.num_modes == 2
1544+
assert sorted_data.monitor.mode_spec.sort_spec.drop_modes is True
1545+
1546+
1547+
def test_mode_sort_spec_drop_modes_all_filtered():
1548+
freqs = np.array([2e14, 4e14])
1549+
mode_spec = td.ModeSpec(num_modes=3)
1550+
monitor = td.ModeSolverMonitor(
1551+
size=(1.0, 0.0, 1.0),
1552+
center=(0.0, 0.0, 0.0),
1553+
freqs=freqs,
1554+
mode_spec=mode_spec,
1555+
name="drop_all",
1556+
)
1557+
n_complex = ModeIndexDataArray(
1558+
np.array(
1559+
[
1560+
[1.1 + 0.1j, 1.05 + 0.05j, 1.0 + 0.01j],
1561+
[1.1 + 0.1j, 1.05 + 0.05j, 1.0 + 0.01j],
1562+
]
1563+
),
1564+
coords={"f": freqs, "mode_index": np.arange(3)},
1565+
)
1566+
data = ModeSolverData(monitor=monitor, n_complex=n_complex)
1567+
1568+
sort_spec = td.ModeSortSpec(
1569+
filter_key="n_eff",
1570+
filter_reference=2.0,
1571+
drop_modes=True,
1572+
)
1573+
1574+
with pytest.raises(ValidationError):
1575+
_ = data.sort_modes(sort_spec)
1576+
1577+
1578+
def test_mode_sort_spec_drop_modes_requires_filter():
1579+
with pytest.raises(pydantic.ValidationError):
1580+
td.ModeSortSpec(drop_modes=True)

tidy3d/components/data/monitor_data.py

Lines changed: 104 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2322,6 +2322,72 @@ def _apply_mode_reorder(self, sort_inds_2d):
23222322

23232323
return self.updated_copy(**modify_data)
23242324

2325+
def _apply_mode_subset(self, subset_inds_2d: np.ndarray) -> ModeSolverData:
2326+
"""Return copy of self containing only the selected modes.
2327+
2328+
Parameters
2329+
----------
2330+
subset_inds_2d : np.ndarray
2331+
Array of shape ``(num_freqs, num_modes_keep)`` containing the indices of the original
2332+
modes to retain at each frequency.
2333+
2334+
Returns
2335+
-------
2336+
:class:`.ModeSolverData`
2337+
Copy of self with only the retained modes.
2338+
"""
2339+
2340+
subset_inds_2d = np.asarray(subset_inds_2d, dtype=int)
2341+
if subset_inds_2d.ndim != 2:
2342+
raise DataError(
2343+
"subset_inds_2d must be a 2D array of shape (num_freqs, num_modes_keep)."
2344+
)
2345+
2346+
num_freqs, num_keep = subset_inds_2d.shape
2347+
if num_keep == 0:
2348+
raise DataError("Cannot create a mode subset with zero modes.")
2349+
2350+
num_modes_full = self.n_eff["mode_index"].size
2351+
2352+
modify_data = {}
2353+
new_mode_index_coord = np.arange(num_keep)
2354+
2355+
for key, data in self.data_arrs.items():
2356+
if "mode_index" not in data.dims or "f" not in data.dims:
2357+
continue
2358+
2359+
dims_orig = tuple(data.dims)
2360+
coords_out = {
2361+
k: (v.values if hasattr(v, "values") else np.asarray(v))
2362+
for k, v in data.coords.items()
2363+
}
2364+
2365+
f_axis = data.get_axis_num("f")
2366+
m_axis = data.get_axis_num("mode_index")
2367+
src_order = (
2368+
[f_axis] + [ax for ax in range(data.ndim) if ax not in (f_axis, m_axis)] + [m_axis]
2369+
)
2370+
2371+
arr = np.moveaxis(data.data, src_order, range(data.ndim))
2372+
nf, nm = arr.shape[0], arr.shape[-1]
2373+
if nf != num_freqs or nm != num_modes_full:
2374+
raise DataError(
2375+
"subset_inds_2d shape does not match array shape in _apply_mode_subset."
2376+
)
2377+
2378+
arr2 = arr.reshape(nf, -1, nm)
2379+
inds = subset_inds_2d[:, None, :]
2380+
arr2_subset = np.take_along_axis(arr2, inds, axis=2)
2381+
arr_subset = arr2_subset.reshape(arr.shape[:-1] + (num_keep,))
2382+
arr_subset = np.moveaxis(arr_subset, range(data.ndim), src_order)
2383+
2384+
coords_out["mode_index"] = new_mode_index_coord
2385+
coords_out["f"] = data.coords["f"].values
2386+
2387+
modify_data[key] = DataArray(arr_subset, coords=coords_out, dims=dims_orig)
2388+
2389+
return self.updated_copy(**modify_data)
2390+
23252391
def sort_modes(
23262392
self, sort_spec: Optional[ModeSortSpec] = None, track_freq: Optional[TrackFreq] = None
23272393
) -> ModeSolverData:
@@ -2356,8 +2422,9 @@ def sort_modes(
23562422
num_freqs = self.n_eff["f"].size
23572423
num_modes = self.n_eff["mode_index"].size
23582424
all_inds = np.arange(num_modes)
2359-
identity = np.arange(num_modes)
2360-
sort_inds_2d = np.tile(identity, (num_freqs, 1))
2425+
drop_modes = getattr(sort_spec, "drop_modes", False)
2426+
if drop_modes and sort_spec.filter_key is None:
2427+
raise ValidationError("ModeSortSpec.drop_modes requires 'filter_key' to be set.")
23612428

23622429
# Helper to compute ordered indices within a subset
23632430
def _order_indices(indices, vals_all):
@@ -2376,12 +2443,13 @@ def _order_indices(indices, vals_all):
23762443
filter_metric = getattr(self, sort_spec.filter_key)
23772444
if sort_spec.sort_key is not None:
23782445
sort_metric = getattr(self, sort_spec.sort_key)
2446+
identity = np.arange(num_modes)
2447+
sort_inds_2d = np.tile(identity, (num_freqs, 1))
23792448

23802449
for ifreq in range(num_freqs):
23812450
# Build groups according to filter if requested
23822451
if filter_metric is not None:
23832452
vals_filt = filter_metric.isel(f=ifreq).values
2384-
# Boolean mask for modes in the first group
23852453
if sort_spec.filter_order == "over":
23862454
mask_first = vals_filt >= sort_spec.filter_reference
23872455
else:
@@ -2406,11 +2474,10 @@ def _order_indices(indices, vals_all):
24062474

24072475
sort_inds_2d[ifreq, : len(sort_inds)] = sort_inds
24082476

2409-
# If all rows are identity, skip
24102477
if np.all(sort_inds_2d == np.tile(identity, (num_freqs, 1))):
24112478
data_sorted = self
24122479
else:
2413-
data_sorted = self._apply_mode_reorder(sort_inds_2d) # this creates a copy
2480+
data_sorted = self._apply_mode_reorder(sort_inds_2d)
24142481
data_sorted = data_sorted.updated_copy(
24152482
path="monitor/mode_spec", sort_spec=sort_spec, deep=False, validate=False
24162483
)
@@ -2422,6 +2489,38 @@ def _order_indices(indices, vals_all):
24222489
if track_freq and num_freqs > 1:
24232490
data_sorted = data_sorted.overlap_sort(track_freq)
24242491

2492+
if drop_modes:
2493+
# Re-evaluate the filter after sorting/tracking so modes are dropped consistently.
2494+
filter_metric_sorted = getattr(data_sorted, sort_spec.filter_key)
2495+
masks_after = []
2496+
for ifreq in range(num_freqs):
2497+
vals = filter_metric_sorted.isel(f=ifreq).values
2498+
if sort_spec.filter_order == "over":
2499+
mask = vals >= sort_spec.filter_reference
2500+
else:
2501+
mask = vals <= sort_spec.filter_reference
2502+
masks_after.append(mask)
2503+
2504+
keep_mask = np.all(np.stack(masks_after, axis=0), axis=0)
2505+
if not np.any(keep_mask):
2506+
raise ValidationError(
2507+
"Filtering removes all modes; relax the filter threshold or disable drop_modes."
2508+
)
2509+
2510+
num_modes_sorted = filter_metric_sorted.sizes["mode_index"]
2511+
if keep_mask.sum() < num_modes_sorted:
2512+
keep_inds = np.where(keep_mask)[0]
2513+
subset_inds_2d = np.tile(keep_inds, (num_freqs, 1))
2514+
data_subset = data_sorted._apply_mode_subset(subset_inds_2d)
2515+
mspec = data_subset.monitor.mode_spec
2516+
mspec_updated = mspec.updated_copy(num_modes=keep_inds.size, validate=False)
2517+
monitor_updated = data_subset.monitor.updated_copy(
2518+
mode_spec=mspec_updated, validate=False
2519+
)
2520+
data_sorted = data_subset.updated_copy(
2521+
monitor=monitor_updated, deep=False, validate=False
2522+
)
2523+
24252524
return data_sorted
24262525

24272526

tidy3d/components/mode_spec.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ class ModeSortSpec(Tidy3dBaseModel):
3535
applied to ``filter_key``: modes "over" or "under" ``filter_reference`` are placed first,
3636
with the remaining modes placed next. Second, an optional sorting step orders modes within
3737
each group according to ``sort_key``, optionally with respect to ``sort_reference`` and in
38-
the specified ``sort_order``.
38+
the specified ``sort_order``. If ``drop_modes`` is set to ``True``, the modes that do not meet
39+
the filter criterion are removed instead of being appended as a second group.
3940
"""
4041

4142
# Filtering stage
@@ -54,6 +55,15 @@ class ModeSortSpec(Tidy3dBaseModel):
5455
title="Filtering order",
5556
description="Select whether the first group contains values over or under the reference.",
5657
)
58+
drop_modes: bool = pd.Field(
59+
False,
60+
title="Drop filtered modes",
61+
description=(
62+
"If ``True``, modes that do not satisfy the filter criterion are removed entirely "
63+
"instead of being appended after the filtered group. Only modes passing the filter "
64+
"at every tracked frequency are kept."
65+
),
66+
)
5767

5868
# Sorting stage
5969
sort_key: Optional[MODE_DATA_KEYS] = pd.Field(
@@ -85,6 +95,12 @@ class ModeSortSpec(Tidy3dBaseModel):
8595
"while at other frequencies it can change depending on the mode tracking.",
8696
)
8797

98+
@pd.validator("drop_modes", always=True)
99+
def _drop_requires_filter(cls, val, values):
100+
if val and values.get("filter_key") is None:
101+
raise ValidationError("ModeSortSpec.drop_modes requires 'filter_key' to be set.")
102+
return val
103+
88104

89105
class FrequencySamplingSpec(Tidy3dBaseModel, ABC):
90106
"""Abstract base class for frequency sampling specifications."""

0 commit comments

Comments
 (0)