Skip to content

Commit 636f5c9

Browse files
committed
Add option to drop filtered modes
1 parent 603c8a8 commit 636f5c9

File tree

3 files changed

+462
-27
lines changed

3 files changed

+462
-27
lines changed

tests/test_plugins/test_mode_solver.py

Lines changed: 197 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@
1010

1111
import tidy3d as td
1212
import tidy3d.plugins.mode.web as msweb
13-
from tidy3d import ScalarFieldDataArray
13+
from tidy3d import Coords, Grid, ModeIndexDataArray, ScalarFieldDataArray, ScalarModeFieldDataArray
1414
from tidy3d.components.data.monitor_data import ModeSolverData
1515
from tidy3d.components.mode.derivatives import create_sfactor_b, create_sfactor_f
1616
from tidy3d.components.mode.solver import compute_modes
1717
from tidy3d.components.mode_spec import MODE_DATA_KEYS
18-
from tidy3d.exceptions import DataError, SetupError
18+
from tidy3d.exceptions import DataError, SetupError, ValidationError
1919
from tidy3d.plugins.mode import ModeSolver
2020
from tidy3d.plugins.mode.mode_solver import MODE_MONITOR_NAME
2121
from tidy3d.web.core.environment import Env
@@ -38,6 +38,69 @@
3838
SOLVER_ID = "Solver-ID"
3939

4040

41+
def make_fill_fraction_mode_data():
42+
freq = np.array([2e14])
43+
mode_spec = td.ModeSpec(num_modes=2)
44+
monitor = td.ModeSolverMonitor(
45+
size=(3.0, 0.0, 3.0),
46+
center=(0.0, 0.0, 0.0),
47+
freqs=freq,
48+
mode_spec=mode_spec,
49+
name="fill_fraction",
50+
)
51+
52+
grid = Grid(
53+
boundaries=Coords(
54+
x=np.array([-1.5, -0.5, 0.5, 1.5]),
55+
y=np.array([-0.5, 0.5]),
56+
z=np.array([-1.5, -0.5, 0.5, 1.5]),
57+
)
58+
)
59+
60+
coords = {
61+
"x": np.array([-1.0, 0.0, 1.0]),
62+
"y": np.array([0.0]),
63+
"z": np.array([-1.0, 0.0, 1.0]),
64+
"f": freq,
65+
"mode_index": np.arange(2),
66+
}
67+
shape = (3, 1, 3, 1, 2)
68+
69+
ex_data = np.zeros(shape, dtype=complex)
70+
ex_data[1, 0, 1, 0, 0] = 2.0
71+
for ix in (0, 2):
72+
for iz in (0, 2):
73+
ex_data[ix, 0, iz, 0, 1] = 1.0
74+
75+
zero_data = np.zeros(shape, dtype=complex)
76+
77+
fields = {
78+
"Ex": ScalarModeFieldDataArray(ex_data, coords=coords),
79+
"Ey": ScalarModeFieldDataArray(np.copy(zero_data), coords=coords),
80+
"Ez": ScalarModeFieldDataArray(np.copy(zero_data), coords=coords),
81+
"Hx": ScalarModeFieldDataArray(np.copy(zero_data), coords=coords),
82+
"Hy": ScalarModeFieldDataArray(np.copy(zero_data), coords=coords),
83+
"Hz": ScalarModeFieldDataArray(np.copy(zero_data), coords=coords),
84+
}
85+
86+
n_complex = ModeIndexDataArray(
87+
np.array([[1.6 + 0.0j, 1.3 + 0.0j]]),
88+
coords={"f": freq, "mode_index": np.arange(2)},
89+
)
90+
91+
data = ModeSolverData(
92+
monitor=monitor,
93+
symmetry=(0, 0, 0),
94+
symmetry_center=(0.0, 0.0, 0.0),
95+
grid_expanded=grid,
96+
n_complex=n_complex,
97+
**fields,
98+
)
99+
100+
bounding_box = td.Box(center=(0.0, 0.0, 0.0), size=(1.0, 2.0, 1.0))
101+
return data, bounding_box
102+
103+
41104
@pytest.fixture
42105
def mock_remote_api(monkeypatch):
43106
def void(*args, **kwargs):
@@ -1343,21 +1406,41 @@ def test_modes_filter_sort():
13431406
for key in get_args(MODE_DATA_KEYS):
13441407
print(key)
13451408
# Test ascending
1346-
sort_spec = td.ModeSortSpec(sort_key=key, sort_order="ascending", track_freq=None)
1409+
sort_kwargs = {
1410+
"sort_key": key,
1411+
"sort_order": "ascending",
1412+
"track_freq": None,
1413+
}
1414+
if key == "fill_fraction_box":
1415+
sort_kwargs["bounding_box"] = td.Box(center=PLANE.center, size=(5.0, 4.0, 5.0))
1416+
sort_spec = td.ModeSortSpec(**sort_kwargs)
13471417
modes = modes.sort_modes(sort_spec)
13481418
metric = getattr(modes, key)
13491419
assert np.all(metric.diff(dim="mode_index") >= 0)
13501420

