Skip to content

Commit 2f87e18

Browse files
uelluematbryan52
andauthored
WIP things to make Microscope Calibration work (#4)
* Default with correct type * Add more tests ...in particular for `run_iter()` and more components. Test descan error implementation for correctly handling the names, not only indices. Drive-by flake8 fixes * Introduce type for single coordinates or pixels; allow float pixels The types work well to make clear what x and y are, and to distinguish pixel vs physical coordinates in code. * Introduce types for single values as opposed to arrays, which correspond well to single rays etc * Make Grid center a single type * Allow floats for pixel coordinates since conversion to discrete values should happen as late as possible to avoid rounding issues * Test for conversion helpers that reduce boilerplate * Set float64 support everywhere --------- Co-authored-by: Matthew Bryan <78845903+matbryan52@users.noreply.github.com>
1 parent 39d20b7 commit 2f87e18

19 files changed

Lines changed: 324 additions & 64 deletions

conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
import jax
22

3+
jax.config.update("jax_enable_x64", True) # noqa: E702
34
jax.config.update("jax_platform_name", "cpu")

src/temgym_core/__init__.py

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from typing_extensions import TypeAlias
2-
from typing import NamedTuple
2+
from typing import NamedTuple, Union
3+
4+
import jax; jax.config.update("jax_enable_x64", True) # noqa: E702
5+
import jax.numpy as jnp
36
import numpy as np
47
from numpy.typing import NDArray
58

@@ -46,6 +49,26 @@ class ScaleYX(NamedTuple):
4649
x: float
4750

4851

52+
class CoordXY(NamedTuple):
53+
"""Continuous coordinates in the optical frame.
54+
55+
Parameters
56+
----------
57+
x : float
58+
X position, metres.
59+
y : float
60+
Y position, metres.
61+
"""
62+
x: float
63+
y: float
64+
65+
def to_coords(self) -> 'CoordsXY':
66+
return CoordsXY(
67+
x=jnp.array((self.x,)),
68+
y=jnp.array((self.y,))
69+
)
70+
71+
4972
class CoordsXY(NamedTuple):
5073
"""Continuous coordinates in the optical frame.
5174
@@ -60,22 +83,46 @@ class CoordsXY(NamedTuple):
6083
y: NDArray[np.floating]
6184

6285

86+
class PixelYX(NamedTuple):
87+
"""Pixel coordinates for images.
88+
89+
Parameters
90+
----------
91+
y : Union[int, float]
92+
Pixel row indices
93+
x : Union[int, float]
94+
Pixel column indices
95+
96+
Notes
97+
-----
98+
Pixel indices are 0-based.
99+
"""
100+
y: Union[int, float]
101+
x: Union[int, float]
102+
103+
def to_pixels(self) -> 'PixelsYX':
104+
return PixelsYX(
105+
x=jnp.array((self.x,)),
106+
y=jnp.array((self.y,))
107+
)
108+
109+
63110
class PixelsYX(NamedTuple):
64-
"""Discrete pixel coordinates for images.
111+
"""Pixel coordinates for images.
65112
66113
Parameters
67114
----------
68115
y : numpy.ndarray
69-
Pixel row indices. Integer dtype.
116+
Pixel row indices. Integer or floating dtype.
70117
x : numpy.ndarray
71-
Pixel column indices. Integer dtype.
118+
Pixel column indices. Integer or floating dtype.
72119
73120
Notes
74121
-----
75122
Pixel indices are 0-based.
76123
"""
77-
y: NDArray[np.integer]
78-
x: NDArray[np.integer]
124+
y: NDArray[Union[np.integer, np.floating]]
125+
x: NDArray[Union[np.integer, np.floating]]
79126

80127

81128
# Convenience re-exports

src/temgym_core/components.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import NamedTuple, Dict
2+
import jax; jax.config.update("jax_enable_x64", True) # noqa: E702
23
import jax_dataclasses as jdc
34
import jax.numpy as jnp
45

src/temgym_core/coordinate_transforms.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import jax; jax.config.update("jax_enable_x64", True) # noqa: E702
12
import jax.numpy as jnp
23
import jax.lax as lax
34
from . import Degrees, Radians, ShapeYX, CoordsXY, ScaleYX

src/temgym_core/gaussian.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1+
import jax; jax.config.update("jax_enable_x64", True) # noqa: E702
12
import jax.numpy as jnp
2-
import jax
3+
import jax_dataclasses as jdc
4+
from jax._src.lax.control_flow.loops import _batch_and_remainder
5+
from jax import lax
6+
37
from .grid import Grid
48
from .run import run_to_end
59
from .utils import custom_jacobian_matrix
610
from .ray import Ray
7-
import jax_dataclasses as jdc
8-
from jax._src.lax.control_flow.loops import _batch_and_remainder
9-
from jax import lax
1011

1112

1213
def w_z(w0, z, z_r):

src/temgym_core/grid.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from typing import Union
22
import numpy as np
3+
import jax; jax.config.update("jax_enable_x64", True) # noqa: E702
34
import jax.numpy as jnp
45

5-
from . import Degrees, ShapeYX, CoordsXY, ScaleYX, PixelsYX
6+
from . import Degrees, ShapeYX, CoordXY, CoordsXY, ScaleYX, PixelsYX
67
from .ray import Ray
78
from .utils import inplace_sum, try_ravel, try_reshape
89
from .coordinate_transforms import pixels_to_metres_transform, apply_transformation
@@ -15,7 +16,7 @@ class Grid:
1516
----------
1617
z : float
1718
Axial position in metres.
18-
centre : CoordsXY
19+
centre : CoordXY
1920
Grid centre in metres (x, y).
2021
shape : ShapeYX
2122
Grid shape (y, x) in pixels.
@@ -27,7 +28,7 @@ class Grid:
2728
If True, apply an additional vertical flip.
2829
"""
2930
z: float
30-
centre: CoordsXY
31+
centre: CoordXY
3132
shape: ShapeYX
3233
pixel_size: ScaleYX
3334
rotation: Degrees
@@ -146,7 +147,10 @@ def metres_to_pixels(self, coords: CoordsXY, cast: bool = True) -> PixelsYX:
146147
if cast:
147148
pixels_y = jnp.round(pixels_y).astype(jnp.int32)
148149
pixels_x = jnp.round(pixels_x).astype(jnp.int32)
149-
return try_reshape(pixels_y, coords_y), try_reshape(pixels_x, coords_x)
150+
return PixelsYX(
151+
y=try_reshape(pixels_y, coords_y),
152+
x=try_reshape(pixels_x, coords_x)
153+
)
150154

151155
def pixels_to_metres(self, pixels: PixelsYX) -> CoordsXY:
152156
"""Convert pixel indices to metric coordinates.
@@ -172,7 +176,10 @@ def pixels_to_metres(self, pixels: PixelsYX) -> CoordsXY:
172176
metres_y, metres_x = apply_transformation(
173177
try_ravel(pixels_y), try_ravel(pixels_x), pixels_to_metres_mat
174178
)
175-
return try_reshape(metres_x, pixels_x), try_reshape(metres_y, pixels_y)
179+
return CoordsXY(
180+
x=try_reshape(metres_x, pixels_x),
181+
y=try_reshape(metres_y, pixels_y)
182+
)
176183

177184
def ray_at_grid(
178185
self, px_y: float, px_x: float, dx: float = 0., dy: float = 0., z: float | None = None

src/temgym_core/ray.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import dataclasses
2+
import jax; jax.config.update("jax_enable_x64", True) # noqa: E702
23
import jax_dataclasses as jdc
34
import jax.numpy as jnp
45
from .tree_utils import HasParamsMixin

src/temgym_core/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import dataclasses
33
from typing import TYPE_CHECKING, Sequence, Union, Any, Callable, Generator
44

5-
import jax
5+
import jax; jax.config.update("jax_enable_x64", True) # noqa: E702
66
import jax.numpy as jnp
77
from .utils import custom_jacobian_matrix
88
from .propagator import FreeSpaceParaxial, BasePropagator, Propagator

src/temgym_core/source.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import jax; jax.config.update("jax_enable_x64", True) # noqa: E702
23
import jax_dataclasses as jdc
34

45
from .tree_utils import HasParamsMixin
@@ -98,7 +99,7 @@ class PointSource(Source):
9899
"""
99100
z: float
100101
semi_conv: float
101-
offset_xy: CoordsXY = (0.0, 0.0)
102+
offset_xy: CoordsXY = CoordsXY(x=0.0, y=0.0)
102103

103104
def generate_array(self, num: int, random: bool = False) -> np.ndarray:
104105
"""Generate rays with varying slopes within a cone of semi-convergence.

src/temgym_core/transfer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import jax; jax.config.update("jax_enable_x64", True) # noqa: E702
12
import jax.numpy as jnp
23
import numpy as np
34

0 commit comments

Comments
 (0)