11import pytest
2+ from numpy .testing import assert_allclose
3+
24import numpy as np
35from jax import jacobian
46import jax .numpy as jnp
57import jax_dataclasses as jdc
68
7- from temgym_core .components import ScanGrid , Detector , Descanner , DescanError , Component
9+ from temgym_core .components import (
10+ ScanGrid , Detector , Descanner , Scanner , DescanError , Component , Plane
11+ )
812from temgym_core .ray import Ray
913from temgym_core .utils import custom_jacobian_matrix
1014
@@ -86,8 +90,8 @@ def test_scan_grid_metres_to_pixels(xy, rotation, expected_pixel_coords):
8690 shape = (11 , 11 ),
8791 )
8892 pixel_coords_y , pixel_coords_x = scan_grid .metres_to_pixels (xy )
89- np . testing . assert_allclose (pixel_coords_y , expected_pixel_coords [0 ], atol = 1e-6 )
90- np . testing . assert_allclose (pixel_coords_x , expected_pixel_coords [1 ], atol = 1e-6 )
93+ assert_allclose (pixel_coords_y , expected_pixel_coords [0 ], atol = 1e-6 )
94+ assert_allclose (pixel_coords_x , expected_pixel_coords [1 ], atol = 1e-6 )
9195
9296
9397@pytest .mark .parametrize (
@@ -115,8 +119,8 @@ def test_scan_grid_pixels_to_metres(pixel_coords, rotation, expected_xy):
115119 shape = (11 , 11 ),
116120 )
117121 metres_coords_x , metres_coords_y = scan_grid .pixels_to_metres (pixel_coords )
118- np . testing . assert_allclose (metres_coords_x , expected_xy [0 ], atol = 1e-6 )
119- np . testing . assert_allclose (metres_coords_y , expected_xy [1 ], atol = 1e-6 )
122+ assert_allclose (metres_coords_x , expected_xy [0 ], atol = 1e-6 )
123+ assert_allclose (metres_coords_y , expected_xy [1 ], atol = 1e-6 )
120124
121125
122126@pytest .mark .parametrize (
@@ -137,8 +141,8 @@ def test_detector_metres_to_pixels(xy, expected_pixel_coords):
137141 flip_y = False ,
138142 )
139143 pixel_coords_y , pixel_coords_x = detector .metres_to_pixels (xy )
140- np . testing . assert_allclose (pixel_coords_y , expected_pixel_coords [0 ], atol = 1e-6 )
141- np . testing . assert_allclose (pixel_coords_x , expected_pixel_coords [1 ], atol = 1e-6 )
144+ assert_allclose (pixel_coords_y , expected_pixel_coords [0 ], atol = 1e-6 )
145+ assert_allclose (pixel_coords_x , expected_pixel_coords [1 ], atol = 1e-6 )
142146
143147
144148# Test cases for Detector:
@@ -161,46 +165,93 @@ def test_detector_pixels_to_metres(pixel_coords, expected_xy):
161165 flip_y = False ,
162166 )
163167 metres_coords_x , metres_coords_y = detector .pixels_to_metres (pixel_coords )
164- np .testing .assert_allclose (metres_coords_x , expected_xy [0 ], atol = 1e-6 )
165- np .testing .assert_allclose (metres_coords_y , expected_xy [1 ], atol = 1e-6 )
168+ assert_allclose (metres_coords_x , expected_xy [0 ], atol = 1e-6 )
169+ assert_allclose (metres_coords_y , expected_xy [1 ], atol = 1e-6 )
170+
171+
172+ def test_plane ():
173+ x , y , dx , dy , z , pathlength = np .random .uniform (- 5.0 , 5.0 , size = 6 )
174+ ray = Ray (x = x , y = y , dx = dx , dy = dy , _one = 1.0 , z = 0.0 , pathlength = 0.0 )
175+ comp = Plane (z = 23 )
176+ out = comp (ray )
177+ for attr in ('x' , 'y' , 'dx' , 'dy' , '_one' , 'z' , 'pathlength' ):
178+ assert getattr (ray , attr ) == getattr (out , attr )
179+
180+
181+ def test_scanner_random ():
182+ # Randomly chosen scan position and ray parameters
183+ sp_x , sp_y = np .random .uniform (- 5.0 , 5.0 ), np .random .uniform (- 5.0 , 5.0 )
184+ st_x , st_y = np .random .uniform (- 0.5 , 0.5 ), np .random .uniform (0.5 , 0.5 )
185+ x , y , dx , dy , z , pathlength = np .random .uniform (- 5.0 , 5.0 , size = 6 )
186+
187+ sc = Scanner (
188+ z = 23 ,
189+ scan_pos_x = sp_x , scan_pos_y = sp_y ,
190+ scan_tilt_y = st_y , scan_tilt_x = st_x ,
191+ )
192+ ray = Ray (x = x , y = y , dx = dx , dy = dy , _one = 1.0 , z = z , pathlength = pathlength )
193+ out = sc (ray )
194+
195+ # Expected values computed using the same formula as in the implementation
196+ exp_x = x + sp_x
197+ exp_y = y + sp_y
198+ exp_dx = dx + st_x
199+ exp_dy = dy + st_y
200+
201+ assert_allclose (out .x , exp_x , atol = 1e-6 )
202+ assert_allclose (out .y , exp_y , atol = 1e-6 )
203+ assert_allclose (out .dx , exp_dx , atol = 1e-6 )
204+ assert_allclose (out .dy , exp_dy , atol = 1e-6 )
205+ for attr in ('_one' , 'z' , 'pathlength' ):
206+ assert getattr (ray , attr ) == getattr (out , attr )
166207
167208
168209def test_descanner_random_descan_error ():
169210 # Randomly chosen scan position and ray parameters
170211 sp_x , sp_y = np .random .uniform (- 5.0 , 5.0 ), np .random .uniform (- 5.0 , 5.0 )
171- x , y , dx , dy = np .random .uniform (- 5.0 , 5.0 , size = 4 )
212+ st_x , st_y = np .random .uniform (- 0.5 , 0.5 ), np .random .uniform (0.5 , 0.5 )
213+ x , y , dx , dy , z , pathlength = np .random .uniform (- 5.0 , 5.0 , size = 6 )
172214
173215 # Randomly chosen non-zero descan error (length 12)
174- err = np .random .rand (12 )
216+ (pxo_pxi , pxo_pyi , pyo_pxi , pyo_pyi ,
217+ sxo_pxi , sxo_pyi , syo_pxi , syo_pyi ,
218+ offpxi , offpyi , offsxi , offsyi ) = np .random .rand (12 )
175219
176220 err = DescanError (
177- pxo_pxi = err [ 0 ] ,
178- pxo_pyi = err [ 1 ] ,
179- pyo_pxi = err [ 2 ] ,
180- pyo_pyi = err [ 3 ] ,
181- sxo_pxi = err [ 4 ] ,
182- sxo_pyi = err [ 5 ] ,
183- syo_pxi = err [ 6 ] ,
184- syo_pyi = err [ 7 ] ,
185- offpxi = err [ 8 ] ,
186- offpyi = err [ 9 ] ,
187- offsxi = err [ 10 ] ,
188- offsyi = err [ 11 ] ,
221+ pxo_pxi = pxo_pxi ,
222+ pxo_pyi = pxo_pyi ,
223+ pyo_pxi = pyo_pxi ,
224+ pyo_pyi = pyo_pyi ,
225+ sxo_pxi = sxo_pxi ,
226+ sxo_pyi = sxo_pyi ,
227+ syo_pxi = syo_pxi ,
228+ syo_pyi = syo_pyi ,
229+ offpxi = offpxi ,
230+ offpyi = offpyi ,
231+ offsxi = offsxi ,
232+ offsyi = offsyi ,
189233 )
190- desc = Descanner (z = 0.0 , scan_pos_x = sp_x , scan_pos_y = sp_y , descan_error = err )
191- ray = Ray (x = x , y = y , dx = dx , dy = dy , _one = 1.0 , z = 0.0 , pathlength = 0.0 )
234+ desc = Descanner (
235+ z = 23 ,
236+ scan_pos_x = sp_x , scan_pos_y = sp_y ,
237+ scan_tilt_y = st_y , scan_tilt_x = st_x ,
238+ descan_error = err
239+ )
240+ ray = Ray (x = x , y = y , dx = dx , dy = dy , _one = 1.0 , z = z , pathlength = pathlength )
192241 out = desc (ray )
193242
194243 # Expected values computed using the same formula as in the implementation
195- exp_x = x + sp_x * err [ 0 ] + sp_y * err [ 1 ] + err [ 8 ] - sp_x
196- exp_y = y + sp_x * err [ 2 ] + sp_y * err [ 3 ] + err [ 9 ] - sp_y
197- exp_dx = dx + sp_x * err [ 4 ] + sp_y * err [ 5 ] + err [ 10 ]
198- exp_dy = dy + sp_x * err [ 6 ] + sp_y * err [ 7 ] + err [ 11 ]
244+ exp_x = x + sp_x * pxo_pxi + sp_y * pxo_pyi + offpxi - sp_x
245+ exp_y = y + sp_x * pyo_pxi + sp_y * pyo_pyi + offpyi - sp_y
246+ exp_dx = dx + sp_x * sxo_pxi + sp_y * sxo_pyi + offsxi - st_x
247+ exp_dy = dy + sp_x * syo_pxi + sp_y * syo_pyi + offsyi - st_y
199248
200- np .testing .assert_allclose (out .x , exp_x , atol = 1e-6 )
201- np .testing .assert_allclose (out .y , exp_y , atol = 1e-6 )
202- np .testing .assert_allclose (out .dx , exp_dx , atol = 1e-6 )
203- np .testing .assert_allclose (out .dy , exp_dy , atol = 1e-6 )
249+ assert_allclose (out .x , exp_x , atol = 1e-6 )
250+ assert_allclose (out .y , exp_y , atol = 1e-6 )
251+ assert_allclose (out .dx , exp_dx , atol = 1e-6 )
252+ assert_allclose (out .dy , exp_dy , atol = 1e-6 )
253+ for attr in ('_one' , 'z' , 'pathlength' ):
254+ assert getattr (ray , attr ) == getattr (out , attr )
204255
205256
206257def test_descanner_offset_consistency ():
@@ -223,7 +274,7 @@ def test_descanner_offset_consistency():
223274 offsyi = err [11 ],
224275 )
225276 desc = Descanner (
226- z = 0.0 , scan_pos_x = scan_pos_x , scan_pos_y = scan_pos_y , descan_error = err
277+ z = 11 , scan_pos_x = scan_pos_x , scan_pos_y = scan_pos_y , descan_error = err
227278 )
228279
229280 # generate a batch of random rays
@@ -251,7 +302,7 @@ def test_descanner_offset_consistency():
251302 # assert that all rays have received the same offset
252303 first = offsets [0 ]
253304 for off in offsets :
254- np . testing . assert_allclose (off , first , atol = 1e-6 )
305+ assert_allclose (off , first , atol = 1e-6 )
255306
256307
257308def test_descanner_jacobian_matrix ():
@@ -294,7 +345,7 @@ def test_descanner_jacobian_matrix():
294345 [0.0 , 0.0 , 0.0 , 0.0 , 1.0 ],
295346 ]
296347 )
297- np . testing . assert_allclose (J , T , atol = 1e-6 )
348+ assert_allclose (J , T , atol = 1e-6 )
298349
299350
300351@pytest .mark .parametrize ("repeat" , tuple (range (5 )))
@@ -319,7 +370,7 @@ def test_scan_grid_rotation_random(repeat):
319370 # expected rotated step vector = R(scan_rot) @ [step_x, 0]
320371 theta = np .deg2rad (scan_rot )
321372 exp_scan = np .array ([np .cos (theta ) * step [0 ], - np .sin (theta ) * step [0 ]])
322- np . testing . assert_allclose (vec_scan , exp_scan , atol = 1e-6 )
373+ assert_allclose (vec_scan , exp_scan , atol = 1e-6 )
323374
324375
325376def test_singular_component_jacobian ():
0 commit comments