13511421
# Test descending
1352-
sort_spec = td.ModeSortSpec(sort_key=key, sort_order="descending", track_freq=None)
1422+
sort_kwargs = {
1423+
"sort_key": key,
1424+
"sort_order": "descending",
1425+
"track_freq": None,
1426+
}
1427+
if key == "fill_fraction_box":
1428+
sort_kwargs["bounding_box"] = td.Box(center=PLANE.center, size=(5.0, 4.0, 5.0))
1429+
sort_spec = td.ModeSortSpec(**sort_kwargs)
13531430
modes = modes.sort_modes(sort_spec)
13541431
metric = getattr(modes, key)
13551432
assert np.all(metric.diff(dim="mode_index") <= 0)
13561433

13571434
# Test descending with a large reference value should be the same as ascending
1358-
sort_spec = td.ModeSortSpec(
1359-
sort_key=key, sort_order="descending", sort_reference=100, track_freq=None
1360-
)
1435+
sort_kwargs = {
1436+
"sort_key": key,
1437+
"sort_order": "descending",
1438+
"sort_reference": 100,
1439+
"track_freq": None,
1440+
}
1441+
if key == "fill_fraction_box":
1442+
sort_kwargs["bounding_box"] = td.Box(center=PLANE.center, size=(5.0, 4.0, 5.0))
1443+
sort_spec = td.ModeSortSpec(**sort_kwargs)
13611444
modes = modes.sort_modes(sort_spec)
13621445
metric = getattr(modes, key)
13631446
assert np.all(metric.diff(dim="mode_index") >= 0)
@@ -1503,3 +1586,110 @@ def test_sort_spec_track_freq():
15031586
assert np.allclose(modes_lowest.Ex.abs, modes_lowest_retracked.Ex.abs)
15041587
assert np.all(modes_lowest.n_eff == modes_lowest_retracked.n_eff)
15051588
assert np.all(modes_lowest.n_group == modes_lowest_retracked.n_group)
1589+
1590+
1591+
def test_mode_sort_spec_drop_modes_reduces_modes():
1592+
freqs = np.array([2e14, 4e14])
1593+
mode_spec = td.ModeSpec(num_modes=3)
1594+
monitor = td.ModeSolverMonitor(
1595+
size=(1.0, 0.0, 1.0),
1596+
center=(0.0, 0.0, 0.0),
1597+
freqs=freqs,
1598+
mode_spec=mode_spec,
1599+
name="drop_modes",
1600+
)
1601+
n_complex = ModeIndexDataArray(
1602+
np.array(
1603+
[
1604+
[1.6 + 0.6j, 1.5 + 0.2j, 1.1 + 0.5j],
1605+
[1.7 + 0.4j, 1.4 + 0.3j, 1.0 + 0.1j],
1606+
]
1607+
),
1608+
coords={"f": freqs, "mode_index": np.arange(3)},
1609+
)
1610+
data = ModeSolverData(monitor=monitor, n_complex=n_complex)
1611+
1612+
sort_spec = td.ModeSortSpec(
1613+
filter_key="n_eff",
1614+
filter_reference=1.3,
1615+
filter_order="over",
1616+
sort_key="k_eff",
1617+
sort_order="ascending",
1618+
drop_modes=True,
1619+
)
1620+
1621+
sorted_data = data.sort_modes(sort_spec)
1622+
1623+
assert sorted_data.n_eff.sizes["mode_index"] == 2
1624+
assert np.allclose(sorted_data.n_eff.isel(f=0).values, [1.5, 1.6])
1625+
assert np.allclose(sorted_data.n_eff.isel(f=1).values, [1.4, 1.7])
1626+
assert sorted_data.monitor.mode_spec.num_modes == 2
1627+
assert sorted_data.monitor.mode_spec.sort_spec.drop_modes is True
1628+
1629+
1630+
def test_mode_sort_spec_drop_modes_all_filtered():
1631+
freqs = np.array([2e14, 4e14])
1632+
mode_spec = td.ModeSpec(num_modes=3)
1633+
monitor = td.ModeSolverMonitor(
1634+
size=(1.0, 0.0, 1.0),
1635+
center=(0.0, 0.0, 0.0),
1636+
freqs=freqs,
1637+
mode_spec=mode_spec,
1638+
name="drop_all",
1639+
)
1640+
n_complex = ModeIndexDataArray(
1641+
np.array(
1642+
[
1643+
[1.1 + 0.1j, 1.05 + 0.05j, 1.0 + 0.01j],
1644+
[1.1 + 0.1j, 1.05 + 0.05j, 1.0 + 0.01j],
1645+
]
1646+
),
1647+
coords={"f": freqs, "mode_index": np.arange(3)},
1648+
)
1649+
data = ModeSolverData(monitor=monitor, n_complex=n_complex)
1650+
1651+
sort_spec = td.ModeSortSpec(
1652+
filter_key="n_eff",
1653+
filter_reference=2.0,
1654+
drop_modes=True,
1655+
)
1656+
1657+
with pytest.raises(ValidationError):
1658+
_ = data.sort_modes(sort_spec)
1659+
1660+
1661+
def test_mode_sort_spec_drop_modes_requires_filter():
1662+
with pytest.raises(pydantic.ValidationError):
1663+
td.ModeSortSpec(drop_modes=True)
1664+
1665+
1666+
def test_mode_sort_spec_fill_fraction_box_filter_drops_modes():
1667+
data, bounding_box = make_fill_fraction_mode_data()
1668+
1669+
sort_spec = td.ModeSortSpec(
1670+
filter_key="fill_fraction_box",
1671+
filter_reference=0.5,
1672+
filter_order="over",
1673+
drop_modes=True,
1674+
bounding_box=bounding_box,
1675+
)
1676+
1677+
filtered = data.sort_modes(sort_spec)
1678+
1679+
assert filtered.n_eff.sizes["mode_index"] == 1
1680+
assert filtered.monitor.mode_spec.num_modes == 1
1681+
1682+
fills = data.fill_fraction(bounding_box)
1683+
assert np.isclose(fills.isel(mode_index=0, f=0).item(), 1.0)
1684+
assert np.isclose(fills.isel(mode_index=1, f=0).item(), 0.0)
1685+
1686+
1687+
def test_mode_sort_spec_fill_fraction_box_requires_bounding_box():
1688+
with pytest.raises(pydantic.ValidationError):
1689+
td.ModeSortSpec(filter_key="fill_fraction_box")
1690+
1691+
1692+
def test_mode_data_fill_fraction_box_requires_intersection():
1693+
data, _ = make_fill_fraction_mode_data()
1694+
with pytest.raises(ValidationError):
1695+
data.fill_fraction(td.Box(center=(0.0, 2.0, 0.0), size=(1.0, 1.0, 1.0)))

0 commit comments

Comments
 (0)