Skip to content

Commit e73e481

Browse files
authored
Several fixes for torch & MPS device (#45)
* Some fixes for torch[mps] * Change dtype of nu in adam solver * Fix underflow NaNs on torch[mps] * Allow det_rotation to be None * Fix plan validation w/ complex numbers * layers regularization support on pytorch * Fix warnings w/ torch * Fix remove_linear_ramp on torch[mps] * Fix tests on pytorch
1 parent 3cb9aa1 commit e73e481

19 files changed

Lines changed: 308 additions & 91 deletions

File tree

phaser/engines/common/regularizers.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
get_array_module, get_scipy_module, Float, unstack,
1010
jit, fft2, ifft2, abs2, xp_is_jax, to_real_dtype, to_numpy
1111
)
12+
from phaser.utils.image import convolve1d
1213
from phaser.state import ReconsState
1314
from phaser.hooks.regularization import (
1415
ClampObjectAmplitudeProps, LimitProbeSupportProps, NonNegObjectPhaseProps,
@@ -142,7 +143,6 @@ def init_state(self, sim: ReconsState) -> None:
142143

143144
def apply_iter(self, sim: ReconsState, state: None) -> t.Tuple[ReconsState, None]:
144145
xp = get_array_module(sim.object.data)
145-
scipy = get_scipy_module(sim.object.data)
146146
dtype = to_real_dtype(sim.object.data)
147147

148148
if len(sim.object.thicknesses) < 2:
@@ -161,17 +161,9 @@ def apply_iter(self, sim: ReconsState, state: None) -> t.Tuple[ReconsState, None
161161

162162
# we convolve the log of object, because the transmission
163163
# function is multiplicative, not additive
164-
165-
if xp_is_jax(xp):
166-
new_obj = xp.exp(scipy.signal.convolve(
167-
xp.pad(xp.log(sim.object.data), ((r, r), (0, 0), (0, 0)), mode='edge'),
168-
kernel[:, None, None],
169-
mode="valid"
170-
))
171-
else:
172-
new_obj = xp.exp(scipy.ndimage.convolve1d(xp.log(
173-
sim.object.data
174-
), kernel, axis=0, mode='nearest'))
164+
new_obj = xp.exp(convolve1d(xp.log(
165+
sim.object.data
166+
), kernel, axis=0, mode='nearest'))
175167

176168
assert new_obj.shape == sim.object.data.shape
177169
assert new_obj.dtype == sim.object.data.dtype

phaser/engines/gradient/solvers.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import numpy
2727
from numpy.typing import NDArray
2828

29-
from phaser.utils.num import get_array_module
29+
from phaser.utils.num import get_array_module, to_real_dtype, xp_is_torch
3030
import phaser.utils.tree as tree
3131
from phaser.hooks.solver import GradientSolver, GradientSolverArgs
3232
from phaser.hooks.schedule import ScheduleLike, Schedule
@@ -229,7 +229,7 @@ def scale_by_adam(
229229
def init_fn(params: Params) -> ScaleByAdamState:
230230
xp = get_array_module(params)
231231
mu = tree.zeros_like(params, dtype=mu_dtype) # First moment
232-
nu = tree.zeros_like(params) # Second moment
232+
nu = tree.map(lambda x: xp.zeros_like(x, dtype=to_real_dtype(x.dtype)), params) # Second moment
233233
return ScaleByAdamState(n=xp.zeros((), dtype=xp.int32), mu=mu, nu=nu)
234234

235235
def update_fn(
@@ -241,6 +241,15 @@ def update_fn(
241241
nu = tree.update_moment_per_elem_norm(updates, state.nu, b2, 2)
242242
n_inc = safe_increment(state.n)
243243

244+
# HACK: on mps we need to prevent small mu values from returning nan
245+
if xp_is_torch(xp) and any(
246+
leaf.device.type == 'mps' for leaf in tree.leaves(updates)
247+
):
248+
mu = tree.map(
249+
lambda arr: xp.nan_to_num(arr, nan=0.),
250+
mu, is_leaf=lambda x: x is None
251+
)
252+
244253
if nesterov:
245254
mu_hat = tree.map(
246255
lambda m, g: b1 * m + (1 - b1) * g,

phaser/hooks/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
if t.TYPE_CHECKING:
1414
from phaser.utils.num import Sampling
1515
from phaser.utils.object import ObjectSampling
16-
from ..state import ObjectState, ProbeState, ReconsState, Patterns
16+
from ..state import ObjectState, ProbeState, ReconsState, Patterns # noqa: F401
1717
from ..execute import Observer
1818

1919

phaser/hooks/io/empad.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def load_empad(args: None, props: LoadEmpadProps) -> RawData:
4646
'shape': scan_shape,
4747
'step_size': tuple(s*1e10 for s in reversed(meta.scan_step)), # m to A
4848
'affine': meta.scan_correction[::-1, ::-1] if meta.scan_correction is not None else None,
49-
'rotation': meta.det_rotation - meta.scan_rotation,
49+
'rotation': (meta.det_rotation or 0.0) - meta.scan_rotation,
5050
}
5151

5252
#TODO: add tilt to metafile
@@ -95,4 +95,4 @@ def load_empad(args: None, props: LoadEmpadProps) -> RawData:
9595
'scan_hook': scan_hook,
9696
'tilt_hook': tilt_hook,
9797
'seed': None,
98-
}
98+
}

phaser/hooks/preprocessing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def offset_patterns(raw_data: RawData, props: OffsetProps) -> RawData:
4343
return raw_data
4444

4545
def bin_patterns(raw_data: RawData, props: BinProps) -> RawData:
46-
xp = get_array_module(raw_data['patterns'])
46+
#xp = get_array_module(raw_data['patterns'])
4747
bin_factor = props.bin
4848
patterns = raw_data['patterns']
4949
Ny, Nx = patterns.shape[-2:]

phaser/hooks/solver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
if t.TYPE_CHECKING:
1414
from phaser.engines.common.simulation import SimulationState
1515
from phaser.execute import Observer
16-
from phaser.plan import ConventionalEnginePlan, GradientEnginePlan
16+
from phaser.plan import ConventionalEnginePlan, GradientEnginePlan # noqa: F401
1717
from phaser.state import ReconsState
1818

1919

phaser/io/empad.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def __post_init__(self):
5858
Flips to apply to the raw diffraction patterns, (flip_y, flip_x, transpose).
5959
Defaults to `(True, False, False)` (appears to be the most common orientation).
6060
"""
61-
det_rotation: float = 0.0
61+
det_rotation: t.Optional[float] = None
6262
"""Detector rotation (degrees)."""
6363

6464
orig_path: t.Optional[Path] = None

phaser/main.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,15 @@ def cli():
1111

1212
@cli.command('run')
1313
@click.argument('path', type=click.Path(exists=True, dir_okay=False))
14-
def run(path: t.Union[str, Path]):
14+
@click.option('--raise-on-warn/--no-raise-on-warn')
15+
def run(path: t.Union[str, Path], *, raise_on_warn: bool = False):
1516
from .plan import ReconsPlan
1617
from .execute import execute_plan
18+
19+
if raise_on_warn:
20+
import warnings
21+
warnings.simplefilter('error')
22+
1723
plans = ReconsPlan.from_yaml_all(path)
1824

1925
for plan in plans:
@@ -65,19 +71,33 @@ def validate(path: t.Union[str, Path], json: bool = False):
6571

6672
sys.exit(1)
6773

74+
if json:
75+
from json import dump, dumps
76+
77+
def _serialize_complex(val: t.Any) -> t.Any:
78+
if isinstance(val, complex):
79+
return {'re': val.real, 'im': val.imag}
80+
raise TypeError()
81+
82+
try:
83+
s = dumps({
84+
'result': 'success',
85+
'plans': [(plan.name, plan.into_data()) for plan in plans],
86+
}, default=_serialize_complex)
87+
except Exception as e:
88+
print(f"Failed to serialize validated plans: {e}", file=sys.stderr)
89+
dump({'result': 'error', 'error': str(e)}, sys.stdout)
90+
print()
91+
sys.exit(2)
92+
93+
sys.stdout.write(s)
94+
print()
95+
6896
if len(plans) == 1:
6997
print("Validation of plan successful!", file=sys.stderr)
7098
else:
7199
print(f"Validation of {len(plans)} plans successful!", file=sys.stderr)
72100

73-
if json:
74-
from json import dump
75-
dump({
76-
'result': 'success',
77-
'plans': [(plan.name, plan.into_data()) for plan in plans],
78-
}, sys.stdout)
79-
print()
80-
81101

82102
@cli.command('worker')
83103
@click.argument('url', type=str, required=True)

phaser/types.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,23 @@ def collect_errors(self, val: t.Any) -> t.Optional[ErrorNode]:
302302
return self.inner.collect_errors(val)
303303

304304

305+
class ComplexCartesian(pane.PaneBase, kw_only=True):
306+
re: float
307+
im: float = 0.0
308+
309+
def __complex__(self) -> complex:
310+
return complex(self.re, self.im)
311+
312+
313+
class ComplexPolar(pane.PaneBase, kw_only=True):
314+
mag: float
315+
angle: float = 0.0 # degrees
316+
317+
def __complex__(self) -> complex:
318+
theta = numpy.deg2rad(self.angle)
319+
return self.mag * complex(numpy.cos(theta), numpy.sin(theta))
320+
321+
305322
__all__ = [
306323
'BackendName', 'Dataclass', 'Slices', 'Flag',
307324
'process_flag', 'flag_any_true',

0 commit comments

Comments
 (0)