Skip to content

Commit f5c0f6d

Browse files
committed
Add more tests
...in particular for `run_iter()` and more components. Test descan error implementation for correctly handling the names, not only indices. Drive-by flake8 fixes
1 parent 3e43e8c commit f5c0f6d

3 files changed

Lines changed: 199 additions & 44 deletions

File tree

tests/test_component.py

Lines changed: 88 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
import pytest
2+
from numpy.testing import assert_allclose
3+
24
import numpy as np
35
from jax import jacobian
46
import jax.numpy as jnp
57
import jax_dataclasses as jdc
68

7-
from temgym_core.components import ScanGrid, Detector, Descanner, DescanError, Component
9+
from temgym_core.components import (
10+
ScanGrid, Detector, Descanner, Scanner, DescanError, Component, Plane
11+
)
812
from temgym_core.ray import Ray
913
from temgym_core.utils import custom_jacobian_matrix
1014

@@ -86,8 +90,8 @@ def test_scan_grid_metres_to_pixels(xy, rotation, expected_pixel_coords):
8690
shape=(11, 11),
8791
)
8892
pixel_coords_y, pixel_coords_x = scan_grid.metres_to_pixels(xy)
89-
np.testing.assert_allclose(pixel_coords_y, expected_pixel_coords[0], atol=1e-6)
90-
np.testing.assert_allclose(pixel_coords_x, expected_pixel_coords[1], atol=1e-6)
93+
assert_allclose(pixel_coords_y, expected_pixel_coords[0], atol=1e-6)
94+
assert_allclose(pixel_coords_x, expected_pixel_coords[1], atol=1e-6)
9195

9296

9397
@pytest.mark.parametrize(
@@ -115,8 +119,8 @@ def test_scan_grid_pixels_to_metres(pixel_coords, rotation, expected_xy):
115119
shape=(11, 11),
116120
)
117121
metres_coords_x, metres_coords_y = scan_grid.pixels_to_metres(pixel_coords)
118-
np.testing.assert_allclose(metres_coords_x, expected_xy[0], atol=1e-6)
119-
np.testing.assert_allclose(metres_coords_y, expected_xy[1], atol=1e-6)
122+
assert_allclose(metres_coords_x, expected_xy[0], atol=1e-6)
123+
assert_allclose(metres_coords_y, expected_xy[1], atol=1e-6)
120124

121125

122126
@pytest.mark.parametrize(
@@ -137,8 +141,8 @@ def test_detector_metres_to_pixels(xy, expected_pixel_coords):
137141
flip_y=False,
138142
)
139143
pixel_coords_y, pixel_coords_x = detector.metres_to_pixels(xy)
140-
np.testing.assert_allclose(pixel_coords_y, expected_pixel_coords[0], atol=1e-6)
141-
np.testing.assert_allclose(pixel_coords_x, expected_pixel_coords[1], atol=1e-6)
144+
assert_allclose(pixel_coords_y, expected_pixel_coords[0], atol=1e-6)
145+
assert_allclose(pixel_coords_x, expected_pixel_coords[1], atol=1e-6)
142146

143147

144148
# Test cases for Detector:
@@ -161,46 +165,93 @@ def test_detector_pixels_to_metres(pixel_coords, expected_xy):
161165
flip_y=False,
162166
)
163167
metres_coords_x, metres_coords_y = detector.pixels_to_metres(pixel_coords)
164-
np.testing.assert_allclose(metres_coords_x, expected_xy[0], atol=1e-6)
165-
np.testing.assert_allclose(metres_coords_y, expected_xy[1], atol=1e-6)
168+
assert_allclose(metres_coords_x, expected_xy[0], atol=1e-6)
169+
assert_allclose(metres_coords_y, expected_xy[1], atol=1e-6)
170+
171+
172+
def test_plane():
173+
x, y, dx, dy, z, pathlength = np.random.uniform(-5.0, 5.0, size=6)
174+
ray = Ray(x=x, y=y, dx=dx, dy=dy, _one=1.0, z=0.0, pathlength=0.0)
175+
comp = Plane(z=23)
176+
out = comp(ray)
177+
for attr in ('x', 'y', 'dx', 'dy', '_one', 'z', 'pathlength'):
178+
assert getattr(ray, attr) == getattr(out, attr)
179+
180+
181+
def test_scanner_random():
182+
# Randomly chosen scan position and ray parameters
183+
sp_x, sp_y = np.random.uniform(-5.0, 5.0), np.random.uniform(-5.0, 5.0)
184+
st_x, st_y = np.random.uniform(-0.5, 0.5), np.random.uniform(0.5, 0.5)
185+
x, y, dx, dy, z, pathlength = np.random.uniform(-5.0, 5.0, size=6)
186+
187+
sc = Scanner(
188+
z=23,
189+
scan_pos_x=sp_x, scan_pos_y=sp_y,
190+
scan_tilt_y=st_y, scan_tilt_x=st_x,
191+
)
192+
ray = Ray(x=x, y=y, dx=dx, dy=dy, _one=1.0, z=z, pathlength=pathlength)
193+
out = sc(ray)
194+
195+
# Expected values computed using the same formula as in the implementation
196+
exp_x = x + sp_x
197+
exp_y = y + sp_y
198+
exp_dx = dx + st_x
199+
exp_dy = dy + st_y
200+
201+
assert_allclose(out.x, exp_x, atol=1e-6)
202+
assert_allclose(out.y, exp_y, atol=1e-6)
203+
assert_allclose(out.dx, exp_dx, atol=1e-6)
204+
assert_allclose(out.dy, exp_dy, atol=1e-6)
205+
for attr in ('_one', 'z', 'pathlength'):
206+
assert getattr(ray, attr) == getattr(out, attr)
166207

