diff --git a/src/nexgen/beamlines/I19_2_nxs.py b/src/nexgen/beamlines/I19_2_nxs.py index 64bcc612..f8e7f530 100644 --- a/src/nexgen/beamlines/I19_2_nxs.py +++ b/src/nexgen/beamlines/I19_2_nxs.py @@ -11,10 +11,10 @@ from typing import Any, NamedTuple, Optional import h5py -import numpy as np from numpy.typing import ArrayLike from pydantic import field_validator +from nexgen.tools.vds_tools import define_vds_dtype_from_bit_depth from nexgen.utils import get_iso_timestamp from .. import log @@ -230,6 +230,7 @@ def eiger_writer( vds_offset: int = 0, notes: dict[str, Any] | None = None, data_entry_key: str = "data", + bit_depth: int = 32, ): """ A function to call the NXmx nexus file writer for Eiger 2X 4M detector. @@ -251,6 +252,8 @@ def eiger_writer( dataset name and value its data. Defaults to None. data_entry_key (str, optional): Dataset entry key in datafiles. eg. for gating mode it's data1.\ Defaults to data. + bit_depth(int, optional): Default bit depth for eiger collections, used to define dtype of vds data. \ + Defaults to 32. Raises: ValueError: If use_meta is set to False but axes_pos and det_pos haven't been passed. @@ -338,7 +341,7 @@ def eiger_writer( logger.info( "Not using meta file to update metadata, only the external links will be set up." ) - vds_dtype = np.uint32 + vds_dtype = define_vds_dtype_from_bit_depth(bit_depth) # Update axes # Goniometer for gax in TR.axes_pos: @@ -457,6 +460,7 @@ def serial_nexus_writer( use_meta: bool = False, vds_offset: int = 0, n_frames: int | None = None, + bit_depth: int = 32, notes: dict[str, Any] | None = None, ): """Wrapper function to gather all parameters from the beamline and kick off the nexus writer for a \ @@ -473,6 +477,8 @@ def serial_nexus_writer( n_frames (int | None, optional): Number of images for the nexus file. Only needed if different \ from the tot_num_images in the collection params. If passed, the VDS will only contain the \ number of frames specified here. Defaults to None. + bit_depth(int, optional): Default bit depth for eiger collections, used to define dtype of vds data. \ + Defaults to 32. notes (dict[str, Any] | None, optional): Any additional information to be written as NXnote, \ passed as a dictionary of (key, value) pairs where key represents the dataset name and \ value its data. Defaults to None. @@ -511,7 +517,8 @@ def serial_nexus_writer( use_meta, n_frames, vds_offset, - notes, + bit_depth=bit_depth, + notes=notes, ) case DetectorName.TRISTAN: tristan_writer(master_file, collection_params, timestamps, notes) @@ -523,6 +530,7 @@ def nexus_writer( timestamps: tuple[datetime, datetime] = (None, None), use_meta: bool = False, data_entry_key: str = "data", + bit_depth: int = 32, ): """Wrapper function to gather all parameters from the beamline and kick off the nexus writer for a \ standard experiment on I19-2. @@ -536,6 +544,8 @@ def nexus_writer( all parameters will need to be passed manually. Defaults to False. data_entry_key (str, optional): Dataset entry key in datafiles. eg. for gating mode it's data1.\ Defaults to data. + bit_depth(int, optional): Default bit depth for eiger collections, used to define dtype of vds data. \ + Defaults to 32. """ collection_params = CollectionParams(**params) wdir = master_file.parent @@ -603,6 +613,7 @@ def nexus_writer( timestamps, use_meta, data_entry_key=data_entry_key, + bit_depth=bit_depth, ) case DetectorName.TRISTAN: tristan_writer( diff --git a/src/nexgen/beamlines/SSX_Eiger_nxs.py b/src/nexgen/beamlines/SSX_Eiger_nxs.py index b6c71f09..0fd3d477 100644 --- a/src/nexgen/beamlines/SSX_Eiger_nxs.py +++ b/src/nexgen/beamlines/SSX_Eiger_nxs.py @@ -8,8 +8,7 @@ from pathlib import Path from typing import Literal, get_args -import numpy as np -from numpy.typing import DTypeLike +from nexgen.tools.vds_tools import define_vds_dtype_from_bit_depth from .. import log from ..nxs_utils import ( @@ -64,16 +63,6 @@ class SerialParams(GeneralParams): experiment_type: str -def _define_vds_dtype_from_bit_depth(bit_depth: int) -> DTypeLike: - """Define dtype of VDS based on the passed bit depth.""" - if bit_depth == 32: - return np.uint32 - elif bit_depth == 8: - return np.uint8 - else: - return np.uint16 - - def _get_beamline_specific_params(beamline: str) -> tuple[BeamlineAxes, EigerDetector]: """Get beamline specific axes and eiger description. @@ -295,7 +284,7 @@ def ssx_eiger_writer( bit_depth = 32 else: bit_depth = ssx_params["bit_depth"] - vds_dtype = _define_vds_dtype_from_bit_depth(bit_depth) + vds_dtype = define_vds_dtype_from_bit_depth(bit_depth) logger.debug(f"VDS dtype will be {vds_dtype}") # Define Goniometer axes diff --git a/src/nexgen/command_line/I19_2_cli.py b/src/nexgen/command_line/I19_2_cli.py index 61d913c8..30f2e7be 100644 --- a/src/nexgen/command_line/I19_2_cli.py +++ b/src/nexgen/command_line/I19_2_cli.py @@ -135,6 +135,7 @@ def nexgen_writer(args): args.use_meta, args.vds_offset, args.n_frames, + bit_depth=args.bit_depth, ) else: nexus_writer( @@ -143,6 +144,7 @@ def nexgen_writer(args): (_start, _stop), args.use_meta, data_entry_key=args.data_key, + bit_depth=args.bit_depth, ) @@ -304,6 +306,13 @@ def nexgen_writer(args): default="data", help="Data entry key of dataset in raw .h5 file. Defaults to data.", ) +parser_nex.add_argument( + "-bits" "--bit-depth", + type=int, + choices=[8, 16, 32], + default=32, + help="Default bit depth for eiger collections, used to define dtype of vds data. Defaults to 32.", +) parser_nex.set_defaults(func=nexgen_writer) diff --git a/src/nexgen/command_line/nexus_generator.py b/src/nexgen/command_line/nexus_generator.py index b307043e..c2ce24c5 100644 --- a/src/nexgen/command_line/nexus_generator.py +++ b/src/nexgen/command_line/nexus_generator.py @@ -28,6 +28,7 @@ from nexgen.nxs_write.nxmx_writer import EventNXmxFileWriter, NXmxFileWriter from nexgen.nxs_write.write_utils import find_number_of_images from nexgen.tools.data_writer import generate_event_files, generate_image_files +from nexgen.tools.vds_tools import define_vds_dtype_from_bit_depth from nexgen.utils import ( get_filename_template, get_iso_timestamp, @@ -143,6 +144,7 @@ def write_nxmx_cli(args): try: entry_key = args.data_key if args.data_key else "data" + vds_dtype = define_vds_dtype_from_bit_depth(args.bit_depth) # Aaaaaaaaaaaand write if params.det.mode == "images": writer = NXmxFileWriter( @@ -156,7 +158,7 @@ def write_nxmx_cli(args): ) writer.write(image_datafiles=datafiles, data_entry_key=entry_key) if not args.no_vds: - writer.write_vds(args.vds_offset) + writer.write_vds(args.vds_offset, vds_dtype=vds_dtype) else: writer = EventNXmxFileWriter( master_file, @@ -390,6 +392,13 @@ def _parse_cli() -> argparse.ArgumentParser: default="data", help="Data entry key of dataset in raw .h5 file. Defaults to data.", ) + nxmx_parser.add_argument( + "-bits" "--bit-depth", + type=int, + choices=[8, 16, 32], + default=32, + help="Default bit depth for eiger collections, used to define dtype of vds data. Defaults to 32.", + ) nxmx_parser.set_defaults(func=write_nxmx_cli) demo_parser = subparsers.add_parser( "2", diff --git a/src/nexgen/tools/vds_tools.py b/src/nexgen/tools/vds_tools.py index a519bd65..1d4c82a6 100644 --- a/src/nexgen/tools/vds_tools.py +++ b/src/nexgen/tools/vds_tools.py @@ -21,6 +21,16 @@ vds_logger = logging.getLogger("nexgen.VDSWriter") +def define_vds_dtype_from_bit_depth(bit_depth: int) -> DTypeLike: + """Define dtype of VDS based on the passed bit depth.""" + if bit_depth == 32: + return np.uint32 + elif bit_depth == 8: + return np.uint8 + else: + return np.uint16 + + @dataclass class Dataset: name: str diff --git a/tests/beamlines/test_i19nxs.py b/tests/beamlines/test_i19nxs.py index 8d23520f..7ffd9e33 100644 --- a/tests/beamlines/test_i19nxs.py +++ b/tests/beamlines/test_i19nxs.py @@ -84,7 +84,8 @@ def test_serial_nexus_writer_calls_correct_writer_for_eiger( True, None, 0, - None, + bit_depth=32, + notes=None, ) diff --git a/tests/tools/test_VDS_tools.py b/tests/tools/test_VDS_tools.py index 92013f9e..5bf94bdf 100644 --- a/tests/tools/test_VDS_tools.py +++ b/tests/tools/test_VDS_tools.py @@ -8,6 +8,7 @@ from nexgen.tools.vds_tools import ( Dataset, create_virtual_layout, + define_vds_dtype_from_bit_depth, find_datasets_in_file, image_vds_writer, jungfrau_vds_writer, @@ -15,6 +16,15 @@ ) +@pytest.mark.parametrize( + "bit_depth, expected_dtype", [(8, np.uint8), (16, np.uint16), (32, np.uint32)] +) +def test_vds_dtype_from_input(bit_depth, expected_dtype): + d = define_vds_dtype_from_bit_depth(bit_depth) + + assert d == expected_dtype + + def test_when_get_frames_and_shape_less_than_1000_then_correct(): sshape = split_datasets(["test1"], (500, 10, 10)) assert sshape == [Dataset("test1", (500, 10, 10), 0, 500)]