111 changes: 109 additions & 2 deletions trx/tests/test_memmap.py
@@ -124,10 +124,11 @@ def test__dichotomic_search(arr, l_bound, r_bound, expected):
 )
 def test__create_memmap(basename, create, expected):
     if create:
-        # Need to create array before evaluating
         with get_trx_tmp_dir() as dirname:
             filename = os.path.join(dirname, basename)
-            fp = np.memmap(filename, dtype=np.int16, mode="w+", shape=(3, 4))
+            fp = tmm._create_memmap(
+                filename=filename, mode="w+", shape=(3, 4), dtype=np.int16
+            )
             fp[:] = expected[:]
             mmarr = tmm._create_memmap(filename=filename, shape=(3, 4), dtype=np.int16)
             assert np.array_equal(mmarr, expected)
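Note: routing the mode="w+" path through tmm._create_memmap means test files are created with the same little-endian normalization exercised below. As a standalone NumPy sketch (not part of the PR) of why byte order matters for the int16 data used here:

import numpy as np

# int16 value 1: little-endian on disk is b'\x01\x00', big-endian is b'\x00\x01'.
# Only the first layout is valid TRX, regardless of the machine that wrote it.
print(np.array([1], dtype="<i2").tobytes())  # b'\x01\x00'
print(np.array([1], dtype=">i2").tobytes())  # b'\x00\x01'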
@@ -361,3 +362,109 @@ def test_trxfile_to_memory():
 
 def test_trxfile_close():
     pass
+
+
+# Endianness tests for cross-platform compatibility (Issue #83)
+@pytest.mark.parametrize(
+    "dtype_input,expected_byteorder",
+    [
+        # Native dtypes should be converted to little-endian
+        (np.float32, "<"),
+        (np.float64, "<"),
+        (np.int32, "<"),
+        (np.int64, "<"),
+        (np.uint32, "<"),
+        (np.uint64, "<"),
+        ("float32", "<"),
+        ("float64", "<"),
+        # Big-endian dtypes should be converted to little-endian
+        (">f4", "<"),
+        (">f8", "<"),
+        (">i4", "<"),
+        (">u4", "<"),
+        # Little-endian dtypes should remain little-endian
+        ("<f4", "<"),
+        ("<i4", "<"),
+        # Single-byte types don't have endianness (byteorder is '|')
+        (np.uint8, "|"),
+        (np.int8, "|"),
+        (np.bool_, "|"),
+    ],
+)
+def test__get_dtype_little_endian(dtype_input, expected_byteorder):
+    """Test that _get_dtype_little_endian correctly converts dtypes."""
+    result = tmm._get_dtype_little_endian(dtype_input)
+    assert result.byteorder == expected_byteorder or (
+        result.byteorder == "=" and np.little_endian and expected_byteorder == "<"
+    )
+
+
+@pytest.mark.parametrize(
+    "dtype,test_value",
+    [
+        (np.float32, 3.14159),
+        (np.float64, 2.71828),
+        (np.int32, 12345),
+        (np.int64, 9876543210),
+        (np.uint32, 0xDEADBEEF),
+        (np.uint64, 0xDEADBEEFCAFEBABE),
+    ],
+)
+def test__ensure_little_endian(dtype, test_value):
+    """Test that _ensure_little_endian correctly converts arrays."""
+    # Create array in native byte order
+    arr = np.array([test_value], dtype=dtype)
+
+    # Ensure little endian
+    result = tmm._ensure_little_endian(arr)
+
+    # Result should be little-endian (or native if system is little-endian)
+    assert result.dtype.byteorder in ("<", "=", "|")
+
+    # Values should be preserved
+    assert result[0] == test_value
+
+
+def test__ensure_little_endian_big_endian_input():
+    """Test _ensure_little_endian with explicitly big-endian input."""
+    # Create a big-endian array
+    big_endian_dtype = np.dtype(">u4")
+    arr = np.array([0x12345678], dtype=big_endian_dtype)
+
+    # Ensure little endian
+    result = tmm._ensure_little_endian(arr)
+
+    # Result should be little-endian
+    assert result.dtype.byteorder == "<"
+
+    # Value should be preserved
+    assert result[0] == 0x12345678
+
+
+def test_endianness_roundtrip():
+    """Test that data survives write/read cycle with correct endianness."""
+    with get_trx_tmp_dir() as dirname:
+        # Test values that would be corrupted if endianness is wrong
+        test_positions = np.array(
+            [[1.5, 2.5, 3.5], [4.5, 5.5, 6.5], [7.5, 8.5, 9.5]], dtype=np.float32
+        )
+        test_offsets = np.array([0, 3], dtype=np.uint32)
+
+        # Write as little-endian
+        pos_file = os.path.join(dirname, "test_positions.3.float32")
+        off_file = os.path.join(dirname, "test_offsets.uint32")
+
+        tmm._ensure_little_endian(test_positions).tofile(pos_file)
+        tmm._ensure_little_endian(test_offsets).tofile(off_file)
+
+        # Read back using _create_memmap (which enforces little-endian)
+        read_positions = tmm._create_memmap(
+            pos_file, mode="r", shape=(3, 3), dtype="float32"
+        )
+        read_offsets = tmm._create_memmap(
+            off_file, mode="r", shape=(2,), dtype="uint32"
+        )
+
+        # Values should match
+        np.testing.assert_array_almost_equal(read_positions, test_positions)
+        np.testing.assert_array_equal(read_offsets, test_offsets)
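The roundtrip test checks decoded values rather than raw bytes. For intuition, here is a minimal sketch (plain NumPy, not part of the PR) of the corruption it guards against: the same four bytes of a float32 reinterpreted in the wrong byte order.

import numpy as np

value = np.float32(1.5)
raw_be = value.astype(">f4").tobytes()         # big-endian bytes: 3F C0 00 00
wrong = np.frombuffer(raw_be, dtype="<f4")[0]  # decoded as little-endian
print(wrong)  # a denormal around 6.9e-41, not 1.5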
73 changes: 67 additions & 6 deletions trx/trx_file_memmap.py
@@ -33,6 +33,63 @@
 dipy_available = False
 
 
+def _get_dtype_little_endian(dtype: Union[np.dtype, str, type]) -> np.dtype:
+    """Convert a dtype to its little-endian equivalent.
+
+    The TRX file format uses little-endian byte order for cross-platform
+    compatibility. This function ensures that dtypes are always interpreted
+    as little-endian when reading/writing TRX files.
+
+    Parameters
+    ----------
+    dtype : np.dtype, str, or type
+        Input dtype specification (e.g., np.float32, 'float32', '>f4')
+
+    Returns
+    -------
+    np.dtype
+        Little-endian dtype. For single-byte types (uint8, int8, bool),
+        returns the original dtype as endianness is not applicable.
+    """
+    dt = np.dtype(dtype)
+    # Single-byte types don't have endianness
+    if dt.byteorder == "|" or dt.itemsize == 1:
+        return dt
+    # Already little-endian
+    if dt.byteorder == "<":
+        return dt
+    # Convert to little-endian
+    return dt.newbyteorder("<")
+
+
+def _ensure_little_endian(arr: np.ndarray) -> np.ndarray:
+    """Ensure array data is in little-endian byte order for writing.
+
+    Parameters
+    ----------
+    arr : np.ndarray
+        Input array
+
+    Returns
+    -------
+    np.ndarray
+        Array with little-endian byte order. Returns a copy if conversion
+        was needed, otherwise returns the original array.
+    """
+    dt = arr.dtype
+    # Single-byte types don't have endianness
+    if dt.byteorder == "|" or dt.itemsize == 1:
+        return arr
+    # Already little-endian
+    if dt.byteorder == "<":
+        return arr
+    # Native byte order on little-endian system
+    if dt.byteorder == "=" and np.little_endian:
+        return arr
+    # Convert to little-endian
+    return arr.astype(dt.newbyteorder("<"))
+
+
 def _append_last_offsets(nib_offsets: np.ndarray, nb_vertices: int) -> np.ndarray:
     """Appends the last element of offsets from header information
 
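A note on the byteorder characters these helpers branch on: NumPy reports '<' (little-endian), '>' (big-endian), '=' (native), or '|' (not applicable), and native-order dtypes report '=' rather than an explicit '<', which is why test__get_dtype_little_endian also accepts '='. A quick standalone illustration, assuming a little-endian host:

import numpy as np

print(np.dtype("float32").byteorder)                # '=' (native, never '<')
print(np.dtype(">f4").byteorder)                    # '>' (non-native order)
print(np.dtype(">f4").newbyteorder("<").byteorder)  # '<' (explicit little-endian)
print(np.dtype("uint8").byteorder)                  # '|' (single byte, no order)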
@@ -200,6 +257,9 @@ def _create_memmap(
     if np.dtype(dtype) == bool:
         filename = filename.replace(".bool", ".bit")
 
+    # TRX format uses little-endian byte order for cross-platform compatibility
+    dtype = _get_dtype_little_endian(dtype)
+
     if shape[0]:
         return np.memmap(
             filename, mode=mode, offset=offset, shape=shape, dtype=dtype, order=order
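Because the dtype is normalized before np.memmap is constructed, a caller can pass any equivalent spec and still get a little-endian view of the file. A hypothetical read (illustrative path and shape, mirroring the roundtrip test):

# ">u4" is rewritten to "<u4" internally, so the bytes of a little-endian
# TRX file decode to the intended values on any host.
mm = _create_memmap("offsets.uint32", mode="r", shape=(2,), dtype=">u4")
assert mm.dtype.byteorder in ("<", "=")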
@@ -794,6 +854,7 @@ def deepcopy(self) -> Type["TrxFile"]:  # noqa: C901
         tmp_header["DIMENSIONS"] = tmp_header["DIMENSIONS"].tolist()
 
         # tofile() always write in C-order
+        # Ensure little-endian byte order for cross-platform compatibility
         if not self._copy_safe:
             to_dump = self.streamlines.copy()._data
             tmp_header["NB_STREAMLINES"] = len(self.streamlines)
@@ -806,7 +867,7 @@ def deepcopy(self) -> Type["TrxFile"]:  # noqa: C901
         positions_filename = _generate_filename_from_data(
             to_dump, os.path.join(tmp_dir.name, "positions")
         )
-        to_dump.tofile(positions_filename)
+        _ensure_little_endian(to_dump).tofile(positions_filename)
 
         if not self._copy_safe:
             to_dump = _append_last_offsets(
@@ -819,7 +880,7 @@ def deepcopy(self) -> Type["TrxFile"]:  # noqa: C901
         offsets_filename = _generate_filename_from_data(
             self.streamlines._offsets, os.path.join(tmp_dir.name, "offsets")
         )
-        to_dump.tofile(offsets_filename)
+        _ensure_little_endian(to_dump).tofile(offsets_filename)
 
         if len(self.data_per_vertex.keys()) > 0:
             os.mkdir(os.path.join(tmp_dir.name, "dpv/"))
@@ -832,7 +893,7 @@ def deepcopy(self) -> Type["TrxFile"]:  # noqa: C901
             dpv_filename = _generate_filename_from_data(
                 to_dump, os.path.join(tmp_dir.name, "dpv/", dpv_key)
             )
-            to_dump.tofile(dpv_filename)
+            _ensure_little_endian(to_dump).tofile(dpv_filename)
 
         if len(self.data_per_streamline.keys()) > 0:
             os.mkdir(os.path.join(tmp_dir.name, "dps/"))
@@ -841,7 +902,7 @@ def deepcopy(self) -> Type["TrxFile"]:  # noqa: C901
             dps_filename = _generate_filename_from_data(
                 to_dump, os.path.join(tmp_dir.name, "dps/", dps_key)
             )
-            to_dump.tofile(dps_filename)
+            _ensure_little_endian(to_dump).tofile(dps_filename)
 
         if len(self.groups.keys()) > 0:
             os.mkdir(os.path.join(tmp_dir.name, "groups/"))
@@ -850,7 +911,7 @@ def deepcopy(self) -> Type["TrxFile"]:  # noqa: C901
             group_filename = _generate_filename_from_data(
                 to_dump, os.path.join(tmp_dir.name, "groups/", group_key)
             )
-            to_dump.tofile(group_filename)
+            _ensure_little_endian(to_dump).tofile(group_filename)
 
             if group_key not in self.data_per_group:
                 continue
@@ -864,7 +925,7 @@ def deepcopy(self) -> Type["TrxFile"]:  # noqa: C901
                 dpg_filename = _generate_filename_from_data(
                     to_dump, os.path.join(tmp_dir.name, "dpg/", group_key, dpg_key)
                 )
-                to_dump.tofile(dpg_filename)
+                _ensure_little_endian(to_dump).tofile(dpg_filename)
 
         copy_trx = load_from_directory(tmp_dir.name)
         copy_trx._uncompressed_folder_handle = tmp_dir
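Every tofile() call in deepcopy now funnels through _ensure_little_endian because tofile() dumps raw memory with no byte-order metadata. A minimal standalone sketch of the write-side pattern (hypothetical path, plain NumPy):

import numpy as np

arr = np.arange(4, dtype=">u4")               # pretend this arrived big-endian
arr.astype("<u4").tofile("/tmp/demo.uint32")  # astype() performs the byteswap
back = np.fromfile("/tmp/demo.uint32", dtype="<u4")
assert (back == arr).all()                    # values survive; bytes are portable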
8 changes: 5 additions & 3 deletions trx/workflows.py
@@ -417,10 +417,12 @@ def _write_header(tmp_dir_name, reference, streamlines):
 def _write_streamline_data(tmp_dir_name, streamlines, positions_dtype, offsets_dtype):
     """Write streamline position and offset data."""
     curr_filename = os.path.join(tmp_dir_name, "positions.3.{}".format(positions_dtype))
-    streamlines._data.astype(positions_dtype).tofile(curr_filename)
+    positions = streamlines._data.astype(positions_dtype)
+    tmm._ensure_little_endian(positions).tofile(curr_filename)
 
     curr_filename = os.path.join(tmp_dir_name, "offsets.{}".format(offsets_dtype))
-    streamlines._offsets.astype(offsets_dtype).tofile(curr_filename)
+    offsets = streamlines._offsets.astype(offsets_dtype)
+    tmm._ensure_little_endian(offsets).tofile(curr_filename)
 
 
 def _normalize_dtype(dtype_str):
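The astype(positions_dtype) call yields native-order data, so without the added _ensure_little_endian step a big-endian host would write byteswapped files; the filename ("positions.3.float32") records only the logical dtype, never the byte order. A condensed sketch of what the helper amounts to on such a host:

import numpy as np

positions = np.array([[1.5, 2.5, 3.5]], dtype="float64").astype("float32")
if not np.little_endian:  # roughly what _ensure_little_endian decides
    positions = positions.astype(positions.dtype.newbyteorder("<"))
positions.tofile("positions.3.float32")  # hypothetical path; bytes now portable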
@@ -460,7 +462,7 @@ def _write_data_array(tmp_dir_name, subdir_name, args, is_dpg=False):
         tmp_dir_name, subdir_name, "{}.{}{}".format(basename, dim, dtype)
     )
 
-    curr_arr.tofile(curr_filename)
+    tmm._ensure_little_endian(curr_arr).tofile(curr_filename)
 
 
 def generate_trx_from_scratch(  # noqa: C901
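With all three writers normalized, a file produced on any host can be checked by reading it back through the little-endian-enforcing read path. A hypothetical end-to-end check (tmm aliases trx.trx_file_memmap, as in the tests):

import numpy as np
import trx.trx_file_memmap as tmm

arr = np.array([[1.0, 2.0, 3.0]], dtype="float32")
tmm._ensure_little_endian(arr).tofile("positions.3.float32")  # hypothetical path
back = tmm._create_memmap(
    "positions.3.float32", mode="r", shape=(1, 3), dtype="float32"
)
np.testing.assert_array_equal(back, arr)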