diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 47433e4..6b461d5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -29,7 +29,7 @@ jobs: if: runner.os == 'macOS' run: | brew update - brew install libzip eigen googletest ninja lcov + brew install libzip eigen googletest ninja lcov zlib - name: Install dependencies (Windows) if: runner.os == 'Windows' @@ -45,6 +45,7 @@ jobs: libzip ` eigen3 ` gtest ` + zlib ` --triplet x64-windows - name: Setup MSVC environment (Windows) @@ -60,6 +61,7 @@ jobs: -DTRX_USE_CONAN=OFF \ -DTRX_BUILD_TESTS=ON \ -DTRX_BUILD_EXAMPLES=ON \ + -DTRX_ENABLE_NIFTI=ON \ -DCMAKE_PREFIX_PATH="${BREW_PREFIX}" \ -DCMAKE_BUILD_TYPE=Debug \ -DCMAKE_C_FLAGS="--coverage" \ @@ -72,6 +74,7 @@ jobs: -G Ninja -DTRX_USE_CONAN=OFF -DTRX_BUILD_TESTS=ON + -DTRX_ENABLE_NIFTI=ON -DCMAKE_BUILD_TYPE=Debug -DCMAKE_TOOLCHAIN_FILE=${{ github.workspace }}/vcpkg/scripts/buildsystems/vcpkg.cmake diff --git a/.github/workflows/trx-cpp-tests.yml b/.github/workflows/trx-cpp-tests.yml index a981bf1..3ad1b81 100644 --- a/.github/workflows/trx-cpp-tests.yml +++ b/.github/workflows/trx-cpp-tests.yml @@ -50,6 +50,7 @@ jobs: -G Ninja \ -DTRX_BUILD_TESTS=ON \ -DTRX_BUILD_EXAMPLES=ON \ + -DTRX_ENABLE_NIFTI=ON \ -DGTest_DIR=${GITHUB_WORKSPACE}/deps/googletest/install/lib/cmake/GTest \ -DCMAKE_BUILD_TYPE=Debug \ -DCMAKE_C_FLAGS="--coverage" \ diff --git a/.gitignore b/.gitignore index 4559da9..43bcab0 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,8 @@ tests/data test_package/build/ test_package/CMakeUserPresets.json syntax.log +docs/_build/ +docs/api/ test_package/build test_package/CMakeUserPresets.json \ No newline at end of file diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 0000000..3a6876b --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,29 @@ +version: 2 + +build: + os: ubuntu-22.04 + tools: + python: "3.11" + apt_packages: + - cmake + - g++ + - pkg-config + - zlib1g-dev + - libeigen3-dev + - 
libzip-dev + - ninja-build + - zipcmp + - zipmerge + - ziptool + + jobs: + pre_build: + - cmake -S . -B build -DTRX_BUILD_DOCS=ON + - cmake --build build --target docs + +python: + install: + - requirements: docs/requirements.txt + +sphinx: + configuration: docs/conf.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 5d5a4c9..2dc74c4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -7,6 +7,7 @@ project(trx VERSION 0.1.0) include(GNUInstallDirs) include(CMakePackageConfigHelpers) include(FetchContent) +list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") set(TRX_IS_TOP_LEVEL OFF) if(CMAKE_SOURCE_DIR STREQUAL PROJECT_SOURCE_DIR) @@ -25,6 +26,8 @@ option(TRX_BUILD_TESTS "Build trx tests" OFF) option(TRX_BUILD_EXAMPLES "Build trx example commandline programs" ON) option(TRX_ENABLE_CLANG_TIDY "Run clang-tidy during builds" OFF) option(TRX_ENABLE_INSTALL "Install trx-cpp targets" ${TRX_IS_TOP_LEVEL}) +option(TRX_BUILD_DOCS "Build API documentation with Doxygen/Sphinx" OFF) +option(TRX_ENABLE_NIFTI "Enable optional NIfTI header utilities" ${TRX_BUILD_EXAMPLES}) if(TRX_ENABLE_CLANG_TIDY) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) @@ -84,6 +87,7 @@ endif() add_library(trx src/trx.cpp + src/detail/dtype_helpers.cpp include/trx/trx.h include/trx/trx.tpp third_party/json11/json11.cpp @@ -144,6 +148,25 @@ if(TRX_BUILD_TESTS) endif() endif() +if(TRX_ENABLE_NIFTI) + find_package(ZLIB REQUIRED) + add_library(trx-nifti + src/nifti_io.cpp + ) + add_library(trx-cpp::trx-nifti ALIAS trx-nifti) + target_compile_features(trx-nifti PUBLIC cxx_std_17) + target_include_directories(trx-nifti + PUBLIC + $ + $ + ) + target_link_libraries(trx-nifti + PUBLIC + Eigen3::Eigen + ZLIB::ZLIB + ) +endif() + if(TRX_BUILD_EXAMPLES) FetchContent_Declare( cxxopts @@ -154,6 +177,25 @@ if(TRX_BUILD_EXAMPLES) add_subdirectory(examples) endif() +if(TRX_BUILD_DOCS) + find_package(Doxygen QUIET) + find_program(TRX_SPHINX_BUILD_EXE NAMES sphinx-build) + if(DOXYGEN_FOUND AND TRX_SPHINX_BUILD_EXE) 
+ add_custom_target(docs + COMMAND ${CMAKE_COMMAND} -E make_directory ${CMAKE_CURRENT_SOURCE_DIR}/docs/_build/doxygen + COMMAND ${DOXYGEN_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/docs/Doxyfile + COMMAND ${TRX_SPHINX_BUILD_EXE} -b html + ${CMAKE_CURRENT_SOURCE_DIR}/docs + ${CMAKE_CURRENT_SOURCE_DIR}/docs/_build/html + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/docs + COMMENT "Building Doxygen XML and Sphinx HTML docs" + VERBATIM + ) + else() + message(STATUS "Docs disabled: Doxygen or sphinx-build not found.") + endif() +endif() + # Installation and package config if(TRX_ENABLE_INSTALL) set(TRX_INSTALL_CONFIGDIR "${CMAKE_INSTALL_LIBDIR}/cmake/trx-cpp") diff --git a/README.md b/README.md index cca4451..ed87447 100644 --- a/README.md +++ b/README.md @@ -4,92 +4,22 @@ TRX-cpp is a C++11 library for reading, writing, and memory-mapping the TRX tractography file format (zip archives or on-disk directories of memmaps). -## Dependencies +## Documentation -Required: -- C++11 compiler and CMake (>= 3.10) -- libzip -- Eigen3 +Project documentation lives in `docs/` and includes build/usage instructions +plus the API reference. Build the site locally with the `docs` CMake target: -Optional: -- GTest (for building tests) - -## Build and Install - -### Quick build (no tests) - -``` -cmake -S . -B build -cmake --build build -``` - -### Build with tests - -``` -cmake -S . -B build \ - -DTRX_BUILD_TESTS=ON \ - -DGTest_DIR=/path/to/GTestConfig.cmake -cmake --build build -ctest --test-dir build --output-on-failure -``` - -## Style checks - -This repo includes `.clang-tidy` and `.clang-format` at the root, plus -MRtrix-inspired helper scripts for formatting and style checks. - -### Prerequisites - -- `clang-format` available on `PATH` - - macOS (Homebrew): `brew install llvm` (or `llvm@17`) and ensure `clang-format` is on `PATH` - - Ubuntu: `sudo apt-get install clang-format` -- For `check_syntax` on macOS, GNU grep is required (`brew install grep`, then it will use `ggrep`). 
- -### clang-format (bulk formatting) - -Run the formatter across repo sources: - -``` -./clang-format-all -``` - -You can target a specific clang-format binary: - -``` -./clang-format-all --executable /path/to/clang-format -``` - -### check_syntax (style rules) - -Run the MRtrix-style checks against the C++ sources: - -``` -./check_syntax -``` - -Results are written to `syntax.log` when issues are found. - -### clang-tidy - -Generate a build with compile commands, then run clang-tidy (matches CI): - -``` -cmake -S . -B build \ - -DTRX_USE_CONAN=OFF \ - -DTRX_BUILD_EXAMPLES=ON \ - -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -run-clang-tidy -p build $(git ls-files '*.cpp' '*.h' '*.hpp' '*.tpp' ':!third_party/**') ``` +# Install prerequisites (Ubuntu example) +sudo apt-get install -y doxygen python3-pip +python3 -m pip install --user -r docs/requirements.txt -To run clang-tidy automatically during builds: - -``` -cmake -S . -B build -DTRX_ENABLE_CLANG_TIDY=ON -cmake --build build +# Configure once, then build the docs target +cmake -S . -B build +cmake --build build --target docs ``` -If you have `run-clang-tidy` installed (LLVM extras), you can lint everything -tracked by the repo (excluding `third_party`), which matches CI. +Open `docs/_build/html/index.html` in a browser. ## Third-party notices @@ -104,75 +34,5 @@ tracked by the repo (excluding `third_party`), which matches CI. 
## Usage Examples -All examples assume: - -``` -#include - -using namespace trxmmap; -``` - -### Read a TRX zip and inspect data - -``` -TrxFile *trx = load_from_zip("tracks.trx"); - -// Access streamlines: vertices are stored as an Eigen matrix -const auto num_vertices = trx->streamlines->_data.size() / 3; -const auto num_streamlines = trx->streamlines->_offsets.size() - 1; - -std::cout << "Vertices: " << num_vertices << "\n"; -std::cout << "Streamlines: " << num_streamlines << "\n"; -std::cout << "First vertex (x,y,z): " - << trx->streamlines->_data(0, 0) << "," - << trx->streamlines->_data(0, 1) << "," - << trx->streamlines->_data(0, 2) << "\n"; - -// Data-per-streamline and data-per-vertex are stored in maps -for (const auto &kv : trx->data_per_streamline) { - std::cout << "DPS '" << kv.first << "' elements: " - << kv.second->_matrix.size() << "\n"; -} -for (const auto &kv : trx->data_per_vertex) { - std::cout << "DPV '" << kv.first << "' elements: " - << kv.second->_data.size() << "\n"; -} - -trx->close(); // cleans up temporary on-disk data -delete trx; -``` - -### Read from an on-disk TRX directory - -``` -TrxFile *trx = load_from_directory("/path/to/trx_dir"); -std::cout << "Header JSON:\n" << trx->header.dump() << "\n"; -trx->close(); -delete trx; -``` - -### Write a TRX file - -You can modify a loaded `TrxFile` and save it to a new archive: - -``` -TrxFile *trx = load_from_zip("tracks.trx"); - -// Example: update header metadata -auto header_obj = trx->header.object_items(); -header_obj["COMMENT"] = "saved by trx-cpp"; -trx->header = json(header_obj); - -// Save with no compression (ZIP_CM_STORE) or another libzip compression level -save(*trx, "tracks_copy.trx", ZIP_CM_STORE); - -trx->close(); -delete trx; -``` - -### Notes on memory mapping - -`TrxFile` uses memory-mapped arrays under the hood for large datasets. The -`close()` method cleans up any temporary folders created during zip extraction. 
-If you skip `close()`, temporary directories may be left behind. +See the documentation in `docs/` for examples and API details. diff --git a/docs/Doxyfile b/docs/Doxyfile new file mode 100644 index 0000000..9d48cc6 --- /dev/null +++ b/docs/Doxyfile @@ -0,0 +1,19 @@ +PROJECT_NAME = "trx-cpp" +PROJECT_BRIEF = "C++ library for reading and writing the TRX format." + +OUTPUT_DIRECTORY = _build/doxygen + +GENERATE_HTML = NO +GENERATE_LATEX = NO +GENERATE_XML = YES +XML_OUTPUT = xml + +INPUT = ../include ../src +RECURSIVE = YES +FILE_PATTERNS = *.h *.hpp *.tpp + +EXTRACT_ALL = YES +EXTRACT_ANON_NSPACES = NO +HIDE_UNDOC_NAMESPACES = YES +EXCLUDE_SYMBOLS = Eigen* trx::detail* +QUIET = YES diff --git a/docs/building.rst b/docs/building.rst new file mode 100644 index 0000000..202af97 --- /dev/null +++ b/docs/building.rst @@ -0,0 +1,121 @@ +Building +======== + +Dependencies +------------ + +Required: + +- C++17 compiler +- libzip +- Eigen3 + + +Installing dependencies: +------------------------ + +These examples include installing google test, +but this is only necessary if you want to build the tests. +Similarly, ninja is not strictly necessary, but it is recommended. +zlib is only required if you want to use the NIfTI I/O features. + +On Debian-based systems the zip tools have been split into separate packages +on recent ubuntu versions. + +.. code-block:: bash + + sudo apt-get install \ + zlib1g-dev \ + libeigen3-dev \ + libzip-dev \ + zipcmp \ + zipmerge \ + ziptool \ + ninja-build \ + libgtest-dev + +On Mac OS, you can install the dependencies with brew: + +.. code-block:: bash + + brew install libzip eigen googletest ninja zlib + + +On Windows, you can install the dependencies through vcpkg and chocolatey: + +.. code-block:: powershell + + choco install ninja -y + vcpkg install libzip eigen3 gtest zlib + + +Building to use in other projects +--------------------------------- + +.. code-block:: bash + + cmake -S . 
-B build \
+    -G Ninja \
+    -DCMAKE_BUILD_TYPE=Release \
+    -DTRX_BUILD_EXAMPLES=OFF \
+    -DTRX_ENABLE_INSTALL=ON \
+    -DCMAKE_INSTALL_PREFIX=/path/to/installation/directory
+   cmake --build build --config Release
+   cmake --install build
+
+Key CMake options:
+
+- ``TRX_ENABLE_INSTALL``: Install package config and targets (default ON for top-level builds)
+- ``TRX_BUILD_EXAMPLES``: Build example executables (default ON)
+- ``TRX_BUILD_TESTS``: Build tests (default OFF)
+- ``TRX_BUILD_DOCS``: Build docs with Doxygen/Sphinx (default OFF)
+- ``TRX_ENABLE_CLANG_TIDY``: Run clang-tidy during builds (default OFF)
+- ``TRX_USE_CONAN``: Use Conan setup in ``cmake/ConanSetup.cmake`` (default OFF)
+
+To use trx-cpp from another CMake project after installation:
+
+.. code-block:: cmake
+
+    find_package(trx-cpp CONFIG REQUIRED)
+    target_link_libraries(your_target PRIVATE trx-cpp::trx)
+
+If you prefer vendoring trx-cpp, you can add it as a subdirectory and link the
+target directly:
+
+.. code-block:: cmake
+
+    add_subdirectory(path/to/trx-cpp)
+    target_link_libraries(your_target PRIVATE trx-cpp::trx)
+
+
+Building for testing
+--------------------
+
+.. code-block:: bash
+
+    cmake -S . -B build \
+    -G Ninja \
+    -DTRX_BUILD_TESTS=ON \
+    -DTRX_ENABLE_NIFTI=ON \
+    -DTRX_BUILD_EXAMPLES=OFF
+
+    cmake --build build
+    ctest --test-dir build --output-on-failure
+
+Tests require GTest to be discoverable by CMake (e.g., via a system package or
+``GTest_DIR``). If GTest is not found, tests will be skipped.
+NIfTI I/O tests additionally require zlib to be discoverable (``ZLIB::ZLIB``).
+
+
+Building documentation:
+-----------------------
+
+Building the docs requires both Doxygen and ``sphinx-build`` on your PATH.
+
+.. code-block:: bash
+
+    cmake -S . 
-B build \ + -G Ninja \ + -DTRX_BUILD_DOCS=ON + + cmake --build build --target docs diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000..6d8a612 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,39 @@ +import os + +project = "trx-cpp" +author = "trx-cpp contributors" + +extensions = [ + "breathe", + "exhale", + "sphinx.ext.autosectionlabel", +] + +root_doc = "index" +templates_path = ["_templates"] +exclude_patterns = ["_build"] + +html_theme = "sphinx_rtd_theme" +html_theme_options = { + "collapse_navigation": False, + "navigation_depth": 3, +} + +breathe_projects = { + "trx-cpp": os.path.join(os.path.dirname(__file__), "_build", "doxygen", "xml"), +} +breathe_default_project = "trx-cpp" + +primary_domain = "cpp" +highlight_language = "cpp" + +autosectionlabel_prefix_document = True + +exhale_args = { + "containmentFolder": "./api", + "rootFileName": "library_root.rst", + "rootFileTitle": "API Reference", + "doxygenStripFromPath": "..", + "createTreeView": True, + "exhaleExecutesDoxygen": False, +} diff --git a/docs/downstream_usage.rst b/docs/downstream_usage.rst new file mode 100644 index 0000000..6e1dc9b --- /dev/null +++ b/docs/downstream_usage.rst @@ -0,0 +1,449 @@ +Downstream Usage +================ + +Writing MRtrix-style streamlines to TRX +--------------------------------------- + +MRtrix3 streamlines are created by ``MR::DWI::Tractography::Tracking::Exec``, +which appends points to a per-streamline container as tracking progresses. +That container is ``MR::DWI::Tractography::Tracking::GeneratedTrack``, which is +a ``std::vector`` with some tracking metadata +fields (seed index and status). During tracking, each call that advances the +streamline pushes the current position into this vector, so the resulting +streamline is just an ordered list of 3D points. + +In practice: + +- Streamline points are stored as ``Eigen::Vector3f`` entries in + ``GeneratedTrack``. 
+- The tracking code appends ``method.pos`` into that vector as each step + completes (seed point first, then subsequent vertices). +- The final output is a list of accepted ``GeneratedTrack`` instances, each + representing one streamline. + +TRX stores streamlines as a single ``positions`` array of shape ``(NB_VERTICES, 3)`` +and an ``offsets`` array of length ``(NB_STREAMLINES + 1)`` that provides the +prefix-sum offsets of each streamline in ``positions``. The example below shows +one idea on how to convert a list of MRtrix ``GeneratedTrack`` streamlines into a TRX file. + +.. code-block:: cpp + + #include + #include "dwi/tractography/tracking/generated_track.h" + + using MR::DWI::Tractography::Tracking::GeneratedTrack; + + void write_trx_from_mrtrix(const std::vector &tracks, + const std::string &out_path) { + // Count accepted streamlines and total vertices. + std::vector accepted; + accepted.reserve(tracks.size()); + size_t total_vertices = 0; + for (const auto &tck : tracks) { + if (tck.get_status() != GeneratedTrack::status_t::ACCEPTED) { + continue; + } + accepted.push_back(&tck); + total_vertices += tck.size(); + } + + const size_t nb_streamlines = accepted.size(); + const size_t nb_vertices = total_vertices; + + // Allocate a TRX file (float positions) with the desired sizes. 
+ trx::TrxFile trx(nb_vertices, nb_streamlines); + + auto &positions = trx.streamlines->_data; // (NB_VERTICES, 3) + auto &offsets = trx.streamlines->_offsets; // (NB_STREAMLINES + 1, 1) + auto &lengths = trx.streamlines->_lengths; // (NB_STREAMLINES, 1) + + size_t cursor = 0; + offsets(0) = 0; + for (size_t i = 0; i < nb_streamlines; ++i) { + const auto &tck = *accepted[i]; + lengths(i) = static_cast(tck.size()); + offsets(i + 1) = offsets(i) + tck.size(); + + for (size_t j = 0; j < tck.size(); ++j, ++cursor) { + positions(cursor, 0) = tck[j].x(); + positions(cursor, 1) = tck[j].y(); + positions(cursor, 2) = tck[j].z(); + } + } + + trx.save(out_path, ZIP_CM_STORE); + trx.close(); + } + + +Streaming TRX from MRtrix tckgen +-------------------------------- + +MRtrix ``tckgen`` writes streamlines as they are generated. To stream into TRX +without buffering all streamlines in memory, use ``trx::TrxStream`` to append +streamlines and finalize once tracking completes. This mirrors how the tck +writer works. + +.. code-block:: cpp + + #include + #include "dwi/tractography/tracking/generated_track.h" + + using MR::DWI::Tractography::Tracking::GeneratedTrack; + + trx::TrxStream trx_stream; + + // Called for each accepted streamline. + void on_streamline(const GeneratedTrack &tck) { + std::vector xyz; + xyz.reserve(tck.size() * 3); + for (const auto &pt : tck) { + xyz.push_back(pt[0]); + xyz.push_back(pt[1]); + xyz.push_back(pt[2]); + } + trx_stream.push_streamline(xyz); + } + + trx_stream.finalize("tracks.trx", ZIP_CM_STORE); + + +Using TRX in DSI Studio +----------------------- + +DSI Studio stores tractography in ``tract_model.cpp`` as a list of per-tract +point arrays and optional cluster assignments. The TRX format maps cleanly onto +this representation: + +- DSI Studio cluster assignments map to TRX ``groups/`` files. Each cluster is a + group containing the indices of streamlines that belong to that cluster. 
+- Per-streamline values (e.g., DSI's loaded scalar values) map to TRX DPS + (``data_per_streamline``) arrays. +- Per-vertex values (e.g., along-tract scalars) map to TRX DPV + (``data_per_vertex``) arrays. + +This means a TRX file can carry the tract geometry, cluster membership, and +both per-streamline and per-vertex metrics in a single archive, and DSI Studio +can round-trip these fields without custom sidecars. + +Usage sketch (DSI Studio) +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: cpp + + // DSI Studio tract data is stored as std::vector> + // (see tract_model.cpp), with each streamline as an interleaved xyz list. + // Coordinates are in DSI Studio's internal voxel space; convert to RASMM + // as needed (e.g., multiply by voxel size and apply transforms). + std::vector> streamlines = /* DSI Studio tract_data */; + + // Optional per-streamline cluster labels or group membership. + std::vector cluster_ids = /* same length as streamlines */; + + // Convert to TRX positions/offsets. + size_t total_vertices = 0; + for (const auto &sl : streamlines) { + total_vertices += sl.size() / 3; + } + + trx::TrxFile trx(static_cast(total_vertices), + static_cast(streamlines.size())); + auto &positions = trx.streamlines->_data; + auto &offsets = trx.streamlines->_offsets; + auto &lengths = trx.streamlines->_lengths; + + size_t cursor = 0; + offsets(0) = 0; + for (size_t i = 0; i < streamlines.size(); ++i) { + const auto &sl = streamlines[i]; + const size_t points = sl.size() / 3; + lengths(i) = static_cast(points); + offsets(i + 1) = offsets(i) + points; + for (size_t p = 0; p < points; ++p, ++cursor) { + positions(cursor, 0) = sl[p * 3 + 0]; + positions(cursor, 1) = sl[p * 3 + 1]; + positions(cursor, 2) = sl[p * 3 + 2]; + } + } + + // Map cluster labels to TRX groups (one group per label). 
+ std::map> clusters; + for (size_t i = 0; i < cluster_ids.size(); ++i) { + clusters[cluster_ids[i]].push_back(static_cast(i)); + } + for (const auto &kv : clusters) { + trx.add_group_from_indices("cluster_" + std::to_string(kv.first), kv.second); + } + + trx.save("out.trx", ZIP_CM_STORE); + trx.close(); + +Using TRX with nibrary (dmriTrekker) +------------------------------------ + +nibrary (used by dmriTrekker for tractogram handling) could provide TRX +reading and writing support in its tractography I/O layer. TRX fits this +pipeline well because it exposes the same primitives that nibrary uses +internally: a list of streamlines (each a list of 3D points) plus optional +per-streamline and per-vertex fields. + +Coordinate systems: + +- TRX ``positions`` are stored in world space (RASMM), which is RAS+ and matches + the coordinate system used by MRtrix3's ``.tck`` format. nibrary's internal + streamline points are the same coordinates written out by its TCK writer, so + those points map directly to TRX ``positions`` when using the same reference + space. +- TRX header fields ``VOXEL_TO_RASMM`` and ``DIMENSIONS`` should be populated from + the reference image used by dmriTrekker/nibrary so downstream tools interpret + coordinates consistently. + +Usage sketch (nibrary) +~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: cpp + + // nibrary uses Streamline = std::vector and Tractogram = std::vector. + using NIBR::Streamline; + using NIBR::Tractogram; + + Tractogram nibr_streamlines = /* nibrary tractogram data */; + + // Write nibrary streamlines to TRX. 
+ size_t total_vertices = 0; + for (const auto &sl : nibr_streamlines) { + total_vertices += sl.size(); + } + + trx::TrxFile trx_out(static_cast(total_vertices), + static_cast(nibr_streamlines.size())); + auto &positions = trx_out.streamlines->_data; + auto &offsets = trx_out.streamlines->_offsets; + auto &lengths = trx_out.streamlines->_lengths; + + size_t cursor = 0; + offsets(0) = 0; + for (size_t i = 0; i < nibr_streamlines.size(); ++i) { + const auto &sl = nibr_streamlines[i]; + lengths(i) = static_cast(sl.size()); + offsets(i + 1) = offsets(i) + sl.size(); + for (size_t p = 0; p < sl.size(); ++p, ++cursor) { + positions(cursor, 0) = sl[p][0]; + positions(cursor, 1) = sl[p][1]; + positions(cursor, 2) = sl[p][2]; + } + } + + trx_out.save("tracks.trx", ZIP_CM_STORE); + trx_out.close(); + + // Read TRX into nibrary-style streamlines. + auto trx_in = trx::load_any("tracks.trx"); + const auto pos = trx_in.positions.as_matrix(); + const auto offs = trx_in.offsets.as_matrix(); + + Tractogram out_streamlines; + out_streamlines.reserve(trx_in.num_streamlines()); + for (size_t i = 0; i < trx_in.num_streamlines(); ++i) { + const size_t start = static_cast(offs(i, 0)); + const size_t end = static_cast(offs(i + 1, 0)); + Streamline sl; + sl.reserve(end - start); + for (size_t j = start; j < end; ++j) { + sl.push_back({pos(j, 0), pos(j, 1), pos(j, 2)}); + } + out_streamlines.push_back(std::move(sl)); + } + + trx_in.close(); + +Loading MITK Diffusion streamlines into TRX +------------------------------------------- + +MITK Diffusion stores its streamline output in ``StreamlineTrackingFilter`` as a +``BundleType``, which is a ``std::vector`` of ``FiberType`` objects. Each ``FiberType`` +is a ``std::deque>``, i.e., an ordered list of 3D points in +physical space. Converting this to TRX follows the same pattern as any other +list-of-points representation: flatten all points into the TRX ``positions`` +array and build a prefix-sum ``offsets`` array. 
+ +Note on physical space, headers, and affines: + +- MITK streamlines are in physical space (millimeters) using ITK's LPS+ + convention by default. TRX ``positions`` are expected to be in RASMM, so you + should flip the x and y axes when writing TRX (and flip back when reading). + +The sketch below shows how to write a ``BundleType`` to TRX and how to reconstruct +it from a TRX file if needed. + +.. code-block:: cpp + + #include + #include + #include + #include + + using FiberType = std::deque>; + using BundleType = std::vector; + + void mitk_bundle_to_trx(const BundleType &bundle, const std::string &out_path) { + size_t total_vertices = 0; + for (const auto &fiber : bundle) { + total_vertices += fiber.size(); + } + + trx::TrxFile trx(total_vertices, bundle.size()); + auto &positions = trx.streamlines->_data; + auto &offsets = trx.streamlines->_offsets; + auto &lengths = trx.streamlines->_lengths; + + size_t cursor = 0; + offsets(0) = 0; + for (size_t i = 0; i < bundle.size(); ++i) { + const auto &fiber = bundle[i]; + lengths(i) = static_cast(fiber.size()); + offsets(i + 1) = offsets(i) + fiber.size(); + size_t j = 0; + for (const auto &pt : fiber) { + // LPS (MITK/ITK) -> RAS (TRX) + positions(cursor, 0) = -pt[0]; + positions(cursor, 1) = -pt[1]; + positions(cursor, 2) = pt[2]; + ++cursor; + ++j; + } + } + + trx.save(out_path, ZIP_CM_STORE); + trx.close(); + } + + BundleType trx_to_mitk_bundle(const std::string &trx_path) { + auto trx = trx::load_any(trx_path); + const auto positions = trx.positions.as_matrix(); // (NB_VERTICES, 3) + const auto offsets = trx.offsets.as_matrix(); // (NB_STREAMLINES + 1, 1) + + BundleType bundle; + bundle.reserve(trx.num_streamlines()); + + for (size_t i = 0; i < trx.num_streamlines(); ++i) { + const size_t start = static_cast(offsets(i, 0)); + const size_t end = static_cast(offsets(i + 1, 0)); + FiberType fiber; + fiber.resize(end - start); + for (size_t j = start; j < end; ++j) { + // RAS (TRX) -> LPS (MITK/ITK) + fiber[j - 
start][0] = -positions(j, 0); + fiber[j - start][1] = -positions(j, 1); + fiber[j - start][2] = positions(j, 2); + } + bundle.push_back(std::move(fiber)); + } + + trx.close(); + return bundle; + } + +TRX in ITK-SNAP slice views +---------------------------- + +ITK-SNAP is a very nice image viewer that has not had the ability to visualize +streamlines. It is a very useful tool to check image alignment, especially if +you are working with ITK/ANTs, as it interprets image headers using ITK. + +Streamlines could be added to ITK-SNAP slice views by adding a renderer delegate to +the slice rendering pipeline. DSI Studio and MRview both have the ability to plot +streamlines on slices, but neither use ITK to interpret nifti headers. Someday +when TRX is directly integrated into ITK, ITK-SNAP integration could be used +to check that ``antsApplyTransformsToTRX`` is working correctly. + +Where to maybe integrate: + +- ``GUI/Qt/Components/SliceViewPanel`` sets up the slice view and installs + renderer delegates. +- ``GUI/Renderer/GenericSliceRenderer`` and ``SliceRendererDelegate`` define + the overlay rendering API (lines, polylines, etc.). +- Existing overlays (e.g., ``CrosshairsRenderer`` and + ``PolygonDrawingRenderer``) show how to draw line-based primitives. + +Possible workflow: + +1. Load a TRX file via GUI. +2. Create a new renderer delegate (e.g., ``StreamlineTrajectoryRenderer``) that: + - Filters streamlines that intersect the current slice plane (optionally + using cached AABBs for speed). + - Projects 3D points into slice coordinates using + ``GenericSliceModel::MapImageToSlice`` or + ``GenericSliceModel::MapSliceToImagePhysical``. + - Draws each trajectory with ``DrawPolyLine`` in the render context. +3. Register the delegate in ``SliceViewPanel`` so it renders above the image. + +Coordinate systems: + +- ITK-SNAP uses LPS+ physical coordinates by default. 
+- TRX stores positions in RAS+ world coordinates, so x/y should be negated when + moving between TRX and ITK-SNAP physical space. + +This design keeps the TRX integration localized to the slice overlay system and +does not require changes to core ITK-SNAP data structures. + +TRX in SlicerDMRI +----------------- + +SlicerDMRI represents tractography as ``vtkPolyData`` inside a +``vtkMRMLFiberBundleNode``. TRX support is implemented by converting TRX +structures to that polydata representation (and back on save). + +High-level mapping: + +- TRX ``positions`` + ``offsets`` map to polydata points and polyline cells. + Each streamline becomes one line cell; point coordinates are stored in RAS+. +- TRX DPV (data-per-vertex) becomes ``PointData`` arrays on the polydata. +- TRX DPS (data-per-streamline) becomes ``CellData`` arrays on the polydata. +- TRX groups are represented as a single label array per streamline + (``TRX_GroupId``), with a name table stored in ``FieldData`` as + ``TRX_GroupNames``. + +Round-trip metadata convention: + +- DPV arrays are stored as ``TRX_DPV_`` in ``PointData``. +- DPS arrays are stored as ``TRX_DPS_`` in ``CellData``. +- The storage node only exports arrays with these prefixes back into TRX, so + metadata remains recognizable and round-trippable. + +How users can visualize and interact with TRX metadata in the Slicer GUI: + +- **DPV**: can be used for per-point coloring (e.g., FA along the fiber) by + selecting the corresponding ``TRX_DPV_*`` array as the scalar to display. +- **DPS**: can be used for per-streamline coloring or thresholding by selecting + a ``TRX_DPS_*`` array in the fiber bundle display controls. +- **Groups**: color by ``TRX_GroupId`` to show group membership, and use + thresholding or selection filters to isolate specific group ids. The group + id-to-name mapping is stored in ``TRX_GroupNames`` for reference. + +Users can add their own DPV/DPS arrays in Slicer (via Python, modules, or +filters). 
To ensure these arrays are written back into TRX, name them with the
+``TRX_DPV_`` or ``TRX_DPS_`` prefixes and keep them single-component with the
+correct tuple counts (points for DPV, streamlines for DPS).
+
+TrxReader vs TrxFile
+--------------------
+
+``TrxFile<DT>
`` is the core typed container. It owns the memory-mapped arrays,
+exposes streamlines and metadata as Eigen matrices, and provides mutation and
+save operations. The template parameter ``DT`` fixes the positions dtype
+(``half``, ``float``, or ``double``), which allows zero-copy access and avoids
+per-element conversions.
+
+``TrxReader<DT>
`` is a small RAII wrapper that loads a TRX file and manages the +backing resources. It ensures the temporary extraction directory (for zipped +TRX) is cleaned up when the reader goes out of scope, and provides safe access +to the underlying ``TrxFile``. This separation keeps ``TrxFile`` focused on the +data model, while ``TrxReader`` handles ownership and lifecycle concerns for +loaded files. + +In practice, most downstream users do not need to instantiate ``TrxReader`` +directly. The common entry points are convenience functions like +``trx::load_any`` or ``trx::with_trx_reader`` and higher-level wrappers that +return a ready-to-use ``TrxFile``. ``TrxReader`` remains available for advanced +use cases where explicit lifetime control of the backing resources is needed. \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 0000000..58f0d3f --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,25 @@ +TRX-cpp Documentation +===================== + +.. image:: https://codecov.io/gh/tee-ar-ex/trx-cpp/branch/main/graph/badge.svg + :target: https://codecov.io/gh/tee-ar-ex/trx-cpp + :alt: codecov + +TRX-cpp is a C++ library for reading, writing, and memory-mapping the TRX +tractography file format. + +.. toctree:: + :maxdepth: 2 + :caption: Getting Started + + overview + building + usage + downstream_usage + linting + +.. toctree:: + :maxdepth: 2 + :caption: API Reference + + api/library_root diff --git a/docs/linting.rst b/docs/linting.rst new file mode 100644 index 0000000..da1153b --- /dev/null +++ b/docs/linting.rst @@ -0,0 +1,62 @@ +Linting and Style Checks +======================== + +This repo includes ``.clang-tidy`` and ``.clang-format`` at the root, plus +MRtrix-inspired helper scripts for formatting and style checks. 
+ +Prerequisites +------------- + +- ``clang-format`` available on ``PATH`` + - macOS (Homebrew): ``brew install llvm`` (or ``llvm@17``) and ensure + ``clang-format`` is on ``PATH`` + - Ubuntu: ``sudo apt-get install clang-format`` +- For ``check_syntax`` on macOS, GNU grep is required (``brew install grep``, + then it will use ``ggrep``). + +clang-format (bulk formatting) +------------------------------ + +.. code-block:: bash + + ./clang-format-all + +You can target a specific clang-format binary: + +.. code-block:: bash + + ./clang-format-all --executable /path/to/clang-format + +check_syntax (style rules) +-------------------------- + +Run the MRtrix-style checks against the C++ sources: + +.. code-block:: bash + + ./check_syntax + +Results are written to ``syntax.log`` when issues are found. + +clang-tidy +---------- + +Generate a build with compile commands, then run clang-tidy (matches CI): + +.. code-block:: bash + + cmake -S . -B build \ + -DTRX_USE_CONAN=OFF \ + -DTRX_BUILD_EXAMPLES=ON \ + -DCMAKE_EXPORT_COMPILE_COMMANDS=ON + run-clang-tidy -p build $(git ls-files '*.cpp' '*.h' '*.hpp' '*.tpp' ':!third_party/**') + +To run clang-tidy automatically during builds: + +.. code-block:: bash + + cmake -S . -B build -DTRX_ENABLE_CLANG_TIDY=ON + cmake --build build + +If you have ``run-clang-tidy`` installed (LLVM extras), you can lint everything +tracked by the repo (excluding ``third_party``), which matches CI. diff --git a/docs/overview.rst b/docs/overview.rst new file mode 100644 index 0000000..37e783e --- /dev/null +++ b/docs/overview.rst @@ -0,0 +1,108 @@ +Overview +======== + +TRX-cpp provides: + +- Read and write TRX archives and directories +- Memory-mapped access for large datasets +- Simple access to streamlines and metadata + +The API is header-focused under ``include/trx``, with implementation in ``src/``. 
+
+TRX in C++
+==========
+
+TRX file format overview
+------------------------
+
+TRX is a tractography container for streamlines and auxiliary data. The core
+geometry lives in two arrays: a single ``positions`` array containing all
+points and an ``offsets`` array that delineates each streamline. Coordinates
+are stored in **RAS+ world space (millimeters)**. This matches the coordinate
+convention used by MRtrix ``.tck`` and many other tractography tools.
+It also avoids the pitfalls of image-based coordinate systems.
+
+TRX can be stored either as:
+
+- A **directory** containing ``header.json`` and data files (positions,
+  offsets, dpv/dps/groups/dpg).
+- A **zip archive** (``.trx``) with the same internal structure.
+
+Auxiliary data stored alongside streamlines:
+
+- **DPV** (data per vertex): values aligned with each point.
+- **DPS** (data per streamline): values aligned with each streamline.
+- **Groups**: named sets of streamline indices for labeling or clustering.
+- **DPG** (data per group): values aligned with groups (one set per group).
+
+The ``header.json`` includes spatial metadata such as ``VOXEL_TO_RASMM`` and
+``DIMENSIONS`` to preserve interpretation of coordinates. See the
+`TRX specification `_ for details.
+
+
+Positions array
+---------------
+
+The ``positions`` array is a single contiguous matrix of shape
+``(NB_VERTICES, 3)``. Storing all vertices in one array is cache-friendly and
+enables efficient memory mapping. In trx-cpp, ``positions`` is backed by a
+``mio::shared_mmap_sink`` and exposed as an ``Eigen::Matrix`` for zero-copy access
+when possible.
+
+Offsets and the sentinel value
+------------------------------
+
+The ``offsets`` array is a prefix-sum index into ``positions``. Its length is
+``NB_STREAMLINES + 1``. The final element is a **sentinel** that equals
+``NB_VERTICES`` and makes length computation trivial:
+
+``length_i = offsets[i + 1] - offsets[i]``.
+ +This design avoids per-streamline allocations and supports fast slicing of the +global ``positions`` array. In trx-cpp, offsets are stored as ``uint64`` and +mapped directly into Eigen. + +Data per vertex (DPV) +--------------------- + +DPV stores a value for each vertex in ``positions``. Examples include FA values +along the tract, per-point colors, or confidence measures. DPV arrays have +shape ``(NB_VERTICES, 1)`` or ``(NB_VERTICES, N)`` for multi-component values. + +In trx-cpp, DPV fields are stored under ``dpv/`` and are memory-mapped similarly +to ``positions``. This keeps per-point metadata aligned and contiguous, which +is important for large tractograms. + +Data per streamline (DPS) +------------------------- + +DPS stores a value per streamline. Examples include streamline length, average +FA, or per-tract weights. DPS arrays have shape ``(NB_STREAMLINES, 1)`` or +``(NB_STREAMLINES, N)``. + +In trx-cpp, DPS fields live under ``dps/`` and are mapped into ``MMappedMatrix`` +objects, enabling efficient access without loading entire arrays into RAM. + +Groups and data per group (DPG) +------------------------------- + +Groups provide a sparse, overlapping labeling of streamlines. Each group is a +named list of streamline indices, and a streamline can belong to multiple +groups. Examples: + +- **Bundle labels** (e.g., ``CST_L``, ``CST_R``, ``CC``) +- **Clusters** from quickbundles or similar algorithms +- **Connectivity subsets** (e.g., streamlines that connect two ROIs) + +DPG (data per group) attaches metadata to each group. Examples: + +- Mean FA for each bundle +- A per-group color or display weight +- Scalar summaries computed over the group + +In the TRX on-disk layout, groups are stored under ``groups/`` as index arrays, +and DPG data is stored under ``dpg//`` as one or more arrays. + +In trx-cpp, groups are represented as ``MMappedMatrix`` objects and +DPG fields are stored as ``MMappedMatrix
`` entries. This keeps group data
+as memory-mapped arrays so large group sets can be accessed without copying.
diff --git a/docs/requirements.txt b/docs/requirements.txt
new file mode 100644
index 0000000..ef182b1
--- /dev/null
+++ b/docs/requirements.txt
@@ -0,0 +1,4 @@
+sphinx
+breathe
+exhale
+sphinx_rtd_theme
diff --git a/docs/spec.md b/docs/spec.md
new file mode 100644
index 0000000..56c122f
--- /dev/null
+++ b/docs/spec.md
@@ -0,0 +1,152 @@
+
+# Generals
+- (Un)-Compressed Zip File or simple folder architecture
+  - File architecture describes the data
+  - Each file basename is the metadata’s name
+  - Each file extension is the metadata’s dtype
+  - Each file dimension is in the value between basename and metadata, 1-dimension arrays do not have to follow this convention for readability
+- All arrays have a C-style memory layout(row-major)
+- All arrays have a little-endian byte order
+- Compression is optional
+  - use ZIP_STORE, if compression is desired use ZIP_DEFLATE
+  - Compressed TRX files will have to be decompressed before being loaded
+
+# Header
+Only (or mostly) for user-readability, read-time checks and broader compatibility.
+
+- Dictionary in JSON
+  - VOXEL_TO_RASMM (4 lists of 4 float, 4x4 transformation matrix)
+  - DIMENSIONS (list of 3 uint16)
+  - NB_STREAMLINES (uint32)
+  - NB_VERTICES (uint64)
+
+# Arrays
+# positions.float16
+- Written in world space(RASMM)
+  - Like TCK file
+- Should always be a float16/32/64
+  - Default could be float16
+- As contiguous 3D array(NB_VERTICES, 3)
+
+# offsets.uint64
+- Should always be a uint32/64
+- Where the first vertex of each streamline is, starting at 0
+- Two ways of knowing how many vertices there are:
+  - Check the header
+  - Positions array size / dtypes / 3
+
+- To get streamlines lengths: append the total number of vertices to the end of offsets and take the differences between consecutive elements of the array(ediff1d in numpy).
+ +# dpv (data_per_vertex) +- Always of size (NB_VERTICES, 1) or (NB_VERTICES, N) + +# dps (data_per_streamline) +- Always of size (NB_STREAMLINES, 1) or (NB_STREAMLINES, N) + +# Groups +Groups are tables of indices that allow sparse & overlapping representation(clusters, connectomics, bundles). +- All indices must be 0 < id < NB_STREAMLINES +- Datatype should be uint32 +- Allow to get a predefined streamlines subset from the memmaps efficiently +- Variables in sizes + +# dpg (data_per_group) +- Each folder is the name of a group +- Not all metadata have to be present in all groups +- Always of size (1,) or (N,) + +# Accepted extensions (datatype) +- int8/16/32/64 +- uint8/16/32/64 +- float16/32/64 + +# Example structure +```bash +OHBM_demo.trx +├── dpg +│ ├── AF_L +│ │ ├── mean_fa.float16 +│ │ ├── shuffle_colors.3.uint8 +│ │ └── volume.uint32 +│ ├── AF_R +│ │ ├── mean_fa.float16 +│ │ ├── shuffle_colors.3.uint8 +│ │ └── volume.uint32 +│ ├── CC +│ │ ├── mean_fa.float16 +│ │ ├── shuffle_colors.3.uint8 +│ │ └── volume.uint32 +│ ├── CST_L +│ │ └── shuffle_colors.3.uint8 +│ ├── CST_R +│ │ └── shuffle_colors.3.uint8 +│ ├── SLF_L +│ │ ├── mean_fa.float16 +│ │ ├── shuffle_colors.3.uint8 +│ │ └── volume.uint32 +│ └── SLF_R +│ ├── mean_fa.float16 +│ ├── shuffle_colors.3.uint8 +│ └── volume.uint32 +├── dpv +│ ├── color_x.uint8 +│ ├── color_y.uint8 +│ ├── color_z.uint8 +│ └── fa.float16 +├── dps +│ ├── algo.uint8 +│ ├── algo.json +│ ├── clusters_QB.uint16 +│ ├── commit_colors.3.uint8 +│ └── commit_weights.float32 +├── groups +│ ├── AF_L.uint32 +│ ├── AF_R.uint32 +│ ├── CC.uint32 +│ ├── CST_L.uint32 +│ ├── CST_R.uint32 +│ ├── SLF_L.uint32 +│ └── SLF_R.uint32 +├── header.json +├── offsets.uint64 +└── positions.3.float16 +``` + +# Example code +```python +from trx_file_memmap import TrxFile, load, save +import numpy as np + +trx = load('complete_big_v5.trx') + +# Access the header (dict) / streamlines (ArraySequences) +trx.header +trx.streamlines + +# Access the dpv (dict) / dps 
(dict) +trx.data_per_vertex +trx.data_per_streamline + +# Access the groups (dict) / dpg (dict) +trx.groups +trx.data_per_group + +# Get a random subset of 10000 streamlines +indices = np.arange(len(trx.streamlines._lengths)) +np.random.shuffle(indices) +sub_trx = trx.select(indices[0:10000]) +save(sub_trx, 'random_1000.trx') + +# Get sub-groups only, from the random subset +for key in sub_trx.groups.keys(): + group_trx = sub_trx.get_group(key) + save(group_trx, '{}.trx'.format(key)) + +# Pre-allocate memmaps and append 100x the random subset +alloc_trx = TrxFile(nb_streamlines=1500000, nb_vertices=500000000, init_as=trx) +for i in range(100): + alloc_trx.append(sub_trx) + +# Resize to remove the unused portion of the memmap +alloc_trx.resize() +``` diff --git a/docs/usage.rst b/docs/usage.rst new file mode 100644 index 0000000..16860ee --- /dev/null +++ b/docs/usage.rst @@ -0,0 +1,122 @@ +Usage +===== + +AnyTrxFile vs TrxFile +--------------------- + +``AnyTrxFile`` is the runtime-typed API. It reads the dtype from the file and +exposes arrays as ``TypedArray`` with a ``dtype`` string. This is the simplest +entry point when you only have a TRX path. + +``TrxFile
`` is the typed API. It is templated on the positions dtype +(``half``, ``float``, or ``double``) and maps data directly into Eigen matrices of +that type. It provides stronger compile-time guarantees but requires knowing +the dtype at compile time or doing manual dispatch. The recommended typed entry +point is :func:`trx::with_trx_reader`, which performs dtype detection and +dispatches to the matching ``TrxReader
``. + +See the API reference for details: :class:`trx::AnyTrxFile` and +:class:`trx::TrxFile`. + +Read a TRX zip and inspect data +------------------------------- + +.. code-block:: cpp + + #include + + using namespace trx; + + const std::string path = "/path/to/tracks.trx"; + + auto trx = load_any(path); + + std::cout << "dtype: " << trx.positions.dtype << "\n"; + std::cout << "Vertices: " << trx.num_vertices() << "\n"; + std::cout << "Streamlines: " << trx.num_streamlines() << "\n"; + + trx.close(); + + +Write a TRX file +---------------- + +.. code-block:: cpp + + auto trx = load_any("tracks.trx"); + auto header_obj = trx.header.object_items(); + header_obj["COMMENT"] = "saved by trx-cpp"; + trx.header = json(header_obj); + + trx.save("tracks_copy.trx", ZIP_CM_STORE); + trx.close(); + +Optional NIfTI header support +----------------------------- + +If downstream software will have to interact with ``trk`` format data, the NIfTI header +is going to be essential to go back and forth between ``trk`` and ``trx``. + +When built with ``TRX_ENABLE_NIFTI=ON``, or by default if you're building the examples, +you can read a NIfTI header (``.nii``, ``.hdr``, optionally ``.nii.gz``) and populate +``VOXEL_TO_RASMM`` in a TRX header. The qform is preferred; if missing, the +sform is orthogonalized to a qform-equivalent matrix (consistent with ITK's handling of nifti). +If using this feature, you will also need zlib available. The qform/sform logic here is +translated from nibabel's MIT-licensed implementation (see ``third_party/nibabel/LICENSE``). + +.. 
code-block:: cpp + + #include + #include + + Eigen::Matrix4f affine = trx::read_nifti_voxel_to_rasmm("reference.nii.gz"); + auto trx = trx::load("tracks.trx"); + trx->set_voxel_to_rasmm(affine); + trx->save("tracks_with_ref.trx"); + trx->close(); + +Build and query AABBs +--------------------- + +TRX can build per-streamline axis-aligned bounding boxes (AABB) and use them to +extract a subset of streamlines intersecting a rectangular region. AABBs are +stored in ``float16`` for memory efficiency, while comparisons are done in +``float32``. + +.. code-block:: cpp + + #include + + auto trx = trx::load("/path/to/tracks.trx"); + + // Query an axis-aligned box (min/max corners in RAS+ world coordinates). + std::array min_corner{-10.0f, -10.0f, -10.0f}; + std::array max_corner{10.0f, 10.0f, 10.0f}; + + auto subset = trx->query_aabb(min_corner, max_corner); + // Or precompute and pass the AABB cache explicitly: + // auto aabbs = trx->build_streamline_aabbs(); + // auto subset = trx->query_aabb(min_corner, max_corner, &aabbs); + // Optionally build cache for the result: + // auto subset = trx->query_aabb(min_corner, max_corner, &aabbs, true); + subset->save("subset.trx", ZIP_CM_STORE); + subset->close(); + +Subset by streamline IDs +------------------------ + +If you already have a list of streamline indices (for example, from a clustering +step or a spatial query), you can create a new TrxFile directly from those +indices. + +.. 
code-block:: cpp + + #include + + auto trx = trx::load("/path/to/tracks.trx"); + + std::vector ids{0, 4, 42, 99}; + auto subset = trx->subset_streamlines(ids); + + subset->save("subset_by_id.trx", ZIP_CM_STORE); + subset->close(); diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index cae2ed3..c2aea9a 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -2,6 +2,12 @@ add_executable(trxinfo trxinfo.cpp) target_link_libraries(trxinfo PRIVATE trx cxxopts::cxxopts) target_compile_features(trxinfo PRIVATE cxx_std_17) +if(TRX_ENABLE_NIFTI) + add_executable(trx-update-affine trx_nifti_affine.cpp) + target_link_libraries(trx-update-affine PRIVATE trx trx-nifti) + target_compile_features(trx-update-affine PRIVATE cxx_std_17) +endif() + install(TARGETS trxinfo RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} ) diff --git a/examples/load_trx.cpp b/examples/load_trx.cpp index c708394..1b7fec2 100644 --- a/examples/load_trx.cpp +++ b/examples/load_trx.cpp @@ -1,8 +1,12 @@ -#include "../src/trx.h" +#include -using namespace trxmmap; int main(int argc, char **argv) { // check_syntax off - trxmmap::TrxFile *trx = trxmmap::load_from_zip(argv[1]); + if (argc < 2) { + std::cerr << "Usage: load_trx \n"; + return 1; + } + + auto trx = trx::TrxFile::load_from_zip(argv[1]); std::cout << "Vertices: " << trx->streamlines->_data.size() / 3 << "\n"; std::cout << "First vertex (x,y,z): " << trx->streamlines->_data(0, 0) << "," << trx->streamlines->_data(0, 1) << "," diff --git a/examples/trx_nifti_affine.cpp b/examples/trx_nifti_affine.cpp new file mode 100644 index 0000000..da9aaa2 --- /dev/null +++ b/examples/trx_nifti_affine.cpp @@ -0,0 +1,25 @@ +#include +#include + +#include + +int main(int argc, char **argv) { + if (argc < 4) { + std::cerr << "Usage: trx-update-affine \n"; + return 1; + } + + const std::string input_trx = argv[1]; + const std::string nifti_path = argv[2]; + const std::string output_trx = argv[3]; + + Eigen::Matrix4f affine = 
trx::read_nifti_voxel_to_rasmm(nifti_path); + + return trx::with_trx_reader(input_trx, [&](auto &reader, trx::TrxScalarType) { + auto &trx_file = *reader; + trx_file.set_voxel_to_rasmm(affine); + trx_file.save(output_trx); + trx_file.close(); + return 0; + }); +} diff --git a/examples/trxinfo.cpp b/examples/trxinfo.cpp index a1f4f22..abec187 100644 --- a/examples/trxinfo.cpp +++ b/examples/trxinfo.cpp @@ -55,9 +55,7 @@ std::string format_matrix_row(const json &row) { return out.str(); } -template -void print_trx_info( - trxmmap::TrxFile
*trx, const std::string &path, bool is_dir, trxmmap::TrxScalarType dtype, bool show_stats) { +void print_trx_info(const trx::AnyTrxFile &trx, const std::string &path, bool is_dir, bool show_stats) { trx_cli::Colors colors; colors.enabled = trx_cli::stdout_supports_color(); @@ -65,10 +63,9 @@ void print_trx_info( std::cout << trx_cli::colorize(colors, colors.cyan, "Path") << ": " << path << "\n"; std::cout << trx_cli::colorize(colors, colors.cyan, "Storage") << ": " << (is_dir ? "directory" : "zip archive") << "\n"; - std::cout << trx_cli::colorize(colors, colors.cyan, "Positions dtype") << ": " << trxmmap::scalar_type_name(dtype) - << "\n"; + std::cout << trx_cli::colorize(colors, colors.cyan, "Positions dtype") << ": " << trx.positions.dtype << "\n"; - const json &header = trx->header; + const json &header = trx.header; if (!header.is_null()) { std::cout << trx_cli::colorize(colors, colors.green, "Header") << ":\n"; const json &nb_streamlines = header["NB_STREAMLINES"]; @@ -90,15 +87,15 @@ void print_trx_info( } } - if (show_stats && trx->streamlines != nullptr) { - const auto &lengths = trx->streamlines->_lengths; - const Eigen::Index count = lengths.size(); // NOLINT(misc-include-cleaner) + if (show_stats && !trx.lengths.empty()) { + const auto &lengths = trx.lengths; + const size_t count = lengths.size(); if (count > 0) { - uint64_t min_len = lengths(0); - uint64_t max_len = lengths(0); + uint64_t min_len = lengths[0]; + uint64_t max_len = lengths[0]; uint64_t total_len = 0; - for (Eigen::Index i = 0; i < count; ++i) { - const uint64_t value = lengths(i); + for (size_t i = 0; i < count; ++i) { + const uint64_t value = lengths[i]; min_len = std::min(min_len, value); max_len = std::max(max_len, value); total_len += value; @@ -112,46 +109,46 @@ void print_trx_info( } std::cout << trx_cli::colorize(colors, colors.green, "Data arrays") << ":\n"; - if (!trx->data_per_vertex.empty()) { - std::cout << " " << trx_cli::colorize(colors, colors.cyan, "Data per vertex") << 
": " - << trx->data_per_vertex.size() << "\n"; - for (const auto &kv : trx->data_per_vertex) { - const auto &arr = kv.second->_data; - std::cout << " - " << kv.first << " (" << arr.rows() << "x" << arr.cols() << ")\n"; + if (!trx.data_per_vertex.empty()) { + std::cout << " " << trx_cli::colorize(colors, colors.cyan, "Data per vertex") << ": " << trx.data_per_vertex.size() + << "\n"; + for (const auto &kv : trx.data_per_vertex) { + const auto &arr = kv.second; + std::cout << " - " << kv.first << " (" << arr.rows << "x" << arr.cols << ")\n"; } } else { std::cout << " " << trx_cli::colorize(colors, colors.cyan, "Data per vertex") << ": none\n"; } - if (!trx->data_per_streamline.empty()) { + if (!trx.data_per_streamline.empty()) { std::cout << " " << trx_cli::colorize(colors, colors.cyan, "Data per streamline") << ": " - << trx->data_per_streamline.size() << "\n"; - for (const auto &kv : trx->data_per_streamline) { - const auto &arr = kv.second->_matrix; - std::cout << " - " << kv.first << " (" << arr.rows() << "x" << arr.cols() << ")\n"; + << trx.data_per_streamline.size() << "\n"; + for (const auto &kv : trx.data_per_streamline) { + const auto &arr = kv.second; + std::cout << " - " << kv.first << " (" << arr.rows << "x" << arr.cols << ")\n"; } } else { std::cout << " " << trx_cli::colorize(colors, colors.cyan, "Data per streamline") << ": none\n"; } - if (!trx->groups.empty()) { - std::cout << " " << trx_cli::colorize(colors, colors.cyan, "Groups") << ": " << trx->groups.size() << "\n"; - for (const auto &kv : trx->groups) { - const auto &arr = kv.second->_matrix; - std::cout << " - " << kv.first << " (" << arr.rows() << "x" << arr.cols() << ")\n"; + if (!trx.groups.empty()) { + std::cout << " " << trx_cli::colorize(colors, colors.cyan, "Groups") << ": " << trx.groups.size() << "\n"; + for (const auto &kv : trx.groups) { + const auto &arr = kv.second; + std::cout << " - " << kv.first << " (" << arr.rows << "x" << arr.cols << ")\n"; } } else { std::cout << " " << 
trx_cli::colorize(colors, colors.cyan, "Groups") << ": none\n"; } - if (!trx->data_per_group.empty()) { + if (!trx.data_per_group.empty()) { std::cout << " " << trx_cli::colorize(colors, colors.cyan, "Data per group") << ":\n"; - for (const auto &grp : trx->data_per_group) { + for (const auto &grp : trx.data_per_group) { std::cout << " - " << trx_cli::colorize(colors, colors.magenta, grp.first) << ": " << grp.second.size() << "\n"; for (const auto &kv : grp.second) { - const auto &arr = kv.second->_matrix; - std::cout << " * " << kv.first << " (" << arr.rows() << "x" << arr.cols() << ")\n"; + const auto &arr = kv.second; + std::cout << " * " << kv.first << " (" << arr.rows << "x" << arr.cols << ")\n"; } } } else { @@ -159,17 +156,6 @@ void print_trx_info( } } -struct ReaderPrinter { - std::string path; - bool is_dir; - bool show_stats; - - template int operator()(ReaderT &reader, trxmmap::TrxScalarType dtype) const { - print_trx_info(reader.get(), path, is_dir, dtype, show_stats); - return 0; - } -}; - } // namespace int main(int argc, char **argv) { // check_syntax off @@ -192,9 +178,11 @@ int main(int argc, char **argv) { // check_syntax off path = result["path"].as(); show_stats = result["stats"].as(); - const bool is_dir = trxmmap::is_trx_directory(path); - const ReaderPrinter printer{path, is_dir, show_stats}; - return trxmmap::with_trx_reader(path, printer); + const bool is_dir = trx::is_trx_directory(path); + auto trx_file = trx::load_any(path); + print_trx_info(trx_file, path, is_dir, show_stats); + trx_file.close(); + return 0; } catch (const cxxopts::exceptions::exception &e) { std::cerr << "trxinfo: " << e.what() << "\n"; if (!help_text.empty()) { diff --git a/include/trx/detail/dtype_helpers.h b/include/trx/detail/dtype_helpers.h new file mode 100644 index 0000000..703edb2 --- /dev/null +++ b/include/trx/detail/dtype_helpers.h @@ -0,0 +1,60 @@ +#ifndef TRX_DETAIL_DTYPE_HELPERS_H +#define TRX_DETAIL_DTYPE_HELPERS_H + +#include + +#include +#include + 
+namespace trx { +namespace detail { + +int _sizeof_dtype(const std::string &dtype); +std::string _get_dtype(const std::string &dtype); +bool _is_dtype_valid(const std::string &ext); +std::tuple _split_ext_with_dimensionality(const std::string &filename); + +template +inline Eigen::Matrix _compute_lengths(const Eigen::MatrixBase
&offsets, + int nb_vertices) { + static_cast(nb_vertices); + if (offsets.size() > 1) { + const auto casted = offsets.template cast(); + const Eigen::Index len = offsets.size() - 1; + Eigen::Matrix lengths(len); + for (Eigen::Index i = 0; i < len; ++i) { + lengths(i) = static_cast(casted(i + 1) - casted(i)); + } + return lengths; + } + // If offsets are empty or only contain the sentinel, there are zero streamlines. + return Eigen::Matrix(0); +} + +template +inline int _dichotomic_search(const Eigen::MatrixBase
&x, int l_bound = -1, int r_bound = -1) { + if (l_bound == -1 && r_bound == -1) { + l_bound = 0; + r_bound = static_cast(x.size()) - 1; + } + + if (l_bound == r_bound) { + int val = -1; + if (x(l_bound) != 0) { + val = l_bound; + } + return val; + } + + int mid_bound = (l_bound + r_bound + 1) / 2; + + if (x(mid_bound) == 0) { + return _dichotomic_search(x, l_bound, mid_bound - 1); + } + return _dichotomic_search(x, mid_bound, r_bound); +} + +} // namespace detail +} // namespace trx + +#endif // TRX_DETAIL_DTYPE_HELPERS_H diff --git a/include/trx/nifti_io.h b/include/trx/nifti_io.h new file mode 100644 index 0000000..d0e5b5d --- /dev/null +++ b/include/trx/nifti_io.h @@ -0,0 +1,26 @@ +#ifndef TRX_NIFTI_IO_H +#define TRX_NIFTI_IO_H + +#include +#include + +namespace trx { + +/** + * @brief Read VOXEL_TO_RASMM from a NIfTI header on disk. + * + * Implementation notes: + * - The qform/sform handling follows a direct translation of nibabel's + * NIfTI header logic (see nibabel/nifti1.py). The translation is adapted + * to C++ and Eigen, and avoids depending on nibabel at runtime. + * - Prefers qform when present. If qform is missing but sform is present, + * the sform is orthogonalized to a qform-equivalent matrix. + * + * Licensing: + * - nibabel is MIT-licensed; see third_party/nibabel/LICENSE for details. 
+ */ +Eigen::Matrix4f read_nifti_voxel_to_rasmm(const std::string &path); + +} // namespace trx + +#endif diff --git a/include/trx/streamlines_ops.h b/include/trx/streamlines_ops.h index 1201059..8fd2749 100644 --- a/include/trx/streamlines_ops.h +++ b/include/trx/streamlines_ops.h @@ -13,7 +13,7 @@ #include #include -namespace trxmmap { +namespace trx { constexpr int kMinNbPoints = 5; inline float round_to_precision(float value, int precision) { @@ -150,6 +150,6 @@ perform_streamlines_operation( return {output, indices}; } -} // namespace trxmmap +} // namespace trx #endif diff --git a/include/trx/trx.h b/include/trx/trx.h index cfdc3d2..03ab146 100644 --- a/include/trx/trx.h +++ b/include/trx/trx.h @@ -13,13 +13,10 @@ #include #include #include -#include #include -#include +#include #include #include -#include -#include #include #include #include @@ -33,10 +30,9 @@ namespace trx { namespace fs = std::filesystem; } -using namespace Eigen; using json = json11::Json; -namespace trxmmap { +namespace trx { inline json::object _json_object(const json &value) { if (value.is_object()) { return value.object_items(); @@ -161,34 +157,29 @@ template <> struct DTypeName { static constexpr std::string_view value() { return "uint64"; } }; -template <> struct DTypeName { - static constexpr std::string_view value() { return "bit"; } -}; - template inline std::string dtype_from_scalar() { typedef typename std::remove_cv::type>::type CleanT; return std::string(DTypeName::value()); } -const std::string SEPARATOR = "/"; -const std::vector dtypes({"float16", - "bit", - "uint8", - "uint16", - "ushort", - "uint32", - "uint64", - "int8", - "int16", - "int32", - "int64", - "float32", - "float64"}); +inline constexpr const char *SEPARATOR = "/"; +inline const std::array dtypes = {"float16", + "uint8", + "uint16", + "ushort", + "uint32", + "uint64", + "int8", + "int16", + "int32", + "int64", + "float32", + "float64"}; template struct ArraySequence { - Map> _data; - Map> _offsets; - Matrix 
_lengths; + Eigen::Map> _data; + Eigen::Map> _offsets; + Eigen::Matrix _lengths; std::vector _offsets_owned; mio::shared_mmap_sink mmap_pos; mio::shared_mmap_sink mmap_off; @@ -197,7 +188,7 @@ template struct ArraySequence { }; template struct MMappedMatrix { - Map> _matrix; + Eigen::Map> _matrix; mio::shared_mmap_sink mmap; MMappedMatrix() : _matrix(nullptr, 1, 1) {} @@ -208,14 +199,14 @@ template class TrxFile { public: // Data Members json header; - ArraySequence
*streamlines; + std::unique_ptr> streamlines; - std::map *> groups; // vector of indices + std::map>> groups; // vector of indices - // int or float --check python floa *> data_per_streamline; - std::map *> data_per_vertex; - std::map *>> data_per_group; + // int or float --check python float precision (singletons) + std::map>> data_per_streamline; + std::map>> data_per_vertex; + std::map>>> data_per_group; std::string _uncompressed_folder_handle; bool _copy_safe; bool _owns_uncompressed_folder = false; @@ -238,18 +229,20 @@ template class TrxFile { * @param root The dirname of the ZipFile pointer * @return TrxFile* */ - static TrxFile
* + static std::unique_ptr> _create_trx_from_pointer(json header, std::map> dict_pointer_size, std::string root_zip = "", std::string root = ""); + template friend class TrxReader; + /** * @brief Create a deepcopy of the TrxFile * * @return TrxFile
* A deepcopied TrxFile of the current object */ - TrxFile
*deepcopy(); + std::unique_ptr> deepcopy(); /** * @brief Remove the ununsed portion of preallocated memmaps @@ -260,6 +253,45 @@ template class TrxFile { */ void resize(int nb_streamlines = -1, int nb_vertices = -1, bool delete_dpg = false); + /** + * @brief Save a TrxFile + * + * @param filename The path to save the TrxFile to + * @param compression_standard The compression standard to use, as defined by libzip (default: no compression) + */ + void save(const std::string &filename, zip_uint32_t compression_standard = ZIP_CM_STORE); + + void add_dps_from_text(const std::string &name, const std::string &dtype, const std::string &path); + template + void add_dps_from_vector(const std::string &name, const std::string &dtype, const std::vector &values); + /** + * @brief Add per-vertex values as DPV from an in-memory vector. + * + * @param name DPV name (used as filename in dpv/). + * @param dtype Output dtype (float16/float32/float64). + * @param values Per-vertex values; size must match NB_VERTICES. + */ + template + void add_dpv_from_vector(const std::string &name, const std::string &dtype, const std::vector &values); + /** + * @brief Add a group from a list of streamline indices. + * + * @param name Group name (used as filename in groups/). + * @param indices Streamline indices (uint32) belonging to the group. + */ + void add_group_from_indices(const std::string &name, const std::vector &indices); + /** + * @brief Set the VOXEL_TO_RASMM affine matrix in the TRX header. + * + * This updates header["VOXEL_TO_RASMM"] with the provided 4x4 matrix. 
+ */ + void set_voxel_to_rasmm(const Eigen::Matrix4f &affine); + void add_dpv_from_tsf(const std::string &name, const std::string &dtype, const std::string &path); + void export_dpv_to_tsf(const std::string &name, + const std::string &path, + const std::string ×tamp, + const std::string &dtype = "float32") const; + /** * @brief Cleanup on-disk temporary folder and initialize an empty TrxFile * @@ -267,7 +299,141 @@ template class TrxFile { void close(); void _cleanup_temporary_directory(); + size_t num_vertices() const { + if (streamlines && streamlines->_offsets.size() > 0) { + const auto last = streamlines->_offsets(streamlines->_offsets.size() - 1); + return static_cast(last); + } + if (streamlines && streamlines->_data.size() > 0) { + return static_cast(streamlines->_data.rows()); + } + if (header["NB_VERTICES"].is_number()) { + return static_cast(header["NB_VERTICES"].int_value()); + } + return 0; + } + + size_t num_streamlines() const { + if (streamlines && streamlines->_offsets.size() > 0) { + return static_cast(streamlines->_offsets.size() - 1); + } + if (streamlines && streamlines->_lengths.size() > 0) { + return static_cast(streamlines->_lengths.size()); + } + if (header["NB_STREAMLINES"].is_number()) { + return static_cast(header["NB_STREAMLINES"].int_value()); + } + return 0; + } + + /** + * @brief Build per-streamline axis-aligned bounding boxes (AABB). + * + * Each entry is {min_x, min_y, min_z, max_x, max_y, max_z} in TRX coordinates. + */ + std::vector> build_streamline_aabbs() const; + + /** + * @brief Extract a subset of streamlines intersecting an axis-aligned box. + * + * The box is defined by min/max corners in TRX coordinates. + * Returns a new TrxFile with positions, DPV/DPS, and groups remapped. + * Optionally builds the AABB cache for the returned TrxFile. + */ + /** + * @brief Extract a subset of streamlines intersecting an axis-aligned box. + * + * The box is defined by min/max corners in TRX coordinates. 
+ * Returns a new TrxFile with positions, DPV/DPS, and groups remapped. + * Optionally builds the AABB cache for the returned TrxFile. + */ + std::unique_ptr> + query_aabb(const std::array &min_corner, + const std::array &max_corner, + const std::vector> *precomputed_aabbs = nullptr, + bool build_cache_for_result = false) const; + + /** + * @brief Extract a subset of streamlines by index. + * + * The returned TrxFile remaps positions, DPV/DPS, and groups. + * Optionally builds the AABB cache for the returned TrxFile. + */ + std::unique_ptr> + subset_streamlines(const std::vector &streamline_ids, + bool build_cache_for_result = false) const; + + /** + * @brief Add a data-per-group (DPG) field from a flat vector. + * + * Values are stored as a matrix of shape (rows, cols). If cols is -1, + * it is inferred from values.size() / rows. + */ + template + void add_dpg_from_vector(const std::string &group, + const std::string &name, + const std::string &dtype, + const std::vector &values, + int rows = 1, + int cols = -1); + + /** + * @brief Add a data-per-group (DPG) field from an Eigen matrix. + */ + template + void add_dpg_from_matrix(const std::string &group, + const std::string &name, + const std::string &dtype, + const Eigen::MatrixBase &matrix); + + /** + * @brief Get a DPG field or nullptr if missing. + */ + const MMappedMatrix
*get_dpg(const std::string &group, const std::string &name) const; + + /** + * @brief List DPG groups present in this TrxFile. + */ + std::vector list_dpg_groups() const; + + /** + * @brief List DPG fields for a given group. + */ + std::vector list_dpg_fields(const std::string &group) const; + + /** + * @brief Remove a DPG field from a group. + */ + void remove_dpg(const std::string &group, const std::string &name); + + /** + * @brief Remove all DPG fields for a group. + */ + void remove_dpg_group(const std::string &group); + private: + void invalidate_aabb_cache() const; + mutable std::vector> aabb_cache_; + /** + * @brief Load a TrxFile from a zip archive. + * + * Internal: prefer TrxReader / with_trx_reader in public API. + */ + static std::unique_ptr> load_from_zip(const std::string &path); + + /** + * @brief Load a TrxFile from an on-disk directory. + * + * Internal: prefer TrxReader / with_trx_reader in public API. + */ + static std::unique_ptr> load_from_directory(const std::string &path); + + /** + * @brief Load a TrxFile from either a zip archive or directory. + * + * Internal: prefer TrxReader / with_trx_reader in public API. 
+ */ + static std::unique_ptr> load(const std::string &path); /** * @brief Get the real size of data (ignoring zeros of preallocation) * @@ -291,40 +457,231 @@ template class TrxFile { int len(); }; -/** - * TODO: This function might be completely unecessary - * - * @param[in] root a Json::Value root obtained from reading a header file with JsonCPP - * @param[out] header a header containing the same elements as the original root - * */ -json assignHeader(const json &root); +namespace detail { +int _sizeof_dtype(const std::string &dtype); +} // namespace detail -/** - * Returns the properly formatted datatype name - * - * @param[in] dtype the returned Eigen datatype - * @param[out] fmt_dtype the formatted datatype - * - * */ -std::string _get_dtype(const std::string &dtype); +struct TypedArray { + std::string dtype; + int rows = 0; + int cols = 0; + mio::shared_mmap_sink mmap; + + bool empty() const { return rows == 0 || cols == 0 || mmap.data() == nullptr; } + size_t size() const { return static_cast(rows) * static_cast(cols); } + + /** + * @brief View the buffer as a row-major Eigen matrix of type T. + * + * This is a zero-copy view over the underlying memory map. The dtype must + * match the requested T or an exception is thrown. + */ + template Eigen::Map> as_matrix() { + return Eigen::Map>(data_as(), rows, cols); + } + + /** + * @brief View the buffer as a const row-major Eigen matrix of type T. + * + * This is a zero-copy view over the underlying memory map. The dtype must + * match the requested T or an exception is thrown. + */ + template + Eigen::Map> as_matrix() const { + return Eigen::Map>( + data_as(), rows, cols); + } + + struct ByteView { + const std::uint8_t *data = nullptr; + size_t size = 0; + }; + + struct MutableByteView { + std::uint8_t *data = nullptr; + size_t size = 0; + }; + + /** + * @brief Return a read-only byte view of the underlying buffer. + * + * The size is computed from dtype * rows * cols. 
This is useful for + * interop, hashing, or serialization without exposing raw pointers. + */ + ByteView to_bytes() const { + if (empty()) { + return {}; + } + return {reinterpret_cast(mmap.data()), + static_cast(detail::_sizeof_dtype(dtype)) * size()}; + } + + /** + * @brief Return a mutable byte view of the underlying buffer. + * + * The size is computed from dtype * rows * cols. Use with care: mutating + * the bytes will mutate the mapped file contents. + */ + MutableByteView to_bytes_mutable() { + if (empty()) { + return {}; + } + return {reinterpret_cast(mmap.data()), static_cast(detail::_sizeof_dtype(dtype)) * size()}; + } + +private: + const void *data() const { return mmap.data(); } + void *data() { return mmap.data(); } + + template T *data_as() { + const std::string expected = dtype_from_scalar(); + if (dtype != expected) { + throw std::invalid_argument("TypedArray dtype mismatch: expected " + expected + " got " + dtype); + } + return reinterpret_cast(mmap.data()); + } + + template const T *data_as() const { + const std::string expected = dtype_from_scalar(); + if (dtype != expected) { + throw std::invalid_argument("TypedArray dtype mismatch: expected " + expected + " got " + dtype); + } + return reinterpret_cast(mmap.data()); + } +}; + +class AnyTrxFile { +public: + AnyTrxFile() = default; + ~AnyTrxFile(); + + AnyTrxFile(const AnyTrxFile &) = delete; + AnyTrxFile &operator=(const AnyTrxFile &) = delete; + AnyTrxFile(AnyTrxFile &&) noexcept = default; + AnyTrxFile &operator=(AnyTrxFile &&) noexcept = default; + + json header; + TypedArray positions; + TypedArray offsets; + std::vector offsets_u64; + std::vector lengths; + + std::map groups; + std::map data_per_streamline; + std::map data_per_vertex; + std::map> data_per_group; + + size_t num_vertices() const; + size_t num_streamlines() const; + void close(); + void save(const std::string &filename, zip_uint32_t compression_standard = ZIP_CM_STORE); + + static AnyTrxFile load(const std::string &path); + 
static AnyTrxFile load_from_zip(const std::string &path); + static AnyTrxFile load_from_directory(const std::string &path); + +private: + std::string _uncompressed_folder_handle; + bool _owns_uncompressed_folder = false; + std::string _backing_directory; + + static std::string _normalize_dtype(const std::string &dtype); + static AnyTrxFile + _create_from_pointer(json header, + const std::map> &dict_pointer_size, + const std::string &root); + void _cleanup_temporary_directory(); +}; + +inline AnyTrxFile load_any(const std::string &path) { return AnyTrxFile::load(path); } /** - * @brief Get the size of the datatype + * @brief Streaming-friendly TRX builder that spools positions and finalizes to TrxFile. * - * @param dtype the string name of the datatype - * @return int corresponding to the size of the datatype + * TrxStream allows appending streamlines without knowing totals up front. + * It writes positions to a temporary binary file, tracks lengths, and can + * add DPS/DPV/group metadata. Call finalize() to produce a standard TRX. */ -int _sizeof_dtype(const std::string &dtype); +class TrxStream { +public: + explicit TrxStream(std::string positions_dtype = "float32"); + ~TrxStream(); + + TrxStream(const TrxStream &) = delete; + TrxStream &operator=(const TrxStream &) = delete; + TrxStream(TrxStream &&) noexcept = default; + TrxStream &operator=(TrxStream &&) noexcept = default; + + /** + * @brief Append a streamline from a flat xyz buffer. + * + * @param xyz Pointer to interleaved x,y,z data. + * @param point_count Number of 3D points in the buffer. + */ + void push_streamline(const float *xyz, size_t point_count); + /** + * @brief Append a streamline from a flat xyz vector. + */ + void push_streamline(const std::vector &xyz_flat); + /** + * @brief Append a streamline from a vector of 3D points. + */ + void push_streamline(const std::vector> &points); + + /** + * @brief Add per-streamline values (DPS) from an in-memory vector. 
+ */ + template + void push_dps_from_vector(const std::string &name, const std::string &dtype, const std::vector &values); + /** + * @brief Add per-vertex values (DPV) from an in-memory vector. + */ + template + void push_dpv_from_vector(const std::string &name, const std::string &dtype, const std::vector &values); + /** + * @brief Add a group from a list of streamline indices. + */ + void push_group_from_indices(const std::string &name, const std::vector &indices); + + /** + * @brief Finalize and write a TRX file. + */ + template void finalize(const std::string &filename, zip_uint32_t compression_standard = ZIP_CM_STORE); + + size_t num_streamlines() const { return lengths_.size(); } + size_t num_vertices() const { return total_vertices_; } + + json header; + +private: + struct FieldValues { + std::string dtype; + std::vector values; + }; + + void ensure_positions_stream(); + void cleanup_tmp(); + + std::string positions_dtype_; + std::string tmp_dir_; + std::string positions_path_; + std::ofstream positions_out_; + std::vector lengths_; + size_t total_vertices_ = 0; + bool finalized_ = false; + + std::map> groups_; + std::map dps_; + std::map dpv_; +}; /** - * Determine whether the extension is a valid extension - * - * - * @param[in] ext a string consisting of the extension starting by a . - * @param[out] is_valid a boolean denoting whether the extension is valid. + * TODO: This function might be completely unecessary * + * @param[in] root a Json::Value root obtained from reading a header file with JsonCPP + * @param[out] header a header containing the same elements as the original root * */ -bool _is_dtype_valid(const std::string &ext); +json assignHeader(const json &root); /** * This function loads the header json file @@ -343,7 +700,7 @@ json load_header(zip_t *zfolder); * @param[out] status return 0 if success else 1 * * */ -template TrxFile
*load_from_zip(std::string path); +template std::unique_ptr> load_from_zip(const std::string &path); /** * @brief Load a TrxFile from a folder containing memmaps @@ -352,7 +709,7 @@ template TrxFile
*load_from_zip(std::string path); * @param path path of the zipped TrxFile * @return TrxFile
* TrxFile representing the read data */ -template TrxFile
*load_from_directory(std::string path); +template std::unique_ptr> load_from_directory(const std::string &path); /** * @brief Detect the dtype of the positions array for a TRX path. @@ -407,7 +764,7 @@ bool is_trx_directory(const std::string &path); * @param path Path to TRX archive or directory * @return TrxFile
* TrxFile representing the read data */ -template TrxFile
*load(std::string path); +template std::unique_ptr> load(const std::string &path); /** * @brief RAII wrapper for loading TRX files from a path. @@ -417,19 +774,19 @@ template TrxFile
*load(std::string path); template class TrxReader { public: explicit TrxReader(const std::string &path); - ~TrxReader(); + ~TrxReader() = default; TrxReader(const TrxReader &) = delete; TrxReader &operator=(const TrxReader &) = delete; TrxReader(TrxReader &&other) noexcept; TrxReader &operator=(TrxReader &&other) noexcept; - TrxFile
*get() const { return trx_; } + TrxFile
*get() const { return trx_.get(); } TrxFile
&operator*() const { return *trx_; } - TrxFile
*operator->() const { return trx_; } + TrxFile
*operator->() const { return trx_.get(); } private: - TrxFile
*trx_ = nullptr; + std::unique_ptr> trx_; }; /** @@ -452,7 +809,9 @@ auto with_trx_reader(const std::string &path, Fn &&fn) * @param[in] dimensions vector of size 3 * * */ -void get_reference_info(const std::string &reference, const MatrixXf &affine, const RowVectorXi &dimensions); +void get_reference_info(const std::string &reference, + const Eigen::MatrixXf &affine, + const Eigen::RowVectorXi &dimensions); template std::ostream &operator<<(std::ostream &out, const TrxFile
&trx); // private: @@ -473,34 +832,13 @@ void allocate_file(const std::string &path, std::size_t size); // TODO: change tuple to vector to support ND arrays? // TODO: remove data type as that's done outside of this function mio::shared_mmap_sink _create_memmap(std::string filename, - std::tuple &shape, + const std::tuple &shape, const std::string &mode = "r", const std::string &dtype = "float32", long long offset = 0); -template std::string _generate_filename_from_data(const MatrixBase
&arr, const std::string filename); -std::tuple _split_ext_with_dimensionality(const std::string &filename); - -/** - * @brief Compute the lengths from offsets and header information - * - * @tparam DT The datatype (used for the input matrix) - * @param[in] offsets An array of offsets - * @param[in] nb_vertices the number of vertices - * @return Matrix of lengths - */ -template Matrix _compute_lengths(const MatrixBase
&offsets, int nb_vertices); - -/** - * @brief Find where data of a contiguous array is actually ending - * - * @tparam DT (the datatype) - * @param x Matrix of values - * @param l_bound lower bound index for search - * @param r_bound upper bound index for search - * @return int index at which array value is 0 (if possible), otherwise returns -1 - */ -template int _dichotomic_search(const MatrixBase
&x, int l_bound = -1, int r_bound = -1); +template +std::string _generate_filename_from_data(const Eigen::MatrixBase
&arr, const std::string filename); /** * @brief Create on-disk memmaps of a certain size (preallocation) @@ -511,10 +849,13 @@ template int _dichotomic_search(const MatrixBase
&x, int l_bou * @return TrxFile
An empty TrxFile preallocated with a certain size */ template -TrxFile
*_initialize_empty_trx(int nb_streamlines, int nb_vertices, const TrxFile
*init_as = nullptr); +std::unique_ptr> +_initialize_empty_trx(int nb_streamlines, int nb_vertices, const TrxFile
*init_as = nullptr); template -void ediff1d(Matrix &lengths, const Matrix &tmp, uint32_t to_end); +void ediff1d(Eigen::Matrix &lengths, + const Eigen::Matrix &tmp, + uint32_t to_end); /** * @brief Save a TrxFile @@ -525,21 +866,6 @@ void ediff1d(Matrix &lengths, const Matrix * @param compression_standard The compression standard to use, as defined by libzip (default: no * compression) */ -template -void save(TrxFile
&trx, const std::string filename, zip_uint32_t compression_standard = ZIP_CM_STORE); - -template -void add_dps_from_text(TrxFile
&trx, const std::string &name, const std::string &dtype, const std::string &path); - -template -void add_dpv_from_tsf(TrxFile
&trx, const std::string &name, const std::string &dtype, const std::string &path); - -template -void export_dpv_to_tsf(const TrxFile
&trx, - const std::string &name, - const std::string &path, - const std::string ×tamp, - const std::string &dtype = "float32"); /** * @brief Utils function to zip on-disk memmaps @@ -564,8 +890,17 @@ std::string make_temp_dir(const std::string &prefix); std::string extract_zip_to_directory(zip_t *zfolder); std::string rm_root(const std::string &root, const std::string &path); -#include +#ifndef TRX_TPP_STANDALONE +#endif + +} // namespace trx + +#include -} // namespace trxmmap +namespace trx { +#ifndef TRX_TPP_STANDALONE +#include +#endif +} // namespace trx #endif /* TRX_H */ \ No newline at end of file diff --git a/include/trx/trx.tpp b/include/trx/trx.tpp index 9e2ebfa..f0c8079 100644 --- a/include/trx/trx.tpp +++ b/include/trx/trx.tpp @@ -1,4 +1,17 @@ // Taken from: https://stackoverflow.com/a/25389481 +#ifndef TRX_H +#define TRX_TPP_STANDALONE +#define TRX_TPP_OPEN_NAMESPACE +#include +#undef TRX_TPP_STANDALONE +namespace trx { +#endif +using Eigen::Dynamic; +using Eigen::half; +using Eigen::Index; +using Eigen::Map; +using Eigen::Matrix; +using Eigen::RowMajor; template void write_binary(const std::string &filename, const Matrix &matrix) { std::ofstream out(filename, std::ios::out | std::ios::binary | std::ios::trunc); typename Matrix::Index rows = matrix.rows(), cols = matrix.cols(); @@ -23,7 +36,7 @@ template void read_binary(const std::string &filename, Matrix &ma template void ediff1d(Matrix &lengths, Matrix &tmp, uint32_t to_end) { - Map> v(tmp.data(), tmp.size()); + Map> v(tmp.data(), tmp.size()); lengths.resize(v.size(), 1); // TODO: figure out if there's a built in way to manage this @@ -35,7 +48,7 @@ void ediff1d(Matrix &lengths, Matrix &tmp, template // Caveat: if filename has an extension, it will be replaced by the generated dtype extension. -std::string _generate_filename_from_data(const MatrixBase
&arr, std::string filename) { +std::string _generate_filename_from_data(const Eigen::MatrixBase
&arr, std::string filename) { std::string base, ext; @@ -60,43 +73,6 @@ std::string _generate_filename_from_data(const MatrixBase
&arr, std::string return new_filename; } -template Matrix _compute_lengths(const MatrixBase
&offsets, int nb_vertices) { - if (offsets.size() > 1) { - const auto casted = offsets.template cast(); - const Eigen::Index len = offsets.size() - 1; - Matrix lengths(len); - for (Eigen::Index i = 0; i < len; ++i) { - lengths(i) = static_cast(casted(i + 1) - casted(i)); - } - return lengths; - } - // If offsets are empty or only contain the sentinel, there are zero streamlines. - return Matrix(0); -} - -template int _dichotomic_search(const MatrixBase
&x, int l_bound, int r_bound) { - if (l_bound == -1 && r_bound == -1) { - l_bound = 0; - r_bound = static_cast(x.size()) - 1; - } - - if (l_bound == r_bound) { - int val; - if (x(l_bound) != 0) - val = l_bound; - else - val = -1; - return val; - } - - int mid_bound = (l_bound + r_bound + 1) / 2; - - if (x(mid_bound) == 0) - return _dichotomic_search(x, l_bound, mid_bound - 1); - else - return _dichotomic_search(x, mid_bound, r_bound); -} - template TrxFile
::TrxFile(int nb_vertices, int nb_streamlines, const TrxFile
*init_as, std::string reference) { std::vector> affine(4); @@ -133,7 +109,7 @@ TrxFile
::TrxFile(int nb_vertices, int nb_streamlines, const TrxFile
*ini // will remove as completely unecessary. using as placeholders this->header = {}; - this->streamlines = nullptr; + this->streamlines.reset(); // TODO: maybe create a matrix to map to of specified DT. Do we need this?? // set default datatype to half @@ -144,18 +120,17 @@ TrxFile
::TrxFile(int nb_vertices, int nb_streamlines, const TrxFile
*ini nb_vertices = 0; nb_streamlines = 0; } else if (nb_vertices > 0 && nb_streamlines > 0) { - TrxFile
*trx = _initialize_empty_trx
(nb_streamlines, nb_vertices, init_as); - this->streamlines = trx->streamlines; - this->groups = trx->groups; - this->data_per_streamline = trx->data_per_streamline; - this->data_per_vertex = trx->data_per_vertex; - this->data_per_group = trx->data_per_group; - this->_uncompressed_folder_handle = trx->_uncompressed_folder_handle; + auto trx = _initialize_empty_trx
(nb_streamlines, nb_vertices, init_as); + this->streamlines = std::move(trx->streamlines); + this->groups = std::move(trx->groups); + this->data_per_streamline = std::move(trx->data_per_streamline); + this->data_per_vertex = std::move(trx->data_per_vertex); + this->data_per_group = std::move(trx->data_per_group); + this->_uncompressed_folder_handle = std::move(trx->_uncompressed_folder_handle); this->_owns_uncompressed_folder = trx->_owns_uncompressed_folder; this->_copy_safe = trx->_copy_safe; trx->_owns_uncompressed_folder = false; trx->_uncompressed_folder_handle.clear(); - delete trx; } else { throw std::invalid_argument("You must declare both NB_VERTICES AND NB_STREAMLINES"); } @@ -171,8 +146,8 @@ TrxFile
::TrxFile(int nb_vertices, int nb_streamlines, const TrxFile
*ini } template -TrxFile
*_initialize_empty_trx(int nb_streamlines, int nb_vertices, const TrxFile
*init_as) { - TrxFile
*trx = new TrxFile
(); +std::unique_ptr> _initialize_empty_trx(int nb_streamlines, int nb_vertices, const TrxFile
*init_as) { + auto trx = std::make_unique>(); std::string tmp_dir = make_temp_dir("trx"); @@ -203,8 +178,8 @@ TrxFile
*_initialize_empty_trx(int nb_streamlines, int nb_vertices, const Tr std::tuple shape = std::make_tuple(nb_vertices, 3); - trx->streamlines = new ArraySequence
(); - trx->streamlines->mmap_pos = trxmmap::_create_memmap(positions_filename, shape, "w+", positions_dtype); + trx->streamlines = std::make_unique>(); + trx->streamlines->mmap_pos = trx::_create_memmap(positions_filename, shape, "w+", positions_dtype); // TODO: find a better way to get the dtype than using all these switch cases. Also refactor // into function as per specifications, positions can only be floats @@ -224,7 +199,7 @@ TrxFile
*_initialize_empty_trx(int nb_streamlines, int nb_vertices, const Tr std::tuple shape_off = std::make_tuple(nb_streamlines + 1, 1); - trx->streamlines->mmap_off = trxmmap::_create_memmap(offsets_filename, shape_off, "w+", offsets_dtype); + trx->streamlines->mmap_off = trx::_create_memmap(offsets_filename, shape_off, "w+", offsets_dtype); new (&(trx->streamlines->_offsets)) Map>( reinterpret_cast(trx->streamlines->mmap_off.data()), std::get<0>(shape_off), std::get<1>(shape_off)); @@ -269,8 +244,8 @@ TrxFile
*_initialize_empty_trx(int nb_streamlines, int nb_vertices, const Tr } std::tuple dpv_shape = std::make_tuple(rows, cols); - trx->data_per_vertex[x.first] = new ArraySequence
(); - trx->data_per_vertex[x.first]->mmap_pos = trxmmap::_create_memmap(dpv_filename, dpv_shape, "w+", dpv_dtype); + trx->data_per_vertex[x.first] = std::make_unique>(); + trx->data_per_vertex[x.first]->mmap_pos = trx::_create_memmap(dpv_filename, dpv_shape, "w+", dpv_dtype); if (dpv_dtype.compare("float16") == 0) { new (&(trx->data_per_vertex[x.first]->_data)) Map>( reinterpret_cast(trx->data_per_vertex[x.first]->mmap_pos.data()), rows, cols); @@ -307,9 +282,9 @@ TrxFile
*_initialize_empty_trx(int nb_streamlines, int nb_vertices, const Tr } std::tuple dps_shape = std::make_tuple(rows, cols); - trx->data_per_streamline[x.first] = new trxmmap::MMappedMatrix
(); + trx->data_per_streamline[x.first] = std::make_unique>(); trx->data_per_streamline[x.first]->mmap = - trxmmap::_create_memmap(dps_filename, dps_shape, std::string("w+"), dps_dtype); + trx::_create_memmap(dps_filename, dps_shape, std::string("w+"), dps_dtype); if (dps_dtype.compare("float16") == 0) { new (&(trx->data_per_streamline[x.first]->_matrix)) Map>( @@ -332,14 +307,14 @@ TrxFile
*_initialize_empty_trx(int nb_streamlines, int nb_vertices, const Tr } template -TrxFile
* +std::unique_ptr> TrxFile
::_create_trx_from_pointer(json header, std::map> dict_pointer_size, std::string root_zip, std::string root) { - trxmmap::TrxFile
*trx = new trxmmap::TrxFile
(); + auto trx = std::make_unique>(); trx->header = header; - trx->streamlines = new ArraySequence
(); + trx->streamlines = std::make_unique>(); std::string filename; @@ -372,15 +347,11 @@ TrxFile
::_create_trx_from_pointer(json header, } // _split_ext_with_dimensionality - std::tuple base_tuple = _split_ext_with_dimensionality(elem_filename); + std::tuple base_tuple = trx::detail::_split_ext_with_dimensionality(elem_filename); std::string base(std::get<0>(base_tuple)); int dim = std::get<1>(base_tuple); std::string ext(std::get<2>(base_tuple)); - if (ext.compare("bit") == 0) { - ext = "bool"; - } - long long mem_adress = std::get<0>(x->second); long long size = std::get<1>(x->second); @@ -392,7 +363,7 @@ TrxFile
::_create_trx_from_pointer(json header, std::tuple shape = std::make_tuple(static_cast(trx->header["NB_VERTICES"].int_value()), 3); trx->streamlines->mmap_pos = - trxmmap::_create_memmap(filename, shape, "r+", ext.substr(1, ext.size() - 1), mem_adress); + trx::_create_memmap(filename, shape, "r+", ext.substr(1, ext.size() - 1), mem_adress); // TODO: find a better way to get the dtype than using all these switch cases. Also // refactor into function as per specifications, positions can only be floats @@ -418,7 +389,7 @@ TrxFile
::_create_trx_from_pointer(json header, const int nb_str = static_cast(trx->header["NB_STREAMLINES"].int_value()); std::tuple shape = std::make_tuple(nb_str + 1, 1); - trx->streamlines->mmap_off = trxmmap::_create_memmap(filename, shape, "r+", ext, mem_adress); + trx->streamlines->mmap_off = trx::_create_memmap(filename, shape, "r+", ext, mem_adress); if (ext.compare("uint64") == 0) { new (&(trx->streamlines->_offsets)) Map>( @@ -435,12 +406,13 @@ TrxFile
::_create_trx_from_pointer(json header, } Matrix offsets = trx->streamlines->_offsets; - trx->streamlines->_lengths = _compute_lengths(offsets, static_cast(trx->header["NB_VERTICES"].int_value())); + trx->streamlines->_lengths = + trx::detail::_compute_lengths(offsets, static_cast(trx->header["NB_VERTICES"].int_value())); } else if (folder.compare("dps") == 0) { std::tuple shape; - trx->data_per_streamline[base] = new MMappedMatrix
(); + trx->data_per_streamline[base] = std::make_unique>(); int nb_scalar = size / static_cast(trx->header["NB_STREAMLINES"].int_value()); if (size % static_cast(trx->header["NB_STREAMLINES"].int_value()) != 0 || nb_scalar != dim) { @@ -449,7 +421,7 @@ TrxFile
::_create_trx_from_pointer(json header, } else { shape = std::make_tuple(static_cast(trx->header["NB_STREAMLINES"].int_value()), nb_scalar); } - trx->data_per_streamline[base]->mmap = trxmmap::_create_memmap(filename, shape, "r+", ext, mem_adress); + trx->data_per_streamline[base]->mmap = trx::_create_memmap(filename, shape, "r+", ext, mem_adress); if (ext.compare("float16") == 0) { new (&(trx->data_per_streamline[base]->_matrix)) @@ -471,7 +443,7 @@ TrxFile
::_create_trx_from_pointer(json header, else if (folder.compare("dpv") == 0) { std::tuple shape; - trx->data_per_vertex[base] = new ArraySequence
(); + trx->data_per_vertex[base] = std::make_unique>(); int nb_scalar = size / static_cast(trx->header["NB_VERTICES"].int_value()); if (size % static_cast(trx->header["NB_VERTICES"].int_value()) != 0 || nb_scalar != dim) { @@ -480,7 +452,7 @@ TrxFile
::_create_trx_from_pointer(json header, } else { shape = std::make_tuple(static_cast(trx->header["NB_VERTICES"].int_value()), nb_scalar); } - trx->data_per_vertex[base]->mmap_pos = trxmmap::_create_memmap(filename, shape, "r+", ext, mem_adress); + trx->data_per_vertex[base]->mmap_pos = trx::_create_memmap(filename, shape, "r+", ext, mem_adress); if (ext.compare("float16") == 0) { new (&(trx->data_per_vertex[base]->_data)) @@ -517,9 +489,8 @@ TrxFile
::_create_trx_from_pointer(json header, std::string data_name = path_basename(base); std::string sub_folder = path_basename(folder); - trx->data_per_group[sub_folder][data_name] = new MMappedMatrix
(); - trx->data_per_group[sub_folder][data_name]->mmap = - trxmmap::_create_memmap(filename, shape, "r+", ext, mem_adress); + trx->data_per_group[sub_folder][data_name] = std::make_unique>(); + trx->data_per_group[sub_folder][data_name]->mmap = trx::_create_memmap(filename, shape, "r+", ext, mem_adress); if (ext.compare("float16") == 0) { new (&(trx->data_per_group[sub_folder][data_name]->_matrix)) Map>( @@ -546,8 +517,8 @@ TrxFile
::_create_trx_from_pointer(json header, } else { shape = std::make_tuple(static_cast(size), 1); } - trx->groups[base] = new MMappedMatrix(); - trx->groups[base]->mmap = trxmmap::_create_memmap(filename, shape, "r+", ext, mem_adress); + trx->groups[base] = std::make_unique>(); + trx->groups[base]->mmap = trx::_create_memmap(filename, shape, "r+", ext, mem_adress); new (&(trx->groups[base]->_matrix)) Map>( reinterpret_cast(trx->groups[base]->mmap.data()), std::get<0>(shape), std::get<1>(shape)); } else { @@ -563,9 +534,9 @@ TrxFile
::_create_trx_from_pointer(json header, } // TODO: Major refactoring -template TrxFile
*TrxFile
::deepcopy() { - if (this->streamlines == nullptr || this->streamlines->_data.size() == 0 || this->streamlines->_offsets.size() == 0) { - trxmmap::TrxFile
*empty_copy = new trxmmap::TrxFile
(); +template std::unique_ptr> TrxFile
::deepcopy() { + if (!this->streamlines || this->streamlines->_data.size() == 0 || this->streamlines->_offsets.size() == 0) { + auto empty_copy = std::make_unique>(); empty_copy->header = this->header; return empty_copy; } @@ -577,7 +548,7 @@ template TrxFile
*TrxFile
::deepcopy() { // TODO: Definitely a better way to deepcopy json tmp_header = this->header; - ArraySequence
*to_dump = new ArraySequence
(); + auto to_dump = std::make_unique>(); // TODO: Verify that this is indeed a deep copy new (&(to_dump->_data)) Matrix(this->streamlines->_data); new (&(to_dump->_offsets)) Matrix(this->streamlines->_offsets); @@ -694,7 +665,7 @@ template TrxFile
*TrxFile
::deepcopy() { } } - TrxFile
*copy_trx = load_from_directory
(tmp_dir); + auto copy_trx = TrxFile
::load_from_directory(tmp_dir); copy_trx->_uncompressed_folder_handle = tmp_dir; copy_trx->_owns_uncompressed_folder = true; @@ -707,11 +678,11 @@ template std::tuple TrxFile
::_get_real_len() { if (this->streamlines->_lengths.size() == 0) return std::make_tuple(0, 0); - int last_elem_pos = _dichotomic_search(this->streamlines->_lengths); + int last_elem_pos = trx::detail::_dichotomic_search(this->streamlines->_lengths); if (last_elem_pos != -1) { int strs_end = last_elem_pos + 1; - int pts_end = this->streamlines->_lengths(seq(0, last_elem_pos), 0).sum(); + int pts_end = this->streamlines->_lengths(Eigen::seq(0, last_elem_pos), 0).sum(); return std::make_tuple(strs_end, pts_end); } @@ -730,7 +701,7 @@ TrxFile
::_copy_fixed_arrays_from(TrxFile
*trx, int strs_start, int pts_s curr_pts_len = std::get<1>(curr); } else { curr_strs_len = nb_strs_to_copy; - curr_pts_len = trx->streamlines->_lengths(seq(0, curr_strs_len - 1)).sum(); + curr_pts_len = trx->streamlines->_lengths(Eigen::seq(0, curr_strs_len - 1)).sum(); } if (pts_start == -1) { @@ -778,7 +749,26 @@ TrxFile
::_copy_fixed_arrays_from(TrxFile
*trx, int strs_start, int pts_s template void TrxFile
::close() { this->_cleanup_temporary_directory(); - *this = TrxFile
(); // probably dangerous to do + this->streamlines.reset(); + this->groups.clear(); + this->data_per_streamline.clear(); + this->data_per_vertex.clear(); + this->data_per_group.clear(); + this->_uncompressed_folder_handle.clear(); + this->_owns_uncompressed_folder = false; + this->_copy_safe = true; + + std::vector> affine(4, std::vector(4, 0.0f)); + for (int i = 0; i < 4; i++) { + affine[i][i] = 1.0f; + } + std::vector dimensions{1, 1, 1}; + json::object header_obj; + header_obj["VOXEL_TO_RASMM"] = affine; + header_obj["DIMENSIONS"] = dimensions; + header_obj["NB_VERTICES"] = 0; + header_obj["NB_STREAMLINES"] = 0; + this->header = json(header_obj); } template TrxFile
::~TrxFile() { this->_cleanup_temporary_directory(); } @@ -826,7 +816,7 @@ void TrxFile
::resize(int nb_streamlines, int nb_vertices, bool delete_dpg) { return; } - TrxFile
*trx = _initialize_empty_trx(nb_streamlines, nb_vertices, this); + auto trx = _initialize_empty_trx(nb_streamlines, nb_vertices, this); if (nb_streamlines < this->header["NB_STREAMLINES"].int_value()) trx->_copy_fixed_arrays_from(this, -1, -1, nb_streamlines); @@ -868,8 +858,8 @@ void TrxFile
::resize(int nb_streamlines, int nb_vertices, bool delete_dpg) { Matrix tmp = this->groups[x.first]->_matrix(keep_rows, keep_cols); std::tuple group_shape = std::make_tuple(tmp.size(), 1); - trx->groups[x.first] = new MMappedMatrix(); - trx->groups[x.first]->mmap = trxmmap::_create_memmap(group_name, group_shape, "w+", group_dtype); + trx->groups[x.first] = std::make_unique>(); + trx->groups[x.first]->mmap = trx::_create_memmap(group_name, group_shape, "w+", group_dtype); new (&(trx->groups[x.first]->_matrix)) Map>(reinterpret_cast(trx->groups[x.first]->mmap.data()), std::get<0>(group_shape), @@ -911,7 +901,9 @@ void TrxFile
::resize(int nb_streamlines, int nb_vertices, bool delete_dpg) { } if (trx->data_per_group.find(x.first) == trx->data_per_group.end()) { - trx->data_per_group[x.first] = {}; + trx->data_per_group.emplace(x.first, std::map>>{}); + } else { + trx->data_per_group[x.first].clear(); } for (auto const &y : this->data_per_group[x.first]) { @@ -923,7 +915,7 @@ void TrxFile
::resize(int nb_streamlines, int nb_vertices, bool delete_dpg) { this->data_per_group[x.first][y.first]->_matrix.cols()); if (trx->data_per_group[x.first].find(y.first) == trx->data_per_group[x.first].end()) { - trx->data_per_group[x.first][y.first] = new MMappedMatrix
(); + trx->data_per_group[x.first][y.first] = std::make_unique>(); } trx->data_per_group[x.first][y.first]->mmap = _create_memmap(dpg_filename, dpg_shape, "w+", dpg_dtype); @@ -945,7 +937,7 @@ void TrxFile
::resize(int nb_streamlines, int nb_vertices, bool delete_dpg) { } } -template TrxFile
*load_from_zip(std::string filename) { +template std::unique_ptr> TrxFile
::load_from_zip(const std::string &filename) { int errorp = 0; zip_t *zf = open_zip_for_read(filename, errorp); if (zf == nullptr) { @@ -955,13 +947,13 @@ template TrxFile
*load_from_zip(std::string filename) { std::string temp_dir = extract_zip_to_directory(zf); zip_close(zf); - TrxFile
*trx = load_from_directory
(temp_dir); + auto trx = TrxFile
::load_from_directory(temp_dir); trx->_uncompressed_folder_handle = temp_dir; trx->_owns_uncompressed_folder = true; return trx; } -template TrxFile
*load_from_directory(std::string path) { +template std::unique_ptr> TrxFile
::load_from_directory(const std::string &path) { std::string directory = path; { std::error_code ec; @@ -991,38 +983,25 @@ template TrxFile
*load_from_directory(std::string path) { return TrxFile
::_create_trx_from_pointer(header, files_pointer_size, "", directory); } -template TrxFile
*load(std::string path) { +template std::unique_ptr> TrxFile
::load(const std::string &path) { trx::fs::path input(path); if (!trx::fs::exists(input)) { throw std::runtime_error("Input path does not exist: " + path); } std::error_code ec; if (trx::fs::is_directory(input, ec) && !ec) { - return load_from_directory
(path); + return TrxFile
::load_from_directory(path); } - return load_from_zip
(path); + return TrxFile
::load_from_zip(path); } -template TrxReader
::TrxReader(const std::string &path) { trx_ = load
(path); } +template TrxReader
::TrxReader(const std::string &path) { trx_ = TrxFile
::load(path); } -template TrxReader
::~TrxReader() { - if (trx_ != nullptr) { - trx_->close(); - delete trx_; - trx_ = nullptr; - } -} - -template TrxReader
::TrxReader(TrxReader &&other) noexcept : trx_(other.trx_) { other.trx_ = nullptr; } +template TrxReader
::TrxReader(TrxReader &&other) noexcept : trx_(std::move(other.trx_)) {} template TrxReader
&TrxReader
::operator=(TrxReader &&other) noexcept { if (this != &other) { - if (trx_ != nullptr) { - trx_->close(); - delete trx_; - } - trx_ = other.trx_; - other.trx_ = nullptr; + trx_ = std::move(other.trx_); } return *this; } @@ -1048,15 +1027,44 @@ auto with_trx_reader(const std::string &path, Fn &&fn) } } -template void save(TrxFile
&trx, const std::string filename, zip_uint32_t compression_standard) { +template void TrxFile
::save(const std::string &filename, zip_uint32_t compression_standard) { std::string ext = get_ext(filename); if (ext.size() > 0 && (ext != "zip" && ext != "trx")) { throw std::invalid_argument("Unsupported extension." + ext); } - TrxFile
*copy_trx = trx.deepcopy(); + auto copy_trx = this->deepcopy(); copy_trx->resize(); + if (!copy_trx->streamlines || copy_trx->streamlines->_offsets.size() == 0) { + throw std::runtime_error("Cannot save TRX without offsets data"); + } + if (copy_trx->header["NB_STREAMLINES"].is_number()) { + const auto nb_streamlines = static_cast(copy_trx->header["NB_STREAMLINES"].int_value()); + if (copy_trx->streamlines->_offsets.size() != static_cast(nb_streamlines + 1)) { + throw std::runtime_error("TRX offsets size does not match NB_STREAMLINES"); + } + } + if (copy_trx->header["NB_VERTICES"].is_number()) { + const auto nb_vertices = static_cast(copy_trx->header["NB_VERTICES"].int_value()); + const auto last = + static_cast(copy_trx->streamlines->_offsets(copy_trx->streamlines->_offsets.size() - 1)); + if (last != nb_vertices) { + throw std::runtime_error("TRX offsets sentinel does not match NB_VERTICES"); + } + } + for (Eigen::Index i = 1; i < copy_trx->streamlines->_offsets.size(); ++i) { + if (copy_trx->streamlines->_offsets(i) < copy_trx->streamlines->_offsets(i - 1)) { + throw std::runtime_error("TRX offsets must be monotonically increasing"); + } + } + if (copy_trx->streamlines->_data.size() > 0) { + const auto last = + static_cast(copy_trx->streamlines->_offsets(copy_trx->streamlines->_offsets.size() - 1)); + if (last != static_cast(copy_trx->streamlines->_data.rows())) { + throw std::runtime_error("TRX positions row count does not match offsets sentinel"); + } + } std::string tmp_dir_name = copy_trx->_uncompressed_folder_handle; if (ext.size() > 0 && (ext == "zip" || ext == "trx")) { @@ -1102,7 +1110,27 @@ template void save(TrxFile
&trx, const std::string filename, z } template -void add_dps_from_text(TrxFile
&trx, const std::string &name, const std::string &dtype, const std::string &path) { +void TrxFile
::add_dps_from_text(const std::string &name, const std::string &dtype, const std::string &path) { + std::ifstream input(path); + if (!input.is_open()) { + throw std::runtime_error("Failed to open DPS text file: " + path); + } + + std::vector values; + double value = 0.0; + while (input >> value) { + values.push_back(value); + } + if (!input.eof() && input.fail()) { + throw std::runtime_error("Failed to parse DPS text file: " + path); + } + + add_dps_from_vector(name, dtype, values); +} + +template +template +void TrxFile
::add_dps_from_vector(const std::string &name, const std::string &dtype, const std::vector &values) { if (name.empty()) { throw std::invalid_argument("DPS name cannot be empty"); } @@ -1112,37 +1140,22 @@ void add_dps_from_text(TrxFile
&trx, const std::string &name, const std::str return static_cast(std::tolower(c)); }); - if (!_is_dtype_valid(dtype_norm)) { + if (!trx::detail::_is_dtype_valid(dtype_norm)) { throw std::invalid_argument("Unsupported DPS dtype: " + dtype); } if (dtype_norm != "float16" && dtype_norm != "float32" && dtype_norm != "float64") { - throw std::invalid_argument("Unsupported DPS dtype for text input: " + dtype); + throw std::invalid_argument("Unsupported DPS dtype: " + dtype); } - if (trx._uncompressed_folder_handle.empty()) { + if (this->_uncompressed_folder_handle.empty()) { throw std::runtime_error("TRX file has no backing directory to store DPS data"); } size_t nb_streamlines = 0; - if (trx.streamlines) { - nb_streamlines = static_cast(trx.streamlines->_lengths.size()); - } else if (trx.header["NB_STREAMLINES"].is_number()) { - nb_streamlines = static_cast(trx.header["NB_STREAMLINES"].int_value()); - } - - std::ifstream input(path); - if (!input.is_open()) { - throw std::runtime_error("Failed to open DPS text file: " + path); - } - - std::vector values; - values.reserve(nb_streamlines); - double value = 0.0; - while (input >> value) { - values.push_back(value); - } - if (!input.eof() && input.fail()) { - throw std::runtime_error("Failed to parse DPS text file: " + path); + if (this->streamlines) { + nb_streamlines = static_cast(this->streamlines->_lengths.size()); + } else if (this->header["NB_STREAMLINES"].is_number()) { + nb_streamlines = static_cast(this->header["NB_STREAMLINES"].int_value()); } if (values.size() != nb_streamlines) { @@ -1150,7 +1163,7 @@ void add_dps_from_text(TrxFile
&trx, const std::string &name, const std::str std::to_string(nb_streamlines) + ")"); } - std::string dps_dirname = trx._uncompressed_folder_handle + SEPARATOR + "dps" + SEPARATOR; + std::string dps_dirname = this->_uncompressed_folder_handle + SEPARATOR + "dps" + SEPARATOR; { std::error_code ec; trx::fs::create_directories(dps_dirname, ec); @@ -1167,18 +1180,17 @@ void add_dps_from_text(TrxFile
&trx, const std::string &name, const std::str } } - auto existing = trx.data_per_streamline.find(name); - if (existing != trx.data_per_streamline.end()) { - delete existing->second; - trx.data_per_streamline.erase(existing); + auto existing = this->data_per_streamline.find(name); + if (existing != this->data_per_streamline.end()) { + this->data_per_streamline.erase(existing); } const int rows = static_cast(nb_streamlines); const int cols = 1; std::tuple shape = std::make_tuple(rows, cols); - auto *matrix = new trxmmap::MMappedMatrix
(); - matrix->mmap = trxmmap::_create_memmap(dps_filename, shape, "w+", dtype_norm); + auto matrix = std::make_unique>(); + matrix->mmap = trx::_create_memmap(dps_filename, shape, "w+", dtype_norm); if (dtype_norm == "float16") { auto *data = reinterpret_cast(matrix->mmap.data()); @@ -1203,11 +1215,358 @@ void add_dps_from_text(TrxFile
&trx, const std::string &name, const std::str } } - trx.data_per_streamline[name] = matrix; + this->data_per_streamline[name] = std::move(matrix); +} + +template +template +void TrxFile
::add_dpv_from_vector(const std::string &name, const std::string &dtype, const std::vector &values) { + if (name.empty()) { + throw std::invalid_argument("DPV name cannot be empty"); + } + + std::string dtype_norm = dtype; + std::transform(dtype_norm.begin(), dtype_norm.end(), dtype_norm.begin(), [](unsigned char c) { + return static_cast(std::tolower(c)); + }); + + if (!trx::detail::_is_dtype_valid(dtype_norm)) { + throw std::invalid_argument("Unsupported DPV dtype: " + dtype); + } + if (dtype_norm != "float16" && dtype_norm != "float32" && dtype_norm != "float64") { + throw std::invalid_argument("Unsupported DPV dtype: " + dtype); + } + + if (this->_uncompressed_folder_handle.empty()) { + throw std::runtime_error("TRX file has no backing directory to store DPV data"); + } + + size_t nb_vertices = 0; + if (this->streamlines) { + nb_vertices = static_cast(this->streamlines->_data.rows()); + } else if (this->header["NB_VERTICES"].is_number()) { + nb_vertices = static_cast(this->header["NB_VERTICES"].int_value()); + } + + if (values.size() != nb_vertices) { + throw std::runtime_error("DPV values (" + std::to_string(values.size()) + ") do not match number of vertices (" + + std::to_string(nb_vertices) + ")"); + } + + std::string dpv_dirname = this->_uncompressed_folder_handle + SEPARATOR + "dpv" + SEPARATOR; + { + std::error_code ec; + trx::fs::create_directories(dpv_dirname, ec); + if (ec) { + throw std::runtime_error("Could not create directory " + dpv_dirname); + } + } + + std::string dpv_filename = dpv_dirname + name + "." 
+ dtype_norm; + { + std::error_code ec; + if (trx::fs::exists(dpv_filename, ec)) { + trx::fs::remove(dpv_filename, ec); + } + } + + auto existing = this->data_per_vertex.find(name); + if (existing != this->data_per_vertex.end()) { + this->data_per_vertex.erase(existing); + } + + const int rows = static_cast(nb_vertices); + const int cols = 1; + std::tuple shape = std::make_tuple(rows, cols); + + auto seq = std::make_unique>(); + seq->mmap_pos = trx::_create_memmap(dpv_filename, shape, "w+", dtype_norm); + + if (dtype_norm == "float16") { + auto *data = reinterpret_cast(seq->mmap_pos.data()); + Map> mapped(data, rows, cols); + new (&(seq->_data)) Map>(data, rows, cols); + for (int i = 0; i < rows; ++i) { + mapped(i, 0) = static_cast(values[static_cast(i)]); + } + } else if (dtype_norm == "float32") { + auto *data = reinterpret_cast(seq->mmap_pos.data()); + Map> mapped(data, rows, cols); + new (&(seq->_data)) Map>(data, rows, cols); + for (int i = 0; i < rows; ++i) { + mapped(i, 0) = static_cast(values[static_cast(i)]); + } + } else { + auto *data = reinterpret_cast(seq->mmap_pos.data()); + Map> mapped(data, rows, cols); + new (&(seq->_data)) Map>(data, rows, cols); + for (int i = 0; i < rows; ++i) { + mapped(i, 0) = static_cast(values[static_cast(i)]); + } + } + + if (this->streamlines && this->streamlines->_offsets.size() > 0) { + new (&(seq->_offsets)) Map>(this->streamlines->_offsets.data(), + int(this->streamlines->_offsets.rows()), + int(this->streamlines->_offsets.cols())); + seq->_lengths = this->streamlines->_lengths; + } + + this->data_per_vertex[name] = std::move(seq); +} + +template +void TrxFile
::add_group_from_indices(const std::string &name, const std::vector &indices) { + if (name.empty()) { + throw std::invalid_argument("Group name cannot be empty"); + } + if (this->_uncompressed_folder_handle.empty()) { + throw std::runtime_error("TRX file has no backing directory to store groups"); + } + + size_t nb_streamlines = 0; + if (this->streamlines) { + nb_streamlines = static_cast(this->streamlines->_lengths.size()); + } else if (this->header["NB_STREAMLINES"].is_number()) { + nb_streamlines = static_cast(this->header["NB_STREAMLINES"].int_value()); + } + + for (const auto idx : indices) { + if (idx >= nb_streamlines) { + throw std::runtime_error("Group index out of range: " + std::to_string(idx)); + } + } + + std::string groups_dirname = this->_uncompressed_folder_handle + SEPARATOR + "groups" + SEPARATOR; + { + std::error_code ec; + trx::fs::create_directories(groups_dirname, ec); + if (ec) { + throw std::runtime_error("Could not create directory " + groups_dirname); + } + } + + std::string group_filename = groups_dirname + name + ".uint32"; + { + std::error_code ec; + if (trx::fs::exists(group_filename, ec)) { + trx::fs::remove(group_filename, ec); + } + } + + auto existing = this->groups.find(name); + if (existing != this->groups.end()) { + this->groups.erase(existing); + } + + const int rows = static_cast(indices.size()); + const int cols = 1; + std::tuple shape = std::make_tuple(rows, cols); + + auto group = std::make_unique>(); + group->mmap = trx::_create_memmap(group_filename, shape, "w+", "uint32"); + new (&(group->_matrix)) Map>( + reinterpret_cast(group->mmap.data()), std::get<0>(shape), std::get<1>(shape)); + for (int i = 0; i < rows; ++i) { + group->_matrix(i, 0) = indices[static_cast(i)]; + } + this->groups[name] = std::move(group); +} + +template +void TrxFile
::set_voxel_to_rasmm(const Eigen::Matrix4f &affine) { + std::vector> matrix(4, std::vector(4, 0.0f)); + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + matrix[static_cast(i)][static_cast(j)] = affine(i, j); + } + } + this->header = _json_set(this->header, "VOXEL_TO_RASMM", matrix); +} + +inline TrxStream::TrxStream(std::string positions_dtype) : positions_dtype_(std::move(positions_dtype)) { + std::transform(positions_dtype_.begin(), positions_dtype_.end(), positions_dtype_.begin(), [](unsigned char c) { + return static_cast(std::tolower(c)); + }); + if (positions_dtype_ != "float32") { + throw std::invalid_argument("TrxStream only supports float32 positions for now"); + } + tmp_dir_ = make_temp_dir("trx_proto"); + positions_path_ = tmp_dir_ + SEPARATOR + "positions.tmp"; + ensure_positions_stream(); +} + +inline TrxStream::~TrxStream() { cleanup_tmp(); } + +inline void TrxStream::ensure_positions_stream() { + if (!positions_out_.is_open()) { + positions_out_.open(positions_path_, std::ios::binary | std::ios::out | std::ios::trunc); + if (!positions_out_.is_open()) { + throw std::runtime_error("Failed to open TrxStream temp positions file: " + positions_path_); + } + } +} + +inline void TrxStream::cleanup_tmp() { + if (positions_out_.is_open()) { + positions_out_.close(); + } + if (!tmp_dir_.empty()) { + rm_dir(tmp_dir_); + tmp_dir_.clear(); + } +} + +inline void TrxStream::push_streamline(const float *xyz, size_t point_count) { + if (finalized_) { + throw std::runtime_error("TrxStream already finalized"); + } + if (point_count == 0) { + lengths_.push_back(0); + return; + } + ensure_positions_stream(); + const size_t byte_count = point_count * 3 * sizeof(float); + positions_out_.write(reinterpret_cast(xyz), static_cast(byte_count)); + if (!positions_out_) { + throw std::runtime_error("Failed to write TrxStream positions"); + } + total_vertices_ += point_count; + lengths_.push_back(static_cast(point_count)); +} + +inline void 
TrxStream::push_streamline(const std::vector &xyz_flat) { + if (xyz_flat.size() % 3 != 0) { + throw std::invalid_argument("TrxStream streamline buffer must be a multiple of 3"); + } + push_streamline(xyz_flat.data(), xyz_flat.size() / 3); +} + +inline void TrxStream::push_streamline(const std::vector> &points) { + push_streamline(reinterpret_cast(points.data()), points.size()); +} + +template +inline void +TrxStream::push_dps_from_vector(const std::string &name, const std::string &dtype, const std::vector &values) { + if (name.empty()) { + throw std::invalid_argument("DPS name cannot be empty"); + } + std::string dtype_norm = dtype; + std::transform(dtype_norm.begin(), dtype_norm.end(), dtype_norm.begin(), [](unsigned char c) { + return static_cast(std::tolower(c)); + }); + if (!trx::detail::_is_dtype_valid(dtype_norm)) { + throw std::invalid_argument("Unsupported DPS dtype: " + dtype); + } + if (dtype_norm != "float16" && dtype_norm != "float32" && dtype_norm != "float64") { + throw std::invalid_argument("Unsupported DPS dtype: " + dtype); + } + FieldValues field; + field.dtype = dtype_norm; + field.values.reserve(values.size()); + for (const auto &v : values) { + field.values.push_back(static_cast(v)); + } + dps_[name] = std::move(field); +} + +template +inline void +TrxStream::push_dpv_from_vector(const std::string &name, const std::string &dtype, const std::vector &values) { + if (name.empty()) { + throw std::invalid_argument("DPV name cannot be empty"); + } + std::string dtype_norm = dtype; + std::transform(dtype_norm.begin(), dtype_norm.end(), dtype_norm.begin(), [](unsigned char c) { + return static_cast(std::tolower(c)); + }); + if (!trx::detail::_is_dtype_valid(dtype_norm)) { + throw std::invalid_argument("Unsupported DPV dtype: " + dtype); + } + if (dtype_norm != "float16" && dtype_norm != "float32" && dtype_norm != "float64") { + throw std::invalid_argument("Unsupported DPV dtype: " + dtype); + } + FieldValues field; + field.dtype = dtype_norm; + 
field.values.reserve(values.size()); + for (const auto &v : values) { + field.values.push_back(static_cast(v)); + } + dpv_[name] = std::move(field); +} + +inline void TrxStream::push_group_from_indices(const std::string &name, const std::vector &indices) { + if (name.empty()) { + throw std::invalid_argument("Group name cannot be empty"); + } + groups_[name] = indices; +} + +template void TrxStream::finalize(const std::string &filename, zip_uint32_t compression_standard) { + if (finalized_) { + throw std::runtime_error("TrxStream already finalized"); + } + finalized_ = true; + + if (positions_out_.is_open()) { + positions_out_.flush(); + positions_out_.close(); + } + + const size_t nb_streamlines = lengths_.size(); + const size_t nb_vertices = total_vertices_; + + TrxFile
trx(static_cast(nb_vertices), static_cast(nb_streamlines)); + + json header_out = header; + header_out = _json_set(header_out, "NB_VERTICES", static_cast(nb_vertices)); + header_out = _json_set(header_out, "NB_STREAMLINES", static_cast(nb_streamlines)); + trx.header = header_out; + + auto &positions = trx.streamlines->_data; + auto &offsets = trx.streamlines->_offsets; + auto &lengths = trx.streamlines->_lengths; + + offsets(0, 0) = 0; + for (size_t i = 0; i < nb_streamlines; ++i) { + lengths(static_cast(i)) = static_cast(lengths_[i]); + offsets(static_cast(i + 1), 0) = offsets(static_cast(i), 0) + lengths_[i]; + } + + std::ifstream in(positions_path_, std::ios::binary); + if (!in.is_open()) { + throw std::runtime_error("Failed to open TrxStream temp positions file for read: " + positions_path_); + } + for (size_t i = 0; i < nb_vertices; ++i) { + float xyz[3]; + in.read(reinterpret_cast(xyz), sizeof(xyz)); + if (!in) { + throw std::runtime_error("Failed to read TrxStream positions"); + } + positions(static_cast(i), 0) = static_cast
(xyz[0]); + positions(static_cast(i), 1) = static_cast
(xyz[1]); + positions(static_cast(i), 2) = static_cast
(xyz[2]); + } + + for (const auto &kv : dps_) { + trx.add_dps_from_vector(kv.first, kv.second.dtype, kv.second.values); + } + for (const auto &kv : dpv_) { + trx.add_dpv_from_vector(kv.first, kv.second.dtype, kv.second.values); + } + for (const auto &kv : groups_) { + trx.add_group_from_indices(kv.first, kv.second); + } + + trx.save(filename, compression_standard); + trx.close(); + + cleanup_tmp(); } template -void add_dpv_from_tsf(TrxFile
&trx, const std::string &name, const std::string &dtype, const std::string &path) { +void TrxFile
::add_dpv_from_tsf(const std::string &name, const std::string &dtype, const std::string &path) { if (name.empty()) { throw std::invalid_argument("DPV name cannot be empty"); } @@ -1217,23 +1576,23 @@ void add_dpv_from_tsf(TrxFile
&trx, const std::string &name, const std::stri return static_cast(std::tolower(c)); }); - if (!_is_dtype_valid(dtype_norm)) { + if (!trx::detail::_is_dtype_valid(dtype_norm)) { throw std::invalid_argument("Unsupported DPV dtype: " + dtype); } if (dtype_norm != "float16" && dtype_norm != "float32" && dtype_norm != "float64") { throw std::invalid_argument("Unsupported DPV dtype for TSF input: " + dtype); } - if (!trx.streamlines) { + if (!this->streamlines) { throw std::runtime_error("TRX file has no streamlines to attach DPV data"); } - if (trx._uncompressed_folder_handle.empty()) { + if (this->_uncompressed_folder_handle.empty()) { throw std::runtime_error("TRX file has no backing directory to store DPV data"); } - const auto &lengths = trx.streamlines->_lengths; + const auto &lengths = this->streamlines->_lengths; const size_t nb_streamlines = static_cast(lengths.size()); - const size_t nb_vertices = static_cast(trx.streamlines->_data.rows()); + const size_t nb_vertices = static_cast(this->streamlines->_data.rows()); std::ifstream input(path); if (!input.is_open()) { @@ -1422,7 +1781,7 @@ void add_dpv_from_tsf(TrxFile
&trx, const std::string &name, const std::stri std::to_string(nb_vertices) + ")"); } - std::string dpv_dirname = trx._uncompressed_folder_handle + SEPARATOR + "dpv" + SEPARATOR; + std::string dpv_dirname = this->_uncompressed_folder_handle + SEPARATOR + "dpv" + SEPARATOR; { std::error_code ec; trx::fs::create_directories(dpv_dirname, ec); @@ -1439,18 +1798,17 @@ void add_dpv_from_tsf(TrxFile
&trx, const std::string &name, const std::stri } } - auto existing = trx.data_per_vertex.find(name); - if (existing != trx.data_per_vertex.end()) { - delete existing->second; - trx.data_per_vertex.erase(existing); + auto existing = this->data_per_vertex.find(name); + if (existing != this->data_per_vertex.end()) { + this->data_per_vertex.erase(existing); } const int rows = static_cast(nb_vertices); const int cols = 1; std::tuple shape = std::make_tuple(rows, cols); - auto *seq = new trxmmap::ArraySequence
(); - seq->mmap_pos = trxmmap::_create_memmap(dpv_filename, shape, "w+", dtype_norm); + auto seq = std::make_unique>(); + seq->mmap_pos = trx::_create_memmap(dpv_filename, shape, "w+", dtype_norm); if (dtype_norm == "float16") { auto *data = reinterpret_cast(seq->mmap_pos.data()); @@ -1475,20 +1833,19 @@ void add_dpv_from_tsf(TrxFile
&trx, const std::string &name, const std::stri } } - new (&(seq->_offsets)) Map>(trx.streamlines->_offsets.data(), - static_cast(trx.streamlines->_offsets.rows()), - static_cast(trx.streamlines->_offsets.cols())); - seq->_lengths = trx.streamlines->_lengths; + new (&(seq->_offsets)) Map>(this->streamlines->_offsets.data(), + static_cast(this->streamlines->_offsets.rows()), + static_cast(this->streamlines->_offsets.cols())); + seq->_lengths = this->streamlines->_lengths; - trx.data_per_vertex[name] = seq; + this->data_per_vertex[name] = std::move(seq); } template -void export_dpv_to_tsf(const TrxFile
&trx, - const std::string &name, - const std::string &path, - const std::string ×tamp, - const std::string &dtype) { +void TrxFile
::export_dpv_to_tsf(const std::string &name, + const std::string &path, + const std::string ×tamp, + const std::string &dtype) const { if (name.empty()) { throw std::invalid_argument("DPV name cannot be empty"); } @@ -1501,31 +1858,34 @@ void export_dpv_to_tsf(const TrxFile
&trx, return static_cast(std::tolower(c)); }); - if (!_is_dtype_valid(dtype_norm)) { + if (!trx::detail::_is_dtype_valid(dtype_norm)) { throw std::invalid_argument("Unsupported TSF dtype: " + dtype); } if (dtype_norm != "float32" && dtype_norm != "float64") { throw std::invalid_argument("Unsupported TSF dtype for output: " + dtype); } - if (!trx.streamlines) { + if (!this->streamlines) { throw std::runtime_error("TRX file has no streamlines to export DPV data"); } - const auto dpv_it = trx.data_per_vertex.find(name); - if (dpv_it == trx.data_per_vertex.end()) { + const auto dpv_it = this->data_per_vertex.find(name); + if (dpv_it == this->data_per_vertex.end()) { throw std::runtime_error("DPV entry not found: " + name); } - const auto *seq = dpv_it->second; + const auto *seq = dpv_it->second.get(); + if (!seq) { + throw std::runtime_error("DPV entry is null: " + name); + } if (seq->_data.cols() != 1) { throw std::runtime_error("DPV must be 1D to export as TSF: " + name); } - const auto &lengths = trx.streamlines->_lengths; + const auto &lengths = this->streamlines->_lengths; const size_t nb_streamlines = static_cast(lengths.size()); const size_t nb_vertices = static_cast(seq->_data.rows()); - if (nb_vertices != static_cast(trx.streamlines->_data.rows())) { + if (nb_vertices != static_cast(this->streamlines->_data.rows())) { throw std::runtime_error("DPV vertex count does not match streamlines data"); } @@ -1626,3 +1986,533 @@ template std::ostream &operator<<(std::ostream &out, const TrxFile out << trx.header.dump(); return out; } + +template +std::vector> TrxFile
::build_streamline_aabbs() const { + std::vector> aabbs; + if (!this->streamlines) { + return aabbs; + } + + std::vector offsets; + if (this->streamlines->_offsets.size() > 0) { + offsets.resize(static_cast(this->streamlines->_offsets.size())); + for (Eigen::Index i = 0; i < this->streamlines->_offsets.size(); ++i) { + offsets[static_cast(i)] = this->streamlines->_offsets(i, 0); + } + } else if (this->streamlines->_lengths.size() > 0) { + const size_t nb_streamlines = static_cast(this->streamlines->_lengths.size()); + offsets.resize(nb_streamlines + 1); + offsets[0] = 0; + for (size_t i = 0; i < nb_streamlines; ++i) { + offsets[i + 1] = offsets[i] + static_cast(this->streamlines->_lengths(static_cast(i))); + } + } else { + return aabbs; + } + + const size_t nb_streamlines = offsets.size() > 0 ? offsets.size() - 1 : 0; + aabbs.resize(nb_streamlines); + + for (size_t i = 0; i < nb_streamlines; ++i) { + const uint64_t start = offsets[i]; + const uint64_t end = offsets[i + 1]; + if (end <= start) { + aabbs[i] = {Eigen::half(0), Eigen::half(0), Eigen::half(0), + Eigen::half(0), Eigen::half(0), Eigen::half(0)}; + continue; + } + + float min_x = std::numeric_limits::infinity(); + float min_y = std::numeric_limits::infinity(); + float min_z = std::numeric_limits::infinity(); + float max_x = -std::numeric_limits::infinity(); + float max_y = -std::numeric_limits::infinity(); + float max_z = -std::numeric_limits::infinity(); + + for (uint64_t p = start; p < end; ++p) { + const float x = static_cast(this->streamlines->_data(static_cast(p), 0)); + const float y = static_cast(this->streamlines->_data(static_cast(p), 1)); + const float z = static_cast(this->streamlines->_data(static_cast(p), 2)); + min_x = std::min(min_x, x); + min_y = std::min(min_y, y); + min_z = std::min(min_z, z); + max_x = std::max(max_x, x); + max_y = std::max(max_y, y); + max_z = std::max(max_z, z); + } + + aabbs[i] = {static_cast(min_x), static_cast(min_y), static_cast(min_z), + static_cast(max_x), 
static_cast(max_y), static_cast(max_z)}; + } + + this->aabb_cache_ = aabbs; + return aabbs; +} + +template +std::unique_ptr> TrxFile
::query_aabb( + const std::array &min_corner, + const std::array &max_corner, + const std::vector> *precomputed_aabbs, + bool build_cache_for_result) const { + if (!this->streamlines) { + auto empty = std::make_unique>(); + empty->header = _json_set(this->header, "NB_VERTICES", 0); + empty->header = _json_set(empty->header, "NB_STREAMLINES", 0); + return empty; + } + + size_t nb_streamlines = 0; + if (this->streamlines->_offsets.size() > 0) { + nb_streamlines = static_cast(this->streamlines->_offsets.size() - 1); + } else if (this->streamlines->_lengths.size() > 0) { + nb_streamlines = static_cast(this->streamlines->_lengths.size()); + } else { + auto empty = std::make_unique>(); + empty->header = _json_set(this->header, "NB_VERTICES", 0); + empty->header = _json_set(empty->header, "NB_STREAMLINES", 0); + return empty; + } + + std::vector> aabbs_local; + const std::vector> &aabbs = precomputed_aabbs + ? *precomputed_aabbs + : (!this->aabb_cache_.empty() ? this->aabb_cache_ : (aabbs_local = this->build_streamline_aabbs())); + if (aabbs.size() != nb_streamlines) { + throw std::invalid_argument("AABB size does not match streamlines count"); + } + + const float min_x = min_corner[0]; + const float min_y = min_corner[1]; + const float min_z = min_corner[2]; + const float max_x = max_corner[0]; + const float max_y = max_corner[1]; + const float max_z = max_corner[2]; + + std::vector selected; + selected.reserve(nb_streamlines); + + for (size_t i = 0; i < nb_streamlines; ++i) { + const auto &box = aabbs[i]; + const float box_min_x = static_cast(box[0]); + const float box_min_y = static_cast(box[1]); + const float box_min_z = static_cast(box[2]); + const float box_max_x = static_cast(box[3]); + const float box_max_y = static_cast(box[4]); + const float box_max_z = static_cast(box[5]); + + if (box_min_x <= max_x && box_max_x >= min_x && + box_min_y <= max_y && box_max_y >= min_y && + box_min_z <= max_z && box_max_z >= min_z) { + selected.push_back(static_cast(i)); + } + } + 
+ return this->subset_streamlines(selected, build_cache_for_result); +} + +template +void TrxFile
::invalidate_aabb_cache() const { + this->aabb_cache_.clear(); +} + +template +template +void TrxFile
::add_dpg_from_vector(const std::string &group, + const std::string &name, + const std::string &dtype, + const std::vector &values, + int rows, + int cols) { + if (group.empty()) { + throw std::invalid_argument("DPG group cannot be empty"); + } + if (name.empty()) { + throw std::invalid_argument("DPG name cannot be empty"); + } + std::string dtype_norm = dtype; + std::transform(dtype_norm.begin(), dtype_norm.end(), dtype_norm.begin(), [](unsigned char c) { + return static_cast(std::tolower(c)); + }); + if (!trx::detail::_is_dtype_valid(dtype_norm)) { + throw std::invalid_argument("Unsupported DPG dtype: " + dtype); + } + if (dtype_norm != "float16" && dtype_norm != "float32" && dtype_norm != "float64") { + throw std::invalid_argument("Unsupported DPG dtype: " + dtype); + } + if (this->_uncompressed_folder_handle.empty()) { + throw std::runtime_error("TRX file has no backing directory to store DPG data"); + } + if (rows <= 0) { + throw std::invalid_argument("DPG rows must be positive"); + } + if (cols < 0) { + if (values.size() % static_cast(rows) != 0) { + throw std::invalid_argument("DPG values size does not match rows"); + } + cols = static_cast(values.size() / static_cast(rows)); + } + if (cols <= 0) { + throw std::invalid_argument("DPG cols must be positive"); + } + if (static_cast(rows) * static_cast(cols) != values.size()) { + throw std::invalid_argument("DPG values size does not match rows*cols"); + } + + std::string dpg_dir = this->_uncompressed_folder_handle + SEPARATOR + "dpg" + SEPARATOR; + { + std::error_code ec; + trx::fs::create_directories(dpg_dir, ec); + if (ec) { + throw std::runtime_error("Could not create directory " + dpg_dir); + } + } + std::string dpg_subdir = dpg_dir + group; + { + std::error_code ec; + trx::fs::create_directories(dpg_subdir, ec); + if (ec) { + throw std::runtime_error("Could not create directory " + dpg_subdir); + } + } + + std::string dpg_filename = dpg_subdir + SEPARATOR + name + "." 
+ dtype_norm; + { + std::error_code ec; + if (trx::fs::exists(dpg_filename, ec)) { + trx::fs::remove(dpg_filename, ec); + } + } + + auto &group_map = this->data_per_group[group]; + group_map.erase(name); + + std::tuple shape = std::make_tuple(rows, cols); + group_map[name] = std::make_unique>(); + group_map[name]->mmap = _create_memmap(dpg_filename, shape, "w+", dtype_norm); + + if (dtype_norm == "float16") { + auto *data = reinterpret_cast(group_map[name]->mmap.data()); + Map> mapped(data, rows, cols); + new (&(group_map[name]->_matrix)) Map>(data, rows, cols); + for (int i = 0; i < rows * cols; ++i) { + data[i] = static_cast(values[static_cast(i)]); + } + } else if (dtype_norm == "float32") { + auto *data = reinterpret_cast(group_map[name]->mmap.data()); + Map> mapped(data, rows, cols); + new (&(group_map[name]->_matrix)) Map>(data, rows, cols); + for (int i = 0; i < rows * cols; ++i) { + data[i] = static_cast(values[static_cast(i)]); + } + } else { + auto *data = reinterpret_cast(group_map[name]->mmap.data()); + Map> mapped(data, rows, cols); + new (&(group_map[name]->_matrix)) Map>(data, rows, cols); + for (int i = 0; i < rows * cols; ++i) { + data[i] = static_cast(values[static_cast(i)]); + } + } +} + +template +template +void TrxFile
::add_dpg_from_matrix(const std::string &group, + const std::string &name, + const std::string &dtype, + const Eigen::MatrixBase &matrix) { + if (matrix.size() == 0) { + throw std::invalid_argument("DPG matrix cannot be empty"); + } + std::vector values; + values.reserve(static_cast(matrix.size())); + for (Eigen::Index i = 0; i < matrix.rows(); ++i) { + for (Eigen::Index j = 0; j < matrix.cols(); ++j) { + values.push_back(matrix(i, j)); + } + } + add_dpg_from_vector(group, name, dtype, values, static_cast(matrix.rows()), + static_cast(matrix.cols())); +} + +template +const MMappedMatrix
*TrxFile
::get_dpg(const std::string &group, const std::string &name) const { + auto group_it = this->data_per_group.find(group); + if (group_it == this->data_per_group.end()) { + return nullptr; + } + auto field_it = group_it->second.find(name); + if (field_it == group_it->second.end()) { + return nullptr; + } + return field_it->second.get(); +} + +template +std::vector TrxFile
::list_dpg_groups() const { + std::vector groups; + groups.reserve(this->data_per_group.size()); + for (const auto &kv : this->data_per_group) { + groups.push_back(kv.first); + } + return groups; +} + +template +std::vector TrxFile
::list_dpg_fields(const std::string &group) const { + std::vector fields; + auto it = this->data_per_group.find(group); + if (it == this->data_per_group.end()) { + return fields; + } + fields.reserve(it->second.size()); + for (const auto &kv : it->second) { + fields.push_back(kv.first); + } + return fields; +} + +template +void TrxFile
::remove_dpg(const std::string &group, const std::string &name) { + auto group_it = this->data_per_group.find(group); + if (group_it == this->data_per_group.end()) { + return; + } + group_it->second.erase(name); + if (group_it->second.empty()) { + this->data_per_group.erase(group_it); + } +} + +template +void TrxFile
::remove_dpg_group(const std::string &group) { + this->data_per_group.erase(group); +} + +template +std::unique_ptr> TrxFile
::subset_streamlines(const std::vector &streamline_ids, + bool build_cache_for_result) const { + if (!this->streamlines) { + auto empty = std::make_unique>(); + empty->header = _json_set(this->header, "NB_VERTICES", 0); + empty->header = _json_set(empty->header, "NB_STREAMLINES", 0); + return empty; + } + + std::vector offsets; + if (this->streamlines->_offsets.size() > 0) { + offsets.resize(static_cast(this->streamlines->_offsets.size())); + for (Eigen::Index i = 0; i < this->streamlines->_offsets.size(); ++i) { + offsets[static_cast(i)] = this->streamlines->_offsets(i, 0); + } + } else if (this->streamlines->_lengths.size() > 0) { + const size_t nb_streamlines = static_cast(this->streamlines->_lengths.size()); + offsets.resize(nb_streamlines + 1); + offsets[0] = 0; + for (size_t i = 0; i < nb_streamlines; ++i) { + offsets[i + 1] = offsets[i] + static_cast(this->streamlines->_lengths(static_cast(i))); + } + } else { + auto empty = std::make_unique>(); + empty->header = _json_set(this->header, "NB_VERTICES", 0); + empty->header = _json_set(empty->header, "NB_STREAMLINES", 0); + return empty; + } + + const size_t nb_streamlines = offsets.size() > 0 ? 
offsets.size() - 1 : 0; + if (nb_streamlines == 0) { + auto empty = std::make_unique>(); + empty->header = _json_set(this->header, "NB_VERTICES", 0); + empty->header = _json_set(empty->header, "NB_STREAMLINES", 0); + return empty; + } + + std::vector selected; + selected.reserve(streamline_ids.size()); + std::vector seen(nb_streamlines, 0); + for (uint32_t id : streamline_ids) { + if (id >= nb_streamlines) { + throw std::invalid_argument("Streamline id out of range"); + } + if (!seen[id]) { + selected.push_back(id); + seen[id] = 1; + } + } + + if (selected.empty()) { + auto empty = std::make_unique>(); + empty->header = _json_set(this->header, "NB_VERTICES", 0); + empty->header = _json_set(empty->header, "NB_STREAMLINES", 0); + return empty; + } + + std::vector old_to_new(nb_streamlines, -1); + size_t total_vertices = 0; + for (size_t i = 0; i < selected.size(); ++i) { + const uint32_t idx = selected[i]; + old_to_new[idx] = static_cast(i); + const uint64_t start = offsets[idx]; + const uint64_t end = offsets[idx + 1]; + total_vertices += static_cast(end - start); + } + + auto out = std::make_unique>(static_cast(total_vertices), + static_cast(selected.size()), + this); + out->header = _json_set(this->header, "NB_VERTICES", static_cast(total_vertices)); + out->header = _json_set(out->header, "NB_STREAMLINES", static_cast(selected.size())); + + auto &out_positions = out->streamlines->_data; + auto &out_offsets = out->streamlines->_offsets; + auto &out_lengths = out->streamlines->_lengths; + + size_t cursor = 0; + out_offsets(0, 0) = 0; + for (size_t new_idx = 0; new_idx < selected.size(); ++new_idx) { + const uint32_t old_idx = selected[new_idx]; + const uint64_t start = offsets[old_idx]; + const uint64_t end = offsets[old_idx + 1]; + const uint64_t len = end - start; + + out_lengths(static_cast(new_idx)) = static_cast(len); + out_offsets(static_cast(new_idx + 1), 0) = + out_offsets(static_cast(new_idx), 0) + len; + + if (len > 0) { + 
out_positions.block(static_cast(cursor), 0, + static_cast(len), 3) = + this->streamlines->_data.block(static_cast(start), 0, + static_cast(len), 3); + + for (const auto &kv : this->data_per_vertex) { + const std::string &name = kv.first; + auto out_it = out->data_per_vertex.find(name); + if (out_it == out->data_per_vertex.end()) { + continue; + } + auto &out_dpv = out_it->second->_data; + auto &src_dpv = kv.second->_data; + const Eigen::Index cols = src_dpv.cols(); + out_dpv.block(static_cast(cursor), 0, + static_cast(len), cols) = + src_dpv.block(static_cast(start), 0, + static_cast(len), cols); + } + } + + for (const auto &kv : this->data_per_streamline) { + const std::string &name = kv.first; + auto out_it = out->data_per_streamline.find(name); + if (out_it == out->data_per_streamline.end()) { + continue; + } + out_it->second->_matrix.row(static_cast(new_idx)) = + kv.second->_matrix.row(static_cast(old_idx)); + } + + cursor += static_cast(len); + } + + for (const auto &kv : this->groups) { + const std::string &group_name = kv.first; + std::vector indices; + auto &matrix = kv.second->_matrix; + indices.reserve(static_cast(matrix.size())); + for (Eigen::Index r = 0; r < matrix.rows(); ++r) { + for (Eigen::Index c = 0; c < matrix.cols(); ++c) { + const uint32_t old_idx = matrix(r, c); + if (old_idx >= old_to_new.size()) { + continue; + } + const int new_idx = old_to_new[old_idx]; + if (new_idx >= 0) { + indices.push_back(static_cast(new_idx)); + } + } + } + if (!indices.empty()) { + out->add_group_from_indices(group_name, indices); + } + } + + if (!this->data_per_group.empty() && !out->groups.empty()) { + std::string dpg_dir = out->_uncompressed_folder_handle + SEPARATOR + "dpg" + SEPARATOR; + { + std::error_code ec; + trx::fs::create_directories(dpg_dir, ec); + if (ec) { + throw std::runtime_error("Could not create directory " + dpg_dir); + } + } + + for (const auto &group_kv : out->groups) { + const std::string &group_name = group_kv.first; + auto src_group_it = 
this->data_per_group.find(group_name); + if (src_group_it == this->data_per_group.end()) { + continue; + } + + std::string dpg_subdir = dpg_dir + group_name; + { + std::error_code ec; + trx::fs::create_directories(dpg_subdir, ec); + if (ec) { + throw std::runtime_error("Could not create directory " + dpg_subdir); + } + } + + if (out->data_per_group.find(group_name) == out->data_per_group.end()) { + out->data_per_group.emplace(group_name, std::map>>{}); + } else { + out->data_per_group[group_name].clear(); + } + + for (const auto &field_kv : src_group_it->second) { + const std::string &field_name = field_kv.first; + std::string dpg_dtype = dtype_from_scalar
(); + std::string dpg_filename = dpg_subdir + SEPARATOR + field_name; + dpg_filename = _generate_filename_from_data(field_kv.second->_matrix, dpg_filename); + + std::tuple dpg_shape = std::make_tuple(field_kv.second->_matrix.rows(), + field_kv.second->_matrix.cols()); + + out->data_per_group[group_name][field_name] = std::make_unique>(); + out->data_per_group[group_name][field_name]->mmap = + _create_memmap(dpg_filename, dpg_shape, "w+", dpg_dtype); + + if (dpg_dtype.compare("float16") == 0) { + new (&(out->data_per_group[group_name][field_name]->_matrix)) Map>( + reinterpret_cast(out->data_per_group[group_name][field_name]->mmap.data()), + std::get<0>(dpg_shape), std::get<1>(dpg_shape)); + } else if (dpg_dtype.compare("float32") == 0) { + new (&(out->data_per_group[group_name][field_name]->_matrix)) Map>( + reinterpret_cast(out->data_per_group[group_name][field_name]->mmap.data()), + std::get<0>(dpg_shape), std::get<1>(dpg_shape)); + } else { + new (&(out->data_per_group[group_name][field_name]->_matrix)) Map>( + reinterpret_cast(out->data_per_group[group_name][field_name]->mmap.data()), + std::get<0>(dpg_shape), std::get<1>(dpg_shape)); + } + + for (int i = 0; i < out->data_per_group[group_name][field_name]->_matrix.rows(); ++i) { + for (int j = 0; j < out->data_per_group[group_name][field_name]->_matrix.cols(); ++j) { + out->data_per_group[group_name][field_name]->_matrix(i, j) = + field_kv.second->_matrix(i, j); + } + } + } + } + } + + if (build_cache_for_result) { + out->build_streamline_aabbs(); + } + return out; +} + +#ifdef TRX_TPP_OPEN_NAMESPACE +} // namespace trx +#undef TRX_TPP_OPEN_NAMESPACE +#endif diff --git a/src/detail/dtype_helpers.cpp b/src/detail/dtype_helpers.cpp new file mode 100644 index 0000000..197e5b7 --- /dev/null +++ b/src/detail/dtype_helpers.cpp @@ -0,0 +1,104 @@ +#include + +namespace trx { +namespace detail { + +int _sizeof_dtype(const std::string &dtype) { + if (dtype == "bit") + return 1; + if (dtype == "uint8") + return 
sizeof(uint8_t); + // Treat "ushort" as an alias of uint16 for cross-platform consistency. + if (dtype == "uint16" || dtype == "ushort") + return sizeof(uint16_t); + if (dtype == "uint32") + return sizeof(uint32_t); + if (dtype == "uint64") + return sizeof(uint64_t); + if (dtype == "int8") + return sizeof(int8_t); + if (dtype == "int16") + return sizeof(int16_t); + if (dtype == "int32") + return sizeof(int32_t); + if (dtype == "int64") + return sizeof(int64_t); + if (dtype == "float32") + return sizeof(float); + if (dtype == "float64") + return sizeof(double); + return sizeof(std::uint16_t); // default to 16-bit float size +} + +std::string _get_dtype(const std::string &dtype) { + const char dt = dtype.back(); + switch (dt) { + case 'h': + return "uint8"; + case 't': + return "uint16"; + case 'j': + return "uint32"; + case 'm': + case 'y': // unsigned long long (Itanium ABI) + return "uint64"; + case 'a': + return "int8"; + case 's': + return "int16"; + case 'i': + return "int32"; + case 'l': + case 'x': // long long (Itanium ABI) + return "int64"; + case 'f': + return "float32"; + case 'd': + return "float64"; + default: + return "float16"; // setting this as default for now but a better solution is needed + } +} + +bool _is_dtype_valid(const std::string &ext) { + if (std::find(::trx::dtypes.begin(), ::trx::dtypes.end(), ext) != ::trx::dtypes.end()) + return true; + return false; +} + +std::tuple _split_ext_with_dimensionality(const std::string &filename) { + std::string base = ::trx::path_basename(filename); + + const size_t num_splits = std::count(base.begin(), base.end(), '.'); + int dim = 0; + + if (num_splits != 1 && num_splits != 2) { + throw std::invalid_argument("Invalid filename"); + } + + const std::string ext = ::trx::get_ext(base); + + base = base.substr(0, base.length() - ext.length() - 1); + + if (num_splits == 1) { + dim = 1; + } else { + const size_t pos = base.find_last_of('.'); + dim = std::stoi(base.substr(pos + 1, base.size())); + base = 
base.substr(0, pos); + } + + const bool is_valid = _is_dtype_valid(ext); + + if (!is_valid) { + // TODO: make formatted string and include provided extension name + throw std::invalid_argument("Unsupported file extension"); + } + + std::tuple output{base, dim, ext}; + + return output; +} + +} // namespace detail +} // namespace trx diff --git a/src/nifti_io.cpp b/src/nifti_io.cpp new file mode 100644 index 0000000..68cd5e3 --- /dev/null +++ b/src/nifti_io.cpp @@ -0,0 +1,322 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace trx { +namespace detail { + +#pragma pack(push, 1) +struct NiftiHeader { + std::int32_t sizeof_hdr; + char data_type[10]; + char db_name[18]; + std::int32_t extents; + std::int16_t session_error; + char regular; + char dim_info; + std::int16_t dim[8]; + float intent_p1; + float intent_p2; + float intent_p3; + std::int16_t intent_code; + std::int16_t datatype; + std::int16_t bitpix; + std::int16_t slice_start; + float pixdim[8]; + float vox_offset; + float scl_slope; + float scl_inter; + std::int16_t slice_end; + char slice_code; + char xyzt_units; + float cal_max; + float cal_min; + float slice_duration; + float toffset; + std::int32_t glmax; + std::int32_t glmin; + char descrip[80]; + char aux_file[24]; + std::int16_t qform_code; + std::int16_t sform_code; + float quatern_b; + float quatern_c; + float quatern_d; + float qoffset_x; + float qoffset_y; + float qoffset_z; + float srow_x[4]; + float srow_y[4]; + float srow_z[4]; + char intent_name[16]; + char magic[4]; +}; +#pragma pack(pop) + +static_assert(sizeof(NiftiHeader) == 348, "NIfTI-1 header must be 348 bytes"); + +bool has_gz_extension(const std::string &path) { + const std::string suffix = ".gz"; + if (path.size() < suffix.size()) { + return false; + } + return path.compare(path.size() - suffix.size(), suffix.size(), suffix) == 0; +} + +void swap_2(std::int16_t &value) { + const uint16_t u = static_cast(value); + const 
uint16_t swapped = static_cast((u >> 8) | (u << 8)); + value = static_cast(swapped); +} + +void swap_4(std::int32_t &value) { + const uint32_t u = static_cast(value); + const uint32_t swapped = + ((u >> 24) & 0x000000FFu) | + ((u >> 8) & 0x0000FF00u) | + ((u << 8) & 0x00FF0000u) | + ((u << 24) & 0xFF000000u); + value = static_cast(swapped); +} + +void swap_4(float &value) { + uint32_t u = 0; + std::memcpy(&u, &value, sizeof(u)); + const uint32_t swapped = + ((u >> 24) & 0x000000FFu) | + ((u >> 8) & 0x0000FF00u) | + ((u << 8) & 0x00FF0000u) | + ((u << 24) & 0xFF000000u); + std::memcpy(&value, &swapped, sizeof(value)); +} + +void swap_nifti_header(NiftiHeader &hdr) { + swap_4(hdr.sizeof_hdr); + swap_4(hdr.extents); + swap_2(hdr.session_error); + for (std::int16_t &dim : hdr.dim) { + swap_2(dim); + } + swap_4(hdr.intent_p1); + swap_4(hdr.intent_p2); + swap_4(hdr.intent_p3); + swap_2(hdr.intent_code); + swap_2(hdr.datatype); + swap_2(hdr.bitpix); + swap_2(hdr.slice_start); + for (float &pixdim : hdr.pixdim) { + swap_4(pixdim); + } + swap_4(hdr.vox_offset); + swap_4(hdr.scl_slope); + swap_4(hdr.scl_inter); + swap_2(hdr.slice_end); + swap_4(hdr.cal_max); + swap_4(hdr.cal_min); + swap_4(hdr.slice_duration); + swap_4(hdr.toffset); + swap_4(hdr.glmax); + swap_4(hdr.glmin); + swap_2(hdr.qform_code); + swap_2(hdr.sform_code); + swap_4(hdr.quatern_b); + swap_4(hdr.quatern_c); + swap_4(hdr.quatern_d); + swap_4(hdr.qoffset_x); + swap_4(hdr.qoffset_y); + swap_4(hdr.qoffset_z); + for (float &value : hdr.srow_x) { + swap_4(value); + } + for (float &value : hdr.srow_y) { + swap_4(value); + } + for (float &value : hdr.srow_z) { + swap_4(value); + } +} + +void read_bytes_gz(const std::string &path, void *buffer, size_t length) { + gzFile file = gzopen(path.c_str(), "rb"); + if (!file) { + throw std::runtime_error("Failed to open gzip NIfTI header: " + path); + } + size_t total = 0; + while (total < length) { + const int read_now = gzread(file, static_cast(buffer) + total, + 
static_cast(length - total)); + if (read_now <= 0) { + gzclose(file); + throw std::runtime_error("Failed to read gzip NIfTI header: " + path); + } + total += static_cast(read_now); + } + gzclose(file); +} + +void read_bytes_raw(const std::string &path, void *buffer, size_t length) { + std::ifstream in(path, std::ios::binary); + if (!in) { + throw std::runtime_error("Failed to open NIfTI header: " + path); + } + in.read(static_cast(buffer), static_cast(length)); + if (in.gcount() != static_cast(length)) { + throw std::runtime_error("Failed to read NIfTI header: " + path); + } +} + +NiftiHeader read_header(const std::string &path) { + NiftiHeader hdr{}; + if (has_gz_extension(path)) { + read_bytes_gz(path, &hdr, sizeof(NiftiHeader)); + } else { + read_bytes_raw(path, &hdr, sizeof(NiftiHeader)); + } + + if (hdr.sizeof_hdr != 348) { + swap_nifti_header(hdr); + if (hdr.sizeof_hdr != 348) { + throw std::runtime_error("Invalid NIfTI header size for: " + path); + } + } + + return hdr; +} + +Eigen::Matrix4f quatern_to_mat44(float qb, + float qc, + float qd, + float qx, + float qy, + float qz, + float dx, + float dy, + float dz, + float qfac) { + Eigen::Matrix4f R = Eigen::Matrix4f::Identity(); + + double a = 1.0 - (qb * qb + qc * qc + qd * qd); + double b = qb; + double c = qc; + double d = qd; + + if (a < 1.e-7) { + a = 1.0 / std::sqrt(b * b + c * c + d * d); + b *= a; + c *= a; + d *= a; + a = 0.0; + } else { + a = std::sqrt(a); + } + + const double xd = (dx > 0.0f) ? dx : 1.0; + const double yd = (dy > 0.0f) ? dy : 1.0; + double zd = (dz > 0.0f) ? 
dz : 1.0; + if (qfac < 0.0f) { + zd = -zd; + } + + R(0, 0) = static_cast((a * a + b * b - c * c - d * d) * xd); + R(0, 1) = static_cast(2.0 * (b * c - a * d) * yd); + R(0, 2) = static_cast(2.0 * (b * d + a * c) * zd); + R(1, 0) = static_cast(2.0 * (b * c + a * d) * xd); + R(1, 1) = static_cast((a * a + c * c - b * b - d * d) * yd); + R(1, 2) = static_cast(2.0 * (c * d - a * b) * zd); + R(2, 0) = static_cast(2.0 * (b * d - a * c) * xd); + R(2, 1) = static_cast(2.0 * (c * d + a * b) * yd); + R(2, 2) = static_cast((a * a + d * d - c * c - b * b) * zd); + + R(0, 3) = qx; + R(1, 3) = qy; + R(2, 3) = qz; + return R; +} + +Eigen::Matrix4f sform_to_qform_matrix(const NiftiHeader &hdr) { + Eigen::Matrix3f M; + M << hdr.srow_x[0], hdr.srow_x[1], hdr.srow_x[2], + hdr.srow_y[0], hdr.srow_y[1], hdr.srow_y[2], + hdr.srow_z[0], hdr.srow_z[1], hdr.srow_z[2]; + + float dx = M.col(0).norm(); + float dy = M.col(1).norm(); + float dz = M.col(2).norm(); + + if (dx == 0.0f) { + dx = 1.0f; + M.col(0) = Eigen::Vector3f::UnitX(); + } + if (dy == 0.0f) { + dy = 1.0f; + M.col(1) = Eigen::Vector3f::UnitY(); + } + if (dz == 0.0f) { + dz = 1.0f; + M.col(2) = Eigen::Vector3f::UnitZ(); + } + + Eigen::Matrix3f R = M; + R.col(0) /= dx; + R.col(1) /= dy; + R.col(2) /= dz; + + Eigen::JacobiSVD svd(R, Eigen::ComputeFullU | Eigen::ComputeFullV); + Eigen::Matrix3f R_orth = svd.matrixU() * svd.matrixV().transpose(); + + if (R_orth.determinant() < 0.0f) { + R_orth.col(2) *= -1.0f; + dz = -dz; + } + + Eigen::Matrix4f out = Eigen::Matrix4f::Identity(); + out.block<3, 3>(0, 0) = R_orth * Eigen::DiagonalMatrix(dx, dy, dz); + out(0, 3) = hdr.srow_x[3]; + out(1, 3) = hdr.srow_y[3]; + out(2, 3) = hdr.srow_z[3]; + return out; +} + +} // namespace detail + +Eigen::Matrix4f read_nifti_voxel_to_rasmm(const std::string &path) { + const detail::NiftiHeader hdr = detail::read_header(path); + + if (hdr.qform_code > 0) { + float qfac = hdr.pixdim[0]; + if (qfac == 0.0f) { + qfac = 1.0f; + } else if (qfac > 0.0f) { + qfac 
= 1.0f; + } else { + qfac = -1.0f; + } + return detail::quatern_to_mat44(hdr.quatern_b, + hdr.quatern_c, + hdr.quatern_d, + hdr.qoffset_x, + hdr.qoffset_y, + hdr.qoffset_z, + hdr.pixdim[1], + hdr.pixdim[2], + hdr.pixdim[3], + qfac); + } + + if (hdr.sform_code > 0) { + return detail::sform_to_qform_matrix(hdr); + } + + throw std::runtime_error("NIfTI header has no qform or sform: " + path); +} + +} // namespace trx diff --git a/src/trx.cpp b/src/trx.cpp index ff937bf..5762b34 100644 --- a/src/trx.cpp +++ b/src/trx.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -38,7 +39,7 @@ using std::tuple; using std::uniform_int_distribution; using std::vector; -namespace trxmmap { +namespace trx { namespace { inline int sys_error() { return errno; } @@ -55,7 +56,7 @@ std::string normalize_slashes(std::string path) { bool parse_positions_dtype(const std::string &filename, std::string &out_dtype) { const std::string normalized = normalize_slashes(filename); try { - const auto tuple = trxmmap::_split_ext_with_dimensionality(normalized); + const auto tuple = trx::detail::_split_ext_with_dimensionality(normalized); const std::string &base = std::get<0>(tuple); if (base == "positions") { out_dtype = std::get<2>(tuple); @@ -86,6 +87,17 @@ bool is_path_within(const trx::fs::path &child, const trx::fs::path &parent) { const char next = child_str[parent_str.size()]; return next == '/' || next == '\\'; } + +TypedArray make_typed_array(const std::string &filename, int rows, int cols, const std::string &dtype) { + TypedArray array; + array.dtype = dtype; + array.rows = rows; + array.cols = cols; + if (rows > 0 && cols > 0) { + array.mmap = _create_memmap(filename, std::make_tuple(rows, cols), "r+", dtype, 0); + } + return array; +} } // namespace std::string detect_positions_dtype(const std::string &path) { @@ -97,7 +109,7 @@ std::string detect_positions_dtype(const std::string &path) { std::error_code ec; if (trx::fs::is_directory(input, ec) && !ec) 
{ std::map> files; - trxmmap::populate_fps(path, files); + trx::populate_fps(path, files); for (const auto &kv : files) { std::string dtype; if (parse_positions_dtype(kv.first, dtype)) { @@ -150,6 +162,362 @@ bool is_trx_directory(const std::string &path) { return trx::fs::is_directory(input, ec) && !ec; } +AnyTrxFile::~AnyTrxFile() { _cleanup_temporary_directory(); } + +std::string AnyTrxFile::_normalize_dtype(const std::string &dtype) { + if (dtype == "bool") { + return "bit"; + } + if (dtype == "ushort") { + return "uint16"; + } + return dtype; +} + +size_t AnyTrxFile::num_vertices() const { + if (!positions.empty()) { + return static_cast(positions.rows); + } + if (header["NB_VERTICES"].is_number()) { + return static_cast(header["NB_VERTICES"].int_value()); + } + return 0; +} + +size_t AnyTrxFile::num_streamlines() const { + if (!lengths.empty()) { + return lengths.size(); + } + if (header["NB_STREAMLINES"].is_number()) { + return static_cast(header["NB_STREAMLINES"].int_value()); + } + return 0; +} + +void AnyTrxFile::close() { + _cleanup_temporary_directory(); + positions = TypedArray(); + offsets = TypedArray(); + offsets_u64.clear(); + lengths.clear(); + groups.clear(); + data_per_streamline.clear(); + data_per_vertex.clear(); + data_per_group.clear(); + _uncompressed_folder_handle.clear(); + _owns_uncompressed_folder = false; + + std::vector> affine(4, std::vector(4, 0.0f)); + for (int i = 0; i < 4; i++) { + affine[i][i] = 1.0f; + } + std::vector dimensions{1, 1, 1}; + json::object header_obj; + header_obj["VOXEL_TO_RASMM"] = affine; + header_obj["DIMENSIONS"] = dimensions; + header_obj["NB_VERTICES"] = 0; + header_obj["NB_STREAMLINES"] = 0; + header = json(header_obj); +} + +void AnyTrxFile::_cleanup_temporary_directory() { + if (_owns_uncompressed_folder && !_uncompressed_folder_handle.empty()) { + if (rm_dir(_uncompressed_folder_handle) != 0) { + } + _uncompressed_folder_handle.clear(); + _owns_uncompressed_folder = false; + } +} + +AnyTrxFile 
AnyTrxFile::load(const std::string &path) { + trx::fs::path input(path); + if (!trx::fs::exists(input)) { + throw std::runtime_error("Input path does not exist: " + path); + } + std::error_code ec; + if (trx::fs::is_directory(input, ec) && !ec) { + return AnyTrxFile::load_from_directory(path); + } + return AnyTrxFile::load_from_zip(path); +} + +AnyTrxFile AnyTrxFile::load_from_zip(const std::string &filename) { + int errorp = 0; + zip_t *zf = open_zip_for_read(filename, errorp); + if (zf == nullptr) { + throw std::runtime_error("Could not open zip file: " + filename); + } + + std::string temp_dir = extract_zip_to_directory(zf); + zip_close(zf); + + auto trx = AnyTrxFile::load_from_directory(temp_dir); + trx._uncompressed_folder_handle = temp_dir; + trx._owns_uncompressed_folder = true; + return trx; +} + +AnyTrxFile AnyTrxFile::load_from_directory(const std::string &path) { + std::string directory = path; + { + std::error_code ec; + trx::fs::path resolved = trx::fs::weakly_canonical(trx::fs::path(path), ec); + if (!ec) { + directory = resolved.string(); + } + } + + std::string header_name = directory + SEPARATOR + "header.json"; + std::ifstream header_file(header_name); + if (!header_file.is_open()) { + throw std::runtime_error("Failed to open header.json at: " + header_name); + } + std::string jstream((std::istreambuf_iterator(header_file)), std::istreambuf_iterator()); + header_file.close(); + std::string err; + json header = json::parse(jstream, err); + if (!err.empty()) { + throw std::runtime_error("Failed to parse header.json: " + err); + } + + std::map> files_pointer_size; + populate_fps(directory, files_pointer_size); + + auto trx = AnyTrxFile::_create_from_pointer(header, files_pointer_size, directory); + trx._backing_directory = directory; + return trx; +} + +AnyTrxFile +AnyTrxFile::_create_from_pointer(json header, + const std::map> &dict_pointer_size, + const std::string &root) { + AnyTrxFile trx; + trx.header = header; + + if 
(!header["NB_VERTICES"].is_number() || !header["NB_STREAMLINES"].is_number()) { + throw std::invalid_argument("Missing NB_VERTICES or NB_STREAMLINES in header.json"); + } + + const int nb_vertices = header["NB_VERTICES"].int_value(); + const int nb_streamlines = header["NB_STREAMLINES"].int_value(); + + for (auto x = dict_pointer_size.rbegin(); x != dict_pointer_size.rend(); ++x) { + const std::string elem_filename = x->first; + + trx::fs::path elem_path(elem_filename); + trx::fs::path folder_path = elem_path.parent_path(); + std::string folder; + if (!root.empty()) { + trx::fs::path rel_path = elem_path.lexically_relative(trx::fs::path(root)); + std::string rel_str = rel_path.string(); + if (!rel_str.empty() && rel_str.rfind("..", 0) != 0) { + folder = rel_path.parent_path().string(); + } else { + folder = folder_path.string(); + } + } else { + folder = folder_path.string(); + } + if (folder == ".") { + folder.clear(); + } + + std::tuple base_tuple = trx::detail::_split_ext_with_dimensionality(elem_filename); + std::string base(std::get<0>(base_tuple)); + int dim = std::get<1>(base_tuple); + std::string ext(std::get<2>(base_tuple)); + + ext = _normalize_dtype(ext); + + long long size = std::get<1>(x->second); + + if (base == "positions" && (folder.empty() || folder == ".")) { + if (size != static_cast(nb_vertices) * 3 || dim != 3) { + throw std::invalid_argument("Wrong positions size/dimensionality"); + } + if (ext != "float16" && ext != "float32" && ext != "float64") { + throw std::invalid_argument("Unsupported positions dtype: " + ext); + } + trx.positions = make_typed_array(elem_filename, nb_vertices, 3, ext); + } else if (base == "offsets" && (folder.empty() || folder == ".")) { + if (size != static_cast(nb_streamlines) + 1 || dim != 1) { + throw std::invalid_argument("Wrong offsets size/dimensionality"); + } + if (ext != "uint32" && ext != "uint64") { + throw std::invalid_argument("Unsupported offsets dtype: " + ext); + } + trx.offsets = 
make_typed_array(elem_filename, nb_streamlines + 1, 1, ext); + } else if (folder == "dps") { + const int nb_scalar = nb_streamlines > 0 ? static_cast(size / nb_streamlines) : 0; + if (nb_streamlines == 0 || size % nb_streamlines != 0 || nb_scalar != dim) { + throw std::invalid_argument("Wrong dps size/dimensionality"); + } + trx.data_per_streamline.emplace(base, make_typed_array(elem_filename, nb_streamlines, nb_scalar, ext)); + } else if (folder == "dpv") { + const int nb_scalar = nb_vertices > 0 ? static_cast(size / nb_vertices) : 0; + if (nb_vertices == 0 || size % nb_vertices != 0 || nb_scalar != dim) { + throw std::invalid_argument("Wrong dpv size/dimensionality"); + } + trx.data_per_vertex.emplace(base, make_typed_array(elem_filename, nb_vertices, nb_scalar, ext)); + } else if (folder.rfind("dpg", 0) == 0) { + if (size != dim) { + throw std::invalid_argument("Wrong dpg size/dimensionality"); + } + std::string data_name = path_basename(base); + std::string sub_folder = path_basename(folder); + trx.data_per_group[sub_folder].emplace(data_name, + make_typed_array(elem_filename, 1, static_cast(size), ext)); + } else if (folder == "groups") { + if (dim != 1) { + throw std::invalid_argument("Wrong group dimensionality"); + } + if (ext != "uint32") { + throw std::invalid_argument("Unsupported group dtype: " + ext); + } + trx.groups.emplace(base, make_typed_array(elem_filename, static_cast(size), 1, ext)); + } else { + throw std::invalid_argument("Entry is not part of a valid TRX structure: " + elem_filename); + } + } + + if (trx.positions.empty() || trx.offsets.empty()) { + throw std::invalid_argument("Missing essential data."); + } + + const size_t offsets_count = trx.offsets.size(); + if (offsets_count > 0) { + trx.offsets_u64.resize(offsets_count); + const auto bytes = trx.offsets.to_bytes(); + if (trx.offsets.dtype == "uint64") { + const auto *src = reinterpret_cast(bytes.data); + for (size_t i = 0; i < offsets_count; ++i) { + trx.offsets_u64[i] = src[i]; + } + 
} else if (trx.offsets.dtype == "uint32") { + const auto *src = reinterpret_cast(bytes.data); + for (size_t i = 0; i < offsets_count; ++i) { + trx.offsets_u64[i] = static_cast(src[i]); + } + } else { + throw std::invalid_argument("Unsupported offsets datatype: " + trx.offsets.dtype); + } + } + + if (offsets_count > 1) { + trx.lengths.resize(offsets_count - 1); + for (size_t i = 0; i + 1 < offsets_count; ++i) { + const uint64_t diff = trx.offsets_u64[i + 1] - trx.offsets_u64[i]; + if (diff > std::numeric_limits::max()) { + throw std::runtime_error("Offset difference exceeds uint32 range"); + } + trx.lengths[i] = static_cast(diff); + } + } + + return trx; +} + +void AnyTrxFile::save(const std::string &filename, zip_uint32_t compression_standard) { + const std::string ext = get_ext(filename); + if (ext.size() > 0 && (ext != "zip" && ext != "trx")) { + throw std::invalid_argument("Unsupported extension." + ext); + } + + if (offsets.empty()) { + throw std::runtime_error("Cannot save TRX without offsets data"); + } + if (offsets_u64.empty()) { + throw std::runtime_error("Cannot save TRX without decoded offsets"); + } + if (header["NB_STREAMLINES"].is_number()) { + const auto nb_streamlines = static_cast(header["NB_STREAMLINES"].int_value()); + if (offsets_u64.size() != nb_streamlines + 1) { + throw std::runtime_error("TRX offsets size does not match NB_STREAMLINES"); + } + } + if (header["NB_VERTICES"].is_number()) { + const auto nb_vertices = static_cast(header["NB_VERTICES"].int_value()); + const auto last = offsets_u64.back(); + if (last != nb_vertices) { + throw std::runtime_error("TRX offsets sentinel does not match NB_VERTICES"); + } + } + for (size_t i = 1; i < offsets_u64.size(); ++i) { + if (offsets_u64[i] < offsets_u64[i - 1]) { + throw std::runtime_error("TRX offsets must be monotonically increasing"); + } + } + if (!positions.empty()) { + const auto last = offsets_u64.back(); + if (last != static_cast(positions.rows)) { + throw std::runtime_error("TRX 
positions row count does not match offsets sentinel"); + } + } + + const std::string source_dir = + !_uncompressed_folder_handle.empty() ? _uncompressed_folder_handle : _backing_directory; + if (source_dir.empty()) { + throw std::runtime_error("TRX file has no backing directory to save from"); + } + + std::string tmp_dir = make_temp_dir("trx_runtime"); + copy_dir(source_dir, tmp_dir); + + { + const trx::fs::path header_path = trx::fs::path(tmp_dir) / "header.json"; + std::ofstream out_json(header_path); + if (!out_json.is_open()) { + throw std::runtime_error("Failed to write header.json to: " + header_path.string()); + } + out_json << header.dump() << std::endl; + } + + if (ext.size() > 0 && (ext == "zip" || ext == "trx")) { + int errorp; + zip_t *zf; + if ((zf = zip_open(filename.c_str(), ZIP_CREATE + ZIP_TRUNCATE, &errorp)) == nullptr) { + rm_dir(tmp_dir); + throw std::runtime_error("Could not open archive " + filename + ": " + strerror(errorp)); + } + zip_from_folder(zf, tmp_dir, tmp_dir, compression_standard); + if (zip_close(zf) != 0) { + rm_dir(tmp_dir); + throw std::runtime_error("Unable to close archive " + filename + ": " + zip_strerror(zf)); + } + } else { + std::error_code ec; + if (trx::fs::exists(filename, ec) && trx::fs::is_directory(filename, ec)) { + if (rm_dir(filename) != 0) { + rm_dir(tmp_dir); + throw std::runtime_error("Could not remove existing directory " + filename); + } + } + trx::fs::path dest_path(filename); + if (dest_path.has_parent_path()) { + std::error_code parent_ec; + trx::fs::create_directories(dest_path.parent_path(), parent_ec); + if (parent_ec) { + rm_dir(tmp_dir); + throw std::runtime_error("Could not create output parent directory: " + dest_path.parent_path().string()); + } + } + copy_dir(tmp_dir, filename); + ec.clear(); + if (!trx::fs::exists(filename, ec) || !trx::fs::is_directory(filename, ec)) { + rm_dir(tmp_dir); + throw std::runtime_error("Failed to create output directory: " + filename); + } + const trx::fs::path 
header_path = dest_path / "header.json"; + if (!trx::fs::exists(header_path)) { + rm_dir(tmp_dir); + throw std::runtime_error("Missing header.json in output directory: " + header_path.string()); + } + } + + rm_dir(tmp_dir); +} + void populate_fps(const string &name, std::map> &files_pointer_size) { const trx::fs::path root(name); std::error_code ec; @@ -188,15 +556,11 @@ void populate_fps(const string &name, std::map _split_ext_with_dimensionality(const std::string &filename) { - std::string base = path_basename(filename); - - const size_t num_splits = std::count(base.begin(), base.end(), '.'); - int dim = 0; - - if (num_splits != 1 && num_splits != 2) { - throw std::invalid_argument("Invalid filename"); - } - - const std::string ext = get_ext(base); - - base = base.substr(0, base.length() - ext.length() - 1); - - if (num_splits == 1) { - dim = 1; - } else { - const size_t pos = base.find_last_of('.'); - dim = std::stoi(base.substr(pos + 1, base.size())); - base = base.substr(0, pos); - } - - const bool is_valid = _is_dtype_valid(ext); - - if (!is_valid) { - // TODO: make formatted string and include provided extension name - throw std::invalid_argument("Unsupported file extension"); - } - - std::tuple output{base, dim, ext}; - - return output; -} - -bool _is_dtype_valid(const std::string &ext) { - if (ext == "bit") - return true; - if (std::find(trxmmap::dtypes.begin(), trxmmap::dtypes.end(), ext) != trxmmap::dtypes.end()) - return true; - return false; -} - json load_header(zip_t *zfolder) { if (zfolder == nullptr) { throw std::invalid_argument("Zip archive pointer is null"); @@ -380,23 +643,17 @@ void allocate_file(const std::string &path, std::size_t size) { } mio::shared_mmap_sink _create_memmap(std::string filename, - std::tuple &shape, + const std::tuple &shape, const std::string &mode, const std::string &dtype, long long offset) { static_cast(mode); - if (dtype == "bool") { - const std::string ext = "bit"; - filename.replace(filename.size() - 4, 3, ext); - 
filename.pop_back(); - } - const std::size_t filesize = static_cast(std::get<0>(shape)) * static_cast(std::get<1>(shape)) * - static_cast(_sizeof_dtype(dtype)); + static_cast(trx::detail::_sizeof_dtype(dtype)); // if file does not exist, create and allocate it - struct stat buffer{}; + struct stat buffer {}; if (stat(filename.c_str(), &buffer) != 0) { allocate_file(filename, filesize); } @@ -440,9 +697,10 @@ json assignHeader(const json &root) { return header; } -void get_reference_info(const std::string &reference, - const MatrixXf &affine, // NOLINT(misc-include-cleaner) - const RowVectorXi &dimensions) { // NOLINT(misc-use-internal-linkage,misc-include-cleaner) +void get_reference_info( + const std::string &reference, + const Eigen::MatrixXf &affine, // NOLINT(misc-include-cleaner) + const Eigen::RowVectorXi &dimensions) { // NOLINT(misc-use-internal-linkage,misc-include-cleaner) static_cast(affine); static_cast(dimensions); // TODO: find a library to use for nifti and trk (MRtrix??) @@ -470,51 +728,19 @@ void copy_dir(const string &src, const string &dst) { } ec.clear(); - for (trx::fs::recursive_directory_iterator it(src_path, ec), end; it != end; it.increment(ec)) { - if (ec) { - throw std::runtime_error("Failed to read directory: " + src_path.string()); - } - const trx::fs::path current = it->path(); - const trx::fs::path rel = current.lexically_relative(src_path); - const trx::fs::path target = dst_path / rel; - std::error_code entry_ec; - if (it->is_directory(entry_ec)) { - trx::fs::create_directories(target, entry_ec); - if (entry_ec) { - throw std::runtime_error("Could not create directory " + target.string()); - } - continue; - } - if (!it->is_regular_file(entry_ec)) { - continue; - } - copy_file(current.string(), target.string()); + const auto options = trx::fs::copy_options::recursive | trx::fs::copy_options::overwrite_existing | + trx::fs::copy_options::skip_symlinks; + trx::fs::copy(src_path, dst_path, options, ec); + if (ec) { + throw 
std::runtime_error("Failed to copy directory: " + ec.message()); } } void copy_file(const string &src, const string &dst) { - std::ifstream in(src, std::ios::binary); - if (!in.is_open()) { - throw std::runtime_error(std::string("Failed to open source file ") + src); - } - std::ofstream out(dst, std::ios::binary | std::ios::trunc); - if (!out.is_open()) { - throw std::runtime_error(std::string("Failed to open destination file ") + dst); - } - - std::array buffer{}; - while (in) { - in.read(buffer.data(), static_cast(buffer.size())); - const std::streamsize n = in.gcount(); - if (n > 0) { - out.write(buffer.data(), n); - if (!out) { - throw std::runtime_error(std::string("Error writing to file ") + dst); - } - } - } - if (!in.eof()) { - throw std::runtime_error(std::string("Error reading file ") + src); + std::error_code ec; + trx::fs::copy_file(src, dst, trx::fs::copy_options::overwrite_existing, ec); + if (ec) { + throw std::runtime_error(std::string("Failed to copy file ") + src + ": " + ec.message()); } } int rm_dir(const string &d) { @@ -723,4 +949,4 @@ std::string rm_root(const std::string &root, const std::string &path) { } return stripped; } -}; // namespace trxmmap \ No newline at end of file +}; // namespace trx \ No newline at end of file diff --git a/test_package/src/example.cpp b/test_package/src/example.cpp index 97bff04..6cedce8 100644 --- a/test_package/src/example.cpp +++ b/test_package/src/example.cpp @@ -2,7 +2,7 @@ int main() { // Basic construction and cleanup exercises the public API and linkage. 
- trxmmap::TrxFile file; + trx::TrxFile file; file.close(); return 0; } diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index c3175b6..55634e9 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -75,6 +75,12 @@ target_compile_features(test_mmap PRIVATE cxx_std_17) add_executable(test_io test_trx_io.cpp) target_link_libraries(test_io PRIVATE trx GTest::gtest_main) target_compile_features(test_io PRIVATE cxx_std_17) +if(TRX_ENABLE_NIFTI) + target_link_libraries(test_io PRIVATE trx-nifti) + target_compile_definitions(test_io PRIVATE TRX_ENABLE_NIFTI) + find_package(ZLIB REQUIRED) + target_link_libraries(test_io PRIVATE ZLIB::ZLIB) +endif() add_executable(test_streamlines_ops test_trx_streamlines_ops.cpp) target_link_libraries(test_streamlines_ops PRIVATE trx GTest::gtest_main) @@ -88,6 +94,10 @@ add_executable(test_trxfile test_trx_trxfile.cpp) target_link_libraries(test_trxfile PRIVATE trx GTest::gtest_main) target_compile_features(test_trxfile PRIVATE cxx_std_17) +add_executable(test_anytrxfile test_trx_anytrxfile.cpp) +target_link_libraries(test_anytrxfile PRIVATE trx GTest::gtest_main) +target_compile_features(test_anytrxfile PRIVATE cxx_std_17) + include(GoogleTest) gtest_discover_tests(test_mmap PROPERTIES ENVIRONMENT "TRX_TEST_DATA_DIR=${TRX_TEST_DATA_DIR}" @@ -104,3 +114,7 @@ gtest_discover_tests(test_filesystem) gtest_discover_tests(test_trxfile PROPERTIES ENVIRONMENT "TRX_TEST_DATA_DIR=${TRX_TEST_DATA_DIR}" ) + +gtest_discover_tests(test_anytrxfile PROPERTIES + ENVIRONMENT "TRX_TEST_DATA_DIR=${TRX_TEST_DATA_DIR}" +) diff --git a/tests/test_trx_anytrxfile.cpp b/tests/test_trx_anytrxfile.cpp new file mode 100644 index 0000000..6d66463 --- /dev/null +++ b/tests/test_trx_anytrxfile.cpp @@ -0,0 +1,641 @@ +#include +#include +#define private public +#include +#undef private + +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace trx; +namespace fs = std::filesystem; + +namespace { +std::string 
get_test_data_root() { + const auto *env = std::getenv("TRX_TEST_DATA_DIR"); // NOLINT(concurrency-mt-unsafe) + if (env == nullptr || std::string(env).empty()) { + return {}; + } + return std::string(env); +} + +fs::path resolve_gold_standard_dir(const std::string &root_dir) { + fs::path root(root_dir); + fs::path gs_dir = root / "gold_standard"; + if (fs::exists(gs_dir)) { + return gs_dir; + } + return root; +} + +fs::path require_gold_standard_dir() { + const auto root = get_test_data_root(); + if (root.empty()) { + throw std::runtime_error("TRX_TEST_DATA_DIR not set"); + } + const auto gs_dir = resolve_gold_standard_dir(root); + if (!fs::exists(gs_dir / "gs.trx")) { + throw std::runtime_error("Missing gold_standard gs.trx"); + } + if (!fs::exists(gs_dir / "gs_fldr.trx")) { + throw std::runtime_error("Missing gold_standard gs_fldr.trx"); + } + return gs_dir; +} + +fs::path make_temp_test_dir(const std::string &prefix) { + std::error_code ec; + auto base = fs::temp_directory_path(ec); + if (ec) { + throw std::runtime_error("Failed to get temp directory: " + ec.message()); + } + + thread_local std::mt19937_64 rng(std::random_device{}()); + std::uniform_int_distribution dist; + + for (int attempt = 0; attempt < 100; ++attempt) { + fs::path candidate = base / (prefix + "_" + std::to_string(dist(rng))); + std::error_code dir_ec; + if (fs::create_directory(candidate, dir_ec)) { + return candidate; + } + if (dir_ec && dir_ec != std::errc::file_exists) { + throw std::runtime_error("Failed to create temporary directory: " + dir_ec.message()); + } + } + throw std::runtime_error("Unable to create unique temporary directory"); +} + +fs::path copy_gold_standard_dir(const fs::path &gs_dir, const std::string &prefix, fs::path &temp_root) { + const fs::path source = gs_dir / "gs_fldr.trx"; + if (!fs::exists(source)) { + throw std::runtime_error("Missing gold_standard gs_fldr.trx"); + } + temp_root = make_temp_test_dir(prefix); + fs::path dest = temp_root / source.filename(); + 
std::error_code ec; + fs::copy(source, dest, fs::copy_options::recursive, ec); + if (ec) { + throw std::runtime_error("Failed to copy gold_standard directory: " + ec.message()); + } + return dest; +} + +fs::path find_file_with_prefix(const fs::path &dir, const std::string &prefix) { + std::error_code ec; + for (fs::directory_iterator it(dir, ec), end; it != end; it.increment(ec)) { + if (ec) { + break; + } + if (!it->is_regular_file(ec)) { + continue; + } + const std::string name = it->path().filename().string(); + if (name.rfind(prefix, 0) == 0) { + return it->path(); + } + } + throw std::runtime_error("Failed to find file with prefix: " + prefix); +} + +fs::path find_first_file_recursive(const fs::path &dir) { + std::error_code ec; + for (fs::recursive_directory_iterator it(dir, ec), end; it != end; it.increment(ec)) { + if (ec) { + break; + } + if (it->is_regular_file(ec)) { + return it->path(); + } + } + throw std::runtime_error("Failed to find file under: " + dir.string()); +} + +json read_header_file(const fs::path &dir) { + const fs::path header_path = dir / "header.json"; + std::ifstream in(header_path.string()); + if (!in.is_open()) { + throw std::runtime_error("Failed to open header.json"); + } + std::string jstream((std::istreambuf_iterator(in)), std::istreambuf_iterator()); + in.close(); + std::string err; + json header = json::parse(jstream, err); + if (!err.empty()) { + throw std::runtime_error("Failed to parse header.json: " + err); + } + return header; +} + +void write_header_file(const fs::path &dir, const json &header) { + const fs::path header_path = dir / "header.json"; + std::ofstream out(header_path.string()); + if (!out.is_open()) { + throw std::runtime_error("Failed to write header.json"); + } + out << header.dump() << '\n'; + out.close(); +} + +std::string pick_int_dtype_same_size(const std::string &ext) { + const int size = trx::detail::_sizeof_dtype(ext); + if (size == 8) { + return "int64"; + } + if (size == 4) { + return "int32"; + } + 
return "int16"; +} + +fs::path rename_with_new_dim(const fs::path &file_path, int new_dim) { + const std::string filename = file_path.filename().string(); + auto parsed = trx::detail::_split_ext_with_dimensionality(filename); + const std::string base = std::get<0>(parsed); + const std::string ext = std::get<2>(parsed); + std::string new_name = base + "." + std::to_string(new_dim) + "." + ext; + fs::path new_path = file_path.parent_path() / new_name; + fs::rename(file_path, new_path); + return new_path; +} + +fs::path rename_with_new_ext(const fs::path &file_path, const std::string &new_ext) { + const std::string filename = file_path.filename().string(); + const std::string old_ext = get_ext(filename); + const std::string new_name = filename.substr(0, filename.size() - old_ext.size()) + new_ext; + fs::path new_path = file_path.parent_path() / new_name; + fs::rename(file_path, new_path); + return new_path; +} + +void ensure_directory_exists(const fs::path &dir_path) { + std::error_code ec; + if (!fs::exists(dir_path, ec)) { + fs::create_directories(dir_path, ec); + } + if (ec) { + throw std::runtime_error("Failed to create directory: " + dir_path.string()); + } +} + +void write_zero_filled_file(const fs::path &file_path, const std::string &dtype, size_t count) { + const int dtype_size = trx::detail::_sizeof_dtype(dtype); + const size_t total_bytes = count * static_cast(dtype_size); + std::vector bytes(total_bytes, 0); + std::ofstream out(file_path.string(), std::ios::binary | std::ios::trunc); + if (!out.is_open()) { + throw std::runtime_error("Failed to write file: " + file_path.string()); + } + if (!bytes.empty()) { + out.write(bytes.data(), static_cast(bytes.size())); + } + out.close(); +} + +bool has_regular_file_recursive(const fs::path &dir_path) { + std::error_code ec; + for (fs::recursive_directory_iterator it(dir_path, ec), end; it != end; it.increment(ec)) { + if (ec) { + break; + } + if (it->is_regular_file(ec)) { + return true; + } + } + return false; +} 
+ +void expect_basic_consistency(const AnyTrxFile &trx) { + ASSERT_TRUE(trx.header["NB_STREAMLINES"].is_number()); + ASSERT_TRUE(trx.header["NB_VERTICES"].is_number()); + + const auto nb_streamlines = static_cast(trx.header["NB_STREAMLINES"].int_value()); + const auto nb_vertices = static_cast(trx.header["NB_VERTICES"].int_value()); + + EXPECT_EQ(trx.num_streamlines(), nb_streamlines); + EXPECT_EQ(trx.num_vertices(), nb_vertices); + + ASSERT_EQ(trx.offsets_u64.size(), nb_streamlines + 1); + EXPECT_EQ(trx.offsets_u64.back(), static_cast(nb_vertices)); + + EXPECT_EQ(trx.positions.cols, 3); + EXPECT_EQ(trx.positions.rows, static_cast(nb_vertices)); + EXPECT_FALSE(trx.positions.empty()); + + const auto bytes = trx.positions.to_bytes(); + EXPECT_NE(bytes.data, nullptr); + EXPECT_GT(bytes.size, 0U); +} +} // namespace + +TEST(AnyTrxFile, LoadZipAndValidate) { + const auto gs_dir = require_gold_standard_dir(); + const fs::path gs_trx = gs_dir / "gs.trx"; + auto trx = load_any(gs_trx.string()); + + EXPECT_TRUE(trx.positions.dtype == "float16" || trx.positions.dtype == "float32" || trx.positions.dtype == "float64"); + expect_basic_consistency(trx); + + if (trx.positions.dtype == "float32") { + auto positions = trx.positions.as_matrix(); + EXPECT_EQ(positions.rows(), trx.positions.rows); + EXPECT_EQ(positions.cols(), trx.positions.cols); + } + trx.close(); +} + +TEST(AnyTrxFile, LoadDirectoryAndValidate) { + const auto gs_dir = require_gold_standard_dir(); + const fs::path gs_trx = gs_dir / "gs_fldr.trx"; + auto trx = load_any(gs_trx.string()); + + EXPECT_TRUE(trx.positions.dtype == "float16" || trx.positions.dtype == "float32" || trx.positions.dtype == "float64"); + expect_basic_consistency(trx); + trx.close(); +} + +TEST(AnyTrxFile, SaveUpdatesHeader) { + const auto gs_dir = require_gold_standard_dir(); + const fs::path gs_trx = gs_dir / "gs.trx"; + auto trx = load_any(gs_trx.string()); + + auto header_obj = trx.header.object_items(); + header_obj["COMMENT"] = "saved by 
anytrxfile test"; + trx.header = json(header_obj); + + const auto temp_dir = make_temp_test_dir("trx_any_save"); + const fs::path out_path = temp_dir / "saved_copy.trx"; + trx.save(out_path.string(), ZIP_CM_STORE); + trx.close(); + + auto reloaded = load_any(out_path.string()); + EXPECT_EQ(reloaded.header["COMMENT"].string_value(), "saved by anytrxfile test"); + reloaded.close(); +} + +TEST(AnyTrxFile, MissingHeaderCountsThrows) { + const auto gs_dir = require_gold_standard_dir(); + fs::path temp_root; + const fs::path corrupt_dir = copy_gold_standard_dir(gs_dir, "trx_any_missing_counts", temp_root); + + auto header = read_header_file(corrupt_dir); + auto header_obj = header.object_items(); + header_obj.erase("NB_VERTICES"); + write_header_file(corrupt_dir, json(header_obj)); + + EXPECT_THROW(load_any(corrupt_dir.string()), std::invalid_argument); + + std::error_code ec; + fs::remove_all(temp_root, ec); +} + +TEST(AnyTrxFile, WrongPositionsDimThrows) { + const auto gs_dir = require_gold_standard_dir(); + fs::path temp_root; + const fs::path corrupt_dir = copy_gold_standard_dir(gs_dir, "trx_any_pos_dim", temp_root); + + const fs::path positions = find_file_with_prefix(corrupt_dir, "positions"); + rename_with_new_dim(positions, 4); + + EXPECT_THROW(load_any(corrupt_dir.string()), std::invalid_argument); + + std::error_code ec; + fs::remove_all(temp_root, ec); +} + +TEST(AnyTrxFile, UnsupportedPositionsDtypeThrows) { + const auto gs_dir = require_gold_standard_dir(); + fs::path temp_root; + const fs::path corrupt_dir = copy_gold_standard_dir(gs_dir, "trx_any_pos_dtype", temp_root); + + const fs::path positions = find_file_with_prefix(corrupt_dir, "positions"); + const std::string ext = get_ext(positions.string()); + rename_with_new_ext(positions, pick_int_dtype_same_size(ext)); + + EXPECT_THROW(load_any(corrupt_dir.string()), std::invalid_argument); + + std::error_code ec; + fs::remove_all(temp_root, ec); +} + +TEST(AnyTrxFile, WrongOffsetsDimThrows) { + const auto 
gs_dir = require_gold_standard_dir(); + fs::path temp_root; + const fs::path corrupt_dir = copy_gold_standard_dir(gs_dir, "trx_any_offsets_dim", temp_root); + + const fs::path offsets = find_file_with_prefix(corrupt_dir, "offsets"); + rename_with_new_dim(offsets, 2); + + EXPECT_THROW(load_any(corrupt_dir.string()), std::invalid_argument); + + std::error_code ec; + fs::remove_all(temp_root, ec); +} + +TEST(AnyTrxFile, UnsupportedOffsetsDtypeThrows) { + const auto gs_dir = require_gold_standard_dir(); + fs::path temp_root; + const fs::path corrupt_dir = copy_gold_standard_dir(gs_dir, "trx_any_offsets_dtype", temp_root); + + const fs::path offsets = find_file_with_prefix(corrupt_dir, "offsets"); + const std::string ext = get_ext(offsets.string()); + rename_with_new_ext(offsets, pick_int_dtype_same_size(ext)); + + EXPECT_THROW(load_any(corrupt_dir.string()), std::invalid_argument); + + std::error_code ec; + fs::remove_all(temp_root, ec); +} + +TEST(AnyTrxFile, WrongDpsDimThrows) { + const auto gs_dir = require_gold_standard_dir(); + fs::path temp_root; + const fs::path corrupt_dir = copy_gold_standard_dir(gs_dir, "trx_any_dps_dim", temp_root); + + const fs::path dps_dir = corrupt_dir / "dps"; + if (!fs::exists(dps_dir) || !has_regular_file_recursive(dps_dir)) { + ensure_directory_exists(dps_dir); + const auto header = read_header_file(corrupt_dir); + const auto nb_streamlines = static_cast(header["NB_STREAMLINES"].int_value()); + const fs::path dps_file = dps_dir / "weight.float32"; + write_zero_filled_file(dps_file, "float32", nb_streamlines); + } + const fs::path dps_file = find_first_file_recursive(dps_dir); + rename_with_new_dim(dps_file, 2); + + EXPECT_THROW(load_any(corrupt_dir.string()), std::invalid_argument); + + std::error_code ec; + fs::remove_all(temp_root, ec); +} + +TEST(AnyTrxFile, WrongDpvDimThrows) { + const auto gs_dir = require_gold_standard_dir(); + fs::path temp_root; + const fs::path corrupt_dir = copy_gold_standard_dir(gs_dir, "trx_any_dpv_dim", 
temp_root); + + const fs::path dpv_dir = corrupt_dir / "dpv"; + if (!fs::exists(dpv_dir) || !has_regular_file_recursive(dpv_dir)) { + ensure_directory_exists(dpv_dir); + const auto header = read_header_file(corrupt_dir); + const auto nb_vertices = static_cast(header["NB_VERTICES"].int_value()); + const fs::path dpv_file = dpv_dir / "color.float32"; + write_zero_filled_file(dpv_file, "float32", nb_vertices); + } + const fs::path dpv_file = find_first_file_recursive(dpv_dir); + rename_with_new_dim(dpv_file, 2); + + EXPECT_THROW(load_any(corrupt_dir.string()), std::invalid_argument); + + std::error_code ec; + fs::remove_all(temp_root, ec); +} + +TEST(AnyTrxFile, WrongDpgDimThrows) { + const auto gs_dir = require_gold_standard_dir(); + fs::path temp_root; + const fs::path corrupt_dir = copy_gold_standard_dir(gs_dir, "trx_any_dpg_dim", temp_root); + + const fs::path dpg_dir = corrupt_dir / "dpg"; + if (!fs::exists(dpg_dir) || !has_regular_file_recursive(dpg_dir)) { + const fs::path group_dir = dpg_dir / "GroupA"; + ensure_directory_exists(group_dir); + const fs::path dpg_file = group_dir / "mean.float32"; + write_zero_filled_file(dpg_file, "float32", 1); + } + const fs::path dpg_file = find_first_file_recursive(dpg_dir); + rename_with_new_dim(dpg_file, 2); + + EXPECT_THROW(load_any(corrupt_dir.string()), std::invalid_argument); + + std::error_code ec; + fs::remove_all(temp_root, ec); +} + +TEST(AnyTrxFile, UnsupportedGroupDtypeThrows) { + const auto gs_dir = require_gold_standard_dir(); + fs::path temp_root; + const fs::path corrupt_dir = copy_gold_standard_dir(gs_dir, "trx_any_group_dtype", temp_root); + + const fs::path groups_dir = corrupt_dir / "groups"; + if (!fs::exists(groups_dir) || !has_regular_file_recursive(groups_dir)) { + ensure_directory_exists(groups_dir); + const auto header = read_header_file(corrupt_dir); + const auto nb_streamlines = static_cast(header["NB_STREAMLINES"].int_value()); + const fs::path group_file = groups_dir / "GroupA.uint32"; + 
write_zero_filled_file(group_file, "uint32", nb_streamlines); + } + const fs::path group_file = find_first_file_recursive(groups_dir); + rename_with_new_ext(group_file, "int32"); + + EXPECT_THROW(load_any(corrupt_dir.string()), std::invalid_argument); + + std::error_code ec; + fs::remove_all(temp_root, ec); +} + +TEST(AnyTrxFile, InvalidEntryThrows) { + const auto gs_dir = require_gold_standard_dir(); + fs::path temp_root; + const fs::path corrupt_dir = copy_gold_standard_dir(gs_dir, "trx_any_invalid_entry", temp_root); + + const fs::path bogus = corrupt_dir / "bogus.float32"; + std::ofstream out(bogus.string(), std::ios::binary); + float value = 1.0F; + std::array value_bytes{}; + std::memcpy(value_bytes.data(), &value, sizeof(value)); + out.write(value_bytes.data(), static_cast(value_bytes.size())); + out.close(); + + EXPECT_THROW(load_any(corrupt_dir.string()), std::invalid_argument); + + std::error_code ec; + fs::remove_all(temp_root, ec); +} + +TEST(AnyTrxFile, MissingEssentialDataThrows) { + const auto gs_dir = require_gold_standard_dir(); + fs::path temp_root; + const fs::path corrupt_dir = copy_gold_standard_dir(gs_dir, "trx_any_missing_essential", temp_root); + + const fs::path positions = find_file_with_prefix(corrupt_dir, "positions"); + std::error_code ec; + fs::remove(positions, ec); + + EXPECT_THROW(load_any(corrupt_dir.string()), std::invalid_argument); + + fs::remove_all(temp_root, ec); +} + +TEST(AnyTrxFile, OffsetsOverflowThrows) { + const auto gs_dir = require_gold_standard_dir(); + fs::path temp_root; + const fs::path corrupt_dir = copy_gold_standard_dir(gs_dir, "trx_any_offsets_overflow", temp_root); + + const auto header = read_header_file(corrupt_dir); + const int nb_streamlines = header["NB_STREAMLINES"].int_value(); + const fs::path offsets = find_file_with_prefix(corrupt_dir, "offsets"); + const std::string ext = get_ext(offsets.string()); + + fs::path offsets_path = offsets; + if (ext != "uint64") { + offsets_path = 
rename_with_new_ext(offsets, "uint64"); + } + + std::vector data(static_cast(nb_streamlines) + 1, 0); + if (data.size() >= 2) { + data[1] = static_cast(std::numeric_limits::max()) + 1; + for (size_t i = 2; i < data.size(); ++i) { + data[i] = data[i - 1] + 1; + } + } + std::ofstream out(offsets_path.string(), std::ios::binary | std::ios::trunc); + std::vector data_bytes(data.size() * sizeof(uint64_t)); + std::memcpy(data_bytes.data(), data.data(), data_bytes.size()); + out.write(data_bytes.data(), static_cast(data_bytes.size())); + out.close(); + + EXPECT_THROW(load_any(corrupt_dir.string()), std::runtime_error); + + std::error_code ec; + fs::remove_all(temp_root, ec); +} + +TEST(AnyTrxFile, SaveRejectsUnsupportedExtension) { + const auto gs_dir = require_gold_standard_dir(); + const fs::path gs_trx = gs_dir / "gs_fldr.trx"; + auto trx = load_any(gs_trx.string()); + + const auto temp_dir = make_temp_test_dir("trx_any_save_badext"); + const fs::path out_path = temp_dir / "bad.txt"; + EXPECT_THROW(trx.save(out_path.string(), ZIP_CM_STORE), std::invalid_argument); + trx.close(); + + std::error_code ec; + fs::remove_all(temp_dir, ec); +} + +TEST(AnyTrxFile, SaveRejectsMissingOffsets) { + const auto gs_dir = require_gold_standard_dir(); + const fs::path gs_trx = gs_dir / "gs_fldr.trx"; + auto trx = load_any(gs_trx.string()); + + trx.offsets = TypedArray(); + const auto temp_dir = make_temp_test_dir("trx_any_save_no_offsets"); + const fs::path out_path = temp_dir / "missing_offsets.trx"; + EXPECT_THROW(trx.save(out_path.string(), ZIP_CM_STORE), std::runtime_error); + trx.close(); + + std::error_code ec; + fs::remove_all(temp_dir, ec); +} + +TEST(AnyTrxFile, SaveRejectsMissingDecodedOffsets) { + const auto gs_dir = require_gold_standard_dir(); + const fs::path gs_trx = gs_dir / "gs_fldr.trx"; + auto trx = load_any(gs_trx.string()); + + trx.offsets_u64.clear(); + const auto temp_dir = make_temp_test_dir("trx_any_save_no_offsets_u64"); + const fs::path out_path = temp_dir / 
"missing_offsets_u64.trx"; + EXPECT_THROW(trx.save(out_path.string(), ZIP_CM_STORE), std::runtime_error); + trx.close(); + + std::error_code ec; + fs::remove_all(temp_dir, ec); +} + +TEST(AnyTrxFile, SaveRejectsStreamlineCountMismatch) { + const auto gs_dir = require_gold_standard_dir(); + const fs::path gs_trx = gs_dir / "gs_fldr.trx"; + auto trx = load_any(gs_trx.string()); + + auto header_obj = trx.header.object_items(); + header_obj["NB_STREAMLINES"] = header_obj["NB_STREAMLINES"].int_value() + 1; + trx.header = json(header_obj); + + const auto temp_dir = make_temp_test_dir("trx_any_save_bad_streamlines"); + const fs::path out_path = temp_dir / "bad_streamlines.trx"; + EXPECT_THROW(trx.save(out_path.string(), ZIP_CM_STORE), std::runtime_error); + trx.close(); + + std::error_code ec; + fs::remove_all(temp_dir, ec); +} + +TEST(AnyTrxFile, SaveRejectsVertexCountMismatch) { + const auto gs_dir = require_gold_standard_dir(); + const fs::path gs_trx = gs_dir / "gs_fldr.trx"; + auto trx = load_any(gs_trx.string()); + + auto header_obj = trx.header.object_items(); + header_obj["NB_VERTICES"] = static_cast(trx.offsets_u64.back() + 1); + trx.header = json(header_obj); + + const auto temp_dir = make_temp_test_dir("trx_any_save_bad_vertices"); + const fs::path out_path = temp_dir / "bad_vertices.trx"; + EXPECT_THROW(trx.save(out_path.string(), ZIP_CM_STORE), std::runtime_error); + trx.close(); + + std::error_code ec; + fs::remove_all(temp_dir, ec); +} + +TEST(AnyTrxFile, SaveRejectsNonMonotonicOffsets) { + const auto gs_dir = require_gold_standard_dir(); + const fs::path gs_trx = gs_dir / "gs_fldr.trx"; + auto trx = load_any(gs_trx.string()); + + ASSERT_GT(trx.offsets_u64.size(), 1U); + trx.offsets_u64[0] = 1; + trx.offsets_u64[1] = 0; + + const auto temp_dir = make_temp_test_dir("trx_any_save_non_mono"); + const fs::path out_path = temp_dir / "non_mono.trx"; + EXPECT_THROW(trx.save(out_path.string(), ZIP_CM_STORE), std::runtime_error); + trx.close(); + + std::error_code 
ec; + fs::remove_all(temp_dir, ec); +} + +TEST(AnyTrxFile, SaveRejectsPositionsRowMismatch) { + const auto gs_dir = require_gold_standard_dir(); + const fs::path gs_trx = gs_dir / "gs_fldr.trx"; + auto trx = load_any(gs_trx.string()); + + const uint64_t sentinel = trx.offsets_u64.back(); + ASSERT_GT(sentinel, 0U); + trx.positions.rows = static_cast(sentinel - 1); + + const auto temp_dir = make_temp_test_dir("trx_any_save_bad_positions"); + const fs::path out_path = temp_dir / "bad_positions.trx"; + EXPECT_THROW(trx.save(out_path.string(), ZIP_CM_STORE), std::runtime_error); + trx.close(); + + std::error_code ec; + fs::remove_all(temp_dir, ec); +} + +TEST(AnyTrxFile, SaveRejectsMissingBackingDirectory) { + const auto gs_dir = require_gold_standard_dir(); + const fs::path gs_trx = gs_dir / "gs_fldr.trx"; + auto trx = load_any(gs_trx.string()); + + trx._backing_directory.clear(); + trx._uncompressed_folder_handle.clear(); + + const auto temp_dir = make_temp_test_dir("trx_any_save_no_backing"); + const fs::path out_path = temp_dir / "no_backing.trx"; + EXPECT_THROW(trx.save(out_path.string(), ZIP_CM_STORE), std::runtime_error); + trx.close(); + + std::error_code ec; + fs::remove_all(temp_dir, ec); +} diff --git a/tests/test_trx_io.cpp b/tests/test_trx_io.cpp index d0ca947..46c492a 100644 --- a/tests/test_trx_io.cpp +++ b/tests/test_trx_io.cpp @@ -1,6 +1,10 @@ #include #include #include +#ifdef TRX_ENABLE_NIFTI +#include +#include +#endif #include #include @@ -9,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -17,7 +22,7 @@ #include using namespace Eigen; -using namespace trxmmap; +using namespace trx; namespace fs = std::filesystem; namespace { @@ -142,12 +147,7 @@ void expect_allclose(const Matrix &actual, } } -template trxmmap::TrxFile
*load_trx(const fs::path &path) { - if (is_dir(path)) { - return trxmmap::load_from_directory
(path.string()); - } - return trxmmap::load_from_zip
(path.string()); -} +template trx::TrxReader
load_trx(const fs::path &path) { return trx::TrxReader
(path.string()); } class ScopedEnvVar { public: @@ -350,11 +350,11 @@ TEST(TrxFileIo, load_rasmm) { const std::vector inputs = {gs_dir / "gs.trx", gs_dir / "gs_fldr.trx"}; for (const auto &input : inputs) { ASSERT_TRUE(fs::exists(input)); - trxmmap::TrxFile *trx = load_trx(input); + auto reader = load_trx(input); + auto *trx = reader.get(); Matrix actual = trx->streamlines->_data; expect_allclose(actual, coords); trx->close(); - delete trx; } } @@ -367,30 +367,142 @@ TEST(TrxFileIo, multi_load_save_rasmm) { ASSERT_TRUE(fs::exists(input)); fs::path tmp_dir = make_temp_test_dir("trx_gs"); - trxmmap::TrxFile *trx = load_trx(input); + auto reader = load_trx(input); + auto *trx = reader.get(); const std::string input_str = normalize_path(input.string()); - const std::string basename = trxmmap::get_base("/", input_str); - const std::string ext = trxmmap::get_ext(input_str); + const std::string basename = trx::get_base("/", input_str); + const std::string ext = trx::get_ext(input_str); const std::string basename_no_ext = ext.empty() ? basename : basename.substr(0, basename.size() - ext.size() - 1); for (int i = 0; i < 3; ++i) { fs::path out_path = tmp_dir / (basename_no_ext + "_tmp" + std::to_string(i) + (ext.empty() ? "" : ("." 
+ ext))); - trxmmap::save(*trx, out_path.string()); + trx->save(out_path.string()); trx->close(); - delete trx; - trx = load_trx(out_path); + reader = load_trx(out_path); + trx = reader.get(); } Matrix actual = trx->streamlines->_data; expect_allclose(actual, coords); trx->close(); - delete trx; std::error_code ec; fs::remove_all(tmp_dir, ec); } } +TEST(TrxFileIo, roundtrip_voxel_to_rasmm) { + const auto gs_dir = require_gold_standard_dir(); + const fs::path input = gs_dir / "gs.trx"; + ASSERT_TRUE(fs::exists(input)); + + auto reader = load_trx(input); + auto *trx = reader.get(); + + Eigen::Matrix4f affine; + affine << 1.0f, 0.1f, 0.2f, 10.0f, + -0.3f, 1.2f, 0.4f, -5.0f, + 0.5f, -0.6f, 0.9f, 2.5f, + 0.0f, 0.0f, 0.0f, 1.0f; + trx->set_voxel_to_rasmm(affine); + + fs::path tmp_dir = make_temp_test_dir("trx_affine"); + fs::path out_path = tmp_dir / "gs_affine.trx"; + trx->save(out_path.string()); + trx->close(); + + auto reader2 = load_trx(out_path); + auto *trx2 = reader2.get(); + const auto &vox = trx2->header["VOXEL_TO_RASMM"]; + ASSERT_TRUE(vox.is_array()); + const auto rows = vox.array_items(); + ASSERT_EQ(rows.size(), 4U); + for (size_t i = 0; i < 4; ++i) { + const auto cols = rows[i].array_items(); + ASSERT_EQ(cols.size(), 4U); + for (size_t j = 0; j < 4; ++j) { + EXPECT_FLOAT_EQ(cols[j].number_value(), affine(static_cast(i), static_cast(j))); + } + } + trx2->close(); + + std::error_code ec; + fs::remove_all(tmp_dir, ec); +} + +#ifdef TRX_ENABLE_NIFTI +fs::path gzip_copy(const fs::path &input, const fs::path &output) { + std::ifstream in(input, std::ios::binary); + if (!in.is_open()) { + throw std::runtime_error("Failed to open NIfTI for gzip: " + input.string()); + } + + gzFile out = gzopen(output.string().c_str(), "wb"); + if (!out) { + throw std::runtime_error("Failed to open gzip output: " + output.string()); + } + + std::array buffer{}; + while (in) { + in.read(buffer.data(), static_cast(buffer.size())); + const std::streamsize got = in.gcount(); + if (got 
<= 0) { + break; + } + const int written = gzwrite(out, buffer.data(), static_cast(got)); + if (written != got) { + gzclose(out); + throw std::runtime_error("Failed to write gzip output: " + output.string()); + } + } + + gzclose(out); + return output; +} + +TEST(TrxFileIo, nifti_voxel_to_rasmm_gs) { + const auto gs_dir = require_gold_standard_dir(); + const fs::path nifti_path = gs_dir / "gs.nii"; + ASSERT_TRUE(fs::exists(nifti_path)); + + Eigen::Matrix4f expected; + expected << 3.96961546e+00f, -2.45575607e-01f, 7.59612350e-03f, 1.20822277e+01f, + 4.91151214e-01f, 1.96961546e+00f, -1.22787803e-01f, 2.21644382e+01f, + 3.03844940e-02f, 2.45575607e-01f, 9.92403865e-01f, 3.79177742e+01f, + 0.00000000e+00f, 0.00000000e+00f, 0.00000000e+00f, 1.00000000e+00f; + + const Eigen::Matrix4f actual = trx::read_nifti_voxel_to_rasmm(nifti_path.string()); + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + EXPECT_NEAR(actual(i, j), expected(i, j), 1e-5f); + } + } +} + +TEST(TrxFileIo, nifti_voxel_to_rasmm_gs_gz_roundtrip) { + const auto gs_dir = require_gold_standard_dir(); + const fs::path nifti_path = gs_dir / "gs.nii"; + ASSERT_TRUE(fs::exists(nifti_path)); + + const Eigen::Matrix4f uncompressed = trx::read_nifti_voxel_to_rasmm(nifti_path.string()); + + fs::path tmp_dir = make_temp_test_dir("trx_gs_nifti_gz"); + fs::path gz_path = tmp_dir / "gs.nii.gz"; + gzip_copy(nifti_path, gz_path); + ASSERT_TRUE(fs::exists(gz_path)); + + const Eigen::Matrix4f gz_read = trx::read_nifti_voxel_to_rasmm(gz_path.string()); + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + EXPECT_NEAR(gz_read(i, j), uncompressed(i, j), 1e-5f); + } + } + + std::error_code ec; + fs::remove_all(tmp_dir, ec); +} +#endif + TEST(TrxFileIo, delete_tmp_gs_dir_rasmm) { const auto gs_dir = require_gold_standard_dir(); const auto coords = load_rasmm_coords(gs_dir / "gs_rasmm_space.txt"); @@ -398,7 +510,8 @@ TEST(TrxFileIo, delete_tmp_gs_dir_rasmm) { const std::vector inputs = {gs_dir / 
"gs.trx", gs_dir / "gs_fldr.trx"}; for (const auto &input : inputs) { ASSERT_TRUE(fs::exists(input)); - trxmmap::TrxFile *trx = load_trx(input); + auto reader = load_trx(input); + auto *trx = reader.get(); std::string tmp_dir = trx->_uncompressed_folder_handle; if (is_regular(input)) { @@ -419,13 +532,11 @@ TEST(TrxFileIo, delete_tmp_gs_dir_rasmm) { #endif } - delete trx; - - trx = load_trx(input); + reader = load_trx(input); + trx = reader.get(); Matrix actual2 = trx->streamlines->_data; expect_allclose(actual2, coords); trx->close(); - delete trx; } } @@ -434,7 +545,8 @@ TEST(TrxFileIo, close_tmp_files) { const fs::path input = gs_dir / "gs.trx"; ASSERT_TRUE(fs::exists(input)); - trxmmap::TrxFile *trx = load_trx(input); + auto reader = load_trx(input); + auto *trx = reader.get(); const std::string tmp_dir = trx->_uncompressed_folder_handle; ASSERT_FALSE(tmp_dir.empty()); ASSERT_TRUE(fs::exists(tmp_dir)); @@ -452,7 +564,6 @@ TEST(TrxFileIo, close_tmp_files) { } trx->close(); - delete trx; #if defined(_WIN32) || defined(_WIN64) // Windows can hold file handles briefly after close; avoid flaky removal assertions. 
@@ -479,7 +590,8 @@ TEST(TrxFileIo, change_tmp_dir) { { ScopedEnvVar env("TRX_TMPDIR", "use_working_dir"); - trxmmap::TrxFile *trx = load_trx(input); + auto reader = load_trx(input); + auto *trx = reader.get(); fs::path tmp_dir = trx->_uncompressed_folder_handle; fs::path parent = tmp_dir.parent_path(); fs::path expected = fs::path(get_current_working_dir()); @@ -489,18 +601,16 @@ TEST(TrxFileIo, change_tmp_dir) { } EXPECT_EQ(parent_norm, normalize_path(expected.lexically_normal())); trx->close(); - delete trx; } { ScopedEnvVar env("TRX_TMPDIR", home_env); - trxmmap::TrxFile *trx = load_trx(input); + auto trx = load_trx(input); fs::path tmp_dir = trx->_uncompressed_folder_handle; fs::path parent = tmp_dir.parent_path(); fs::path expected = fs::path(std::string(home_env)); EXPECT_EQ(normalize_path(parent.lexically_normal()), normalize_path(expected.lexically_normal())); trx->close(); - delete trx; } } @@ -518,7 +628,7 @@ TEST(TrxFileIo, complete_dir_from_trx) { const std::vector inputs = {gs_dir / "gs.trx", gs_dir / "gs_fldr.trx"}; for (const auto &input : inputs) { ASSERT_TRUE(fs::exists(input)); - trxmmap::TrxFile *trx = load_trx(input); + auto trx = load_trx(input); fs::path dir_to_check = trx->_uncompressed_folder_handle.empty() ? 
input : fs::path(trx->_uncompressed_folder_handle); @@ -538,7 +648,6 @@ TEST(TrxFileIo, complete_dir_from_trx) { EXPECT_EQ(file_paths, expected_content); trx->close(); - delete trx; } } @@ -578,14 +687,14 @@ TEST(TrxFileIo, complete_zip_from_trx) { } TEST(TrxFileIo, add_dps_from_text_success) { - trxmmap::TrxFile trx(4, 2); - set_streamline_lengths(trx.streamlines, {2, 2}); + trx::TrxFile trx(4, 2); + set_streamline_lengths(trx.streamlines.get(), {2, 2}); const fs::path tmp_dir = make_temp_test_dir("trx_dps_text"); const fs::path input_path = tmp_dir / "dps.txt"; write_text_file(input_path, "0.25 0.75"); - trxmmap::add_dps_from_text(trx, "weight", "float32", input_path.string()); + trx.add_dps_from_text("weight", "float32", input_path.string()); auto it = trx.data_per_streamline.find("weight"); ASSERT_NE(it, trx.data_per_streamline.end()); EXPECT_EQ(it->second->_matrix.rows(), 2); @@ -595,46 +704,45 @@ TEST(TrxFileIo, add_dps_from_text_success) { } TEST(TrxFileIo, add_dps_from_text_errors) { - trxmmap::TrxFile trx(4, 2); - set_streamline_lengths(trx.streamlines, {2, 2}); + trx::TrxFile trx(4, 2); + set_streamline_lengths(trx.streamlines.get(), {2, 2}); const fs::path tmp_dir = make_temp_test_dir("trx_dps_text_err"); const fs::path input_path = tmp_dir / "dps.txt"; write_text_file(input_path, "1.0"); - EXPECT_THROW(trxmmap::add_dps_from_text(trx, "", "float32", input_path.string()), std::invalid_argument); - EXPECT_THROW(trxmmap::add_dps_from_text(trx, "weight", "badtype", input_path.string()), std::invalid_argument); - EXPECT_THROW(trxmmap::add_dps_from_text(trx, "weight", "int32", input_path.string()), std::invalid_argument); + EXPECT_THROW(trx.add_dps_from_text("", "float32", input_path.string()), std::invalid_argument); + EXPECT_THROW(trx.add_dps_from_text("weight", "badtype", input_path.string()), std::invalid_argument); + EXPECT_THROW(trx.add_dps_from_text("weight", "int32", input_path.string()), std::invalid_argument); - 
EXPECT_THROW(trxmmap::add_dps_from_text(trx, "weight", "float32", (tmp_dir / "missing.txt").string()), - std::runtime_error); + EXPECT_THROW(trx.add_dps_from_text("weight", "float32", (tmp_dir / "missing.txt").string()), std::runtime_error); write_text_file(input_path, "1.0 abc"); - EXPECT_THROW(trxmmap::add_dps_from_text(trx, "weight", "float32", input_path.string()), std::runtime_error); + EXPECT_THROW(trx.add_dps_from_text("weight", "float32", input_path.string()), std::runtime_error); write_text_file(input_path, "1.0"); - EXPECT_THROW(trxmmap::add_dps_from_text(trx, "weight", "float32", input_path.string()), std::runtime_error); + EXPECT_THROW(trx.add_dps_from_text("weight", "float32", input_path.string()), std::runtime_error); - trxmmap::TrxFile empty; - EXPECT_THROW(trxmmap::add_dps_from_text(empty, "weight", "float32", input_path.string()), std::runtime_error); + trx::TrxFile empty; + EXPECT_THROW(empty.add_dps_from_text("weight", "float32", input_path.string()), std::runtime_error); } TEST(TrxFileIo, add_dpv_from_tsf_success) { ScopedLocale scoped_locale(std::locale::classic()); - trxmmap::TrxFile source_trx(4, 2); - set_streamline_lengths(source_trx.streamlines, {2, 2}); + trx::TrxFile source_trx(4, 2); + set_streamline_lengths(source_trx.streamlines.get(), {2, 2}); const fs::path tmp_dir = make_temp_test_dir("trx_dpv_tsf"); const fs::path input_path = tmp_dir / "dpv_text.tsf"; write_tsf_text_file(input_path, build_tsf_contents({{0.1, 0.2}, {0.3, 0.4}})); - trxmmap::add_dpv_from_tsf(source_trx, "signal", "float32", input_path.string()); + source_trx.add_dpv_from_tsf("signal", "float32", input_path.string()); const fs::path binary_path = tmp_dir / "dpv_binary.tsf"; - trxmmap::export_dpv_to_tsf(source_trx, "signal", binary_path.string(), "42"); + source_trx.export_dpv_to_tsf("signal", binary_path.string(), "42"); - trxmmap::TrxFile trx(4, 2); - set_streamline_lengths(trx.streamlines, {2, 2}); - trxmmap::add_dpv_from_tsf(trx, "signal", "float32", 
binary_path.string()); + trx::TrxFile trx(4, 2); + set_streamline_lengths(trx.streamlines.get(), {2, 2}); + trx.add_dpv_from_tsf("signal", "float32", binary_path.string()); auto it = trx.data_per_vertex.find("signal"); ASSERT_NE(it, trx.data_per_vertex.end()); EXPECT_EQ(it->second->_data.rows(), 4); @@ -648,57 +756,56 @@ TEST(TrxFileIo, add_dpv_from_tsf_success) { TEST(TrxFileIo, add_dpv_from_tsf_errors) { ScopedLocale scoped_locale(std::locale::classic()); - trxmmap::TrxFile trx(4, 2); - set_streamline_lengths(trx.streamlines, {2, 2}); + trx::TrxFile trx(4, 2); + set_streamline_lengths(trx.streamlines.get(), {2, 2}); const fs::path tmp_dir = make_temp_test_dir("trx_dpv_tsf_err"); const fs::path input_path = tmp_dir / "dpv.tsf"; write_tsf_text_file(input_path, build_tsf_contents({{0.1, 0.2}, {0.3}})); - EXPECT_THROW(trxmmap::add_dpv_from_tsf(trx, "", "float32", input_path.string()), std::invalid_argument); - EXPECT_THROW(trxmmap::add_dpv_from_tsf(trx, "signal", "badtype", input_path.string()), std::invalid_argument); - EXPECT_THROW(trxmmap::add_dpv_from_tsf(trx, "signal", "int32", input_path.string()), std::invalid_argument); + EXPECT_THROW(trx.add_dpv_from_tsf("", "float32", input_path.string()), std::invalid_argument); + EXPECT_THROW(trx.add_dpv_from_tsf("signal", "badtype", input_path.string()), std::invalid_argument); + EXPECT_THROW(trx.add_dpv_from_tsf("signal", "int32", input_path.string()), std::invalid_argument); - EXPECT_THROW(trxmmap::add_dpv_from_tsf(trx, "signal", "float32", (tmp_dir / "missing.tsf").string()), - std::runtime_error); + EXPECT_THROW(trx.add_dpv_from_tsf("signal", "float32", (tmp_dir / "missing.tsf").string()), std::runtime_error); write_tsf_text_file(input_path, "0.1 0.2 abc"); - EXPECT_THROW(trxmmap::add_dpv_from_tsf(trx, "signal", "float32", input_path.string()), std::runtime_error); + EXPECT_THROW(trx.add_dpv_from_tsf("signal", "float32", input_path.string()), std::runtime_error); write_tsf_text_file(input_path, 
build_tsf_contents({{0.1}, {0.2, 0.3}})); - EXPECT_THROW(trxmmap::add_dpv_from_tsf(trx, "signal", "float32", input_path.string()), std::runtime_error); + EXPECT_THROW(trx.add_dpv_from_tsf("signal", "float32", input_path.string()), std::runtime_error); write_tsf_text_file(input_path, build_tsf_contents({{0.1, 0.2}, {0.3}})); - EXPECT_THROW(trxmmap::add_dpv_from_tsf(trx, "signal", "float32", input_path.string()), std::runtime_error); + EXPECT_THROW(trx.add_dpv_from_tsf("signal", "float32", input_path.string()), std::runtime_error); write_text_file(input_path, "mrtrix track scalars\nfile: . 0\ndatatype: Float32LE\ntimestamp: 0\n0.1 0.2"); - EXPECT_THROW(trxmmap::add_dpv_from_tsf(trx, "signal", "float32", input_path.string()), std::runtime_error); + EXPECT_THROW(trx.add_dpv_from_tsf("signal", "float32", input_path.string()), std::runtime_error); - trxmmap::TrxFile empty; - EXPECT_THROW(trxmmap::add_dpv_from_tsf(empty, "signal", "float32", input_path.string()), std::runtime_error); + trx::TrxFile empty; + EXPECT_THROW(empty.add_dpv_from_tsf("signal", "float32", input_path.string()), std::runtime_error); - trxmmap::TrxFile no_dir(4, 2); - set_streamline_lengths(no_dir.streamlines, {2, 2}); + trx::TrxFile no_dir(4, 2); + set_streamline_lengths(no_dir.streamlines.get(), {2, 2}); // Intentional white-box access: there is no public API to construct a TrxFile // with valid streamlines but without an uncompressed folder. This test verifies // that add_dpv_from_tsf fails in that specific internal state. 
no_dir._uncompressed_folder_handle.clear(); - EXPECT_THROW(trxmmap::add_dpv_from_tsf(no_dir, "signal", "float32", input_path.string()), std::runtime_error); + EXPECT_THROW(no_dir.add_dpv_from_tsf("signal", "float32", input_path.string()), std::runtime_error); } TEST(TrxFileIo, export_dpv_to_tsf_success) { ScopedLocale scoped_locale(std::locale::classic()); - trxmmap::TrxFile trx(4, 2); - set_streamline_lengths(trx.streamlines, {2, 2}); + trx::TrxFile trx(4, 2); + set_streamline_lengths(trx.streamlines.get(), {2, 2}); const fs::path tmp_dir = make_temp_test_dir("trx_export_tsf"); const fs::path input_path = tmp_dir / "dpv_input.tsf"; write_tsf_text_file(input_path, build_tsf_contents({{0.1, 0.2}, {0.3, 0.4}})); - trxmmap::add_dpv_from_tsf(trx, "signal", "float32", input_path.string()); + trx.add_dpv_from_tsf("signal", "float32", input_path.string()); const fs::path output_path = tmp_dir / "signal.tsf"; const std::string timestamp = "1234.5678"; - trxmmap::export_dpv_to_tsf(trx, "signal", output_path.string(), timestamp); + trx.export_dpv_to_tsf("signal", output_path.string(), timestamp); const TsfHeader header = read_tsf_header(output_path); EXPECT_EQ(header.timestamp, timestamp); @@ -718,17 +825,17 @@ TEST(TrxFileIo, export_dpv_to_tsf_success) { } TEST(TrxFileIo, export_dpv_to_tsf_errors) { - trxmmap::TrxFile trx(4, 2); - set_streamline_lengths(trx.streamlines, {2, 2}); + trx::TrxFile trx(4, 2); + set_streamline_lengths(trx.streamlines.get(), {2, 2}); const fs::path tmp_dir = make_temp_test_dir("trx_export_tsf_err"); const fs::path output_path = tmp_dir / "signal.tsf"; - EXPECT_THROW(trxmmap::export_dpv_to_tsf(trx, "", output_path.string(), "1"), std::invalid_argument); - EXPECT_THROW(trxmmap::export_dpv_to_tsf(trx, "signal", output_path.string(), ""), std::invalid_argument); - EXPECT_THROW(trxmmap::export_dpv_to_tsf(trx, "signal", output_path.string(), "1", "int32"), std::invalid_argument); - EXPECT_THROW(trxmmap::export_dpv_to_tsf(trx, "missing", 
output_path.string(), "1"), std::runtime_error); + EXPECT_THROW(trx.export_dpv_to_tsf("", output_path.string(), "1"), std::invalid_argument); + EXPECT_THROW(trx.export_dpv_to_tsf("signal", output_path.string(), ""), std::invalid_argument); + EXPECT_THROW(trx.export_dpv_to_tsf("signal", output_path.string(), "1", "int32"), std::invalid_argument); + EXPECT_THROW(trx.export_dpv_to_tsf("missing", output_path.string(), "1"), std::runtime_error); - trxmmap::TrxFile empty; - EXPECT_THROW(trxmmap::export_dpv_to_tsf(empty, "signal", output_path.string(), "1"), std::runtime_error); + trx::TrxFile empty; + EXPECT_THROW(empty.export_dpv_to_tsf("signal", output_path.string(), "1"), std::runtime_error); } diff --git a/tests/test_trx_mmap.cpp b/tests/test_trx_mmap.cpp index daf3b6a..d4fb5ef 100644 --- a/tests/test_trx_mmap.cpp +++ b/tests/test_trx_mmap.cpp @@ -1,16 +1,19 @@ #include #include #include +#include #include #include #include #include -#include #include #include +#include using namespace Eigen; -using namespace trxmmap; +using ::json; +using trx::TrxFile; +using trx::TrxScalarType; namespace fs = std::filesystem; namespace { @@ -145,7 +148,7 @@ TestTrxFixture create_fixture() { Matrix positions(fixture.nb_vertices, 3); positions.setZero(); fs::path positions_path = trx_dir / "positions.3.float16"; - trxmmap::write_binary(positions_path.string(), positions); + trx::write_binary(positions_path.string(), positions); struct stat sb; if (stat(positions_path.string().c_str(), &sb) != 0) { throw std::runtime_error("Failed to stat positions file"); @@ -163,7 +166,7 @@ TestTrxFixture create_fixture() { offsets(fixture.nb_streamlines, 0) = static_cast(fixture.nb_vertices); fs::path offsets_path = trx_dir / "offsets.uint64"; - trxmmap::write_binary(offsets_path.string(), offsets); + trx::write_binary(offsets_path.string(), offsets); if (stat(offsets_path.string().c_str(), &sb) != 0) { throw std::runtime_error("Failed to stat offsets file"); } @@ -178,14 +181,14 @@ TestTrxFixture 
create_fixture() { if (zf == nullptr) { throw std::runtime_error("Failed to create trx zip file"); } - trxmmap::zip_from_folder(zf, trx_dir.string(), trx_dir.string(), ZIP_CM_STORE); + trx::zip_from_folder(zf, trx_dir.string(), trx_dir.string(), ZIP_CM_STORE); if (zip_close(zf) != 0) { throw std::runtime_error("Failed to close trx zip file"); } // Validate zip entry sizes int zip_err = 0; - zip_t *verify_zip = trxmmap::open_zip_for_read(fixture.path, zip_err); + zip_t *verify_zip = trx::open_zip_for_read(fixture.path, zip_err); if (verify_zip == nullptr) { throw std::runtime_error("Failed to reopen trx zip file"); } @@ -211,31 +214,93 @@ const TestTrxFixture &get_fixture() { } } // namespace +std::unique_ptr> create_small_trx() { + auto trx = std::make_unique>(8, 4); + + auto &positions = trx->streamlines->_data; + auto &offsets = trx->streamlines->_offsets; + auto &lengths = trx->streamlines->_lengths; + + lengths(0) = 2; + lengths(1) = 2; + lengths(2) = 3; + lengths(3) = 1; + + offsets(0, 0) = 0; + offsets(1, 0) = 2; + offsets(2, 0) = 4; + offsets(3, 0) = 7; + offsets(4, 0) = 8; + + positions(0, 0) = 0.0f; + positions(0, 1) = 0.0f; + positions(0, 2) = 0.0f; + positions(1, 0) = 1.0f; + positions(1, 1) = 1.0f; + positions(1, 2) = 1.0f; + + positions(2, 0) = 10.0f; + positions(2, 1) = 0.0f; + positions(2, 2) = 0.0f; + positions(3, 0) = 11.0f; + positions(3, 1) = 0.0f; + positions(3, 2) = 0.0f; + + positions(4, 0) = -5.0f; + positions(4, 1) = 5.0f; + positions(4, 2) = 0.0f; + positions(5, 0) = -5.0f; + positions(5, 1) = 6.0f; + positions(5, 2) = 0.0f; + positions(6, 0) = -5.0f; + positions(6, 1) = 7.0f; + positions(6, 2) = 0.0f; + + positions(7, 0) = 0.0f; + positions(7, 1) = 0.0f; + positions(7, 2) = 9.0f; + + std::vector dpv_values{100.0f, 101.0f, 102.0f, 103.0f, + 104.0f, 105.0f, 106.0f, 107.0f}; + trx->add_dpv_from_vector("dpv1", "float32", dpv_values); + + std::vector dps_values{1.0f, 2.0f, 3.0f, 4.0f}; + trx->add_dps_from_vector("dps1", "float32", dps_values); 
+ + trx->add_group_from_indices("g1", std::vector{0, 2}); + trx->add_group_from_indices("g2", std::vector{1, 3}); + + std::vector dpg_values{0.5f, 1.5f}; + trx->add_dpg_from_vector("g1", "dpg1", "float32", dpg_values); + + return trx; +} + // TODO: Test null filenames. Maybe use MatrixBase instead of ArrayBase // TODO: try to update test case to use GTest parameterization // Mirrors trx/tests/test_memmap.py::test__generate_filename_from_data. TEST(TrxFileMemmap, __generate_filename_from_data) { - std::string filename = "mean_fa.bit"; + std::string filename = "mean_fa.uint8"; std::string output_fn; Matrix arr1; std::string exp_1 = "mean_fa.4.int16"; - output_fn = _generate_filename_from_data(arr1, filename); + output_fn = trx::_generate_filename_from_data(arr1, filename); EXPECT_STREQ(output_fn.c_str(), exp_1.c_str()); output_fn.clear(); Matrix arr2; std::string exp_2 = "mean_fa.4.float64"; - output_fn = _generate_filename_from_data(arr2, filename); + output_fn = trx::_generate_filename_from_data(arr2, filename); EXPECT_STREQ(output_fn.c_str(), exp_2.c_str()); output_fn.clear(); Matrix arr3; std::string exp_3 = "mean_fa.float64"; - output_fn = _generate_filename_from_data(arr3, filename); + output_fn = trx::_generate_filename_from_data(arr3, filename); EXPECT_STREQ(output_fn.c_str(), exp_3.c_str()); output_fn.clear(); } @@ -252,7 +317,7 @@ TEST(TrxFileMemmap, detect_positions_dtype_normalizes_slashes) { ASSERT_TRUE(out.is_open()); out.close(); - EXPECT_EQ(trxmmap::detect_positions_dtype(root_dir.string()), "float64"); + EXPECT_EQ(trx::detect_positions_dtype(root_dir.string()), "float64"); } TEST(TrxFileMemmap, detect_positions_scalar_type_directory) { @@ -271,18 +336,14 @@ TEST(TrxFileMemmap, detect_positions_scalar_type_directory) { const fs::path float32_dir = make_dir_with_positions("float32"); const fs::path float64_dir = make_dir_with_positions("float64"); - EXPECT_EQ(trxmmap::detect_positions_scalar_type(float16_dir.string(), TrxScalarType::Float64), - 
TrxScalarType::Float16); - EXPECT_EQ(trxmmap::detect_positions_scalar_type(float32_dir.string(), TrxScalarType::Float16), - TrxScalarType::Float32); - EXPECT_EQ(trxmmap::detect_positions_scalar_type(float64_dir.string(), TrxScalarType::Float32), - TrxScalarType::Float64); + EXPECT_EQ(trx::detect_positions_scalar_type(float16_dir.string(), TrxScalarType::Float64), TrxScalarType::Float16); + EXPECT_EQ(trx::detect_positions_scalar_type(float32_dir.string(), TrxScalarType::Float16), TrxScalarType::Float32); + EXPECT_EQ(trx::detect_positions_scalar_type(float64_dir.string(), TrxScalarType::Float32), TrxScalarType::Float64); } TEST(TrxFileMemmap, detect_positions_scalar_type_fallback) { const fs::path empty_dir = make_temp_test_dir("trx_scalar_empty"); - EXPECT_EQ(trxmmap::detect_positions_scalar_type(empty_dir.string(), TrxScalarType::Float16), - TrxScalarType::Float16); + EXPECT_EQ(trx::detect_positions_scalar_type(empty_dir.string(), TrxScalarType::Float16), TrxScalarType::Float16); const fs::path invalid_dir = make_temp_test_dir("trx_scalar_invalid"); const fs::path invalid_positions = invalid_dir / "positions.3.txt"; @@ -290,13 +351,12 @@ TEST(TrxFileMemmap, detect_positions_scalar_type_fallback) { ASSERT_TRUE(out.is_open()); out.close(); - EXPECT_THROW(trxmmap::detect_positions_scalar_type(invalid_dir.string(), TrxScalarType::Float64), - std::invalid_argument); + EXPECT_THROW(trx::detect_positions_scalar_type(invalid_dir.string(), TrxScalarType::Float64), std::invalid_argument); } TEST(TrxFileMemmap, detect_positions_scalar_type_missing_path) { const fs::path missing = fs::path(make_temp_test_dir("trx_scalar_missing")) / "nope"; - EXPECT_THROW(trxmmap::detect_positions_scalar_type(missing.string(), TrxScalarType::Float32), std::runtime_error); + EXPECT_THROW(trx::detect_positions_scalar_type(missing.string(), TrxScalarType::Float32), std::runtime_error); } TEST(TrxFileMemmap, open_zip_for_read_generic_fallback) { @@ -328,7 +388,7 @@ TEST(TrxFileMemmap, 
open_zip_for_read_generic_fallback) { } errorp = 0; - zip_t *fallback = trxmmap::open_zip_for_read(alt_path, errorp); + zip_t *fallback = trx::open_zip_for_read(alt_path, errorp); ASSERT_NE(fallback, nullptr); zip_close(fallback); #else @@ -341,12 +401,12 @@ TEST(TrxFileMemmap, __split_ext_with_dimensionality) { std::tuple output; const std::string fn1 = "mean_fa.float64"; std::tuple exp1("mean_fa", 1, "float64"); - output = _split_ext_with_dimensionality(fn1); + output = trx::detail::_split_ext_with_dimensionality(fn1); EXPECT_TRUE(output == exp1); const std::string fn2 = "mean_fa.5.int32"; std::tuple exp2("mean_fa", 5, "int32"); - output = _split_ext_with_dimensionality(fn2); + output = trx::detail::_split_ext_with_dimensionality(fn2); // std::cout << std::get<0>(output) << " TEST " << std::get<1>(output) << " " << // std::get<2>(output) << std::endl; EXPECT_TRUE(output == exp2); @@ -355,7 +415,7 @@ TEST(TrxFileMemmap, __split_ext_with_dimensionality) { EXPECT_THROW( { try { - output = _split_ext_with_dimensionality(fn3); + output = trx::detail::_split_ext_with_dimensionality(fn3); } catch (const std::invalid_argument &e) { EXPECT_STREQ("Invalid filename", e.what()); throw; @@ -367,7 +427,7 @@ TEST(TrxFileMemmap, __split_ext_with_dimensionality) { EXPECT_THROW( { try { - output = _split_ext_with_dimensionality(fn4); + output = trx::detail::_split_ext_with_dimensionality(fn4); } catch (const std::invalid_argument &e) { EXPECT_STREQ("Invalid filename", e.what()); throw; @@ -379,7 +439,7 @@ TEST(TrxFileMemmap, __split_ext_with_dimensionality) { EXPECT_THROW( { try { - output = _split_ext_with_dimensionality(fn5); + output = trx::detail::_split_ext_with_dimensionality(fn5); } catch (const std::invalid_argument &e) { EXPECT_STREQ("Unsupported file extension", e.what()); throw; @@ -391,131 +451,126 @@ TEST(TrxFileMemmap, __split_ext_with_dimensionality) { // Mirrors trx/tests/test_memmap.py::test__compute_lengths. 
TEST(TrxFileMemmap, __compute_lengths) { Matrix offsets{uint64_t{0}, uint64_t{1}, uint64_t{2}, uint64_t{3}, uint64_t{4}}; - Matrix lengths(trxmmap::_compute_lengths(offsets, 4)); + Matrix lengths(trx::detail::_compute_lengths(offsets, 4)); Matrix result{uint32_t{1}, uint32_t{1}, uint32_t{1}, uint32_t{1}}; EXPECT_EQ(lengths, result); Matrix offsets2{uint64_t{0}, uint64_t{1}, uint64_t{1}, uint64_t{3}, uint64_t{4}}; - Matrix lengths2(trxmmap::_compute_lengths(offsets2, 4)); + Matrix lengths2(trx::detail::_compute_lengths(offsets2, 4)); Matrix result2{uint32_t{1}, uint32_t{0}, uint32_t{2}, uint32_t{1}}; EXPECT_EQ(lengths2, result2); Matrix offsets3{uint64_t{0}, uint64_t{1}, uint64_t{2}, uint64_t{4}}; - Matrix lengths3(trxmmap::_compute_lengths(offsets3, 4)); + Matrix lengths3(trx::detail::_compute_lengths(offsets3, 4)); Matrix result3{uint32_t{1}, uint32_t{1}, uint32_t{2}}; EXPECT_EQ(lengths3, result3); Matrix offsets4; offsets4 << uint64_t{0}, uint64_t{2}; - Matrix lengths4(trxmmap::_compute_lengths(offsets4, 2)); + Matrix lengths4(trx::detail::_compute_lengths(offsets4, 2)); Matrix result4(uint32_t{2}); EXPECT_EQ(lengths4, result4); Matrix offsets5; - Matrix lengths5(trxmmap::_compute_lengths(offsets5, 2)); + Matrix lengths5(trx::detail::_compute_lengths(offsets5, 2)); EXPECT_EQ(lengths5.size(), 0); Matrix offsets6{int16_t(0), int16_t(1), int16_t(2), int16_t(3), int16_t(4)}; - Matrix lengths6(trxmmap::_compute_lengths(offsets6, 4)); + Matrix lengths6(trx::detail::_compute_lengths(offsets6, 4)); Matrix result6{uint32_t{1}, uint32_t{1}, uint32_t{1}, uint32_t{1}}; EXPECT_EQ(lengths6, result6); Matrix offsets7{int32_t(0), int32_t(1), int32_t(1), int32_t(3), int32_t(4)}; - Matrix lengths7(trxmmap::_compute_lengths(offsets7, 4)); + Matrix lengths7(trx::detail::_compute_lengths(offsets7, 4)); Matrix result7{uint32_t{1}, uint32_t{0}, uint32_t{2}, uint32_t{1}}; EXPECT_EQ(lengths7, result7); } // Mirrors trx/tests/test_memmap.py::test__is_dtype_valid. 
TEST(TrxFileMemmap, __is_dtype_valid) { - std::string ext = "bit"; - EXPECT_TRUE(_is_dtype_valid(ext)); - - std::string ext2 = "int16"; - EXPECT_TRUE(_is_dtype_valid(ext2)); + std::string ext1 = "int16"; + EXPECT_TRUE(trx::detail::_is_dtype_valid(ext1)); - std::string ext3 = "float32"; - EXPECT_TRUE(_is_dtype_valid(ext3)); + std::string ext2 = "float32"; + EXPECT_TRUE(trx::detail::_is_dtype_valid(ext2)); - std::string ext4 = "uint8"; - EXPECT_TRUE(_is_dtype_valid(ext4)); + std::string ext3 = "uint8"; + EXPECT_TRUE(trx::detail::_is_dtype_valid(ext3)); - std::string ext5 = "ushort"; - EXPECT_TRUE(_is_dtype_valid(ext5)); + std::string ext4 = "ushort"; + EXPECT_TRUE(trx::detail::_is_dtype_valid(ext4)); - std::string ext6 = "txt"; - EXPECT_FALSE(_is_dtype_valid(ext6)); + std::string ext5 = "txt"; + EXPECT_FALSE(trx::detail::_is_dtype_valid(ext5)); } // asserts C++ dtype alias behavior. TEST(TrxFileMemmap, __sizeof_dtype_ushort_alias) { - EXPECT_EQ(trxmmap::_sizeof_dtype("ushort"), sizeof(uint16_t)); - EXPECT_EQ(trxmmap::_sizeof_dtype("ushort"), trxmmap::_sizeof_dtype("uint16")); + EXPECT_EQ(trx::detail::_sizeof_dtype("ushort"), sizeof(uint16_t)); + EXPECT_EQ(trx::detail::_sizeof_dtype("ushort"), trx::detail::_sizeof_dtype("uint16")); } // asserts dtype size mapping and default. 
TEST(TrxFileMemmap, __sizeof_dtype_values) { - EXPECT_EQ(trxmmap::_sizeof_dtype("bit"), 1); - EXPECT_EQ(trxmmap::_sizeof_dtype("uint8"), sizeof(uint8_t)); - EXPECT_EQ(trxmmap::_sizeof_dtype("uint16"), sizeof(uint16_t)); - EXPECT_EQ(trxmmap::_sizeof_dtype("uint32"), sizeof(uint32_t)); - EXPECT_EQ(trxmmap::_sizeof_dtype("uint64"), sizeof(uint64_t)); - EXPECT_EQ(trxmmap::_sizeof_dtype("int8"), sizeof(int8_t)); - EXPECT_EQ(trxmmap::_sizeof_dtype("int16"), sizeof(int16_t)); - EXPECT_EQ(trxmmap::_sizeof_dtype("int32"), sizeof(int32_t)); - EXPECT_EQ(trxmmap::_sizeof_dtype("int64"), sizeof(int64_t)); - EXPECT_EQ(trxmmap::_sizeof_dtype("float32"), sizeof(float)); - EXPECT_EQ(trxmmap::_sizeof_dtype("float64"), sizeof(double)); - EXPECT_EQ(trxmmap::_sizeof_dtype("unknown"), sizeof(uint16_t)); + EXPECT_EQ(trx::detail::_sizeof_dtype("uint8"), sizeof(uint8_t)); + EXPECT_EQ(trx::detail::_sizeof_dtype("uint16"), sizeof(uint16_t)); + EXPECT_EQ(trx::detail::_sizeof_dtype("uint32"), sizeof(uint32_t)); + EXPECT_EQ(trx::detail::_sizeof_dtype("uint64"), sizeof(uint64_t)); + EXPECT_EQ(trx::detail::_sizeof_dtype("int8"), sizeof(int8_t)); + EXPECT_EQ(trx::detail::_sizeof_dtype("int16"), sizeof(int16_t)); + EXPECT_EQ(trx::detail::_sizeof_dtype("int32"), sizeof(int32_t)); + EXPECT_EQ(trx::detail::_sizeof_dtype("int64"), sizeof(int64_t)); + EXPECT_EQ(trx::detail::_sizeof_dtype("float32"), sizeof(float)); + EXPECT_EQ(trx::detail::_sizeof_dtype("float64"), sizeof(double)); + EXPECT_EQ(trx::detail::_sizeof_dtype("unknown"), sizeof(uint16_t)); } // asserts dtype code mapping. 
TEST(TrxFileMemmap, __get_dtype_codes) { - EXPECT_EQ(trxmmap::_get_dtype("b"), "bit"); - EXPECT_EQ(trxmmap::_get_dtype("h"), "uint8"); - EXPECT_EQ(trxmmap::_get_dtype("t"), "uint16"); - EXPECT_EQ(trxmmap::_get_dtype("j"), "uint32"); - EXPECT_EQ(trxmmap::_get_dtype("m"), "uint64"); - EXPECT_EQ(trxmmap::_get_dtype("y"), "uint64"); - EXPECT_EQ(trxmmap::_get_dtype("a"), "int8"); - EXPECT_EQ(trxmmap::_get_dtype("s"), "int16"); - EXPECT_EQ(trxmmap::_get_dtype("i"), "int32"); - EXPECT_EQ(trxmmap::_get_dtype("l"), "int64"); - EXPECT_EQ(trxmmap::_get_dtype("x"), "int64"); - EXPECT_EQ(trxmmap::_get_dtype("f"), "float32"); - EXPECT_EQ(trxmmap::_get_dtype("d"), "float64"); - EXPECT_EQ(trxmmap::_get_dtype("z"), "float16"); - EXPECT_EQ(trxmmap::_get_dtype("foo"), "float16"); + EXPECT_EQ(trx::detail::_get_dtype("h"), "uint8"); + EXPECT_EQ(trx::detail::_get_dtype("t"), "uint16"); + EXPECT_EQ(trx::detail::_get_dtype("j"), "uint32"); + EXPECT_EQ(trx::detail::_get_dtype("m"), "uint64"); + EXPECT_EQ(trx::detail::_get_dtype("y"), "uint64"); + EXPECT_EQ(trx::detail::_get_dtype("a"), "int8"); + EXPECT_EQ(trx::detail::_get_dtype("s"), "int16"); + EXPECT_EQ(trx::detail::_get_dtype("i"), "int32"); + EXPECT_EQ(trx::detail::_get_dtype("l"), "int64"); + EXPECT_EQ(trx::detail::_get_dtype("x"), "int64"); + EXPECT_EQ(trx::detail::_get_dtype("f"), "float32"); + EXPECT_EQ(trx::detail::_get_dtype("d"), "float64"); + EXPECT_EQ(trx::detail::_get_dtype("z"), "float16"); + EXPECT_EQ(trx::detail::_get_dtype("foo"), "float16"); } // Mirrors trx/tests/test_memmap.py::test__dichotomic_search. 
TEST(TrxFileMemmap, __dichotomic_search) { Matrix m{0, 1, 2, 3, 4}; - int result = trxmmap::_dichotomic_search(m); + int result = trx::detail::_dichotomic_search(m); EXPECT_EQ(result, 4); Matrix m2{0, 1, 0, 3, 4}; - int result2 = trxmmap::_dichotomic_search(m2); + int result2 = trx::detail::_dichotomic_search(m2); EXPECT_EQ(result2, 1); Matrix m3{0, 1, 2, 0, 4}; - int result3 = trxmmap::_dichotomic_search(m3); + int result3 = trx::detail::_dichotomic_search(m3); EXPECT_EQ(result3, 2); Matrix m4{0, 1, 2, 3, 4}; - int result4 = trxmmap::_dichotomic_search(m4, 1, 2); + int result4 = trx::detail::_dichotomic_search(m4, 1, 2); EXPECT_EQ(result4, 2); Matrix m5{0, 1, 2, 3, 4}; - int result5 = trxmmap::_dichotomic_search(m5, 3, 3); + int result5 = trx::detail::_dichotomic_search(m5, 3, 3); EXPECT_EQ(result5, 3); Matrix m6{0, 0, 0, 0, 0}; - int result6 = trxmmap::_dichotomic_search(m6, 3, 3); + int result6 = trx::detail::_dichotomic_search(m6, 3, 3); EXPECT_EQ(result6, -1); } @@ -528,7 +583,7 @@ TEST(TrxFileMemmap, __create_memmap) { std::tuple shape = std::make_tuple(3, 4); // Test 1: create file and allocate space assert that correct data is filled - mio::shared_mmap_sink empty_mmap = trxmmap::_create_memmap(path.string(), shape); + mio::shared_mmap_sink empty_mmap = trx::_create_memmap(path.string(), shape); Map> expected_m(reinterpret_cast(empty_mmap.data())); Matrix zero_filled{ {half(0), half(0), half(0), half(0)}, {half(0), half(0), half(0), half(0)}, {half(0), half(0), half(0), half(0)}}; @@ -540,7 +595,7 @@ TEST(TrxFileMemmap, __create_memmap) { expected_m(i) = half(i); } - mio::shared_mmap_sink filled_mmap = trxmmap::_create_memmap(path.string(), shape); + mio::shared_mmap_sink filled_mmap = trx::_create_memmap(path.string(), shape); Map> real_m(reinterpret_cast(filled_mmap.data()), std::get<0>(shape), std::get<1>(shape)); EXPECT_EQ(expected_m, real_m); @@ -555,7 +610,7 @@ TEST(TrxFileMemmap, __create_memmap_empty) { fs::path path = dir / "empty.float32"; 
std::tuple shape = std::make_tuple(0, 1); - mio::shared_mmap_sink empty_mmap = trxmmap::_create_memmap(path.string(), shape); + mio::shared_mmap_sink empty_mmap = trx::_create_memmap(path.string(), shape); struct stat sb; ASSERT_EQ(stat(path.string().c_str(), &sb), 0); @@ -570,8 +625,8 @@ TEST(TrxFileMemmap, __create_memmap_empty) { TEST(TrxFileMemmap, load_header) { const auto &fixture = get_fixture(); int errorp = 0; - zip_t *zf = trxmmap::open_zip_for_read(fixture.path, errorp); - json root = trxmmap::load_header(zf); + zip_t *zf = trx::open_zip_for_read(fixture.path, errorp); + json root = trx::load_header(zf); EXPECT_EQ(root, fixture.expected_header); EXPECT_EQ(root.dump(), fixture.expected_header.dump()); @@ -598,9 +653,9 @@ TEST(TrxFileMemmap, load_header) { // Mirrors trx/tests/test_memmap.py::test_load (small.trx via zip path). TEST(TrxFileMemmap, load_zip) { const auto &fixture = get_fixture(); - trxmmap::TrxFile *trx = trxmmap::load_from_zip(fixture.path); + trx::TrxReader reader(fixture.path); + auto *trx = reader.get(); EXPECT_GT(trx->streamlines->_data.size(), 0); - delete trx; } // Mirrors trx/tests/test_memmap.py::test_load for small.trx and small_compressed.trx. 
@@ -613,23 +668,23 @@ TEST(TrxFileMemmap, load_zip_test_data) { const fs::path small_trx = memmap_dir / "small.trx"; ASSERT_TRUE(fs::exists(small_trx)); - trxmmap::TrxFile *trx_small = trxmmap::load_from_zip(small_trx.string()); + trx::TrxReader reader_small(small_trx.string()); + auto *trx_small = reader_small.get(); EXPECT_GT(trx_small->streamlines->_data.size(), 0); - delete trx_small; const fs::path small_compressed = memmap_dir / "small_compressed.trx"; ASSERT_TRUE(fs::exists(small_compressed)); - trxmmap::TrxFile *trx_compressed = trxmmap::load_from_zip(small_compressed.string()); + trx::TrxReader reader_compressed(small_compressed.string()); + auto *trx_compressed = reader_compressed.get(); EXPECT_GT(trx_compressed->streamlines->_data.size(), 0); - delete trx_compressed; } // Mirrors trx/tests/test_memmap.py::test_load (small_fldr.trx via directory path). TEST(TrxFileMemmap, load_directory) { const auto &fixture = get_fixture(); - trxmmap::TrxFile *trx = trxmmap::load_from_directory(fixture.dir_path); + trx::TrxReader reader(fixture.dir_path); + auto *trx = reader.get(); EXPECT_GT(trx->streamlines->_data.size(), 0); - delete trx; } // Mirrors trx/tests/test_memmap.py::test_load_directory. @@ -642,9 +697,9 @@ TEST(TrxFileMemmap, load_directory_test_data) { const fs::path small_dir = memmap_dir / "small_fldr.trx"; ASSERT_TRUE(fs::exists(small_dir)); - trxmmap::TrxFile *trx = trxmmap::load_from_directory(small_dir.string()); + trx::TrxReader reader(small_dir.string()); + auto *trx = reader.get(); EXPECT_GT(trx->streamlines->_data.size(), 0); - delete trx; } // Mirrors trx/tests/test_memmap.py::test_load with missing path raising. 
@@ -656,12 +711,12 @@ TEST(TrxFileMemmap, load_missing_trx_throws) { const auto memmap_dir = resolve_memmap_test_data_dir(root); const fs::path missing_trx = memmap_dir / "dontexist.trx"; - EXPECT_THROW(trxmmap::load_from_zip(missing_trx.string()), std::runtime_error); + EXPECT_THROW(trx::TrxReader(missing_trx.string()), std::runtime_error); } // validates C++ TrxFile initialization. TEST(TrxFileMemmap, TrxFile) { - trxmmap::TrxFile *trx = new TrxFile(); + auto trx = std::make_unique>(); // expected header json::object expected_obj; @@ -680,15 +735,15 @@ TEST(TrxFileMemmap, TrxFile) { const auto &fixture = get_fixture(); int errorp = 0; - zip_t *zf = trxmmap::open_zip_for_read(fixture.path, errorp); - json root = trxmmap::load_header(zf); - TrxFile *root_init = new TrxFile(); + zip_t *zf = trx::open_zip_for_read(fixture.path, errorp); + json root = trx::load_header(zf); + auto root_init = std::make_unique>(); root_init->header = root; zip_close(zf); // TODO: test for now.. - trxmmap::TrxFile *trx_init = new TrxFile(fixture.nb_vertices, fixture.nb_streamlines, root_init); + auto trx_init = std::make_unique>(fixture.nb_vertices, fixture.nb_streamlines, root_init.get()); json::object init_as_obj; init_as_obj["DIMENSIONS"] = json::array{117, 151, 115}; init_as_obj["NB_STREAMLINES"] = fixture.nb_streamlines; @@ -705,47 +760,180 @@ TEST(TrxFileMemmap, TrxFile) { EXPECT_EQ(trx_init->streamlines->_data.size(), fixture.nb_vertices * 3); EXPECT_EQ(trx_init->streamlines->_offsets.size(), fixture.nb_streamlines + 1); EXPECT_EQ(trx_init->streamlines->_lengths.size(), fixture.nb_streamlines); - delete trx; - delete root_init; - delete trx_init; } // validates C++ deepcopy. 
TEST(TrxFileMemmap, deepcopy) { const auto &fixture = get_fixture(); - trxmmap::TrxFile *trx = trxmmap::load_from_zip(fixture.path); - trxmmap::TrxFile *copy = trx->deepcopy(); + trx::TrxReader reader(fixture.path); + auto *trx = reader.get(); + auto copy = trx->deepcopy(); EXPECT_EQ(trx->header, copy->header); EXPECT_EQ(trx->streamlines->_data, trx->streamlines->_data); EXPECT_EQ(trx->streamlines->_offsets, trx->streamlines->_offsets); EXPECT_EQ(trx->streamlines->_lengths, trx->streamlines->_lengths); - delete trx; - delete copy; } // Mirrors trx/tests/test_memmap.py::test_resize. TEST(TrxFileMemmap, resize) { const auto &fixture = get_fixture(); - trxmmap::TrxFile *trx = trxmmap::load_from_zip(fixture.path); + trx::TrxReader reader(fixture.path); + auto *trx = reader.get(); trx->resize(); trx->resize(10); - delete trx; } // exercises save paths in C++. TEST(TrxFileMemmap, save) { const auto &fixture = get_fixture(); - trxmmap::TrxFile *trx = trxmmap::load_from_zip(fixture.path); - trxmmap::save(*trx, (std::string) "testsave"); - trxmmap::save(*trx, (std::string) "testsave.trx"); + trx::TrxReader reader(fixture.path); + auto *trx = reader.get(); + trx->save("testsave"); + trx->save("testsave.trx"); - delete trx; - - // trxmmap::TrxFile *saved = trxmmap::load_from_zip("testsave.trx"); + // trx::TrxFile *saved = trx::load_from_zip("testsave.trx"); // EXPECT_EQ(saved->data_per_vertex["color_x.float16"]->_data, // trx->data_per_vertex["color_x.float16"]->_data); } +TEST(TrxFileMemmap, build_streamline_aabbs) { + auto trx = create_small_trx(); + auto aabbs = trx->build_streamline_aabbs(); + ASSERT_EQ(aabbs.size(), 4u); + + EXPECT_FLOAT_EQ(static_cast(aabbs[0][0]), 0.0f); + EXPECT_FLOAT_EQ(static_cast(aabbs[0][3]), 1.0f); + EXPECT_FLOAT_EQ(static_cast(aabbs[1][0]), 10.0f); + EXPECT_FLOAT_EQ(static_cast(aabbs[1][3]), 11.0f); + EXPECT_FLOAT_EQ(static_cast(aabbs[2][1]), 5.0f); + EXPECT_FLOAT_EQ(static_cast(aabbs[2][4]), 7.0f); + EXPECT_FLOAT_EQ(static_cast(aabbs[3][2]), 
9.0f); + EXPECT_FLOAT_EQ(static_cast(aabbs[3][5]), 9.0f); + + trx->close(); +} + +TEST(TrxFileMemmap, query_aabb_filters_streamlines) { + auto trx = create_small_trx(); + std::array min_corner{9.0f, -1.0f, -1.0f}; + std::array max_corner{12.0f, 1.0f, 1.0f}; + + auto subset = trx->query_aabb(min_corner, max_corner); + EXPECT_EQ(subset->num_streamlines(), 1u); + EXPECT_EQ(subset->num_vertices(), 2u); + + auto dps_it = subset->data_per_streamline.find("dps1"); + ASSERT_NE(dps_it, subset->data_per_streamline.end()); + EXPECT_FLOAT_EQ(dps_it->second->_matrix(0, 0), 2.0f); + + auto dpv_it = subset->data_per_vertex.find("dpv1"); + ASSERT_NE(dpv_it, subset->data_per_vertex.end()); + EXPECT_FLOAT_EQ(dpv_it->second->_data(0, 0), 102.0f); + EXPECT_FLOAT_EQ(dpv_it->second->_data(1, 0), 103.0f); + + subset->close(); + trx->close(); +} + +TEST(TrxFileMemmap, query_aabb_rejects_bad_aabb_size) { + auto trx = create_small_trx(); + std::array min_corner{-1.0f, -1.0f, -1.0f}; + std::array max_corner{1.0f, 1.0f, 1.0f}; + + std::vector> bad_aabbs(1); + EXPECT_THROW(trx->query_aabb(min_corner, max_corner, &bad_aabbs), std::invalid_argument); + + trx->close(); +} + +TEST(TrxFileMemmap, subset_streamlines_basic) { + auto trx = create_small_trx(); + std::vector ids{2, 0, 2}; + + auto subset = trx->subset_streamlines(ids); + EXPECT_EQ(subset->num_streamlines(), 2u); + EXPECT_EQ(subset->num_vertices(), 5u); + + auto dps_it = subset->data_per_streamline.find("dps1"); + ASSERT_NE(dps_it, subset->data_per_streamline.end()); + EXPECT_FLOAT_EQ(dps_it->second->_matrix(0, 0), 3.0f); + EXPECT_FLOAT_EQ(dps_it->second->_matrix(1, 0), 1.0f); + + auto group_it = subset->groups.find("g1"); + ASSERT_NE(group_it, subset->groups.end()); + EXPECT_EQ(group_it->second->_matrix(0, 0), 1u); + EXPECT_EQ(group_it->second->_matrix(1, 0), 0u); + + auto dpg_it = subset->get_dpg("g1", "dpg1"); + ASSERT_NE(dpg_it, nullptr); + ASSERT_EQ(dpg_it->_matrix.rows(), 1); + ASSERT_EQ(dpg_it->_matrix.cols(), 2); + 
EXPECT_FLOAT_EQ(dpg_it->_matrix(0, 0), 0.5f); + EXPECT_FLOAT_EQ(dpg_it->_matrix(0, 1), 1.5f); + + subset->close(); + trx->close(); +} + +TEST(TrxFileMemmap, subset_streamlines_empty) { + auto trx = create_small_trx(); + std::vector ids; + auto subset = trx->subset_streamlines(ids); + EXPECT_EQ(subset->num_streamlines(), 0u); + EXPECT_EQ(subset->num_vertices(), 0u); + subset->close(); + trx->close(); +} + +TEST(TrxFileMemmap, subset_streamlines_out_of_range) { + auto trx = create_small_trx(); + std::vector ids{99}; + EXPECT_THROW(trx->subset_streamlines(ids), std::invalid_argument); + trx->close(); +} + +TEST(TrxFileMemmap, dpg_api_vector_and_matrix) { + auto trx = create_small_trx(); + + std::vector values{1.0f, 2.0f, 3.0f, 4.0f}; + trx->add_dpg_from_vector("g2", "dpg_vec", "float32", values, 2, 2); + + Matrix mat; + mat << 5.0f, 6.0f, 7.0f; + trx->add_dpg_from_matrix("g2", "dpg_mat", "float32", mat); + + auto fields = trx->list_dpg_fields("g2"); + EXPECT_EQ(fields.size(), 2u); + + auto dpg_vec = trx->get_dpg("g2", "dpg_vec"); + ASSERT_NE(dpg_vec, nullptr); + EXPECT_FLOAT_EQ(dpg_vec->_matrix(1, 1), 4.0f); + + auto dpg_mat = trx->get_dpg("g2", "dpg_mat"); + ASSERT_NE(dpg_mat, nullptr); + EXPECT_FLOAT_EQ(dpg_mat->_matrix(0, 2), 7.0f); + + trx->remove_dpg("g2", "dpg_vec"); + EXPECT_EQ(trx->get_dpg("g2", "dpg_vec"), nullptr); + + trx->remove_dpg_group("g2"); + EXPECT_TRUE(trx->list_dpg_fields("g2").empty()); + + trx->close(); +} + +TEST(TrxFileMemmap, dpg_api_invalid_inputs) { + auto trx = create_small_trx(); + + std::vector values{1.0f, 2.0f, 3.0f}; + EXPECT_THROW(trx->add_dpg_from_vector("", "dpg", "float32", values), std::invalid_argument); + EXPECT_THROW(trx->add_dpg_from_vector("g1", "", "float32", values), std::invalid_argument); + EXPECT_THROW(trx->add_dpg_from_vector("g1", "dpg", "int8", values), std::invalid_argument); + EXPECT_THROW(trx->add_dpg_from_vector("g1", "dpg", "float32", values, 2, 2), std::invalid_argument); + + trx->close(); +} + int main(int 
argc, char **argv) { // check_syntax off ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); diff --git a/tests/test_trx_streamlines_ops.cpp b/tests/test_trx_streamlines_ops.cpp index fc45a8a..ed57a49 100644 --- a/tests/test_trx_streamlines_ops.cpp +++ b/tests/test_trx_streamlines_ops.cpp @@ -60,8 +60,7 @@ TEST_P(StreamlinesOpsIntersectionTest, Intersection) { const std::vector streamlines_ori = {s1, s2}; const std::vector streamlines_new = make_streamlines_new(s2, noise); - auto result = - trxmmap::perform_streamlines_operation(trxmmap::intersection, {streamlines_new, streamlines_ori}, precision); + auto result = trx::perform_streamlines_operation(trx::intersection, {streamlines_new, streamlines_ori}, precision); const auto &indices = result.second; EXPECT_EQ(indices, expected); } @@ -90,8 +89,7 @@ TEST_P(StreamlinesOpsUnionTest, Union) { const std::vector streamlines_ori = {s1, s2}; const std::vector streamlines_new = make_streamlines_new(s2, noise); - auto result = - trxmmap::perform_streamlines_operation(trxmmap::union_maps, {streamlines_new, streamlines_ori}, precision); + auto result = trx::perform_streamlines_operation(trx::union_maps, {streamlines_new, streamlines_ori}, precision); EXPECT_EQ(result.first.size(), expected); } @@ -119,8 +117,7 @@ TEST_P(StreamlinesOpsDifferenceTest, Difference) { const std::vector streamlines_ori = {s1, s2}; const std::vector streamlines_new = make_streamlines_new(s2, noise); - auto result = - trxmmap::perform_streamlines_operation(trxmmap::difference, {streamlines_new, streamlines_ori}, precision); + auto result = trx::perform_streamlines_operation(trx::difference, {streamlines_new, streamlines_ori}, precision); EXPECT_EQ(result.first.size(), expected); } diff --git a/tests/test_trx_trxfile.cpp b/tests/test_trx_trxfile.cpp index 37e2cb2..9bb96d9 100644 --- a/tests/test_trx_trxfile.cpp +++ b/tests/test_trx_trxfile.cpp @@ -1,4 +1,5 @@ #include +#include #include #define private public @@ -13,12 +14,12 @@ #include 
using namespace Eigen; -using namespace trxmmap; +using namespace trx; namespace fs = std::filesystem; namespace { -template trxmmap::TrxFile
*load_trx_dir(const fs::path &path) { - return trxmmap::load_from_directory
(path.string()); +template trx::TrxReader
load_trx_dir(const fs::path &path) { + return trx::TrxReader
(path.string()); } fs::path make_temp_test_dir(const std::string &prefix) { @@ -69,59 +70,59 @@ fs::path create_float_trx_dir() { Matrix positions(4, 3); positions << 0.0f, 0.1f, 0.2f, 1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f, 3.0f, 3.1f, 3.2f; - trxmmap::write_binary((root / "positions.3.float32").string(), positions); + trx::write_binary((root / "positions.3.float32").string(), positions); Matrix offsets(3, 1); offsets << 0, 2, 4; - trxmmap::write_binary((root / "offsets.uint32").string(), offsets); + trx::write_binary((root / "offsets.uint32").string(), offsets); fs::path dpv_dir = root / "dpv"; std::error_code ec; fs::create_directories(dpv_dir, ec); Matrix dpv(4, 1); dpv << 0.0f, 0.3f, 0.6f, 0.9f; - trxmmap::write_binary((dpv_dir / "color.float32").string(), dpv); + trx::write_binary((dpv_dir / "color.float32").string(), dpv); fs::path dps_dir = root / "dps"; fs::create_directories(dps_dir, ec); Matrix dps(2, 1); dps << 0.25f, 0.75f; - trxmmap::write_binary((dps_dir / "weight.float32").string(), dps); + trx::write_binary((dps_dir / "weight.float32").string(), dps); fs::path groups_dir = root / "groups"; fs::create_directories(groups_dir, ec); Matrix group_vals(2, 1); group_vals << 0, 1; - trxmmap::write_binary((groups_dir / "GroupA.uint32").string(), group_vals); + trx::write_binary((groups_dir / "GroupA.uint32").string(), group_vals); fs::path dpg_dir = root / "dpg" / "GroupA"; fs::create_directories(dpg_dir, ec); Matrix dpg(1, 1); dpg << 1.0f; - trxmmap::write_binary((dpg_dir / "mean.float32").string(), dpg); + trx::write_binary((dpg_dir / "mean.float32").string(), dpg); return root; } } // namespace TEST(TrxFileTpp, DeepcopyEmpty) { - trxmmap::TrxFile empty; - trxmmap::TrxFile *copy = empty.deepcopy(); + trx::TrxFile empty; + auto copy = empty.deepcopy(); EXPECT_EQ(copy->header, empty.header); if (copy->streamlines != nullptr) { EXPECT_EQ(copy->streamlines->_data.size(), 0); EXPECT_EQ(copy->streamlines->_offsets.size(), 0); 
EXPECT_EQ(copy->streamlines->_lengths.size(), 0); } - delete copy; } // Deepcopy preserves streamlines, dpv/dps, groups, and dpg shapes/content. TEST(TrxFileTpp, DeepcopyWithGroupsDpgDpvDps) { const fs::path data_dir = create_float_trx_dir(); - trxmmap::TrxFile *trx = load_trx_dir(data_dir); - trxmmap::TrxFile *copy = trx->deepcopy(); + auto reader = load_trx_dir(data_dir); + auto *trx = reader.get(); + auto copy = trx->deepcopy(); EXPECT_EQ(copy->header, trx->header); EXPECT_EQ(copy->streamlines->_data, trx->streamlines->_data); @@ -168,8 +169,6 @@ TEST(TrxFileTpp, DeepcopyWithGroupsDpgDpvDps) { trx->close(); copy->close(); - delete trx; - delete copy; std::error_code ec; fs::remove_all(data_dir, ec); @@ -178,10 +177,11 @@ TEST(TrxFileTpp, DeepcopyWithGroupsDpgDpvDps) { // _copy_fixed_arrays_from copies streamlines + dpv/dps into a preallocated target. TEST(TrxFileTpp, CopyFixedArraysFrom) { const fs::path data_dir = create_float_trx_dir(); - trxmmap::TrxFile *src = load_trx_dir(data_dir); + auto src_reader = load_trx_dir(data_dir); + auto *src = src_reader.get(); const int nb_vertices = src->header["NB_VERTICES"].int_value(); const int nb_streamlines = src->header["NB_STREAMLINES"].int_value(); - trxmmap::TrxFile *dst = new trxmmap::TrxFile(nb_vertices, nb_streamlines, src); + auto dst = std::make_unique>(nb_vertices, nb_streamlines, src); dst->_copy_fixed_arrays_from(src, 0, 0, nb_streamlines); @@ -205,22 +205,102 @@ TEST(TrxFileTpp, CopyFixedArraysFrom) { src->close(); dst->close(); - delete src; - delete dst; std::error_code ec; fs::remove_all(data_dir, ec); } +TEST(TrxFileTpp, AddGroupAndDpvFromVector) { + const int nb_vertices = 5; + const int nb_streamlines = 3; + trx::TrxFile trx(nb_vertices, nb_streamlines); + + trx.streamlines->_offsets(0, 0) = 0; + trx.streamlines->_offsets(1, 0) = 2; + trx.streamlines->_offsets(2, 0) = 4; + trx.streamlines->_offsets(3, 0) = 5; + trx.streamlines->_lengths(0) = 2; + trx.streamlines->_lengths(1) = 2; + 
trx.streamlines->_lengths(2) = 1; + + const std::vector group_indices = {0, 2}; + trx.add_group_from_indices("GroupA", group_indices); + ASSERT_EQ(trx.groups.size(), 1u); + auto group_it = trx.groups.find("GroupA"); + ASSERT_NE(group_it, trx.groups.end()); + EXPECT_EQ(group_it->second->_matrix.rows(), 2); + EXPECT_EQ(group_it->second->_matrix.cols(), 1); + EXPECT_EQ(group_it->second->_matrix(0, 0), 0u); + EXPECT_EQ(group_it->second->_matrix(1, 0), 2u); + + const std::vector dpv_values = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f}; + trx.add_dpv_from_vector("quality", "float32", dpv_values); + ASSERT_EQ(trx.data_per_vertex.size(), 1u); + auto dpv_it = trx.data_per_vertex.find("quality"); + ASSERT_NE(dpv_it, trx.data_per_vertex.end()); + EXPECT_EQ(dpv_it->second->_data.rows(), nb_vertices); + EXPECT_EQ(dpv_it->second->_data.cols(), 1); + for (int i = 0; i < nb_vertices; ++i) { + EXPECT_FLOAT_EQ(dpv_it->second->_data(i, 0), dpv_values[static_cast(i)]); + } +} + +TEST(TrxFileTpp, TrxStreamFinalize) { + auto tmp_dir = make_temp_test_dir("trx_proto"); + const fs::path out_path = tmp_dir / "proto.trx"; + + TrxStream proto; + std::vector sl1 = {0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f}; + std::vector sl2 = {2.0f, 0.0f, 0.0f, 2.0f, 1.0f, 0.0f, 2.0f, 2.0f, 0.0f}; + proto.push_streamline(sl1); + proto.push_streamline(sl2); + proto.push_group_from_indices("GroupA", {0, 1}); + proto.push_dps_from_vector("weight", "float32", std::vector{0.5f, 1.5f}); + proto.push_dpv_from_vector("score", "float32", std::vector{1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); + proto.finalize(out_path.string(), ZIP_CM_STORE); + + auto trx = load_any(out_path.string()); + EXPECT_EQ(trx.num_streamlines(), 2u); + EXPECT_EQ(trx.num_vertices(), 5u); + + auto grp_it = trx.groups.find("GroupA"); + ASSERT_NE(grp_it, trx.groups.end()); + auto grp_mat = grp_it->second.as_matrix(); + EXPECT_EQ(grp_mat.rows(), 2); + EXPECT_EQ(grp_mat.cols(), 1); + EXPECT_EQ(grp_mat(0, 0), 0u); + EXPECT_EQ(grp_mat(1, 0), 1u); + + auto dps_it = 
trx.data_per_streamline.find("weight"); + ASSERT_NE(dps_it, trx.data_per_streamline.end()); + auto dps_mat = dps_it->second.as_matrix(); + EXPECT_EQ(dps_mat.rows(), 2); + EXPECT_EQ(dps_mat.cols(), 1); + EXPECT_FLOAT_EQ(dps_mat(0, 0), 0.5f); + EXPECT_FLOAT_EQ(dps_mat(1, 0), 1.5f); + + auto dpv_it = trx.data_per_vertex.find("score"); + ASSERT_NE(dpv_it, trx.data_per_vertex.end()); + auto dpv_mat = dpv_it->second.as_matrix(); + EXPECT_EQ(dpv_mat.rows(), 5); + EXPECT_EQ(dpv_mat.cols(), 1); + EXPECT_FLOAT_EQ(dpv_mat(0, 0), 1.0f); + EXPECT_FLOAT_EQ(dpv_mat(4, 0), 5.0f); + + trx.close(); + std::error_code ec; + fs::remove_all(tmp_dir, ec); +} + // resize() with default arguments is a no-op when sizes already match. TEST(TrxFileTpp, ResizeNoChange) { const fs::path data_dir = create_float_trx_dir(); - trxmmap::TrxFile *trx = load_trx_dir(data_dir); + auto reader = load_trx_dir(data_dir); + auto *trx = reader.get(); json header_before = trx->header; trx->resize(); EXPECT_EQ(trx->header, header_before); trx->close(); - delete trx; std::error_code ec; fs::remove_all(data_dir, ec); @@ -230,7 +310,8 @@ TEST(TrxFileTpp, ResizeNoChange) { // dpv/dps/groups/dpg. 
TEST(TrxFileTpp, ResizeDeleteDpgCloses) { const fs::path data_dir = create_float_trx_dir(); - trxmmap::TrxFile *trx = load_trx_dir(data_dir); + auto reader = load_trx_dir(data_dir); + auto *trx = reader.get(); trx->resize(1, -1, true); EXPECT_EQ(trx->header["NB_STREAMLINES"].int_value(), 0); @@ -240,7 +321,7 @@ TEST(TrxFileTpp, ResizeDeleteDpgCloses) { EXPECT_EQ(trx->data_per_vertex.size(), 0u); EXPECT_EQ(trx->data_per_streamline.size(), 0u); - delete trx; + trx->close(); std::error_code ec; fs::remove_all(data_dir, ec); diff --git a/third_party/nibabel/LICENSE b/third_party/nibabel/LICENSE new file mode 100644 index 0000000..b53f13d --- /dev/null +++ b/third_party/nibabel/LICENSE @@ -0,0 +1,27 @@ +The MIT License + +Copyright (c) 2009-2019 Matthew Brett +Copyright (c) 2010-2013 Stephan Gerhard +Copyright (c) 2006-2014 Michael Hanke +Copyright (c) 2011 Christian Haselgrove +Copyright (c) 2010-2011 Jarrod Millman +Copyright (c) 2011-2019 Yaroslav Halchenko +Copyright (c) 2015-2019 Chris Markiewicz + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE.