Skip to content

Commit e58a259

Browse files
yaugenst-flexdaquinteroflex
authored andcommitted
fix(autograd): clip rescale output to prevent numerical precision errors
1 parent b3206f1 commit e58a259

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

tests/test_plugins/autograd/test_functions.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,30 @@ def test_rescale_exceptions(array, out_min, out_max, in_min, in_max, expected_me
448448
rescale(array, out_min, out_max, in_min, in_max)
449449

450450

451+
def test_rescale_clips_output_to_bounds():
452+
"""Test that rescale clips output to [out_min, out_max] even when input is slightly outside [in_min, in_max].
453+
454+
This is a regression test for a numerical precision issue where filter_project + tanh_projection
455+
could produce values slightly outside [0, 1] (e.g., -1e-15), causing rescale to produce
456+
permittivity values slightly below 1.0, which would fail CustomMedium validation.
457+
"""
458+
# Simulate input slightly outside the expected [0, 1] range due to numerical precision
459+
array_with_numerical_error = np.array([-1e-15, 0.5, 1.0 + 1e-15])
460+
461+
out_min, out_max = 1.0, 2.75
462+
in_min, in_max = 0.0, 1.0
463+
464+
result = rescale(array_with_numerical_error, out_min, out_max, in_min, in_max)
465+
466+
# Without clipping, result[0] would be slightly below 1.0 (e.g., 0.999999999999998)
467+
# and result[2] would be slightly above 2.75
468+
assert result.min() >= out_min, f"Output {result.min()} is below out_min={out_min}"
469+
assert result.max() <= out_max, f"Output {result.max()} is above out_max={out_max}"
470+
471+
npt.assert_equal(result[0], out_min)
472+
npt.assert_equal(result[2], out_max)
473+
474+
451475
@pytest.mark.parametrize(
452476
"ary, vmin, vmax, level, expected",
453477
[

tidy3d/plugins/autograd/functions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -709,7 +709,9 @@ def rescale(
709709
raise ValueError(f"'in_min' ({in_min}) must be less than 'in_max' ({in_max}).")
710710

711711
scaled = (array - in_min) / (in_max - in_min)
712-
return scaled * (out_max - out_min) + out_min
712+
result = scaled * (out_max - out_min) + out_min
713+
714+
return np.clip(result, out_min, out_max)
713715

714716

715717
def threshold(

0 commit comments

Comments
 (0)