From c4a62821d73fb5d5a994c8d13241e0b596f3cb0c Mon Sep 17 00:00:00 2001 From: Serge Koudoro Date: Fri, 16 Jan 2026 17:01:23 -0500 Subject: [PATCH] docs: large documentation update. Add examples, add specs, add scripts --- .github/workflows/docbuild.yml | 2 +- .gitignore | 3 + .spin/cmds.py | 13 +- docs/source/conf.py | 30 ++ docs/source/dev.rst | 5 + docs/source/index.rst | 36 +- docs/source/scripts.rst | 2 + docs/source/trx_specifications.rst | 232 +++++++++++++ examples/README.txt | 24 ++ examples/plot_dps_dpv.py | 202 ++++++++++++ examples/plot_groups.py | 200 +++++++++++ examples/plot_read_write_trx.py | 148 +++++++++ pyproject.toml | 13 +- trx/trx_file_memmap.py | 513 ++++++++++++++++++----------- trx/utils.py | 22 +- 15 files changed, 1234 insertions(+), 211 deletions(-) create mode 100644 docs/source/trx_specifications.rst create mode 100644 examples/README.txt create mode 100644 examples/plot_dps_dpv.py create mode 100644 examples/plot_groups.py create mode 100644 examples/plot_read_write_trx.py diff --git a/.github/workflows/docbuild.yml b/.github/workflows/docbuild.yml index 5109045..6cc89c6 100644 --- a/.github/workflows/docbuild.yml +++ b/.github/workflows/docbuild.yml @@ -35,7 +35,7 @@ jobs: - name: Build docs run: | cd docs - make html + SPHINXOPTS="-W" make html - name: Upload docs uses: actions/upload-artifact@v4 with: diff --git a/.gitignore b/.gitignore index 3987e02..4a201c7 100644 --- a/.gitignore +++ b/.gitignore @@ -134,7 +134,10 @@ dmypy.json .vscode/ tmp/ +auto_examples/ CLAUDE.md claude.md agents.md AGENTS.md +sg_execution_times.rst + diff --git a/.spin/cmds.py b/.spin/cmds.py index ea7b2c3..8b7873e 100644 --- a/.spin/cmds.py +++ b/.spin/cmds.py @@ -204,10 +204,19 @@ def docs(clean, open_browser): click.echo("Cleaning build directory...") build_dir = os.path.join(docs_dir, "_build") if os.path.exists(build_dir): - import shutil - shutil.rmtree(build_dir) + # Clean sphinx-gallery generated files + gallery_dir = os.path.join(docs_dir, "source", "auto_examples") + if os.path.exists(gallery_dir): + click.echo("Cleaning sphinx-gallery generated files...") + shutil.rmtree(gallery_dir) + + # Clean sphinx-gallery execution times file + sg_times = os.path.join(docs_dir, "source", "sg_execution_times.rst") + if os.path.exists(sg_times): + os.remove(sg_times) + click.echo("Building documentation...") cmd = ["make", "-C", docs_dir, "html"] result = run(cmd, capture=False, check=False) diff --git a/docs/source/conf.py b/docs/source/conf.py index 8b02e6e..a05264e 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -60,8 +60,19 @@ 'sphinx.ext.autosummary', 'autoapi.extension', 'numpydoc', + 'sphinx_gallery.gen_gallery', + 'sphinx_design', ] +# Suppress known deprecation warnings from dependencies +# astroid 4.x deprecation - will be fixed when sphinx-autoapi updates for astroid 5.x +import warnings +warnings.filterwarnings( + 'ignore', + message="importing .* from 'astroid' is deprecated", + category=DeprecationWarning +) + # Add any paths that contain templates here, relative to this directory. 
templates_path = ['_templates'] @@ -84,6 +95,10 @@ html_static_path = ['../_static'] html_logo = "../_static/trx_logo.png" +html_sidebars = { + "scripts": [], + "trx_specifications": [], +} html_theme_options = { "icon_links": [ @@ -106,9 +121,24 @@ }, "navbar_start": ["navbar-logo", "version-switcher"], "show_version_warning_banner": True, + # Show table of contents on each page (section navigation) + "secondary_sidebar_items": ["page-toc", "edit-this-page", "sourcelink"], + "show_toc_level": 2, } autoapi_type = 'python' autoapi_dirs = ['../../trx'] autoapi_ignore = ['*test*', '*version*'] + +# Sphinx gallery configuration +sphinx_gallery_conf = { + 'examples_dirs': '../../examples', + 'gallery_dirs': 'auto_examples', + 'within_subsection_order': 'NumberOfCodeLinesSortKey', + 'reference_url': { + 'numpy': 'https://numpy.org/doc/stable/', + 'nibabel': 'https://nipy.org/nibabel/', + }, + 'default_thumb_file': os.path.join(os.path.dirname(__file__), '..', '_static', 'trx_logo.png'), +} diff --git a/docs/source/dev.rst b/docs/source/dev.rst index 5aacdfd..00df196 100644 --- a/docs/source/dev.rst +++ b/docs/source/dev.rst @@ -3,6 +3,11 @@ Developer Guide This guide provides detailed information for developers working on TRX-Python. +.. toctree:: + :maxdepth: 1 + + contributing + Installation for Development ---------------------------- diff --git a/docs/source/index.rst b/docs/source/index.rst index 55ecf8e..714faad 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -11,8 +11,34 @@ exchange, interoperability, and state-of-the-art analyses, acting as a community-driven replacement for the myriad existing file formats. -Why TRX? +Getting Started ~~~~~~~~~~~~~~~ + +New to TRX? Start here: + +1. **Understand the format**: Read the :doc:`trx_specifications` to understand the TRX file structure +2. **Learn by example**: Follow our :doc:`auto_examples/index` to learn how to read, write, and manipulate TRX files +3. **Use the CLI tools**: Check out the :doc:`scripts` documentation for command-line operations + +.. grid:: 2 + + .. grid-item-card:: Tutorials + :link: auto_examples/index + :link-type: doc + + Learn how to work with TRX files through hands-on tutorials covering + reading/writing files, working with groups, and using metadata. + + .. grid-item-card:: TRX Specifications + :link: trx_specifications + :link-type: doc + + Complete technical specifications of the TRX file format including + header fields, array structures, and naming conventions. + + +Why TRX? +~~~~~~~~ File formats that store the results of computational tractography were typically developed within specific software packages. This approach has facilitated a myriad of applications, but this development approach has also generated @@ -44,13 +70,19 @@ Development of TRX is supported by `NIMH grant 1R01MH126699 `_ +- `TRX Python Implementation `_ diff --git a/examples/README.txt b/examples/README.txt new file mode 100644 index 0000000..9c7f85a --- /dev/null +++ b/examples/README.txt @@ -0,0 +1,24 @@ +Tutorials +========= + +These tutorials demonstrate how to use the trx-python library +for working with TRX tractography files. + +Getting Started +--------------- + +New to TRX? These tutorials will guide you through: + +1. **Reading and Writing TRX Files** - Load, inspect, and save TRX files +2. **Working with Groups** - Organize streamlines into anatomical bundles +3. 
**Data Per Vertex and Streamline** - Work with metadata + +Prerequisites +------------- + +To run these tutorials, you need: + +- trx-python installed (``pip install trx-python[all]``) +- An internet connection (for downloading test data on first run) + +Each tutorial can be run as a Python script or in a Jupyter notebook. diff --git a/examples/plot_dps_dpv.py b/examples/plot_dps_dpv.py new file mode 100644 index 0000000..7457bf1 --- /dev/null +++ b/examples/plot_dps_dpv.py @@ -0,0 +1,202 @@ +# -*- coding: utf-8 -*- +# sphinx_gallery_thumbnail_path = '../docs/_static/trx_logo.png' +""" +Data Per Vertex and Data Per Streamline +======================================== + +This tutorial demonstrates how to work with metadata in TRX files. +TRX supports two types of metadata: + +- **Data Per Vertex (dpv)**: Information attached to each point along streamlines +- **Data Per Streamline (dps)**: Information attached to entire streamlines + +By the end of this tutorial, you will know how to: + +- Access dpv and dps data in a TRX file +- Understand the data shapes and organization +- Use metadata for filtering and analysis +""" + +# %% +# Understanding DPV and DPS +# ------------------------- +# +# **Data Per Vertex (dpv):** +# +# - Attached to each individual point (vertex) in all streamlines +# - Shape: (NB_VERTICES, 1) for scalar data or (NB_VERTICES, N) for vector data +# - Common uses: FA values at each point, RGB colors, local orientations +# +# **Data Per Streamline (dps):** +# +# - Attached to entire streamlines (one value per streamline) +# - Shape: (NB_STREAMLINES, 1) for scalar data or (NB_STREAMLINES, N) for vector data +# - Common uses: bundle ID, mean FA, streamline length, tracking algorithm ID + +# %% +# Loading a TRX file with metadata +# -------------------------------- +# +# Let's load a TRX file and explore its metadata. + +import os + +import numpy as np + +from trx.fetcher import fetch_data, get_home, get_testing_files_dict +from trx.trx_file_memmap import load + +# Download test data +fetch_data(get_testing_files_dict(), keys="gold_standard.zip") +trx_home = get_home() +trx_path = os.path.join(trx_home, "gold_standard", "gs.trx") + +# Load the TRX file +trx = load(trx_path) + +print(f"Loaded TRX with {len(trx)} streamlines") +print(f"Total vertices: {trx.header['NB_VERTICES']}") + +# %% +# Exploring Data Per Vertex (dpv) +# ------------------------------- +# +# Let's see what dpv data is available. + +print("Data Per Vertex keys:", list(trx.data_per_vertex.keys())) + +# Examine each dpv field +for key in trx.data_per_vertex: + data = trx.data_per_vertex[key] + print(f"\n {key}:") + print(f" Shape: {data._data.shape}") + print(f" Dtype: {data._data.dtype}") + print(f" Sample values: {data._data[:3].flatten()}") + +# %% +# Accessing dpv for a specific streamline +# --------------------------------------- +# +# The dpv data is organized to match the streamlines. You can access +# the dpv values for a specific streamline using the same indices. + +if len(trx.data_per_vertex) > 0: + first_dpv_key = list(trx.data_per_vertex.keys())[0] + dpv_data = trx.data_per_vertex[first_dpv_key] + + # Get dpv values for the first streamline + first_streamline_dpv = dpv_data[0] + print(f"DPV '{first_dpv_key}' for first streamline:") + print(f" Shape: {first_streamline_dpv.shape}") + print(f" Values: {first_streamline_dpv.flatten()}") + +# %% +# Exploring Data Per Streamline (dps) +# ----------------------------------- +# +# Now let's examine the dps data. 
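+# Unlike dpv entries (ArraySequence-like, accessed through ``._data`` above),
+# each dps entry behaves like a plain NumPy array with one row per
+# streamline, so its shape and dtype can be inspected directly.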
+ +print("Data Per Streamline keys:", list(trx.data_per_streamline.keys())) + +# Examine each dps field +for key in trx.data_per_streamline: + data = trx.data_per_streamline[key] + print(f"\n {key}:") + print(f" Shape: {data.shape}") + print(f" Dtype: {data.dtype}") + print(f" First 5 values: {data[:5].flatten()}") + +# %% +# DPS for filtering streamlines +# ----------------------------- +# +# A common use case is filtering streamlines based on dps values. +# For example, selecting streamlines with high FA values. + +if len(trx.data_per_streamline) > 0: + # Use the first dps key for demonstration + first_dps_key = list(trx.data_per_streamline.keys())[0] + dps_data = trx.data_per_streamline[first_dps_key] + + # Calculate some statistics + print(f"\nStatistics for '{first_dps_key}':") + print(f" Min: {np.min(dps_data):.4f}") + print(f" Max: {np.max(dps_data):.4f}") + print(f" Mean: {np.mean(dps_data):.4f}") + print(f" Std: {np.std(dps_data):.4f}") + +# %% +# File structure for dpv and dps +# ------------------------------ +# +# In the TRX format, dpv and dps are stored in separate directories: +# +# .. code-block:: text +# +# my_tractogram.trx/ +# |-- dpv/ +# | |-- fa.float16 # FA values per vertex +# | |-- colors.3.uint8 # RGB colors (3 values per vertex) +# | +-- curvature.float32 # Curvature per vertex +# |-- dps/ +# | |-- bundle_id.uint8 # Bundle assignment per streamline +# | |-- length.uint16 # Length per streamline +# | +-- mean_fa.float32 # Mean FA per streamline +# +-- ... +# +# The filename format is: ``name.dtype`` or ``name.dimension.dtype`` + +# %% +# Working with multi-dimensional data +# ----------------------------------- +# +# Both dpv and dps can have multiple dimensions. For example, RGB colors +# have 3 values per vertex. + +print("\nDemonstrating multi-dimensional data:") + +# Check for any multi-dimensional dpv +for key in trx.data_per_vertex: + data = trx.data_per_vertex[key] + if len(data._data.shape) > 1 and data._data.shape[1] > 1: + print(f" {key}: {data._data.shape[1]}D data per vertex") + +# Check for any multi-dimensional dps +for key in trx.data_per_streamline: + data = trx.data_per_streamline[key] + if len(data.shape) > 1 and data.shape[1] > 1: + print(f" {key}: {data.shape[1]}D data per streamline") + +# %% +# Relationship between dpv and streamlines +# ---------------------------------------- +# +# It's important to understand how dpv data maps to individual streamlines. +# Each streamline's dpv values can be accessed using the streamline's +# vertex indices. + +# Get vertex counts for first few streamlines +print("\nVertex distribution for first 5 streamlines:") +for i in range(min(5, len(trx))): + streamline = trx.streamlines[i] + print(f" Streamline {i}: {len(streamline)} vertices") + +# Total vertices should match +total_from_streamlines = sum(len(trx.streamlines[i]) for i in range(len(trx))) +print(f"\nTotal vertices from streamlines: {total_from_streamlines}") +print(f"Total vertices in header: {trx.header['NB_VERTICES']}") + +# %% +# Summary +# ------- +# +# In this tutorial, you learned how to: +# +# - Access dpv data using ``trx.data_per_vertex[key]`` +# - Access dps data using ``trx.data_per_streamline[key]`` +# - Understand the shape conventions for scalar and vector data +# - Use metadata for statistical analysis +# - Understand the file structure for dpv and dps +# +# The TRX format's metadata system is designed for flexibility, allowing +# you to attach any kind of information to vertices or streamlines. 
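+
+# %%
+# Going further
+# -------------
+#
+# As a minimal sketch (using whichever dpv key the file happens to provide),
+# a per-streamline summary can be derived from per-vertex data:
+
+if len(trx.data_per_vertex) > 0:
+    key = list(trx.data_per_vertex.keys())[0]
+    dpv_data = trx.data_per_vertex[key]
+    # One mean per streamline, averaged over that streamline's vertices
+    means = np.array([np.mean(dpv_data[i]) for i in range(len(trx))])
+    print(f"Per-streamline mean of '{key}' (first 5): {means[:5]}")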
diff --git a/examples/plot_groups.py b/examples/plot_groups.py new file mode 100644 index 0000000..cddf871 --- /dev/null +++ b/examples/plot_groups.py @@ -0,0 +1,200 @@ +# -*- coding: utf-8 -*- +# sphinx_gallery_thumbnail_path = '../docs/_static/trx_logo.png' +""" +Working with Groups +==================== + +This tutorial demonstrates how to work with groups in TRX files. +Groups allow you to organize streamlines into meaningful subsets, +such as anatomical bundles or clusters. + +By the end of this tutorial, you will know how to: + +- Access groups in a TRX file +- Extract streamlines belonging to a specific group +- Understand the relationship between groups and data_per_group (dpg) +- Work with overlapping groups +""" + +# %% +# What are Groups? +# ---------------- +# +# Groups in TRX files are collections of streamline indices. They enable: +# +# - **Sparse representation**: Only store indices instead of copying data +# - **Overlapping membership**: A streamline can belong to multiple groups +# - **Efficient access**: Quickly extract predefined subsets of streamlines +# +# Common use cases include anatomical bundles (e.g., Arcuate Fasciculus, +# Corpus Callosum), clustering results, or connectivity-based groupings. + +# %% +# Loading a TRX file with groups +# ------------------------------ +# +# Let's load a TRX file that contains group information. + +import os + +import numpy as np + +from trx.fetcher import fetch_data, get_home, get_testing_files_dict +from trx.trx_file_memmap import load + +# Download test data +fetch_data(get_testing_files_dict(), keys="gold_standard.zip") +trx_home = get_home() +trx_path = os.path.join(trx_home, "gold_standard", "gs.trx") + +# Load the TRX file +trx = load(trx_path) + +print(f"Loaded TRX with {len(trx)} streamlines") + +# %% +# Accessing groups +# ---------------- +# +# Groups are stored as a dictionary where keys are group names and values +# are numpy arrays of streamline indices. + +print(f"Available groups: {list(trx.groups.keys())}") + +# Check the number of groups +print(f"Number of groups: {len(trx.groups)}") + +# %% +# Let's examine the groups in more detail: + +for group_name, indices in trx.groups.items(): + print(f" {group_name}: {len(indices)} streamlines") + +# %% +# Extracting a group +# ------------------ +# +# You can extract all streamlines belonging to a specific group using +# the ``get_group()`` method. + +if len(trx.groups) > 0: + # Get the first group name + first_group = list(trx.groups.keys())[0] + + # Extract the group as a new TrxFile + group_trx = trx.get_group(first_group) + print(f"Extracted group '{first_group}' with {len(group_trx)} streamlines") + + # You can also access the raw indices + group_indices = trx.groups[first_group] + print(f"Raw indices (first 10): {group_indices[:10]}") +else: + print("No groups available in this file") + +# %% +# Using group indices directly +# ---------------------------- +# +# You can use group indices to select streamlines directly with the +# ``select()`` method. + +if len(trx.groups) > 0: + first_group = list(trx.groups.keys())[0] + indices = trx.groups[first_group] + + # Select streamlines using indices + selected = trx.select(indices[:5]) # Select first 5 from the group + print(f"Selected {len(selected)} streamlines from group '{first_group}'") + +# %% +# Data per group (dpg) +# -------------------- +# +# Groups can have associated metadata stored in ``data_per_group`` (dpg). +# This is useful for storing group-level statistics like mean FA, volume, +# or color codes. 
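+# Note that ``data_per_group`` is a nested mapping: first by group name,
+# then by metadata key, i.e. ``trx.data_per_group[group_name][key]``.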
+ +print(f"Data per group keys: {list(trx.data_per_group.keys())}") + +# Check what metadata is available for each group +for group_name in trx.data_per_group: + dpg_keys = list(trx.data_per_group[group_name].keys()) + print(f" {group_name}: {dpg_keys}") + +# %% +# Creating groups manually +# ------------------------ +# +# You can create groups by assigning indices to the groups dictionary. +# Here's an example of how groups work conceptually. + +# Example: Create conceptual groups for 10 streamlines +example_groups = { + "bundle_A": np.array([0, 1, 2, 3], dtype=np.uint32), + "bundle_B": np.array([4, 5, 6, 7, 8, 9], dtype=np.uint32), + "overlapping": np.array([3, 4, 5], dtype=np.uint32), # Overlaps with A and B +} + +print("Example groups:") +for name, indices in example_groups.items(): + print(f" {name}: streamlines {indices}") + +# Note: Streamline 3 is in both bundle_A and overlapping +# Note: Streamlines 4, 5 are in both bundle_B and overlapping +print("\nOverlapping groups are allowed in TRX!") + +# %% +# Group file structure +# -------------------- +# +# In the TRX file format, groups are stored as binary files in a ``groups/`` +# directory: +# +# .. code-block:: text +# +# my_tractogram.trx/ +# |-- groups/ +# | |-- AF_L.uint32 # Arcuate Fasciculus Left +# | |-- AF_R.uint32 # Arcuate Fasciculus Right +# | |-- CC.uint32 # Corpus Callosum +# | +-- CST_L.uint32 # Corticospinal Tract Left +# +-- ... +# +# Each file contains a flat array of streamline indices as uint32 values. + +# %% +# Filtering streamlines by group +# ------------------------------ +# +# A common workflow is to filter streamlines based on group membership +# and then analyze or visualize specific bundles. + +if len(trx.groups) > 0: + # Get all group names + group_names = list(trx.groups.keys()) + + # Report statistics for each group + print("Group statistics:") + for group_name in group_names: + group_trx = trx.get_group(group_name) + total_points = len(group_trx.streamlines._data) + avg_length = total_points / len(group_trx) if len(group_trx) > 0 else 0 + print(f" {group_name}:") + print(f" - Streamlines: {len(group_trx)}") + print(f" - Total points: {total_points}") + print(f" - Avg points per streamline: {avg_length:.1f}") + +# %% +# Summary +# ------- +# +# In this tutorial, you learned how to: +# +# - Access groups using ``trx.groups`` +# - Extract group streamlines using ``get_group()`` +# - Work with ``data_per_group`` (dpg) metadata +# - Understand that groups can overlap +# - Filter and analyze streamlines by group membership +# +# Groups are a powerful feature of the TRX format that enable efficient +# organization and retrieval of streamline subsets without data duplication. diff --git a/examples/plot_read_write_trx.py b/examples/plot_read_write_trx.py new file mode 100644 index 0000000..f9af07f --- /dev/null +++ b/examples/plot_read_write_trx.py @@ -0,0 +1,148 @@ +# -*- coding: utf-8 -*- +# sphinx_gallery_thumbnail_path = '../docs/_static/trx_logo.png' +""" +Reading and Writing TRX Files +============================== + +This tutorial demonstrates how to read and write TRX files using trx-python. +TRX is a tractography file format designed for efficient storage and access +of brain fiber tract streamline data. 
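+
+Under the hood, a TRX file is a directory of memory-mapped arrays described
+by a small ``header.json``; it can be stored either as that directory or as
+a single zip archive.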
+ +By the end of this tutorial, you will know how to: + +- Load a TRX file from disk +- Inspect the contents of a TRX file +- Access streamlines and metadata +- Save a TRX file to disk +- Create a TRX file from scratch +""" + +# %% +# Loading a TRX file +# ------------------ +# +# Let's start by loading an existing TRX file. First, we need to download +# some test data. + +import os +import tempfile + +from trx.fetcher import fetch_data, get_home, get_testing_files_dict +from trx.trx_file_memmap import load, save + +# Download test data +fetch_data(get_testing_files_dict(), keys="gold_standard.zip") +trx_home = get_home() +trx_path = os.path.join(trx_home, "gold_standard", "gs.trx") + +# Load the TRX file +trx = load(trx_path) + +print("TRX file loaded successfully!") + +# %% +# Inspecting TRX file contents +# ---------------------------- +# +# The TrxFile object has several key attributes that you can inspect. +# Let's look at what's inside our loaded file. + +# Print a summary of the TRX file +print(trx) + +# %% +# The header contains essential metadata about the tractogram: + +print("Header information:") +print(f" Number of streamlines: {trx.header['NB_STREAMLINES']}") +print(f" Number of vertices: {trx.header['NB_VERTICES']}") +print(f" Image dimensions: {trx.header['DIMENSIONS']}") +print(f" Voxel to RASMM affine:\n{trx.header['VOXEL_TO_RASMM']}") + +# %% +# Accessing streamlines +# --------------------- +# +# Streamlines are the core data in a TRX file. Each streamline is a sequence +# of 3D points representing a fiber tract in the brain. + +print(f"Number of streamlines: {len(trx)}") +print(f"Total number of vertices: {len(trx.streamlines._data)}") + +# Access the first streamline +first_streamline = trx.streamlines[0] +print(f"\nFirst streamline has {len(first_streamline)} points") +print(f"First 3 points of the first streamline:\n{first_streamline[:3]}") + +# %% +# Accessing metadata +# ------------------ +# +# TRX files can contain additional data per vertex (dpv) and per streamline (dps). + +print("Data per vertex (dpv) keys:", list(trx.data_per_vertex.keys())) +print("Data per streamline (dps) keys:", list(trx.data_per_streamline.keys())) +print("Groups:", list(trx.groups.keys())) + +# %% +# Selecting a subset of streamlines +# --------------------------------- +# +# You can easily select a subset of streamlines using indices or slicing. + +# Select first 5 streamlines +subset = trx[:5] +print(f"Subset has {len(subset)} streamlines") + +# Select specific streamlines by indices (ensure indices are valid) +max_idx = len(trx) - 1 +indices = [0, min(2, max_idx), min(5, max_idx)] +selected = trx.select(indices) +print(f"Selected {len(selected)} streamlines") + +# %% +# Saving a TRX file +# ----------------- +# +# You can save a TRX file back to disk. The file can be saved as a compressed +# or uncompressed zip archive, or as a directory. + +with tempfile.TemporaryDirectory() as tmpdir: + # Save as TRX file (zip archive) + output_path = os.path.join(tmpdir, "output.trx") + save(trx, output_path) + print(f"Saved TRX file to: {output_path}") + print(f"File size: {os.path.getsize(output_path)} bytes") + + # Reload to verify + reloaded = load(output_path) + print(f"Reloaded TRX has {len(reloaded)} streamlines") + +# %% +# Creating a TRX file from an existing one +# ---------------------------------------- +# +# A common workflow is to create a new TRX file based on an existing one, +# preserving the spatial reference information. 
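+#
+# Note that ``deepcopy()`` duplicates all of the data. If you only need an
+# empty TrxFile that reuses the spatial metadata, the constructor's
+# ``init_as`` argument (``TrxFile(init_as=trx)``) can be used instead.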
+ +# Create a deepcopy of the loaded TRX file +trx_copy = trx.deepcopy() + +print(f"Created copy with {len(trx_copy)} streamlines") +print(f"Header preserved: DIMENSIONS = {trx_copy.header['DIMENSIONS']}") + +# %% +# Summary +# ------- +# +# In this tutorial, you learned how to: +# +# - Load TRX files using ``load()`` +# - Inspect header information and streamline data +# - Access data per vertex (dpv) and data per streamline (dps) +# - Select subsets of streamlines +# - Save TRX files using ``save()`` +# - Create copies of TRX files using ``deepcopy()`` +# +# The TRX format is designed for memory efficiency through memory-mapping, +# making it suitable for large tractography datasets. diff --git a/pyproject.toml b/pyproject.toml index 46c5a1c..9b298eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,11 +38,13 @@ dev = [ "setuptools_scm", ] doc = [ - "astroid >= 4.0.0", - "sphinx >= 9.0.0", - "pydata-sphinx-theme >= 0.16.1", - "sphinx-autoapi >= 3.0.0", + "matplotlib", "numpydoc", + "pydata-sphinx-theme >= 0.16.1", + "sphinx >= 8.2.0", + "sphinx-autoapi >= 3.4.0", + "sphinx-design", + "sphinx-gallery", ] style = [ "codespell", @@ -95,6 +97,9 @@ __version__ = "{version}" fallback_version = "0.0" local_scheme = "no-local-version" +[tool.codespell] +ignore-words-list = "astroid" + [tool.spin] package = "trx" diff --git a/trx/trx_file_memmap.py b/trx/trx_file_memmap.py index 5a70e82..741bc3d 100644 --- a/trx/trx_file_memmap.py +++ b/trx/trx_file_memmap.py @@ -34,16 +34,20 @@ def _append_last_offsets(nib_offsets: np.ndarray, nb_vertices: int) -> np.ndarray: - """Appends the last element of offsets from header information - - Keyword arguments: - nib_offsets -- np.ndarray - Array of offsets with the last element being the start of the last - streamline (nibabel convention) - nb_vertices -- int - Total number of vertices in the streamlines - Returns: - Offsets -- np.ndarray (VTK convention) + """Append the last element of offsets from header information. + + Parameters + ---------- + nib_offsets : np.ndarray + Array of offsets with the last element being the start of the last + streamline (nibabel convention). + nb_vertices : int + Total number of vertices in the streamlines. + + Returns + ------- + np.ndarray + Offsets array (VTK convention). """ def is_sorted(a): @@ -55,15 +59,19 @@ def is_sorted(a): def _generate_filename_from_data(arr: np.ndarray, filename: str) -> str: - """Determines the data type from array data and generates the appropriate - filename - - Keyword arguments: - arr -- a NumPy array (1-2D, otherwise ValueError raised) - filename -- the original filename - - Returns: - An updated filename + """Determine the data type from array data and generate the appropriate filename. + + Parameters + ---------- + arr : np.ndarray + A NumPy array (1-2D, otherwise ValueError raised). + filename : str + The original filename. + + Returns + ------- + str + An updated filename with appropriate extension. """ base, ext = os.path.splitext(filename) if ext: @@ -87,14 +95,17 @@ def _generate_filename_from_data(arr: np.ndarray, filename: str) -> str: def _split_ext_with_dimensionality(filename: str) -> Tuple[str, int, str]: - """Takes a filename and splits it into its components - - Keyword arguments: - filename -- Input filename + """Take a filename and split it into its components. - Returns: - tuple of strings (basename, dimension, extension) + Parameters + ---------- + filename : str + Input filename. + Returns + ------- + tuple + A tuple of (basename, dimension, extension). 
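+
+    Examples
+    --------
+    Illustrating the ``name.dtype`` / ``name.dim.dtype`` naming convention:
+
+    >>> _split_ext_with_dimensionality("colors.3.uint8")
+    ('colors', 3, '.uint8')
+    >>> _split_ext_with_dimensionality("fa.float16")
+    ('fa', 1, '.float16')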
""" basename = os.path.basename(filename) split = basename.split(".") @@ -111,13 +122,17 @@ def _split_ext_with_dimensionality(filename: str) -> Tuple[str, int, str]: def _compute_lengths(offsets: np.ndarray) -> np.ndarray: - """Compute lengths from offsets + """Compute lengths from offsets. - Keyword arguments: - offsets -- An np.ndarray of offsets + Parameters + ---------- + offsets : np.ndarray + An array of offsets. - Returns: - lengths -- An np.ndarray of lengths + Returns + ------- + np.ndarray + An array of lengths. """ if len(offsets) > 0: last_elem_pos = _dichotomic_search(offsets) @@ -131,13 +146,17 @@ def _compute_lengths(offsets: np.ndarray) -> np.ndarray: def _is_dtype_valid(ext: str) -> bool: - """Verifies that filename extension is a valid datatype + """Verify that filename extension is a valid datatype. - Keyword arguments: - ext -- filename extension + Parameters + ---------- + ext : str + Filename extension. - Returns: - boolean representing if provided datatype is valid + Returns + ------- + bool + True if the provided datatype is valid, False otherwise. """ if ext.replace(".", "") == "bit": return True @@ -151,14 +170,22 @@ def _is_dtype_valid(ext: str) -> bool: def _dichotomic_search( x: np.ndarray, l_bound: Optional[int] = None, r_bound: Optional[int] = None ) -> int: - """Find where data of a contiguous array is actually ending - - Keyword arguments: - x -- np.ndarray of values - l_bound -- lower bound index for search - r_bound -- upper bound index for search - Returns: - index at which array value is 0 (if possible), otherwise returns -1""" + """Find where data of a contiguous array is actually ending. + + Parameters + ---------- + x : np.ndarray + Array of values. + l_bound : int, optional + Lower bound index for search. + r_bound : int, optional + Upper bound index for search. + + Returns + ------- + int + Index at which array value is 0 (if possible), otherwise returns -1. + """ if l_bound is None and r_bound is None: l_bound = 0 r_bound = len(x) - 1 @@ -183,19 +210,27 @@ def _create_memmap( offset: int = 0, order: str = "C", ) -> np.ndarray: - """Wrapper to support empty array as memmaps - - Keyword arguments: - filename -- filename where the empty memmap should be created - mode -- file open mode (see: np.memmap for options) - shape -- shape of memmapped np.ndarray - dtype -- datatype of memmapped np.ndarray - offset -- offset of the data within the file - order -- data representation on disk (C or Fortran) - - Returns: - mmapped np.ndarray or a zero-filled Numpy array if array has a shape of 0 - in the first dimension + """Wrap memmap creation to support empty arrays. + + Parameters + ---------- + filename : str + Filename where the empty memmap should be created. + mode : str, optional + File open mode (see np.memmap for options). Default is 'r'. + shape : tuple, optional + Shape of memmapped array. Default is (1,). + dtype : np.dtype, optional + Datatype of memmapped array. Default is np.float32. + offset : int, optional + Offset of the data within the file. Default is 0. + order : str, optional + Data representation on disk ('C' or 'F'). Default is 'C'. + + Returns + ------- + np.ndarray + Memory-mapped array or a zero-filled array if shape[0] is 0. 
""" if np.dtype(dtype) == bool: filename = filename.replace(".bool", ".bit") @@ -212,14 +247,19 @@ def _create_memmap( def load(input_obj: str, check_dpg: bool = True) -> Type["TrxFile"]: - """Load a TrxFile (compressed or not) - - Keyword arguments: - input_obj -- A directory name or filepath to the trx data - check_dpg -- Boolean denoting if group metadata should be checked - - Returns: - TrxFile object representing the read data + """Load a TrxFile (compressed or not). + + Parameters + ---------- + input_obj : str + A directory name or filepath to the TRX data. + check_dpg : bool, optional + Whether to check group metadata. Default is True. + + Returns + ------- + TrxFile + TrxFile object representing the read data. """ # TODO Check if 0 streamlines, then 0 vertices is expected (vice-versa) # TODO 4x4 affine matrices should contains values (no all-zeros) @@ -258,14 +298,19 @@ def load(input_obj: str, check_dpg: bool = True) -> Type["TrxFile"]: def load_from_zip(filename: str) -> Type["TrxFile"]: - """Load a TrxFile from a single zipfile. Note: does not work with - compressed zipfiles + """Load a TrxFile from a single zipfile. - Keyword arguments: - filename -- path of the zipped TrxFile + Note: Does not work with compressed zipfiles. - Returns: - TrxFile representing the read data + Parameters + ---------- + filename : str + Path of the zipped TrxFile. + + Returns + ------- + TrxFile + TrxFile representing the read data. """ with zipfile.ZipFile(filename, mode="r") as zf: with zf.open("header.json") as zf_header: @@ -307,13 +352,17 @@ def load_from_zip(filename: str) -> Type["TrxFile"]: def load_from_directory(directory: str) -> Type["TrxFile"]: - """Load a TrxFile from a folder containing memmaps + """Load a TrxFile from a folder containing memmaps. - Keyword arguments: - filename -- path of the zipped TrxFile + Parameters + ---------- + directory : str + Path of the directory containing TRX data. - Returns: - TrxFile representing the read data + Returns + ------- + TrxFile + TrxFile representing the read data. """ directory = os.path.abspath(directory) @@ -526,25 +575,33 @@ def concatenate( check_space_attributes: bool = True, preallocation: bool = False, ) -> "TrxFile": - """Concatenate multiple TrxFile together, support preallocation - - Keyword arguments: - trx_list -- A list containing TrxFiles to concatenate - delete_dpv -- Delete dpv keys that do not exist in all the provided - TrxFiles - delete_dps -- Delete dps keys that do not exist in all the provided - TrxFile - delete_groups -- Delete all the groups that currently exist in the - TrxFiles - check_space_attributes -- Verify that dimensions and size of data are - similar between all the TrxFiles - preallocation -- Preallocated TrxFile has already been generated and - is the first element in trx_list - (Note: delete_groups must be set to True as well) - - Returns: - TrxFile representing the concatenated data - + """Concatenate multiple TrxFile together, with support for preallocation. + + Parameters + ---------- + trx_list : list of TrxFile + A list containing TrxFiles to concatenate. + delete_dpv : bool, optional + Delete dpv keys that do not exist in all the provided TrxFiles. + Default is False. + delete_dps : bool, optional + Delete dps keys that do not exist in all the provided TrxFiles. + Default is False. + delete_groups : bool, optional + Delete all the groups that currently exist in the TrxFiles. + Default is False. 
+ check_space_attributes : bool, optional + Verify that dimensions and size of data are similar between all + the TrxFiles. Default is True. + preallocation : bool, optional + Preallocated TrxFile has already been generated and is the first + element in trx_list. Note: delete_groups must be set to True as well. + Default is False. + + Returns + ------- + TrxFile + TrxFile representing the concatenated data. """ trx_list = _filter_empty_trx_files(trx_list) if len(trx_list) == 0: @@ -588,13 +645,17 @@ def concatenate( def save( trx: "TrxFile", filename: str, compression_standard: Any = zipfile.ZIP_STORED ) -> None: - """Save a TrxFile (compressed or not) - - Keyword arguments: - trx -- The TrxFile to save - filename -- The path to save the TrxFile to - compression_standard -- The compression standard to use, as defined by - the ZipFile library + """Save a TrxFile (compressed or not). + + Parameters + ---------- + trx : TrxFile + The TrxFile to save. + filename : str + The path to save the TrxFile to. + compression_standard : int, optional + The compression standard to use, as defined by the ZipFile library. + Default is zipfile.ZIP_STORED. """ _, ext = os.path.splitext(filename) if ext not in [".zip", ".trx", ""]: @@ -615,14 +676,17 @@ def save( def zip_from_folder( directory: str, filename: str, compression_standard: Any = zipfile.ZIP_STORED ) -> None: - """Utils function to zip on-disk memmaps - - Keyword arguments - directory -- The path to the on-disk memmap - filename -- The path where the zip file should be created - compression_standard -- The compression standard to use, as defined by - the ZipFile library - + """Zip on-disk memmaps into a single file. + + Parameters + ---------- + directory : str + The path to the on-disk memmap directory. + filename : str + The path where the zip file should be created. + compression_standard : int, optional + The compression standard to use, as defined by the ZipFile library. + Default is zipfile.ZIP_STORED. """ with zipfile.ZipFile(filename, mode="w", compression=compression_standard) as zf: for root, _, files in os.walk(directory): @@ -656,13 +720,18 @@ def __init__( None, ] = None, ) -> None: - """Initialize an empty TrxFile, support preallocation - - Keyword Arguments: - nb_vertices -- The number of vertices to use in the new TrxFile - nb_streamlines -- The number of streamlines in the new TrxFile - init_as -- A TrxFile to use as reference - reference -- A Nifti or Trk file/obj to use as reference + """Initialize an empty TrxFile with support for preallocation. + + Parameters + ---------- + nb_vertices : int, optional + The number of vertices to use in the new TrxFile. + nb_streamlines : int, optional + The number of streamlines in the new TrxFile. + init_as : TrxFile, optional + A TrxFile to use as reference. + reference : str, dict, Nifti1Image, TrkFile, Nifti1Header, optional + A Nifti or Trk file/obj to use as reference. """ if init_as is not None: affine = init_as.header["VOXEL_TO_RASMM"] @@ -779,10 +848,12 @@ def __deepcopy__(self) -> Type["TrxFile"]: return self.deepcopy() def deepcopy(self) -> Type["TrxFile"]: # noqa: C901 - """Create a deepcopy of the TrxFile + """Create a deepcopy of the TrxFile. Returns - A deepcopied TrxFile of the current TrxFile + ------- + TrxFile + A deepcopied TrxFile of the current TrxFile. 
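+
+        Notes
+        -----
+        The copy is backed by a fresh set of temporary on-disk memmaps, so
+        it can be modified independently of the original TrxFile.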
""" tmp_dir = get_trx_tmp_dir() out_json = open(os.path.join(tmp_dir.name, "header.json"), "w") @@ -872,11 +943,13 @@ def deepcopy(self) -> Type["TrxFile"]: # noqa: C901 return copy_trx def _get_real_len(self) -> Tuple[int, int]: - """Get the real size of data (ignoring zeros of preallocation) + """Get the real size of data (ignoring zeros of preallocation). Returns - A tuple representing the index of the last streamline and the total - length of all the streamlines + ------- + tuple of int + A tuple (strs_end, pts_end) representing the index of the last + streamline and the total length of all the streamlines. """ if len(self.streamlines._lengths) == 0: return 0, 0 @@ -896,18 +969,24 @@ def _copy_fixed_arrays_from( pts_start: int = 0, nb_strs_to_copy: Optional[int] = None, ) -> Tuple[int, int]: - """Fill a TrxFile using another and start indexes (preallocation) - - Keyword arguments: - trx -- TrxFile to copy data from - strs_start -- The start index of the streamline - pts_start -- The start index of the point - nb_strs_to_copy -- The number of streamlines to copy. If not set - will copy all + """Fill a TrxFile using another and start indexes (preallocation). + + Parameters + ---------- + trx : TrxFile + TrxFile to copy data from. + strs_start : int, optional + The start index of the streamline. Default is 0. + pts_start : int, optional + The start index of the point. Default is 0. + nb_strs_to_copy : int, optional + The number of streamlines to copy. If not set, will copy all. Returns - A tuple representing the end of the copied streamlines and end of - copied points + ------- + tuple of int + A tuple (strs_end, pts_end) representing the end of the copied + streamlines and end of copied points. """ if nb_strs_to_copy is None: curr_strs_len, curr_pts_len = trx._get_real_len() @@ -954,17 +1033,23 @@ def _initialize_empty_trx( # noqa: C901 nb_vertices: int, init_as: Optional[Type["TrxFile"]] = None, ) -> Type["TrxFile"]: - """Create on-disk memmaps of a certain size (preallocation) - - Keyword arguments: - nb_streamlines -- The number of streamlines that the empty TrxFile - will be initialized with - nb_vertices -- The number of vertices that the empty TrxFile will - be initialized with - init_as -- A TrxFile to initialize the empty TrxFile with + """Create on-disk memmaps of a certain size (preallocation). + + Parameters + ---------- + nb_streamlines : int + The number of streamlines that the empty TrxFile will be + initialized with. + nb_vertices : int + The number of vertices that the empty TrxFile will be + initialized with. + init_as : TrxFile, optional + A TrxFile to initialize the empty TrxFile with. - Returns: - An empty TrxFile preallocated with a certain size + Returns + ------- + TrxFile + An empty TrxFile preallocated with a certain size. """ trx = TrxFile() tmp_dir = get_trx_tmp_dir() @@ -1078,18 +1163,24 @@ def _create_trx_from_pointer( # noqa: C901 root_zip: Optional[str] = None, root: Optional[str] = None, ) -> Type["TrxFile"]: - """After reading the structure of a zip/folder, create a TrxFile - - Keyword arguments: - header -- A TrxFile header dictionary which will be used for the - new TrxFile - dict_pointer_size -- A dictionary containing the filenames of all - the files within the TrxFile disk file/folder - root_zip -- The path of the ZipFile pointer - root -- The dirname of the ZipFile pointer - - Returns: - A TrxFile constructor from the pointer provided + """Create a TrxFile after reading the structure of a zip/folder. 
+ + Parameters + ---------- + header : dict + A TrxFile header dictionary which will be used for the new TrxFile. + dict_pointer_size : dict + A dictionary containing the filenames of all the files within the + TrxFile disk file/folder. + root_zip : str, optional + The path of the ZipFile pointer. + root : str, optional + The dirname of the ZipFile pointer. + + Returns + ------- + TrxFile + A TrxFile constructed from the pointer provided. """ # TODO support empty positions, using optional tag? trx = TrxFile() @@ -1219,12 +1310,16 @@ def resize( # noqa: C901 nb_vertices: Optional[int] = None, delete_dpg: bool = False, ) -> None: - """Remove the unused portion of preallocated memmaps - - Keyword arguments: - nb_streamlines -- The number of streamlines to keep - nb_vertices -- The number of vertices to keep - delete_dpg -- Remove data_per_group when resizing + """Remove the unused portion of preallocated memmaps. + + Parameters + ---------- + nb_streamlines : int, optional + The number of streamlines to keep. + nb_vertices : int, optional + The number of vertices to keep. + delete_dpg : bool, optional + Remove data_per_group when resizing. Default is False. """ if not self._copy_safe: raise ValueError("Cannot resize a sliced datasets.") @@ -1331,10 +1426,12 @@ def resize( # noqa: C901 self.__dict__ = trx.__dict__ def get_dtype_dict(self): - """Get the dtype dictionary for the TrxFile + """Get the dtype dictionary for the TrxFile. Returns - A dictionary containing the dtype for each data element + ------- + dict + A dictionary containing the dtype for each data element. """ dtype_dict = { "positions": self.streamlines._data.dtype, @@ -1383,11 +1480,14 @@ def append(self, obj, extra_buffer: int = 0) -> None: self._append_trx(obj, extra_buffer=extra_buffer) def _append_trx(self, trx: Type["TrxFile"], extra_buffer: int = 0) -> None: - """Append a TrxFile to another (support buffer) - - Keyword arguments: - trx -- The TrxFile to append to the current TrxFile - extra_buffer -- The additional buffer space required to append data + """Append a TrxFile to another (with buffer support). + + Parameters + ---------- + trx : TrxFile + The TrxFile to append to the current TrxFile. + extra_buffer : int, optional + The additional buffer space required to append data. Default is 0. """ strs_end, pts_end = self._get_real_len() @@ -1407,30 +1507,42 @@ def _append_trx(self, trx: Type["TrxFile"], extra_buffer: int = 0) -> None: def get_group( self, key: str, keep_group: bool = True, copy_safe: bool = False ) -> Type["TrxFile"]: - """Get a particular group from the TrxFile + """Get a particular group from the TrxFile. - Keyword arguments: - key -- The group name to select - keep_group -- Make sure group exists in returned TrxFile - copy_safe -- Perform a deepcopy + Parameters + ---------- + key : str + The group name to select. + keep_group : bool, optional + Make sure group exists in returned TrxFile. Default is True. + copy_safe : bool, optional + Perform a deepcopy. Default is False. Returns - A TrxFile exclusively containing data from said group + ------- + TrxFile + A TrxFile exclusively containing data from said group. """ return self.select(self.groups[key], keep_group=keep_group, copy_safe=copy_safe) def select( self, indices: np.ndarray, keep_group: bool = True, copy_safe: bool = False ) -> Type["TrxFile"]: - """Get a subset of items, always vertices to the same memmaps + """Get a subset of items, always pointing to the same memmaps. 
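+
+        The returned TrxFile is a view into the current memmaps; pass
+        ``copy_safe=True`` to obtain an independent deep copy instead.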
- Keyword arguments: - indices -- The list of indices of elements to return - keep_group -- Ensure group is returned in output TrxFile - copy_safe -- Perform a deep-copy + Parameters + ---------- + indices : np.ndarray + The list of indices of elements to return. + keep_group : bool, optional + Ensure group is returned in output TrxFile. Default is True. + copy_safe : bool, optional + Perform a deep-copy. Default is False. - Returns: - A TrxFile containing data originating from the selected indices + Returns + ------- + TrxFile + A TrxFile containing data originating from the selected indices. """ indices = np.array(indices, dtype=np.uint32) @@ -1510,14 +1622,26 @@ def from_lazy_tractogram( chunk_size: int = 10000, dtype_dict: dict = None, ) -> Type["TrxFile"]: - """Append a TrxFile to another (support buffer) - - Keyword arguments: - trx -- The TrxFile to append to the current TrxFile - extra_buffer -- The buffer space between reallocation. - This number should be a number of streamlines. - Use 0 for no buffer. - chunk_size -- The number of streamlines to save at a time. + """Create a TrxFile from a LazyTractogram with buffer support. + + Parameters + ---------- + obj : LazyTractogram + The LazyTractogram to convert. + reference : object + Reference for spatial information. + extra_buffer : int, optional + The buffer space between reallocation. This number should be a + number of streamlines. Use 0 for no buffer. Default is 0. + chunk_size : int, optional + The number of streamlines to save at a time. Default is 10000. + dtype_dict : dict, optional + Dictionary specifying dtypes for positions, offsets, dpv, and dps. + + Returns + ------- + TrxFile + A TrxFile created from the LazyTractogram. """ if dtype_dict is None: dtype_dict = { @@ -1751,13 +1875,18 @@ def to_tractogram(self, resize=False): return tractogram def to_memory(self, resize: bool = False) -> Type["TrxFile"]: - """Convert a TrxFile to a RAM representation + """Convert a TrxFile to a RAM representation. - Keyword arguments: - resize -- Resize TrxFile when converting to RAM representation + Parameters + ---------- + resize : bool, optional + Resize TrxFile when converting to RAM representation. + Default is False. - Returns: - A non memory mapped TrxFile + Returns + ------- + TrxFile + A non memory-mapped TrxFile. """ if resize: self.resize() diff --git a/trx/utils.py b/trx/utils.py index 314b4c2..02f0869 100644 --- a/trx/utils.py +++ b/trx/utils.py @@ -17,14 +17,12 @@ def close_or_delete_mmap(obj): - """ - Close the memory-mapped file if it exists, otherwise set the object to None. + """Close the memory-mapped file if it exists, otherwise set the object to None. - Parameters: - ----------- + Parameters + ---------- obj : object The object that potentially has a memory-mapped file to be closed. - """ if hasattr(obj, "_mmap") and obj._mmap is not None: obj._mmap.close() @@ -373,13 +371,17 @@ def get_reverse_enum(space_str, origin_str): def convert_data_dict_to_tractogram(data): - """Convert a data from a lazy tractogram to a tractogram + """Convert data from a lazy tractogram to a tractogram. - Keyword arguments: - data -- The data dictionary to convert into a nibabel tractogram + Parameters + ---------- + data : dict + The data dictionary to convert into a nibabel tractogram. - Returns: - A Tractogram object + Returns + ------- + Tractogram + A Tractogram object. """ streamlines = ArraySequence(data["strs"]) streamlines._data = streamlines._data