Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions src/nexgen/beamlines/I19_2_nxs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
from typing import Any, NamedTuple, Optional

import h5py
import numpy as np
from numpy.typing import ArrayLike
from pydantic import field_validator

from nexgen.tools.vds_tools import define_vds_dtype_from_bit_depth
from nexgen.utils import get_iso_timestamp

from .. import log
Expand Down Expand Up @@ -230,6 +230,7 @@ def eiger_writer(
vds_offset: int = 0,
notes: dict[str, Any] | None = None,
data_entry_key: str = "data",
bit_depth: int = 32,
):
"""
A function to call the NXmx nexus file writer for Eiger 2X 4M detector.
Expand All @@ -251,6 +252,8 @@ def eiger_writer(
dataset name and value its data. Defaults to None.
data_entry_key (str, optional): Dataset entry key in datafiles. eg. for gating mode it's data1.\
Defaults to data.
bit_depth(int, optional): Default bit depth for eiger collections, used to define dtype of vds data. \
Defaults to 32.

Raises:
ValueError: If use_meta is set to False but axes_pos and det_pos haven't been passed.
Expand Down Expand Up @@ -338,7 +341,7 @@ def eiger_writer(
logger.info(
"Not using meta file to update metadata, only the external links will be set up."
)
vds_dtype = np.uint32
vds_dtype = define_vds_dtype_from_bit_depth(bit_depth)
# Update axes
# Goniometer
for gax in TR.axes_pos:
Expand Down Expand Up @@ -457,6 +460,7 @@ def serial_nexus_writer(
use_meta: bool = False,
vds_offset: int = 0,
n_frames: int | None = None,
bit_depth: int = 32,
notes: dict[str, Any] | None = None,
):
"""Wrapper function to gather all parameters from the beamline and kick off the nexus writer for a \
Expand All @@ -473,6 +477,8 @@ def serial_nexus_writer(
n_frames (int | None, optional): Number of images for the nexus file. Only needed if different \
from the tot_num_images in the collection params. If passed, the VDS will only contain the \
number of frames specified here. Defaults to None.
bit_depth(int, optional): Default bit depth for eiger collections, used to define dtype of vds data. \
Defaults to 32.
notes (dict[str, Any] | None, optional): Any additional information to be written as NXnote, \
passed as a dictionary of (key, value) pairs where key represents the dataset name and \
value its data. Defaults to None.
Expand Down Expand Up @@ -511,7 +517,8 @@ def serial_nexus_writer(
use_meta,
n_frames,
vds_offset,
notes,
bit_depth=bit_depth,
notes=notes,
)
case DetectorName.TRISTAN:
tristan_writer(master_file, collection_params, timestamps, notes)
Expand All @@ -523,6 +530,7 @@ def nexus_writer(
timestamps: tuple[datetime, datetime] = (None, None),
use_meta: bool = False,
data_entry_key: str = "data",
bit_depth: int = 32,
):
"""Wrapper function to gather all parameters from the beamline and kick off the nexus writer for a \
standard experiment on I19-2.
Expand All @@ -536,6 +544,8 @@ def nexus_writer(
all parameters will need to be passed manually. Defaults to False.
data_entry_key (str, optional): Dataset entry key in datafiles. eg. for gating mode it's data1.\
Defaults to data.
bit_depth(int, optional): Default bit depth for eiger collections, used to define dtype of vds data. \
Defaults to 32.
"""
collection_params = CollectionParams(**params)
wdir = master_file.parent
Expand Down Expand Up @@ -603,6 +613,7 @@ def nexus_writer(
timestamps,
use_meta,
data_entry_key=data_entry_key,
bit_depth=bit_depth,
)
case DetectorName.TRISTAN:
tristan_writer(
Expand Down
15 changes: 2 additions & 13 deletions src/nexgen/beamlines/SSX_Eiger_nxs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from pathlib import Path
from typing import Literal, get_args

import numpy as np
from numpy.typing import DTypeLike
from nexgen.tools.vds_tools import define_vds_dtype_from_bit_depth

from .. import log
from ..nxs_utils import (
Expand Down Expand Up @@ -64,16 +63,6 @@ class SerialParams(GeneralParams):
experiment_type: str


def _define_vds_dtype_from_bit_depth(bit_depth: int) -> DTypeLike:
"""Define dtype of VDS based on the passed bit depth."""
if bit_depth == 32:
return np.uint32
elif bit_depth == 8:
return np.uint8
else:
return np.uint16


def _get_beamline_specific_params(beamline: str) -> tuple[BeamlineAxes, EigerDetector]:
"""Get beamline specific axes and eiger description.

Expand Down Expand Up @@ -295,7 +284,7 @@ def ssx_eiger_writer(
bit_depth = 32
else:
bit_depth = ssx_params["bit_depth"]
vds_dtype = _define_vds_dtype_from_bit_depth(bit_depth)
vds_dtype = define_vds_dtype_from_bit_depth(bit_depth)
logger.debug(f"VDS dtype will be {vds_dtype}")

# Define Goniometer axes
Expand Down
9 changes: 9 additions & 0 deletions src/nexgen/command_line/I19_2_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def nexgen_writer(args):
args.use_meta,
args.vds_offset,
args.n_frames,
bit_depth=args.bit_depth,
)
else:
nexus_writer(
Expand All @@ -143,6 +144,7 @@ def nexgen_writer(args):
(_start, _stop),
args.use_meta,
data_entry_key=args.data_key,
bit_depth=args.bit_depth,
)


Expand Down Expand Up @@ -304,6 +306,13 @@ def nexgen_writer(args):
default="data",
help="Data entry key of dataset in raw .h5 file. Defaults to data.",
)
parser_nex.add_argument(
"-bits" "--bit-depth",
type=int,
choices=[8, 16, 32],
default=32,
help="Default bit depth for eiger collections, used to define dtype of vds data. Defaults to 32.",
)
parser_nex.set_defaults(func=nexgen_writer)


Expand Down
11 changes: 10 additions & 1 deletion src/nexgen/command_line/nexus_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from nexgen.nxs_write.nxmx_writer import EventNXmxFileWriter, NXmxFileWriter
from nexgen.nxs_write.write_utils import find_number_of_images
from nexgen.tools.data_writer import generate_event_files, generate_image_files
from nexgen.tools.vds_tools import define_vds_dtype_from_bit_depth
from nexgen.utils import (
get_filename_template,
get_iso_timestamp,
Expand Down Expand Up @@ -143,6 +144,7 @@ def write_nxmx_cli(args):

try:
entry_key = args.data_key if args.data_key else "data"
vds_dtype = define_vds_dtype_from_bit_depth(args.bit_depth)
# Aaaaaaaaaaaand write
if params.det.mode == "images":
writer = NXmxFileWriter(
Expand All @@ -156,7 +158,7 @@ def write_nxmx_cli(args):
)
writer.write(image_datafiles=datafiles, data_entry_key=entry_key)
if not args.no_vds:
writer.write_vds(args.vds_offset)
writer.write_vds(args.vds_offset, vds_dtype=vds_dtype)
else:
writer = EventNXmxFileWriter(
master_file,
Expand Down Expand Up @@ -390,6 +392,13 @@ def _parse_cli() -> argparse.ArgumentParser:
default="data",
help="Data entry key of dataset in raw .h5 file. Defaults to data.",
)
nxmx_parser.add_argument(
"-bits" "--bit-depth",
type=int,
choices=[8, 16, 32],
default=32,
help="Default bit depth for eiger collections, used to define dtype of vds data. Defaults to 32.",
)
nxmx_parser.set_defaults(func=write_nxmx_cli)
demo_parser = subparsers.add_parser(
"2",
Expand Down
10 changes: 10 additions & 0 deletions src/nexgen/tools/vds_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@
vds_logger = logging.getLogger("nexgen.VDSWriter")


def define_vds_dtype_from_bit_depth(bit_depth: int) -> DTypeLike:
"""Define dtype of VDS based on the passed bit depth."""
if bit_depth == 32:
return np.uint32
elif bit_depth == 8:
return np.uint8
else:
return np.uint16


@dataclass
class Dataset:
name: str
Expand Down
3 changes: 2 additions & 1 deletion tests/beamlines/test_i19nxs.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ def test_serial_nexus_writer_calls_correct_writer_for_eiger(
True,
None,
0,
None,
bit_depth=32,
notes=None,
)


Expand Down
10 changes: 10 additions & 0 deletions tests/tools/test_VDS_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,23 @@
from nexgen.tools.vds_tools import (
Dataset,
create_virtual_layout,
define_vds_dtype_from_bit_depth,
find_datasets_in_file,
image_vds_writer,
jungfrau_vds_writer,
split_datasets,
)


@pytest.mark.parametrize(
"bit_depth, expected_dtype", [(8, np.uint8), (16, np.uint16), (32, np.uint32)]
)
def test_vds_dtype_from_input(bit_depth, expected_dtype):
d = define_vds_dtype_from_bit_depth(bit_depth)

assert d == expected_dtype


def test_when_get_frames_and_shape_less_than_1000_then_correct():
sshape = split_datasets(["test1"], (500, 10, 10))
assert sshape == [Dataset("test1", (500, 10, 10), 0, 500)]
Expand Down
Loading