167208

168209
def test_descanner_random_descan_error():
169210
# Randomly chosen scan position and ray parameters
170211
sp_x, sp_y = np.random.uniform(-5.0, 5.0), np.random.uniform(-5.0, 5.0)
171-
x, y, dx, dy = np.random.uniform(-5.0, 5.0, size=4)
212+
st_x, st_y = np.random.uniform(-0.5, 0.5), np.random.uniform(0.5, 0.5)
213+
x, y, dx, dy, z, pathlength = np.random.uniform(-5.0, 5.0, size=6)
172214

173215
# Randomly chosen non-zero descan error (length 12)
174-
err = np.random.rand(12)
216+
(pxo_pxi, pxo_pyi, pyo_pxi, pyo_pyi,
217+
sxo_pxi, sxo_pyi, syo_pxi, syo_pyi,
218+
offpxi, offpyi, offsxi, offsyi) = np.random.rand(12)
175219

176220
err = DescanError(
177-
pxo_pxi=err[0],
178-
pxo_pyi=err[1],
179-
pyo_pxi=err[2],
180-
pyo_pyi=err[3],
181-
sxo_pxi=err[4],
182-
sxo_pyi=err[5],
183-
syo_pxi=err[6],
184-
syo_pyi=err[7],
185-
offpxi=err[8],
186-
offpyi=err[9],
187-
offsxi=err[10],
188-
offsyi=err[11],
221+
pxo_pxi=pxo_pxi,
222+
pxo_pyi=pxo_pyi,
223+
pyo_pxi=pyo_pxi,
224+
pyo_pyi=pyo_pyi,
225+
sxo_pxi=sxo_pxi,
226+
sxo_pyi=sxo_pyi,
227+
syo_pxi=syo_pxi,
228+
syo_pyi=syo_pyi,
229+
offpxi=offpxi,
230+
offpyi=offpyi,
231+
offsxi=offsxi,
232+
offsyi=offsyi,
189233
)
190-
desc = Descanner(z=0.0, scan_pos_x=sp_x, scan_pos_y=sp_y, descan_error=err)
191-
ray = Ray(x=x, y=y, dx=dx, dy=dy, _one=1.0, z=0.0, pathlength=0.0)
234+
desc = Descanner(
235+
z=23,
236+
scan_pos_x=sp_x, scan_pos_y=sp_y,
237+
scan_tilt_y=st_y, scan_tilt_x=st_x,
238+
descan_error=err
239+
)
240+
ray = Ray(x=x, y=y, dx=dx, dy=dy, _one=1.0, z=z, pathlength=pathlength)
192241
out = desc(ray)
193242

194243
# Expected values computed using the same formula as in the implementation
195-
exp_x = x + sp_x * err[0] + sp_y * err[1] + err[8] - sp_x
196-
exp_y = y + sp_x * err[2] + sp_y * err[3] + err[9] - sp_y
197-
exp_dx = dx + sp_x * err[4] + sp_y * err[5] + err[10]
198-
exp_dy = dy + sp_x * err[6] + sp_y * err[7] + err[11]
244+
exp_x = x + sp_x * pxo_pxi + sp_y * pxo_pyi + offpxi - sp_x
245+
exp_y = y + sp_x * pyo_pxi + sp_y * pyo_pyi + offpyi - sp_y
246+
exp_dx = dx + sp_x * sxo_pxi + sp_y * sxo_pyi + offsxi - st_x
247+
exp_dy = dy + sp_x * syo_pxi + sp_y * syo_pyi + offsyi - st_y
199248

