diff --git a/packages/essreduce/pyproject.toml b/packages/essreduce/pyproject.toml index 374d65426..67985d8bd 100644 --- a/packages/essreduce/pyproject.toml +++ b/packages/essreduce/pyproject.toml @@ -31,7 +31,7 @@ dynamic = ["version"] dependencies = [ "sciline>=25.11.0", - "scipp>=26.3.0", + "scipp>=26.3.1", "scippneutron>=25.11.1", "scippnexus>=25.06.0", ] diff --git a/packages/essreduce/src/ess/reduce/nexus/__init__.py b/packages/essreduce/src/ess/reduce/nexus/__init__.py index c310ca4a0..e739a622f 100644 --- a/packages/essreduce/src/ess/reduce/nexus/__init__.py +++ b/packages/essreduce/src/ess/reduce/nexus/__init__.py @@ -15,6 +15,7 @@ from . import types from ._nexus_loader import ( compute_component_position, + compute_detector_position, extract_signal_data_array, group_event_data, load_all_components, @@ -29,6 +30,7 @@ __all__ = [ 'GenericNeXusWorkflow', 'compute_component_position', + 'compute_detector_position', 'extract_signal_data_array', 'group_event_data', 'load_all_components', diff --git a/packages/essreduce/src/ess/reduce/nexus/_nexus_loader.py b/packages/essreduce/src/ess/reduce/nexus/_nexus_loader.py index 100e9efc2..d14c7d298 100644 --- a/packages/essreduce/src/ess/reduce/nexus/_nexus_loader.py +++ b/packages/essreduce/src/ess/reduce/nexus/_nexus_loader.py @@ -23,6 +23,8 @@ NeXusFile, NeXusGroup, NeXusLocationSpec, + NeXusTransformation, + RunType, ) @@ -454,6 +456,55 @@ def _to_snx_selection(selection, *, for_events: bool) -> snx.typing.ScippIndex: return selection +def compute_detector_position( + da: sc.DataArray, + *, + transform: NeXusTransformation[snx.NXdetector, RunType], + # Strictly speaking we could apply an offset by modifying the transformation chain, + # using a more generic implementation. However, this may in general require + # extending the chain and it is currently not clear if that is desirable. As far as + # I am aware the offset is currently mainly used for handling files from other + # facilities and it is not clear if it is needed for ESS data and should be kept at + # all. + offset: sc.Variable, +) -> sc.Variable | sc.DataArray: + """Compute the positions of detector pixels. + + Parameters + ---------- + da: + Detector (event) data as returned by :func:`extract_signal_data_array`. + transform: + Transformation matrix for the detector. + offset: + Offset to add to the detector position. + + Returns + ------- + : + The detector position as a data array if ``transform`` is time-dependent + or as a variable otherwise. + """ + # Note: We apply offset as early as possible, i.e., right in this function + # the detector array from the raw loader NeXus group, to prevent a source of bugs. + # If the NXdetector in the file is not 1-D, we want to match the order of dims. + # zip_pixel_offsets otherwise yields a vector with dimensions in the order given + # by the x/y/z offsets. + offsets = snx.zip_pixel_offsets(da.coords) + # Get the dims in the order of the detector data array, but filter out dims that + # don't exist in the offsets (e.g. the detector data may have a 'time' dimension). + dims = [dim for dim in da.dims if dim in offsets.dims] + offsets = offsets.transpose(dims).copy() + # We use the unit of the offsets as this is likely what the user expects. + if transform.value.unit is not None and transform.value.unit != '': + transform_value = transform.value.to(unit=offsets.unit) + else: + transform_value = transform.value + position = transform_value * offsets + position += offset.to(unit=position.unit, copy=False) + return position + + def load_data( file_path: FilePath | NeXusFile | NeXusGroup, selection: snx.typing.ScippIndex | slice = (), diff --git a/packages/essreduce/src/ess/reduce/nexus/types.py b/packages/essreduce/src/ess/reduce/nexus/types.py index 05dcbb62e..b135bd473 100644 --- a/packages/essreduce/src/ess/reduce/nexus/types.py +++ b/packages/essreduce/src/ess/reduce/nexus/types.py @@ -186,7 +186,61 @@ class NeXusData(sciline.Scope[Component, RunType, sc.DataArray], sc.DataArray): class Position(sciline.Scope[Component, RunType, sc.Variable], sc.Variable): - """Position of a component such as source, sample, monitor, or detector.""" + """Position of a component that does not move, such as source or sample.""" + + +@dataclass(init=False, repr=False, slots=True) +class DynamicPosition(Generic[Component, RunType]): + """Position of a potentially moving component such as an analyzer or detector. + + The position can depend on time. In this case, a time coordinate is also stored. + Use ``position`` to get the position if it is scalar, or ``positions`` + to get the position as a (potentially time-dependent) DataArray. + """ + + _position: sc.Variable + _time: sc.Variable | None + + def __init__(self, pos: sc.DataArray | sc.Variable) -> None: + if pos.ndim == 0: + self._position = pos.data if isinstance(pos, sc.DataArray) else pos + self._time = None + else: + if not isinstance(pos, sc.DataArray): + raise sc.DimensionError( + "Position is not a scalar, so it must be a DataArray" + ) + self._position = pos.data + self._time = pos.coords['time'] + + @property + def is_dynamic(self) -> bool: + return self._time is not None + + @property + def position(self) -> sc.Variable: + if self.is_dynamic: + raise sc.DimensionError( + "Position is time-dependent, use `positions` instead." + ) + return self._position + + @property + def positions(self) -> sc.DataArray: + da = sc.DataArray(self._position) + if self._time is not None: + da.coords['time'] = self._time + return da + + def __str__(self) -> str: + if self.is_dynamic: + time_str = f", time={self._time}" + else: + time_str = "" + return f"Position(position={self._position}{time_str})" + + def __repr__(self) -> str: + return f"Position(position={self._position}, time={self._time})" class DetectorPositionOffset(sciline.Scope[RunType, sc.Variable], sc.Variable): @@ -225,6 +279,15 @@ class TimeInterval(Generic[RunType]): value: slice +class TransformationTimeFilter(Generic[Component, RunType]): + """Filter for time-dependent transformations.""" + + def __new__(cls, x: Any) -> Any: + return x + + def __call__(self, transform: sc.DataArray) -> sc.Variable | sc.DataArray: ... + + @dataclass class NeXusFileSpec(Generic[RunType]): value: FilePath | NeXusFile | NeXusGroup @@ -280,25 +343,21 @@ class NeXusTransformationChain( @dataclass class NeXusTransformation(Generic[Component, RunType]): - value: sc.Variable + """A NeXus transformation computed from a transformation chain. + + If the transformation is time-dependent, it is stored as a data array + with a 'time' coordinate. + Otherwise, the transformation is stored as a variable. + """ + + value: sc.Variable | sc.DataArray @staticmethod def from_chain( chain: NeXusTransformationChain[Component, RunType], ) -> 'NeXusTransformation[Component, RunType]': - """ - Convert a transformation chain to a single transformation. - - As transformation chains may be time-dependent, this method will need to select - a specific time point to convert to a single transformation. This may include - averaging as well as threshold checks. This is not implemented yet and we - therefore currently raise an error if the transformation chain does not compute - to a scalar. - """ - if chain.transformations.sizes != {}: - raise ValueError(f"Expected scalar transformation, got {chain}") - transform = chain.compute() - return NeXusTransformation(value=transform) + """Convert a transformation chain to a single transformation.""" + return NeXusTransformation(value=chain.compute()) class RawChoppers( diff --git a/packages/essreduce/src/ess/reduce/nexus/workflow.py b/packages/essreduce/src/ess/reduce/nexus/workflow.py index 8d1939cb9..f332d4137 100644 --- a/packages/essreduce/src/ess/reduce/nexus/workflow.py +++ b/packages/essreduce/src/ess/reduce/nexus/workflow.py @@ -6,7 +6,7 @@ import warnings from collections.abc import Iterable from copy import deepcopy -from typing import Any, TypeVar +from typing import Any, Never, TypeVar import sciline import sciline.typing @@ -25,6 +25,7 @@ Component, DetectorBankSizes, DetectorPositionOffset, + DynamicPosition, EmptyDetector, EmptyMonitor, Filename, @@ -50,6 +51,7 @@ RunType, Source, TimeInterval, + TransformationTimeFilter, UniqueComponent, ) @@ -275,38 +277,52 @@ def load_nexus_data( def get_transformation_chain( - detector: NeXusComponent[Component, RunType], + component: NeXusComponent[Component, RunType], ) -> NeXusTransformationChain[Component, RunType]: """ - Extract the transformation chain from a NeXus detector group. + Extract the transformation chain from a NeXus component group. Parameters ---------- - detector: - NeXus detector group. + component: + NeXus component group. """ - chain = detector['depends_on'] + chain = component['depends_on'] return NeXusTransformationChain[Component, RunType](chain) -def _time_filter(transform: sc.DataArray) -> sc.Variable: - if transform.ndim == 0 or transform.sizes == {'time': 1}: - return transform.data.squeeze() +def reject_time_dependent_transform( + transform: sc.DataArray, +) -> Never: + """Raise a value error to forbid time-dependent transformations by default.""" raise ValueError( f"Transform is time-dependent: {transform}, but no filter is provided." ) +def _apply_time_filter( + transform: sc.DataArray, + user_filter: TransformationTimeFilter[Component, RunType], +) -> sc.Variable | sc.DataArray: + if transform.ndim == 0 or transform.sizes == {'time': 1}: + return transform.data.squeeze() + return user_filter(transform) + + def to_transformation( - chain: NeXusTransformationChain[Component, RunType], interval: TimeInterval[RunType] + chain: NeXusTransformationChain[Component, RunType], + interval: TimeInterval[RunType], + time_filter: TransformationTimeFilter[ + Component, RunType + ] = reject_time_dependent_transform, ) -> NeXusTransformation[Component, RunType]: """ - Convert transformation chain into a single transformation matrix. + Convert a transformation chain into a single transformation matrix. If one or more transformations in the chain are time-dependent, the time interval is used to select a specific time point. If the interval is not a single time point, - an error is raised. This may be extended in the future to a more sophisticated - mechanism, e.g., averaging over the interval to remove noise. + ``time_filter`` is applied to the transformation. By default, this will raise an + exception. Provide a different filter to customize how time-dependence is handled. Parameters ---------- @@ -314,6 +330,9 @@ def to_transformation( Transformation chain. interval: Time interval to select from the transformation chain. + time_filter: + Callable to apply to time-dependent transformations. + Defaults to raising a :class:`ValueError`. """ chain = deepcopy(chain) @@ -336,9 +355,9 @@ def to_transformation( idx = label_based_index_to_positional_index( sizes=t.sizes, coord=time, index=interval.value ) - t.value = _time_filter(t.value[idx]) + t.value = _apply_time_filter(t.value[idx], time_filter) else: - t.value = _time_filter(t.value['time', interval.value]) + t.value = _apply_time_filter(t.value['time', interval.value], time_filter) return NeXusTransformation[Component, RunType].from_chain(chain) @@ -347,9 +366,22 @@ def compute_position( transformation: NeXusTransformation[Component, RunType], ) -> Position[Component, RunType]: """Compute the position of a component from a transformation matrix.""" + if isinstance(transformation.value, sc.DataArray): + raise ValueError( + "Attempted to compute a static position from a time-dependent " + "transformation. Either provide a time interval parameter or " + "time filter." + ) return Position[Component, RunType](transformation.value * origin) +def compute_dynamic_position( + transformation: NeXusTransformation[Component, RunType], +) -> DynamicPosition[Component, RunType]: + """Compute the position of a component from a transformation matrix.""" + return DynamicPosition[Component, RunType](transformation.value * origin) + + def get_calibrated_detector( detector: NeXusComponent[snx.NXdetector, RunType], *, @@ -369,6 +401,10 @@ def get_calibrated_detector( The data array is reshaped to the logical detector shape, by folding the data array along the detector_number dimension. + The output contains pixel positions computed from ``transform`` and ``offset``. + If ``transform`` is time-dependent, the output contains a 'time' dimension + and coordinate corresponding to the time coordinate of ``transform``. + Parameters ---------- detector: @@ -385,25 +421,16 @@ def get_calibrated_detector( sizes := (bank_sizes or {}).get(detector.get('nexus_component_name')) ) is not None: da = da.fold(dim="detector_number", sizes=sizes) - # Note: We apply offset as early as possible, i.e., right in this function - # the detector array from the raw loader NeXus group, to prevent a source of bugs. - # If the NXdetector in the file is not 1-D, we want to match the order of dims. - # zip_pixel_offsets otherwise yields a vector with dimensions in the order given - # by the x/y/z offsets. - offsets = snx.zip_pixel_offsets(da.coords) - # Get the dims in the order of the detector data array, but filter out dims that - # don't exist in the offsets (e.g. the detector data may have a 'time' dimension). - dims = [dim for dim in da.dims if dim in offsets.dims] - offsets = offsets.transpose(dims).copy() - # We use the unit of the offsets as this is likely what the user expects. - if transform.value.unit is not None and transform.value.unit != '': - transform_value = transform.value.to(unit=offsets.unit) - else: - transform_value = transform.value - position = transform_value * offsets - return EmptyDetector[RunType]( - da.assign_coords(position=position + offset.to(unit=position.unit)) - ) + + position = nexus.compute_detector_position(da, transform=transform, offset=offset) + + if isinstance(position, sc.DataArray): # time-dependent transform + raise ValueError( + "Time-dependent positions are not supported by default. Either select a " + "time interval or override `get_calibrated_detector`." + ) + + return EmptyDetector[RunType](da.assign_coords(position=position)) def assemble_detector_data( @@ -422,13 +449,29 @@ def assemble_detector_data( neutron_data: Neutron data array (events or histogram). """ - if neutron_data.bins is not None: + detector_coords = dict(detector.coords) + if neutron_data.is_binned: neutron_data = nexus.group_event_data( event_data=neutron_data, detector_number=detector.coords['detector_number'] ) + if 'time' in detector.dims: + # Give the neutron data a 'time' dimension matching the times in the + # detector data. Preserve the `event_time_zero` event coord. + # This is needed to add time-dependent detector coords and masks below. + neutron_data = neutron_data.bin( + event_time_zero=detector_coords['time'].rename(time='event_time_zero') + ).rename_dims(event_time_zero='time') + neutron_data.coords['time'] = neutron_data.coords.pop('event_time_zero') + else: + position = detector_coords.get('position') + if position is not None and 'time' in position.dims: + raise NotImplementedError( + "Time-dependent positions are not yet supported for histogram data." + ) + return RawDetector[RunType]( _add_variances(neutron_data) - .assign_coords(detector.coords) + .assign_coords(detector_coords) .assign_masks(detector.masks) ) @@ -659,7 +702,6 @@ def load_source_metadata_from_nexus( definitions["NXdetector"] = _StrippedDetector definitions["NXmonitor"] = _StrippedMonitor - _common_providers = ( gravity_vector_neg_y, file_path_to_file_spec, @@ -670,11 +712,11 @@ def load_source_metadata_from_nexus( get_transformation_chain, to_transformation, compute_position, + compute_dynamic_position, load_nexus_data, load_nexus_component, load_all_nexus_components, data_by_name, - nx_class_for_crystal, nx_class_for_detector, nx_class_for_monitor, nx_class_for_source, @@ -712,11 +754,14 @@ def LoadMonitorWorkflow( """Generic workflow for loading monitor data from a NeXus file.""" wf = sciline.Pipeline( (*_common_providers, *_monitor_providers), + params={ + PreopenNeXusFile: PreopenNeXusFile(False), + TransformationTimeFilter: reject_time_dependent_transform, + }, constraints=_gather_constraints( run_types=run_types, monitor_types=monitor_types ), ) - wf[PreopenNeXusFile] = PreopenNeXusFile(False) return wf @@ -726,10 +771,13 @@ def LoadDetectorWorkflow( """Generic workflow for loading detector data from a NeXus file.""" wf = sciline.Pipeline( (*_common_providers, *_detector_providers), + params={ + DetectorBankSizes: DetectorBankSizes({}), + PreopenNeXusFile: PreopenNeXusFile(False), + TransformationTimeFilter: reject_time_dependent_transform, + }, constraints=_gather_constraints(run_types=run_types, monitor_types=[]), ) - wf[DetectorBankSizes] = DetectorBankSizes({}) - wf[PreopenNeXusFile] = PreopenNeXusFile(False) return wf @@ -775,12 +823,15 @@ def GenericNeXusWorkflow( *_chopper_providers, *_metadata_providers, ), + params={ + DetectorBankSizes: DetectorBankSizes({}), + PreopenNeXusFile: PreopenNeXusFile(False), + TransformationTimeFilter: reject_time_dependent_transform, + }, constraints=_gather_constraints( run_types=run_types, monitor_types=monitor_types ), ) - wf[DetectorBankSizes] = DetectorBankSizes({}) - wf[PreopenNeXusFile] = PreopenNeXusFile(False) return wf diff --git a/packages/essreduce/tests/nexus/workflow_test.py b/packages/essreduce/tests/nexus/workflow_test.py index 9128c47e8..74e888265 100644 --- a/packages/essreduce/tests/nexus/workflow_test.py +++ b/packages/essreduce/tests/nexus/workflow_test.py @@ -182,6 +182,30 @@ def test_to_transform_raises_if_interval_does_not_yield_unique_value( ) +def test_to_transform_with_custom_time_filter( + time_dependent_depends_on: snx.TransformationChain, +) -> None: + def time_filter(transformation: sc.DataArray) -> sc.DataArray: + # -1* so we can see that the filter does something + return -1 * transformation + + transform = workflow.to_transformation( + time_dependent_depends_on, + TimeInterval(slice(sc.scalar(0.1, unit='s'), sc.scalar(1.9, unit='s'))), + time_filter=time_filter, + ).value + + expected = sc.DataArray( + sc.vectors( + dims=['time'], values=[[1.0, -1.0, 0.0], [1.0, -2.0, 0.0]], unit='m' + ), + coords={'time': sc.array(dims=['time'], values=[0.0, 1.0], unit='s')}, + ) + sc.testing.assert_identical( + transform * sc.vector([0.0, 0.0, 0.0], unit='m'), expected + ) + + def test_given_no_sample_load_nexus_sample_returns_group_with_origin_depends_on( loki_tutorial_sample_run_60250: Path, ) -> None: diff --git a/pixi.lock b/pixi.lock index 214a9f289..caea026e5 100644 --- a/pixi.lock +++ b/pixi.lock @@ -10998,7 +10998,7 @@ packages: timestamp: 1758743805063 - pypi: ./packages/essdiffraction name: essdiffraction - version: 0.1.dev2567+g6d663e796.d20260417 + version: 26.4.2.dev3494+g34223b73.d20260420 sha256: f7442fcb8892eb5baf1f2ab6b3206185a64df9d910462810d00d61961a1eb16f requires_dist: - dask>=2022.1.0 @@ -11036,7 +11036,7 @@ packages: requires_python: '>=3.11' - pypi: ./packages/essimaging name: essimaging - version: 26.4.1.dev332+g6d663e79.d20260417 + version: 26.4.1.dev2050+g34223b73.d20260420 sha256: f0070a5ae1f7957e8ed16a9a8e451dc47af25af14056a52bea18969ccdfa3aff requires_dist: - dask>=2022.1.0 @@ -11069,7 +11069,7 @@ packages: requires_python: '>=3.11' - pypi: ./packages/essnmx name: essnmx - version: 26.4.1.dev332+g6d663e79.d20260417 + version: 26.4.1.dev2050+g34223b73.d20260420 sha256: 55671f87213d0cad915b5def554cb8380e6bcc8f2a0af5acbcdc8a8ebcfd9531 requires_dist: - dask>=2022.1.0 @@ -11119,11 +11119,11 @@ packages: requires_python: '>=3.11' - pypi: ./packages/essreduce name: essreduce - version: 26.4.1.dev347+g6d663e79.d20260417 - sha256: 13cb0465a26df32340f26d3c9b42d1765542d30cdf0d09b87298285872d5fdee + version: 26.4.1.dev2065+g34223b73.d20260420 + sha256: 043a6aefd64757fb4d669459aa893cb3a9d7e81aa6591d9d6ce8fcc0f79a545b requires_dist: - sciline>=25.11.0 - - scipp>=26.3.0 + - scipp>=26.3.1 - scippneutron>=25.11.1 - scippnexus>=25.6.0 - graphviz>=0.20 ; extra == 'test' @@ -11152,7 +11152,7 @@ packages: requires_python: '>=3.11' - pypi: ./packages/essreflectometry name: essreflectometry - version: 0.1.dev2567+g6d663e796.d20260417 + version: 26.4.1.dev3398+g34223b73.d20260420 sha256: 1bdaf0414f6474e2d20b379284338a01705165a8ebabc9848312675ca2a89b0e requires_dist: - dask>=2022.1.0