From f618f6bf8814624449aa86a833635aea1488d00d Mon Sep 17 00:00:00 2001 From: Momchil Minkov Date: Thu, 4 Dec 2025 12:43:58 +0100 Subject: [PATCH] Add option to drop filtered modes --- tests/test_plugins/test_mode_solver.py | 259 ++++++++++++++++++++++++- tidy3d/components/data/monitor_data.py | 258 ++++++++++++++++++++++-- tidy3d/components/mode_spec.py | 55 +++++- 3 files changed, 545 insertions(+), 27 deletions(-) diff --git a/tests/test_plugins/test_mode_solver.py b/tests/test_plugins/test_mode_solver.py index 97155fdd56..af7a47d38a 100644 --- a/tests/test_plugins/test_mode_solver.py +++ b/tests/test_plugins/test_mode_solver.py @@ -10,12 +10,12 @@ import tidy3d as td import tidy3d.plugins.mode.web as msweb -from tidy3d import ScalarFieldDataArray +from tidy3d import Coords, Grid, ModeIndexDataArray, ScalarFieldDataArray, ScalarModeFieldDataArray from tidy3d.components.data.monitor_data import ModeSolverData from tidy3d.components.mode.derivatives import create_sfactor_b, create_sfactor_f from tidy3d.components.mode.solver import compute_modes from tidy3d.components.mode_spec import MODE_DATA_KEYS -from tidy3d.exceptions import DataError, SetupError +from tidy3d.exceptions import DataError, SetupError, ValidationError from tidy3d.plugins.mode import ModeSolver from tidy3d.plugins.mode.mode_solver import MODE_MONITOR_NAME from tidy3d.web.core.environment import Env @@ -38,6 +38,69 @@ SOLVER_ID = "Solver-ID" +def make_fill_fraction_mode_data(): + freq = np.array([2e14]) + mode_spec = td.ModeSpec(num_modes=2) + monitor = td.ModeSolverMonitor( + size=(3.0, 0.0, 3.0), + center=(0.0, 0.0, 0.0), + freqs=freq, + mode_spec=mode_spec, + name="fill_fraction", + ) + + grid = Grid( + boundaries=Coords( + x=np.array([-1.5, -0.5, 0.5, 1.5]), + y=np.array([-0.5, 0.5]), + z=np.array([-1.5, -0.5, 0.5, 1.5]), + ) + ) + + coords = { + "x": np.array([-1.0, 0.0, 1.0]), + "y": np.array([0.0]), + "z": np.array([-1.0, 0.0, 1.0]), + "f": freq, + "mode_index": np.arange(2), + } + shape = (3, 1, 3, 1, 2) + + ex_data = np.zeros(shape, dtype=complex) + ex_data[1, 0, 1, 0, 0] = 2.0 + for ix in (0, 2): + for iz in (0, 2): + ex_data[ix, 0, iz, 0, 1] = 1.0 + + zero_data = np.zeros(shape, dtype=complex) + + fields = { + "Ex": ScalarModeFieldDataArray(ex_data, coords=coords), + "Ey": ScalarModeFieldDataArray(np.copy(zero_data), coords=coords), + "Ez": ScalarModeFieldDataArray(np.copy(zero_data), coords=coords), + "Hx": ScalarModeFieldDataArray(np.copy(zero_data), coords=coords), + "Hy": ScalarModeFieldDataArray(np.copy(zero_data), coords=coords), + "Hz": ScalarModeFieldDataArray(np.copy(zero_data), coords=coords), + } + + n_complex = ModeIndexDataArray( + np.array([[1.6 + 0.0j, 1.3 + 0.0j]]), + coords={"f": freq, "mode_index": np.arange(2)}, + ) + + data = ModeSolverData( + monitor=monitor, + symmetry=(0, 0, 0), + symmetry_center=(0.0, 0.0, 0.0), + grid_expanded=grid, + n_complex=n_complex, + **fields, + ) + + bounding_box = td.Box(center=(0.0, 0.0, 0.0), size=(1.0, 2.0, 1.0)) + return data, bounding_box + + @pytest.fixture def mock_remote_api(monkeypatch): def void(*args, **kwargs): @@ -1343,21 +1406,41 @@ def test_modes_filter_sort(): for key in get_args(MODE_DATA_KEYS): print(key) # Test ascending - sort_spec = td.ModeSortSpec(sort_key=key, sort_order="ascending", track_freq=None) + sort_kwargs = { + "sort_key": key, + "sort_order": "ascending", + "track_freq": None, + } + if key == "fill_fraction_box": + sort_kwargs["bounding_box"] = td.Box(center=PLANE.center, size=(5.0, 4.0, 5.0)) + sort_spec = td.ModeSortSpec(**sort_kwargs) modes = modes.sort_modes(sort_spec) metric = getattr(modes, key) assert np.all(metric.diff(dim="mode_index") >= 0) # Test descending - sort_spec = td.ModeSortSpec(sort_key=key, sort_order="descending", track_freq=None) + sort_kwargs = { + "sort_key": key, + "sort_order": "descending", + "track_freq": None, + } + if key == "fill_fraction_box": + sort_kwargs["bounding_box"] = td.Box(center=PLANE.center, size=(5.0, 4.0, 5.0)) + sort_spec = td.ModeSortSpec(**sort_kwargs) modes = modes.sort_modes(sort_spec) metric = getattr(modes, key) assert np.all(metric.diff(dim="mode_index") <= 0) # Test descending with a large reference value should be the same as ascending - sort_spec = td.ModeSortSpec( - sort_key=key, sort_order="descending", sort_reference=100, track_freq=None - ) + sort_kwargs = { + "sort_key": key, + "sort_order": "descending", + "sort_reference": 100, + "track_freq": None, + } + if key == "fill_fraction_box": + sort_kwargs["bounding_box"] = td.Box(center=PLANE.center, size=(5.0, 4.0, 5.0)) + sort_spec = td.ModeSortSpec(**sort_kwargs) modes = modes.sort_modes(sort_spec) metric = getattr(modes, key) assert np.all(metric.diff(dim="mode_index") >= 0) @@ -1503,3 +1586,165 @@ def test_sort_spec_track_freq(): assert np.allclose(modes_lowest.Ex.abs, modes_lowest_retracked.Ex.abs) assert np.all(modes_lowest.n_eff == modes_lowest_retracked.n_eff) assert np.all(modes_lowest.n_group == modes_lowest_retracked.n_group) + + +def test_mode_sort_spec_drop_modes_reduces_modes(): + freqs = np.array([2e14, 4e14]) + mode_spec = td.ModeSpec(num_modes=3) + monitor = td.ModeSolverMonitor( + size=(1.0, 0.0, 1.0), + center=(0.0, 0.0, 0.0), + freqs=freqs, + mode_spec=mode_spec, + name="drop_modes", + ) + n_complex = ModeIndexDataArray( + np.array( + [ + [1.6 + 0.6j, 1.5 + 0.2j, 1.1 + 0.5j], + [1.7 + 0.4j, 1.4 + 0.3j, 1.0 + 0.1j], + ] + ), + coords={"f": freqs, "mode_index": np.arange(3)}, + ) + data = ModeSolverData(monitor=monitor, n_complex=n_complex) + + sort_spec = td.ModeSortSpec( + filter_key="n_eff", + filter_reference=1.3, + filter_order="over", + sort_key="k_eff", + sort_order="ascending", + keep_modes="filtered", + ) + + sorted_data = data.sort_modes(sort_spec) + + assert sorted_data.n_eff.sizes["mode_index"] == 2 + assert np.allclose(sorted_data.n_eff.isel(f=0).values, [1.5, 1.6]) + assert np.allclose(sorted_data.n_eff.isel(f=1).values, [1.4, 1.7]) + assert sorted_data.monitor.mode_spec.num_modes == 2 + assert sorted_data.monitor.mode_spec.sort_spec.keep_modes == "filtered" + + +@pytest.mark.parametrize("keep_modes", (1, 3)) +def test_mode_sort_spec_keep_modes_integer(keep_modes): + freqs = np.array([2e14, 4e14]) + mode_spec = td.ModeSpec(num_modes=4) + monitor = td.ModeSolverMonitor( + size=(1.0, 0.0, 1.0), + center=(0.0, 0.0, 0.0), + freqs=freqs, + mode_spec=mode_spec, + name="drop_modes", + ) + n_complex = ModeIndexDataArray( + np.array( + [ + [1.6 + 0.6j, 1.5 + 0.2j, 1.1 + 0.5j, 1.05 + 0.2j], + [1.7 + 0.4j, 1.4 + 0.3j, 1.07 + 0.1j, 1.02 + 0.3j], + ] + ), + coords={"f": freqs, "mode_index": np.arange(4)}, + ) + data = ModeSolverData(monitor=monitor, n_complex=n_complex) + + sort_spec = td.ModeSortSpec( + filter_key="n_eff", + filter_reference=1.3, + filter_order="over", + sort_key="k_eff", + sort_order="ascending", + keep_modes=keep_modes, + ) + + if keep_modes == 1: + with AssertLogLevel(None): + sorted_data = data.sort_modes(sort_spec) + else: + with AssertLogLevel("WARNING", contains_str="Filtering"): + sorted_data = data.sort_modes(sort_spec) + + assert sorted_data.n_eff.sizes["mode_index"] == keep_modes + if keep_modes == 1: + assert np.allclose(sorted_data.n_eff.isel(f=0).values, [1.5]) + assert np.allclose(sorted_data.n_eff.isel(f=1).values, [1.4]) + else: + assert np.allclose(sorted_data.n_eff.isel(f=0).values, [1.5, 1.6, 1.05]) + assert np.allclose(sorted_data.n_eff.isel(f=1).values, [1.4, 1.7, 1.07]) + assert sorted_data.monitor.mode_spec.num_modes == keep_modes + assert sorted_data.monitor.mode_spec.sort_spec.keep_modes == keep_modes + + +def test_mode_sort_spec_drop_modes_all_filtered(): + freqs = np.array([2e14, 4e14]) + mode_spec = td.ModeSpec(num_modes=3) + monitor = td.ModeSolverMonitor( + size=(1.0, 0.0, 1.0), + center=(0.0, 0.0, 0.0), + freqs=freqs, + mode_spec=mode_spec, + name="drop_all", + ) + n_complex = ModeIndexDataArray( + np.array( + [ + [1.1 + 0.1j, 1.05 + 0.05j, 1.0 + 0.01j], + [1.1 + 0.1j, 1.05 + 0.05j, 1.0 + 0.01j], + ] + ), + coords={"f": freqs, "mode_index": np.arange(3)}, + ) + data = ModeSolverData(monitor=monitor, n_complex=n_complex) + + sort_spec = td.ModeSortSpec( + filter_key="n_eff", + filter_reference=2.0, + keep_modes="filtered", + ) + + with pytest.raises(ValidationError): + _ = data.sort_modes(sort_spec) + + +def test_mode_sort_spec_drop_modes_requires_filter(): + with pytest.raises(pydantic.ValidationError): + td.ModeSortSpec(keep_modes="filtered") + + +def test_mode_sort_spec_keep_modes_at_most_num_modes(): + sort_spec = td.ModeSortSpec(keep_modes=4) + with pytest.raises(pydantic.ValidationError): + _ = td.ModeSpec(num_modes=2, sort_spec=sort_spec) + + +def test_mode_sort_spec_fill_fraction_box_filter_drops_modes(): + data, bounding_box = make_fill_fraction_mode_data() + + sort_spec = td.ModeSortSpec( + filter_key="fill_fraction_box", + filter_reference=0.5, + filter_order="over", + keep_modes="filtered", + bounding_box=bounding_box, + ) + + filtered = data.sort_modes(sort_spec) + + assert filtered.n_eff.sizes["mode_index"] == 1 + assert filtered.monitor.mode_spec.num_modes == 1 + + fills = data.fill_fraction(bounding_box) + assert np.isclose(fills.isel(mode_index=0, f=0).item(), 1.0) + assert np.isclose(fills.isel(mode_index=1, f=0).item(), 0.0) + + +def test_mode_sort_spec_fill_fraction_box_requires_bounding_box(): + with pytest.raises(pydantic.ValidationError): + td.ModeSortSpec(filter_key="fill_fraction_box") + + +def test_mode_data_fill_fraction_box_requires_intersection(): + data, _ = make_fill_fraction_mode_data() + with pytest.raises(ValidationError): + data.fill_fraction(td.Box(center=(0.0, 2.0, 0.0), size=(1.0, 1.0, 1.0))) diff --git a/tidy3d/components/data/monitor_data.py b/tidy3d/components/data/monitor_data.py index b2bda4e653..45b3092f58 100644 --- a/tidy3d/components/data/monitor_data.py +++ b/tidy3d/components/data/monitor_data.py @@ -16,6 +16,7 @@ from tidy3d.components.base import cached_property, skip_if_fields_missing from tidy3d.components.base_sim.data.monitor_data import AbstractMonitorData +from tidy3d.components.geometry.base import Box from tidy3d.components.grid.grid import Coords, Grid from tidy3d.components.medium import Medium, MediumType from tidy3d.components.mode_spec import ModeSortSpec, ModeSpec @@ -731,6 +732,87 @@ def mode_area(self) -> FreqModeDataArray: return FreqModeDataArray(area) + def _bounding_box_mask(self, bounding_box: Box) -> DataArray: + """Create a mask selecting cells whose centers lie within ``bounding_box``.""" + + tan_dims = self._tangential_dims + intensity = self.intensity + coords = {dim: intensity.coords[dim].values for dim in tan_dims} + + lower, upper = bounding_box.bounds + axis_indices = ["xyz".index(dim) for dim in tan_dims] + + masks_1d = [] + for dim, axis_idx in zip(tan_dims, axis_indices): + coord_vals = coords[dim] + lower_bound = lower[axis_idx] + upper_bound = upper[axis_idx] + masks_1d.append((coord_vals >= lower_bound) & (coord_vals <= upper_bound)) + + if len(masks_1d) != 2: + raise DataError("Bounding box masking currently supports planar monitors only.") + + mask_values = (masks_1d[0][:, None] & masks_1d[1][None, :]).astype(float) + mask = DataArray(mask_values, coords={dim: coords[dim] for dim in tan_dims}, dims=tan_dims) + return mask + + def _validate_bounding_box_intersection(self, bounding_box: Box) -> None: + """Ensure bounding box intersects the monitor plane.""" + + zero_dims = self.monitor.zero_dims + if len(zero_dims) != 1: + raise DataError("Bounding box fill fraction requires a planar monitor.") + + normal_axis = zero_dims[0] + plane_coord = self.monitor.center[normal_axis] + lower, upper = bounding_box.bounds + lower_bound = lower[normal_axis] + upper_bound = upper[normal_axis] + tol = 1e-9 + if plane_coord < lower_bound - tol or plane_coord > upper_bound + tol: + raise ValidationError( + "Bounding box must intersect the monitor plane when using 'fill_fraction_box'." + ) + + def fill_fraction(self, bounding_box: Box) -> FreqModeDataArray: + """Return the field-energy fill fraction within ``bounding_box``. + + The fill fraction is defined as the ratio between the integrated field intensity inside + the bounding box and the total integrated intensity over the monitor plane. + """ + + self._check_fields_stored(["Ex", "Ey", "Ez"]) + self._validate_bounding_box_intersection(bounding_box) + + intensity = self.intensity + area = self._diff_area + mask = self._bounding_box_mask(bounding_box) + + weighted_total = (intensity * area).sum(dim=area.dims) + weighted_box = (intensity * mask * area).sum(dim=area.dims) + + fill_values = weighted_box / weighted_total + fill_values = fill_values.where(weighted_total != 0, 0.0) + fill_values = fill_values.fillna(0.0) + + return FreqModeDataArray(fill_values) + + @property + def fill_fraction_box(self) -> FreqModeDataArray: + """Convenience accessor using the :class:`Box` defined on ``sort_spec``. + + The component of the box along the propagation axis does not influence the fill fraction, + but the box must intersect the monitor plane. + """ + + sort_spec = getattr(self.monitor.mode_spec, "sort_spec", None) + bounding_box = None if sort_spec is None else sort_spec.bounding_box + if bounding_box is None: + raise DataError( + "ModeSortSpec.bounding_box must be set to access 'fill_fraction_box' metric." + ) + return self.fill_fraction(bounding_box) + def dot( self, field_data: Union[FieldData, ModeData, ModeSolverData], conjugate: bool = True ) -> ModeAmpsDataArray: @@ -2329,6 +2411,72 @@ def _apply_mode_reorder(self, sort_inds_2d): return self.updated_copy(**modify_data) + def _apply_mode_subset(self, subset_inds_2d: np.ndarray) -> ModeSolverData: + """Return copy of self containing only the selected modes. + + Parameters + ---------- + subset_inds_2d : np.ndarray + Array of shape ``(num_freqs, num_modes_keep)`` containing the indices of the original + modes to retain at each frequency. + + Returns + ------- + :class:`.ModeSolverData` + Copy of self with only the retained modes. + """ + + subset_inds_2d = np.asarray(subset_inds_2d, dtype=int) + if subset_inds_2d.ndim != 2: + raise DataError( + "subset_inds_2d must be a 2D array of shape (num_freqs, num_modes_keep)." + ) + + num_freqs, num_keep = subset_inds_2d.shape + if num_keep == 0: + raise DataError("Cannot create a mode subset with zero modes.") + + num_modes_full = self.n_eff["mode_index"].size + + modify_data = {} + new_mode_index_coord = np.arange(num_keep) + + for key, data in self.data_arrs.items(): + if "mode_index" not in data.dims or "f" not in data.dims: + continue + + dims_orig = tuple(data.dims) + coords_out = { + k: (v.values if hasattr(v, "values") else np.asarray(v)) + for k, v in data.coords.items() + } + + f_axis = data.get_axis_num("f") + m_axis = data.get_axis_num("mode_index") + src_order = ( + [f_axis] + [ax for ax in range(data.ndim) if ax not in (f_axis, m_axis)] + [m_axis] + ) + + arr = np.moveaxis(data.data, src_order, range(data.ndim)) + nf, nm = arr.shape[0], arr.shape[-1] + if nf != num_freqs or nm != num_modes_full: + raise DataError( + "subset_inds_2d shape does not match array shape in _apply_mode_subset." + ) + + arr2 = arr.reshape(nf, -1, nm) + inds = subset_inds_2d[:, None, :] + arr2_subset = np.take_along_axis(arr2, inds, axis=2) + arr_subset = arr2_subset.reshape(arr.shape[:-1] + (num_keep,)) + arr_subset = np.moveaxis(arr_subset, range(data.ndim), src_order) + + coords_out["mode_index"] = new_mode_index_coord + coords_out["f"] = data.coords["f"].values + + modify_data[key] = DataArray(arr_subset, coords=coords_out, dims=dims_orig) + + return self.updated_copy(**modify_data) + def sort_modes( self, sort_spec: Optional[ModeSortSpec] = None, track_freq: Optional[TrackFreq] = None ) -> ModeSolverData: @@ -2360,35 +2508,55 @@ def sort_modes( if track_freq is None and sort_spec is None: return self - num_freqs = self.n_eff["f"].size - num_modes = self.n_eff["mode_index"].size + data = self + if sort_spec is not None and sort_spec != self.monitor.mode_spec.sort_spec: + data = self.updated_copy( + path="monitor/mode_spec", sort_spec=sort_spec, deep=False, validate=False + ) + + num_freqs = data.n_eff["f"].size + num_modes = data.n_eff["mode_index"].size all_inds = np.arange(num_modes) - identity = np.arange(num_modes) - sort_inds_2d = np.tile(identity, (num_freqs, 1)) + keep_modes = getattr(sort_spec, "keep_modes", False) + if keep_modes == "filtered" and sort_spec.filter_key is None: + raise ValidationError("ModeSortSpec.keep_modes requires 'filter_key' to be set.") # Helper to compute ordered indices within a subset - def _order_indices(indices, vals_all): + def _order_indices(indices, vals_all, sort_order): if indices.size == 0: return indices vals = vals_all.isel(mode_index=indices) order = np.argsort(vals) - if sort_spec.sort_order == "descending": + if sort_order == "descending": order = order[::-1] return indices[order] # Precompute metrics if provided - filter_metric = None - sort_metric = None - if sort_spec.filter_key is not None: - filter_metric = getattr(self, sort_spec.filter_key) - if sort_spec.sort_key is not None: - sort_metric = getattr(self, sort_spec.sort_key) + fill_fraction_metric = None + + def _metric_for_key(key: Optional[str]): + nonlocal fill_fraction_metric + if key is None: + return None + if key == "fill_fraction_box": + if sort_spec is None or sort_spec.bounding_box is None: + raise ValidationError( + "ModeSortSpec.bounding_box must be defined when using 'fill_fraction_box'." + ) + if fill_fraction_metric is None: + fill_fraction_metric = data.fill_fraction_box + return fill_fraction_metric + return getattr(data, key) + + filter_metric = _metric_for_key(sort_spec.filter_key) if sort_spec else None + sort_metric = _metric_for_key(sort_spec.sort_key) if sort_spec else None + identity = np.arange(num_modes) + sort_inds_2d = np.tile(identity, (num_freqs, 1)) for ifreq in range(num_freqs): # Build groups according to filter if requested if filter_metric is not None: vals_filt = filter_metric.isel(f=ifreq).values - # Boolean mask for modes in the first group if sort_spec.filter_order == "over": mask_first = vals_filt >= sort_spec.filter_reference else: @@ -2404,8 +2572,8 @@ def _order_indices(indices, vals_all): vals_sort = sort_metric.isel(f=ifreq) if sort_spec.sort_reference is not None: vals_sort = np.abs(vals_sort - sort_spec.sort_reference) - g1 = _order_indices(group1, vals_sort) - g2 = _order_indices(group2, vals_sort) + g1 = _order_indices(group1, vals_sort, sort_spec.sort_order) + g2 = _order_indices(group2, vals_sort, sort_spec.sort_order) sort_inds = np.concatenate([g1, g2]) else: # only filtering applied, keep original ordering within groups @@ -2413,11 +2581,15 @@ def _order_indices(indices, vals_all): sort_inds_2d[ifreq, : len(sort_inds)] = sort_inds - # If all rows are identity, skip if np.all(sort_inds_2d == np.tile(identity, (num_freqs, 1))): - data_sorted = self + if sort_spec is not None: + data_sorted = data.updated_copy( + path="monitor/mode_spec", sort_spec=sort_spec, deep=False, validate=False + ) + else: + data_sorted = data else: - data_sorted = self._apply_mode_reorder(sort_inds_2d) # this creates a copy + data_sorted = data._apply_mode_reorder(sort_inds_2d) data_sorted = data_sorted.updated_copy( path="monitor/mode_spec", sort_spec=sort_spec, deep=False, validate=False ) @@ -2425,10 +2597,58 @@ def _order_indices(indices, vals_all): # Sort modes across frequencies if requested. # Note: after sorting, ``track_freq`` is set in ``sort_spec`` regardless of how it was # provided. The deprecated ``mode_spec.track_freq`` is cleared. - track_freq = track_freq or sort_spec.track_freq + sort_spec_track_freq = sort_spec.track_freq if sort_spec is not None else None + track_freq = track_freq or sort_spec_track_freq if track_freq and num_freqs > 1: data_sorted = data_sorted.overlap_sort(track_freq) + keep_inds = None + if keep_modes == "filtered" or isinstance(keep_modes, int): + # Re-evaluate the filter after sorting/tracking so modes are dropped consistently. + if sort_spec.filter_key == "fill_fraction_box": + filter_metric_sorted = data_sorted.fill_fraction_box + else: + filter_metric_sorted = getattr(data_sorted, sort_spec.filter_key) + masks_after = [] + for ifreq in range(num_freqs): + vals = filter_metric_sorted.isel(f=ifreq).values + if sort_spec.filter_order == "over": + mask = vals >= sort_spec.filter_reference + else: + mask = vals <= sort_spec.filter_reference + masks_after.append(mask) + + keep_mask = np.all(np.stack(masks_after, axis=0), axis=0) + if keep_modes == "filtered": + if not np.any(keep_mask): + raise ValidationError( + "Filtering removes all modes; relax the filter threshold or change 'keep_modes'." + ) + + num_modes_sorted = filter_metric_sorted.sizes["mode_index"] + if keep_mask.sum() < num_modes_sorted: + keep_inds = np.where(keep_mask)[0] + elif isinstance(keep_modes, int): + if keep_mask.sum() < keep_modes: + log.warning( + f"Filtering with 'keep_modes={keep_modes}' keeps " + f"more than the '{keep_mask.sum()}' modes which pass the filter. Consider " + "relaxing the filter threshold or changing 'keep_modes'." + ) + keep_inds = np.arange(keep_modes) + + if keep_inds is not None: + subset_inds_2d = np.tile(keep_inds, (num_freqs, 1)) + data_subset = data_sorted._apply_mode_subset(subset_inds_2d) + mspec = data_subset.monitor.mode_spec + mspec_updated = mspec.updated_copy(num_modes=keep_inds.size, validate=False) + monitor_updated = data_subset.monitor.updated_copy( + mode_spec=mspec_updated, validate=False + ) + data_sorted = data_subset.updated_copy( + monitor=monitor_updated, deep=False, validate=False + ) + return data_sorted diff --git a/tidy3d/components/mode_spec.py b/tidy3d/components/mode_spec.py index 0dfb12b531..0f20b31b0b 100644 --- a/tidy3d/components/mode_spec.py +++ b/tidy3d/components/mode_spec.py @@ -9,6 +9,7 @@ import numpy as np import pydantic.v1 as pd +from tidy3d.components.geometry.base import Box from tidy3d.constants import GLANCING_CUTOFF, MICROMETER, RADIAN, fp_eps from tidy3d.exceptions import SetupError, ValidationError from tidy3d.log import log @@ -25,6 +26,7 @@ "wg_TE_fraction", "wg_TM_fraction", "mode_area", + "fill_fraction_box", ] @@ -35,7 +37,9 @@ class ModeSortSpec(Tidy3dBaseModel): applied to ``filter_key``: modes "over" or "under" ``filter_reference`` are placed first, with the remaining modes placed next. Second, an optional sorting step orders modes within each group according to ``sort_key``, optionally with respect to ``sort_reference`` and in - the specified ``sort_order``. + the specified ``sort_order``. If ``keep_modes`` is set to "filtered", the modes that do not meet + the filter criterion are removed instead of being appended as a second group. + If ``keep_modes`` is set to an integer, that is the number of modes that will be kept. """ # Filtering stage @@ -54,6 +58,25 @@ class ModeSortSpec(Tidy3dBaseModel): title="Filtering order", description="Select whether the first group contains values over or under the reference.", ) + bounding_box: Optional[Box] = pd.Field( + None, + title="Bounding box", + description=( + "Regular 3D tidy3d :class:`Box` used by metrics such as ``'fill_fraction_box'``. " + "The extent along the propagation axis is ignored for the metric, but the box must " + "still intersect the monitor plane. Required when filtering or sorting with that key." + ), + ) + keep_modes: Union[Literal["all"], Literal["filtered"], pd.PositiveInt] = pd.Field( + "all", + title="Keep Modes", + description=( + "If ``filtered``, modes that do not satisfy the filter criterion are removed entirely " + "instead of being appended after the filtered group. Only modes passing the filter " + "at every tracked frequency are kept. " + "If a positive integer is given, that is the number of modes which will be kept." + ), + ) # Sorting stage sort_key: Optional[MODE_DATA_KEYS] = pd.Field( @@ -85,6 +108,24 @@ class ModeSortSpec(Tidy3dBaseModel): "while at other frequencies it can change depending on the mode tracking.", ) + @pd.validator("keep_modes", always=True) + def _drop_requires_filter(cls, val, values): + if val == "filtered" and values.get("filter_key") is None: + raise ValidationError( + "ModeSortSpec.keep_modes 'filtered' requires 'filter_key' to be set." + ) + return val + + @pd.root_validator(skip_on_failure=True) + def _bounding_box_required_for_fill_fraction(cls, values): + bbox = values.get("bounding_box") + keys = (values.get("filter_key"), values.get("sort_key")) + if any(key == "fill_fraction_box" for key in keys) and bbox is None: + raise ValidationError( + "ModeSortSpec.bounding_box must be set when using 'fill_fraction_box'." + ) + return values + class FrequencySamplingSpec(Tidy3dBaseModel, ABC): """Abstract base class for frequency sampling specifications.""" @@ -561,6 +602,18 @@ class AbstractModeSpec(Tidy3dBaseModel, ABC): "not be ``None``) to ensure consistent mode ordering across frequencies.", ) + @pd.validator("sort_spec", always=True) + def _keep_modes_at_most_num_modes(cls, val, values): + if val is not None: + if isinstance(val.keep_modes, int): + num_modes = values.get("num_modes") + if val.keep_modes > num_modes: + raise ValidationError( + "ModeSortSpec.keep_modes cannot be larger than 'num_modes' ." + f"Currently these are {val.keep_modes} and {num_modes}." + ) + return val + @pd.validator("bend_axis", always=True) @skip_if_fields_missing(["bend_radius"]) def bend_axis_given(cls, val, values):