200-
np.testing.assert_allclose(out.x, exp_x, atol=1e-6)
201-
np.testing.assert_allclose(out.y, exp_y, atol=1e-6)
202-
np.testing.assert_allclose(out.dx, exp_dx, atol=1e-6)
203-
np.testing.assert_allclose(out.dy, exp_dy, atol=1e-6)
249+
assert_allclose(out.x, exp_x, atol=1e-6)
250+
assert_allclose(out.y, exp_y, atol=1e-6)
251+
assert_allclose(out.dx, exp_dx, atol=1e-6)
252+
assert_allclose(out.dy, exp_dy, atol=1e-6)
253+
for attr in ('_one', 'z', 'pathlength'):
254+
assert getattr(ray, attr) == getattr(out, attr)
204255

205256

206257
def test_descanner_offset_consistency():
@@ -223,7 +274,7 @@ def test_descanner_offset_consistency():
223274
offsyi=err[11],
224275
)
225276
desc = Descanner(
226-
z=0.0, scan_pos_x=scan_pos_x, scan_pos_y=scan_pos_y, descan_error=err
277+
z=11, scan_pos_x=scan_pos_x, scan_pos_y=scan_pos_y, descan_error=err
227278
)
228279

229280
# generate a batch of random rays
@@ -251,7 +302,7 @@ def test_descanner_offset_consistency():
251302
# assert that all rays have received the same offset
252303
first = offsets[0]
253304
for off in offsets:
254-
np.testing.assert_allclose(off, first, atol=1e-6)
305+
assert_allclose(off, first, atol=1e-6)
255306

256307

257308
def test_descanner_jacobian_matrix():
@@ -294,7 +345,7 @@ def test_descanner_jacobian_matrix():
294345
[0.0, 0.0, 0.0, 0.0, 1.0],
295346
]
296347
)
297-
np.testing.assert_allclose(J, T, atol=1e-6)
348+
assert_allclose(J, T, atol=1e-6)
298349

299350

300351
@pytest.mark.parametrize("repeat", tuple(range(5)))
@@ -319,7 +370,7 @@ def test_scan_grid_rotation_random(repeat):
319370
# expected rotated step vector = R(scan_rot) @ [step_x, 0]
320371
theta = np.deg2rad(scan_rot)
321372
exp_scan = np.array([np.cos(theta) * step[0], -np.sin(theta) * step[0]])
322-
np.testing.assert_allclose(vec_scan, exp_scan, atol=1e-6)
373+
assert_allclose(vec_scan, exp_scan, atol=1e-6)
323374

324375

325376
def test_singular_component_jacobian():

tests/test_gaussians.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
calculate_z1_and_z2_from_M_and_f,
1818
)
1919
from skimage.restoration import unwrap_phase
20-
from temgym_core.utils import make_aperture, zero_phase, FresnelPropagator, fresnel_lens_imaging_solution
20+
from temgym_core.utils import (
21+
make_aperture, zero_phase, FresnelPropagator, fresnel_lens_imaging_solution
22+
)
2123
import numpy as np
2224
import jax.numpy as jnp
2325
import jax
@@ -293,7 +295,9 @@ def test_gaussian_free_space_vs_fresnel():
293295
gauss_input.shape[1] // 2,
294296
)
295297

296-
fresnel_gauss_image = FresnelPropagator(gauss_input, det_edge_x, wavelength, propagation_distance)
298+
fresnel_gauss_image = FresnelPropagator(
299+
gauss_input, det_edge_x, wavelength, propagation_distance
300+
)
297301

298302
# Normalize amplitude so the maximum magnitude is 1
299303
analytic_gauss_image /= np.max(np.abs(analytic_gauss_image))
@@ -402,8 +406,10 @@ def test_gaussian_lens_vs_fresnel():
402406
gauss_input.shape[1] // 2,
403407
)
404408

405-
fresnel_gauss_image = fresnel_lens_imaging_solution(gauss_input, Y, X, pixel_size[0], wavelength,
406-
defocus+np.abs(z1), f, z2)
409+
fresnel_gauss_image = fresnel_lens_imaging_solution(
410+
gauss_input, Y, X, pixel_size[0], wavelength,
411+
defocus+np.abs(z1), f, z2
412+
)
407413

408414
fresnel_gauss_image = zero_phase(
409415
fresnel_gauss_image,
@@ -563,11 +569,17 @@ def test_gaussian_two_beam_interference_vs_fresnel():
563569
tilted_shifted_plane_wave1 = np.exp(1j * k * dot[0])
564570
tilted_shifted_plane_wave2 = np.exp(1j * k * dot[1])
565571

566-
gaussian_misaligned1 = (gaussian_shifted_1 * tilted_shifted_plane_wave1).reshape(shape[0], shape[1])
567-
gaussian_misaligned2 = (gaussian_shifted_2 * tilted_shifted_plane_wave2).reshape(shape[0], shape[1])
572+
gaussian_misaligned1 = (
573+
gaussian_shifted_1 * tilted_shifted_plane_wave1
574+
).reshape(shape[0], shape[1])
575+
gaussian_misaligned2 = (
576+
gaussian_shifted_2 * tilted_shifted_plane_wave2
577+
).reshape(shape[0], shape[1])
568578
gaussian_misaligned = gaussian_misaligned1 + gaussian_misaligned2
569579

570-
fresnel_gauss_image = fresnel_lens_imaging_solution(gaussian_misaligned, Y, X, pixel_size[0], wavelength, 0.0, f, z2)
580+
fresnel_gauss_image = fresnel_lens_imaging_solution(
581+
gaussian_misaligned, Y, X, pixel_size[0], wavelength, 0.0, f, z2
582+
)
571583
fresnel_gauss_image = zero_phase(fresnel_gauss_image, shape[0]//2, shape[1]//2)
572584

573585
# Normalize amplitude so the maximum magnitude is 1

tests/test_run.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from numpy.testing import assert_allclose
2+
3+
from temgym_core.components import Scanner, Plane, Descanner
4+
from temgym_core.source import PointSource
5+
from temgym_core.ray import Ray
6+
from temgym_core.run import run_iter
7+
from temgym_core.propagator import Propagator, FreeSpaceParaxial
8+
9+
10+
def test_run_iter():
11+
# These components shouldn't change the ray as it passes through
12+
components = (
13+
PointSource(z=0., semi_conv=0.023),
14+
Scanner(z=1.2, scan_pos_x=0., scan_pos_y=0.),
15+
Plane(z=1.2),
16+
Descanner(z=1.2, scan_pos_x=0., scan_pos_y=0.),
17+
Plane(z=3.1)
18+
)
19+
ray = Ray(
20+
x=0.12,
21+
y=0.23,
22+
dx=0.34,
23+
dy=0.45,
24+
z=3.14,
25+
pathlength=0.34
26+
)
27+
res = list(run_iter(ray=ray, components=components))
28+
29+
prev_ray = ray
30+
31+
for i, component in enumerate(components):
32+
prop_index = 2*i
33+
comp_index = 2*i + 1
34+
prop, prop_r = res[prop_index]
35+
comp, comp_r = res[comp_index]
36+
assert isinstance(prop, Propagator)
37+
assert isinstance(prop.propagator, FreeSpaceParaxial)
38+
assert prop.distance == component.z - prev_ray.z
39+
assert_allclose(prop_r.z, comp.z)
40+
assert_allclose(comp_r.z, comp.z)
41+
assert prev_ray.dx == prop_r.dx
42+
assert prev_ray.dy == prop_r.dy
43+
assert_allclose(prop_r.x, prev_ray.x + prev_ray.dx*prop.distance)
44+
assert_allclose(prop_r.y, prev_ray.y + prev_ray.dy*prop.distance)
45+
# FIXME add test for correct path length
46+
47+
prev_ray = comp_r
48+
49+
50+
def test_run_iter_noprop():
51+
# everything at the same z level
52+
z = 1.2
53+
# These components do change the ray, i.e. we
54+
# test that run_iter() actually passes the ray through the components
55+
components = (
56+
PointSource(z=z, semi_conv=0.023),
57+
Scanner(z=z, scan_pos_x=23., scan_pos_y=42.),
58+
Plane(z=z),
59+
Descanner(z=z, scan_pos_x=13., scan_pos_y=11.),
60+
Plane(z=z)
61+
)
62+
ray = Ray(
63+
x=0.12,
64+
y=0.23,
65+
dx=0.34,
66+
dy=0.45,
67+
z=z,
68+
pathlength=0.34
69+
)
70+
res = list(run_iter(ray=ray, components=components))
71+
72+
# Reference result: Compose the components without propagation
73+
res_ref = ray
74+
for comp in components:
75+
res_ref = comp(res_ref)
76+
77+
prev_ray = ray
78+
for i, component in enumerate(components):
79+
prop_index = 2*i
80+
comp_index = 2*i + 1
81+
prop, prop_r = res[prop_index]
82+
comp, comp_r = res[comp_index]
83+
assert isinstance(prop, Propagator)
84+
assert isinstance(prop.propagator, FreeSpaceParaxial)
85+
assert prop.distance == component.z - prev_ray.z
86+
assert_allclose(prop_r.z, comp.z)
87+
assert_allclose(comp_r.z, comp.z)
88+
prev_ray = comp_r
89+
90+
final_ray = res[-1][1]
91+
for attr in ('x', 'y', 'dx', 'dy', '_one', 'z', 'pathlength'):
92+
assert_allclose(getattr(final_ray, attr), getattr(res_ref, attr))

0 commit comments

Comments
 (0)