diff --git a/.flake8 b/.flake8
new file mode 100644
index 0000000..f6b4ec2
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,11 @@
+[flake8]
+max-line-length = 88
+max-complexity = 10
+exclude =
+ .git,
+ __pycache__,
+ .tox,
+ .eggs,
+ *.egg,
+ build,
+ dist
\ No newline at end of file
diff --git a/.github/workflows/codeformat.yml b/.github/workflows/codeformat.yml
new file mode 100644
index 0000000..03e3a21
--- /dev/null
+++ b/.github/workflows/codeformat.yml
@@ -0,0 +1,27 @@
+name: Code Format
+
+on:
+ push:
+ branches: [master]
+ pull_request:
+ branches: [master]
+
+permissions:
+ contents: read
+
+jobs:
+ pre-commit:
+ name: Pre-commit checks
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Check out repository
+ uses: actions/checkout@v4
+
+ - name: Set up Python
+ uses: actions/setup-python@v5
+ with:
+ python-version: "3.12"
+
+ - name: Install and run pre-commit hooks
+ uses: pre-commit/action@v3.0.1
diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml
new file mode 100644
index 0000000..5738183
--- /dev/null
+++ b/.github/workflows/coverage.yml
@@ -0,0 +1,48 @@
+name: Coverage
+
+on:
+ push:
+ branches: [master]
+ pull_request:
+ branches: [master]
+
+permissions:
+ contents: read
+
+jobs:
+ coverage:
+ name: Code Coverage
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Check out repository
+ uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+
+ - name: Set up Python
+ uses: actions/setup-python@v5
+ with:
+ python-version: "3.12"
+ cache: pip
+ cache-dependency-path: pyproject.toml
+
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ python -m pip install -e .[test]
+
+ - name: Run tests with coverage
+ run: |
+ pytest trx/tests --cov=trx --cov-report=xml --cov-report=term-missing
+
+ - name: Upload coverage to Codecov
+ uses: codecov/codecov-action@v4
+ with:
+ files: ./coverage.xml
+ flags: unittests
+ name: codecov-trx
+ fail_ci_if_error: false
+ verbose: true
+ env:
+ CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
diff --git a/.github/workflows/docbuild.yml b/.github/workflows/docbuild.yml
index 757b4c9..5109045 100644
--- a/.github/workflows/docbuild.yml
+++ b/.github/workflows/docbuild.yml
@@ -17,12 +17,14 @@ jobs:
strategy:
fail-fast: false
matrix:
- python-version: ["3.9"]
+ python-version: ["3.13"]
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: 0 # Fetch all history and tags for setuptools_scm
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
+ uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install
@@ -35,13 +37,86 @@ jobs:
cd docs
make html
- name: Upload docs
- uses: actions/upload-artifact@v1
+ uses: actions/upload-artifact@v4
+ with:
+ name: docs
+ path: docs/_build/html
+
+ deploy-dev:
+ needs: build
+ runs-on: ubuntu-latest
+ if: github.event_name == 'push' && github.ref == 'refs/heads/master' && github.repository == 'tee-ar-ex/trx-python'
+ steps:
+ - uses: actions/checkout@v4
+ - uses: actions/download-artifact@v4
+ with:
+ name: docs
+ path: docs/_build/html
+ - name: Publish dev docs to Github Pages
+ uses: JamesIves/github-pages-deploy-action@v4
+ with:
+ branch: gh-pages
+ folder: docs/_build/html
+ target-folder: dev
+
+ deploy-release:
+ needs: build
+ runs-on: ubuntu-latest
+ if: startsWith(github.ref, 'refs/tags/') && github.repository == 'tee-ar-ex/trx-python'
+ steps:
+ - uses: actions/checkout@v4
+ - uses: actions/download-artifact@v4
with:
name: docs
path: docs/_build/html
- - name: Publish docs to Github Pages
- if: startsWith(github.event.ref, 'refs/tags')
+ - name: Get version from tag
+ id: get_version
+ run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT
+ - name: Fetch existing switcher.json from gh-pages
+ run: |
+ curl -sSL https://tee-ar-ex.github.io/trx-python/switcher.json -o switcher.json || echo '[]' > switcher.json
+ - name: Update switcher.json with new version
+ run: python tools/update_switcher.py switcher.json --version ${{ steps.get_version.outputs.VERSION }}
+ - name: Create root files (redirect + switcher.json)
+ run: |
+ mkdir -p root_files
+ cp switcher.json root_files/
+ cat > root_files/index.html << 'EOF'
+ <!DOCTYPE html>
+ <html lang="en">
+ <head>
+ <meta charset="utf-8">
+ <meta http-equiv="refresh" content="0; url=stable/">
+ <link rel="canonical" href="https://tee-ar-ex.github.io/trx-python/stable/">
+ <title>trx-python - TRX File Format for Tractography</title>
+ </head>
+ <body>
+ <h1>trx-python Documentation</h1>
+ <p>Python implementation of the TRX file format for tractography data.</p>
+ <p>If you are not redirected automatically, visit the
+ <a href="stable/">stable documentation</a>.</p>
+ </body>
+ </html>
+ EOF
+ - name: Publish root files (redirect + switcher.json)
+ uses: JamesIves/github-pages-deploy-action@v4
+ with:
+ branch: gh-pages
+ folder: root_files
+ target-folder: .
+ clean: false
+ - name: Publish release docs to Github Pages
+ uses: JamesIves/github-pages-deploy-action@v4
+ with:
+ branch: gh-pages
+ folder: docs/_build/html
+ target-folder: ${{ steps.get_version.outputs.VERSION }}
+ - name: Publish stable docs to Github Pages
uses: JamesIves/github-pages-deploy-action@v4
with:
branch: gh-pages
folder: docs/_build/html
+ target-folder: stable
diff --git a/.github/workflows/publish-to-test-pypi.yml b/.github/workflows/publish-to-test-pypi.yml
index 572d0d1..67918b9 100644
--- a/.github/workflows/publish-to-test-pypi.yml
+++ b/.github/workflows/publish-to-test-pypi.yml
@@ -1,39 +1,95 @@
name: Publish Python 🐍 distributions 📦 to PyPI and TestPyPI
-on: push
+on:
+ push:
+ branches:
+ - master
+ tags:
+ - '*'
+
jobs:
- build-n-publish:
- name: Build and publish Python 🐍 distributions 📦 to PyPI and TestPyPI
+ build:
+ name: Build distribution 📦
runs-on: ubuntu-latest
+ if: github.repository == 'tee-ar-ex/trx-python'
steps:
- - uses: actions/checkout@master
- - name: Set up Python 3.9
- uses: actions/setup-python@v1
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: 0 # Fetch all history and tags for setuptools_scm
+ - name: Set up Python 3.11
+ uses: actions/setup-python@v5
with:
- python-version: 3.9
+ python-version: "3.11"
- name: Install pypa/build
- run: >-
- python -m
- pip install
- build
- --user
+ run: python -m pip install build --user
- name: Build a binary wheel and a source tarball
- run: >-
- python -m
- build
- --sdist
- --wheel
- --outdir dist/
- .
+ run: python -m build --sdist --wheel --outdir dist/ .
+ - name: Upload distribution artifacts
+ uses: actions/upload-artifact@v4
+ with:
+ name: python-package-distributions
+ path: dist/
+
+ publish-to-testpypi:
+ name: Publish to TestPyPI
+ needs: build
+ runs-on: ubuntu-latest
+ environment:
+ name: testpypi
+ url: https://test.pypi.org/p/trx-python
+ permissions:
+ id-token: write
+ steps:
+ - name: Download distribution artifacts
+ uses: actions/download-artifact@v4
+ with:
+ name: python-package-distributions
+ path: dist/
- name: Publish distribution 📦 to Test PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
- user: __token__
- password: ${{ secrets.test_pypi_password }}
- repository_url: https://test.pypi.org/legacy/
+ repository-url: https://test.pypi.org/legacy/
+
+ publish-to-pypi:
+ name: Publish to PyPI
+ needs: build
+ runs-on: ubuntu-latest
+ if: startsWith(github.ref, 'refs/tags/')
+ environment:
+ name: pypi
+ url: https://pypi.org/p/trx-python
+ permissions:
+ id-token: write
+ steps:
+ - name: Download distribution artifacts
+ uses: actions/download-artifact@v4
+ with:
+ name: python-package-distributions
+ path: dist/
- name: Publish distribution 📦 to PyPI
- if: startsWith(github.event.ref, 'refs/tags')
uses: pypa/gh-action-pypi-publish@release/v1
+
+ github-release:
+ name: Create GitHub Release
+ needs: publish-to-pypi
+ runs-on: ubuntu-latest
+ permissions:
+ contents: write
+ steps:
+ - uses: actions/checkout@v4
+ - name: Download distribution artifacts
+ uses: actions/download-artifact@v4
with:
- user: __token__
- password: ${{ secrets.pypi_password }}
+ name: python-package-distributions
+ path: dist/
+ - name: Get version from tag
+ id: get_version
+ run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT
+ - name: Create GitHub Release
+ env:
+ GH_TOKEN: ${{ github.token }}
+ run: |
+ gh release create "${{ steps.get_version.outputs.VERSION }}" \
+ --title "Release ${{ steps.get_version.outputs.VERSION }}" \
+ --generate-notes \
+ dist/*
diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
deleted file mode 100644
index 822b770..0000000
--- a/.github/workflows/python-package.yml
+++ /dev/null
@@ -1,39 +0,0 @@
-name: Python package
-
-on:
- push:
- branches: [ master ]
- pull_request:
- branches: [ master ]
-
-jobs:
- build:
- runs-on: ubuntu-latest
- strategy:
- fail-fast: false
- matrix:
- python-version: ["3.8", "3.9", "3.10", "3.11"]
- os: [ubuntu-latest, windows-latest, macos-latest]
- steps:
- - uses: actions/checkout@v2
- - name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
- with:
- python-version: ${{ matrix.python-version }}
- - name: Install the package
- run: |
- python -m pip install --upgrade pip
- python -m pip install -e .[test]
-
- - name: Report version
- run: |
- python setup.py --version
- - name: Lint with flake8
- run: |
- # stop the build if there are Python syntax errors or undefined names
- flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
- # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
- flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- - name: Test with pytest
- run: |
- pytest trx/tests; pytest scripts/tests
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
new file mode 100644
index 0000000..e9f231e
--- /dev/null
+++ b/.github/workflows/test.yml
@@ -0,0 +1,39 @@
+name: Tests
+
+on:
+ push:
+ branches: [master]
+ pull_request:
+ branches: [master]
+
+permissions:
+ contents: read
+
+jobs:
+ test:
+ name: Python ${{ matrix.python-version }} • ${{ matrix.os }}
+ runs-on: ${{ matrix.os }}
+ strategy:
+ fail-fast: false
+ matrix:
+ python-version: ["3.11", "3.12", "3.13"]
+ os: [ubuntu-latest, windows-latest, macos-latest]
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: 0 # needed for setuptools_scm version detection
+
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python-version }}
+ cache: pip
+ cache-dependency-path: pyproject.toml
+
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ python -m pip install -e .[dev,test]
+
+ - name: Test
+ run: spin test
diff --git a/.gitignore b/.gitignore
index b5a6b76..3987e02 100644
--- a/.gitignore
+++ b/.gitignore
@@ -132,3 +132,9 @@ dmypy.json
# Pyre type checker
.pyre/
.vscode/
+
+tmp/
+CLAUDE.md
+claude.md
+agents.md
+AGENTS.md
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..898b98a
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,19 @@
+default_language_version:
+ python: python3
+
+repos:
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.8.6
+ hooks:
+ # Run the linter
+ - id: ruff
+ args: [ --fix ]
+ # Run the formatter
+ - id: ruff-format
+ - repo: https://github.com/codespell-project/codespell
+ rev: v2.3.0
+ hooks:
+ - id: codespell
+ args: [--skip, "pyproject.toml,docs/_build/*,*.egg-info"]
+ additional_dependencies:
+ - tomli
diff --git a/.spin/__init__.py b/.spin/__init__.py
new file mode 100644
index 0000000..85155f8
--- /dev/null
+++ b/.spin/__init__.py
@@ -0,0 +1 @@
+# Spin commands package
diff --git a/.spin/cmds.py b/.spin/cmds.py
new file mode 100644
index 0000000..ea7b2c3
--- /dev/null
+++ b/.spin/cmds.py
@@ -0,0 +1,261 @@
+"""Custom spin commands for trx-python development."""
+
+import glob
+import os
+import shutil
+import subprocess
+import sys
+import tempfile
+
+import click
+
+UPSTREAM_URL = "https://github.com/tee-ar-ex/trx-python.git"
+UPSTREAM_NAME = "upstream"
+
+
+def run(cmd, check=True, capture=True):
+ """Run a shell command.
+
+ Return stdout (None on failure) if capture=True, else the return code.
+ """
+ result = subprocess.run(cmd, capture_output=capture, text=True, check=False)
+ if check and result.returncode != 0:
+ if capture:
+ click.echo(f"Error: {result.stderr}", err=True)
+ return None
+ return result.stdout.strip() if capture else result.returncode
+
+
+def get_remotes():
+ """Get dict of remote names to URLs."""
+ output = run(["git", "remote", "-v"])
+ if not output:
+ return {}
+ remotes = {}
+ for line in output.split("\n"):
+ if "(fetch)" in line:
+ parts = line.split()
+ remotes[parts[0]] = parts[1]
+ return remotes
+
+
+@click.command()
+def setup():
+ """Set up development environment (fetch tags from upstream).
+
+ This command configures your fork for development by:
+ 1. Adding the upstream remote if not present
+ 2. Fetching tags from upstream (required for correct version detection)
+
+ Run this once after cloning your fork.
+ """
+ click.echo("Setting up trx-python development environment...\n")
+
+ # Check if in git repo
+ if run(["git", "rev-parse", "--git-dir"]) is None:
+ click.echo("Error: Not in a git repository", err=True)
+ sys.exit(1)
+
+ # Check/add upstream remote
+ remotes = get_remotes()
+ upstream_remote = None
+
+ for name, url in remotes.items():
+ if UPSTREAM_URL.removesuffix(".git") in url.removesuffix(".git"):
+ upstream_remote = name
+ click.echo(f"Found upstream remote: {name}")
+ break
+
+ if upstream_remote is None:
+ click.echo(f"Adding upstream remote: {UPSTREAM_URL}")
+ run(["git", "remote", "add", UPSTREAM_NAME, UPSTREAM_URL])
+ upstream_remote = UPSTREAM_NAME
+
+ # Fetch tags
+ click.echo(f"\nFetching tags from {upstream_remote}...")
+ run(["git", "fetch", upstream_remote, "--tags"], capture=False)
+
+ # Verify version
+ click.echo("\nVerifying version detection...")
+ try:
+ from setuptools_scm import get_version
+
+ version = get_version()
+ click.echo(f"Detected version: {version}")
+
+ # Check for suspicious version patterns
+ if version.startswith("0.0"):
+ click.echo(
+ "\nError: Version starts with 0.0 - upstream tags may not be fetched.",
+ err=True,
+ )
+ sys.exit(1)
+ except ImportError:
+ click.echo("Note: Install setuptools_scm to verify version detection")
+
+ click.echo("\nSetup complete! You can now run:")
+ click.echo(" spin install # Install in development mode")
+ click.echo(" spin test # Run tests")
+
+
+@click.command()
+@click.option(
+ "-m",
+ "--match",
+ "pattern",
+ default=None,
+ help="Only run tests matching this pattern (passed to pytest -k)",
+)
+@click.option("-v", "--verbose", is_flag=True, default=False, help="Verbose output")
+@click.argument("pytest_args", nargs=-1)
+def test(pattern, verbose, pytest_args):
+ """Run tests using pytest.
+
+ Additional arguments are passed directly to pytest.
+
+ Examples:
+ spin test # Run all tests
+ spin test -m memmap # Run tests matching 'memmap'
+ spin test -v # Verbose output
+ spin test -- -x --tb=short # Pass args to pytest
+ """
+ cmd = ["pytest", "trx/tests"]
+
+ if pattern:
+ cmd.extend(["-k", pattern])
+
+ if verbose:
+ cmd.append("-v")
+
+ if pytest_args:
+ cmd.extend(pytest_args)
+
+ click.echo(f"Running: {' '.join(cmd)}\n")
+ sys.exit(run(cmd, capture=False, check=False))
+
+
+@click.command()
+@click.option(
+ "--fix", is_flag=True, default=False, help="Automatically fix issues where possible"
+)
+def lint(fix):
+ """Run linting checks using ruff and codespell.
+
+ Examples:
+ spin lint # Run ruff and codespell checks
+ spin lint --fix # Run ruff and auto-fix issues
+ """
+ click.echo("Running ruff linter...")
+ cmd = ["ruff", "check", "."]
+
+ if fix:
+ cmd.append("--fix")
+
+ result = run(cmd, capture=False, check=False)
+ if result != 0:
+ click.echo("\nLinting issues found!", err=True)
+ sys.exit(1)
+
+ click.echo("\nRunning ruff formatter check...")
+ cmd_format = ["ruff", "format", "--check", "."]
+ result = run(cmd_format, capture=False, check=False)
+ if result != 0:
+ click.echo("\nFormatting issues found!", err=True)
+ sys.exit(1)
+
+ click.echo("\nRunning codespell...")
+ cmd_spell = [
+ "codespell",
+ "--skip",
+ "*.pyc,.git,pyproject.toml,./docs/_build/*,*.egg-info,./build/*,./dist/*,./tmp/*",
+ "trx",
+ "docs/source",
+ ".spin",
+ ]
+ result = run(cmd_spell, capture=False, check=False)
+ if result != 0:
+ click.echo("\nSpelling issues found!", err=True)
+ sys.exit(1)
+
+ click.echo("\nAll checks passed!")
+
+
+@click.command()
+@click.option(
+ "--clean", is_flag=True, default=False, help="Clean build directory before building"
+)
+@click.option(
+ "--open",
+ "open_browser",
+ is_flag=True,
+ default=False,
+ help="Open documentation in browser after building",
+)
+def docs(clean, open_browser):
+ """Build documentation using Sphinx.
+
+ Examples:
+ spin docs # Build docs
+ spin docs --clean # Clean and rebuild
+ spin docs --open # Build and open in browser
+ """
+
+ docs_dir = "docs"
+
+ if clean:
+ click.echo("Cleaning build directory...")
+ build_dir = os.path.join(docs_dir, "_build")
+ if os.path.exists(build_dir):
+ shutil.rmtree(build_dir)
+
+ click.echo("Building documentation...")
+ cmd = ["make", "-C", docs_dir, "html"]
+ result = run(cmd, capture=False, check=False)
+
+ if result == 0:
+ index_path = os.path.abspath(
+ os.path.join(docs_dir, "_build", "html", "index.html")
+ )
+ click.echo("\nDocs built successfully!")
+ click.echo(f"Open: {index_path}")
+
+ if open_browser:
+ import webbrowser
+
+ webbrowser.open(f"file://{index_path}")
+
+ sys.exit(result)
+
+
+@click.command()
+def clean(): # noqa: C901
+ """Clean up temporary files and build artifacts."""
+ click.echo("Cleaning up temporary files...")
+
+ # Clean TRX temp directory
+ trx_tmp_dir = os.getenv("TRX_TMPDIR", tempfile.gettempdir())
+ if os.path.exists(trx_tmp_dir):
+ temp_files = glob.glob(os.path.join(trx_tmp_dir, "trx_*"))
+ for temp_dir in temp_files:
+ if os.path.isdir(temp_dir):
+ click.echo(f"Removing temporary directory: {temp_dir}")
+ shutil.rmtree(temp_dir)
+
+ # Clean build artifacts
+ for build_pattern in ["build", "dist", "*.egg-info"]:
+ for path in glob.glob(build_pattern):
+ if os.path.isdir(path):
+ click.echo(f"Removing build directory: {path}")
+ shutil.rmtree(path)
+ elif os.path.isfile(path):
+ click.echo(f"Removing build file: {path}")
+ os.remove(path)
+
+ # Clean Python cache
+ for cache_dir in ["**/__pycache__", "**/.pytest_cache"]:
+ for path in glob.glob(cache_dir, recursive=True):
+ if os.path.isdir(path):
+ click.echo(f"Removing cache directory: {path}")
+ shutil.rmtree(path)
+
+ click.echo("Cleanup complete!")
diff --git a/README.md b/README.md
index deca3d3..f86676a 100644
--- a/README.md
+++ b/README.md
@@ -1,29 +1,162 @@
# trx-python
-This is a Python implementation of the trx file-format for tractography data.
+[![Tests](https://github.com/tee-ar-ex/trx-python/actions/workflows/test.yml/badge.svg)](https://github.com/tee-ar-ex/trx-python/actions/workflows/test.yml)
+[![Code Format](https://github.com/tee-ar-ex/trx-python/actions/workflows/codeformat.yml/badge.svg)](https://github.com/tee-ar-ex/trx-python/actions/workflows/codeformat.yml)
+[![codecov](https://codecov.io/gh/tee-ar-ex/trx-python/branch/master/graph/badge.svg)](https://codecov.io/gh/tee-ar-ex/trx-python)
+[![PyPI version](https://badge.fury.io/py/trx-python.svg)](https://badge.fury.io/py/trx-python)
-For details, please visit the documentation web-page at https://tee-ar-ex.github.io/trx-python/.
+A Python implementation of the TRX file format for tractography data.
-To install this, you can run:
+For details, please visit the [documentation](https://tee-ar-ex.github.io/trx-python/).
- pip install trx-python
+## Installation
-Or, to install from source:
+### From PyPI
- git clone https://github.com/tee-ar-ex/trx-python.git
- cd trx-python
- pip install .
+```bash
+pip install trx-python
+```
-### Temporary Directory
-The TRX file format uses memmaps to limit RAM usage. When dealing with large files this means several gigabytes could be required on disk (instead of RAM).
+### From Source
-By default, the temporary directory on Linux and MacOS is `/tmp` and on Windows it should be `C:\WINDOWS\Temp`.
+```bash
+git clone https://github.com/tee-ar-ex/trx-python.git
+cd trx-python
+pip install .
+```
-If you wish to change the directory add the following variable to your script or to your .bashrc or .bash_profile:
-`export TRX_TMPDIR=/WHERE/I/WANT/MY/TMP/DATA` (a)
-OR
-`export TRX_TMPDIR=use_working_dir` (b)
+## Quick Start
-The provided folder must already exists (a). `use_working_dir` will be the directory where the code is being executed from (b).
+### Loading and Saving Tractograms
-The temporary folders should be automatically cleaned. But, if the code crash unexpectedly, make sure the folders are deleted.
+```python
+from trx.io import load, save
+
+# Load a tractogram (supports .trx, .trk, .tck, .vtk, .fib, .dpy)
+trx = load("tractogram.trx")
+
+# Save to a different format
+save(trx, "output.trk")
+```
+
+### Command-Line Interface
+
+TRX-Python provides a unified CLI (`tff`) for common operations:
+
+```bash
+# Show all available commands
+tff --help
+
+# Convert between formats
+tff convert input.trk output.trx
+
+# Concatenate tractograms
+tff concatenate tract1.trx tract2.trx merged.trx
+
+# Validate a TRX file
+tff validate data.trx
+```
+
+Individual commands are also available for backward compatibility:
+
+```bash
+tff_convert_tractogram input.trk output.trx
+tff_concatenate_tractograms tract1.trx tract2.trx merged.trx
+tff_validate_trx data.trx
+```
+
+## Development
+
+We use [spin](https://github.com/scientific-python/spin) for development workflow.
+
+### First-Time Setup
+
+```bash
+# Clone the repository (or your fork)
+git clone https://github.com/tee-ar-ex/trx-python.git
+cd trx-python
+
+# Install with all dependencies
+pip install -e ".[all]"
+
+# Set up development environment (fetches upstream tags)
+spin setup
+```
+
+### Common Commands
+
+```bash
+spin setup # Set up development environment
+spin install # Install in editable mode
+spin test # Run all tests
+spin test -m memmap # Run tests matching pattern
+spin lint # Run linting (ruff)
+spin lint --fix # Auto-fix linting issues
+spin docs # Build documentation
+spin clean # Clean temporary files
+```
+
+Run `spin` without arguments to see all available commands.
+
+### Code Quality
+
+We use [ruff](https://docs.astral.sh/ruff/) for linting and formatting:
+
+```bash
+# Check for issues
+spin lint
+
+# Auto-fix issues
+spin lint --fix
+
+# Format code
+ruff format .
+```
+
+### Pre-commit Hooks
+
+```bash
+# Install hooks
+pre-commit install
+
+# Run on all files
+pre-commit run --all-files
+```
+
+## Temporary Directory
+
+The TRX file format uses memory-mapped files to limit RAM usage. When dealing with large files, several gigabytes may be required on disk.
+
+By default, temporary files are stored in:
+- Linux/macOS: `/tmp`
+- Windows: `C:\WINDOWS\Temp`
+
+To change the directory:
+
+```bash
+# Use a specific directory (must exist)
+export TRX_TMPDIR=/path/to/tmp
+
+# Use current working directory
+export TRX_TMPDIR=use_working_dir
+```
+
+Temporary folders are automatically cleaned, but if the code crashes unexpectedly, ensure folders are deleted manually.
+
+## Documentation
+
+Full documentation is available at https://tee-ar-ex.github.io/trx-python/
+
+To build locally:
+
+```bash
+spin docs --open
+```
+
+## Contributing
+
+We welcome contributions! Please see our [Contributing Guide](https://tee-ar-ex.github.io/trx-python/contributing.html) for details.
+
+## License
+
+BSD License - see [LICENSE](LICENSE) for details.
diff --git a/docs/_static/switcher.json b/docs/_static/switcher.json
new file mode 100644
index 0000000..ad36d9e
--- /dev/null
+++ b/docs/_static/switcher.json
@@ -0,0 +1,13 @@
+[
+ {
+ "name": "dev",
+ "version": "dev",
+ "url": "https://tee-ar-ex.github.io/trx-python/dev/"
+ },
+ {
+ "name": "stable",
+ "version": "stable",
+ "url": "https://tee-ar-ex.github.io/trx-python/stable/",
+ "preferred": true
+ }
+]
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 5970660..8b02e6e 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -13,13 +13,35 @@
# import os
# import sys
# sys.path.insert(0, os.path.abspath('.'))
+import os
from datetime import datetime as dt
+# -- Version information -----------------------------------------------------
+# Get version from environment variable (set by CI) or package
+version = os.environ.get('TRX_VERSION', None)
+if version is None:
+ try:
+ from trx import __version__
+ version = __version__
+ except ImportError:
+ version = "dev"
+
+# Normalize version for matching against switcher.json: release versions match
+# their own entry; any development version maps to the "dev" entry.
+version_match = "dev" if "dev" in version else version
+
# -- Project information -----------------------------------------------------
project = 'trx-python'
copyright = copyright = f'2021-{dt.now().year}, The TRX developers'
author = 'The TRX developers'
+release = version
# -- General configuration ---------------------------------------------------
@@ -59,6 +81,7 @@
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
+html_static_path = ['../_static']
html_logo = "../_static/trx_logo.png"
@@ -69,12 +92,20 @@
"name": "GitHub",
# URL where the link will redirect
"url": "https://github.com/tee-ar-ex", # required
- # Icon class (if "type": "fontawesome"), or path to local image (if "type": "local")
+ # Icon class (if "type": "fontawesome"), or path to local image
+ # (if "type": "local")
"icon": "fab fa-github-square",
# The type of image to be used (see below for details)
"type": "fontawesome",
}
- ]
+ ],
+ # Version switcher configuration
+ "switcher": {
+ "json_url": "https://tee-ar-ex.github.io/trx-python/switcher.json",
+ "version_match": version_match,
+ },
+ "navbar_start": ["navbar-logo", "version-switcher"],
+ "show_version_warning_banner": True,
}
diff --git a/docs/source/contributing.rst b/docs/source/contributing.rst
new file mode 100644
index 0000000..0ab192a
--- /dev/null
+++ b/docs/source/contributing.rst
@@ -0,0 +1,209 @@
+Contributing to TRX-Python
+==========================
+
+We welcome contributions from the community! This guide will help you get started
+with contributing to the TRX-Python project.
+
+Ways to Contribute
+------------------
+
+There are many ways to contribute to TRX-Python:
+
+- **Report bugs**: If you find a bug, please open an issue on GitHub
+- **Suggest features**: Have an idea? Open an issue to discuss it
+- **Fix bugs**: Look for issues labeled "good first issue" or "help wanted"
+- **Write documentation**: Help improve our docs or add examples
+- **Write tests**: Increase test coverage
+- **Code review**: Review pull requests from other contributors
+
+Getting Started
+---------------
+
+1. **Fork the repository** on GitHub
+2. **Clone your fork**:
+
+ .. code-block:: bash
+
+ git clone https://github.com/YOUR_USERNAME/trx-python.git
+ cd trx-python
+
+3. **Set up development environment**:
+
+ .. code-block:: bash
+
+ pip install -e ".[all]"
+ spin setup
+
+ The ``spin setup`` command fetches version tags from upstream, which is
+ required for correct version detection.
+
+4. **Create a branch** for your changes:
+
+ .. code-block:: bash
+
+ git checkout -b my-feature-branch
+
+Making Changes
+--------------
+
+Development Workflow
+~~~~~~~~~~~~~~~~~~~~
+
+We use `spin <https://github.com/scientific-python/spin>`_ for development workflow:
+
+.. code-block:: bash
+
+ spin install # Install in editable mode
+ spin test # Run all tests
+ spin lint # Run linting (ruff)
+ spin docs # Build documentation
+
+Before Submitting
+~~~~~~~~~~~~~~~~~
+
+1. **Run tests** to ensure your changes don't break existing functionality:
+
+ .. code-block:: bash
+
+ spin test
+
+2. **Run linting** to ensure code style compliance:
+
+ .. code-block:: bash
+
+ spin lint
+
+ You can auto-fix many issues with:
+
+ .. code-block:: bash
+
+ spin lint --fix
+
+3. **Format your code** using ruff:
+
+ .. code-block:: bash
+
+ ruff format .
+
+4. **Write tests** for any new functionality
+
+5. **Update documentation** if needed
+
+Submitting a Pull Request
+-------------------------
+
+1. **Push your changes** to your fork:
+
+ .. code-block:: bash
+
+ git push origin my-feature-branch
+
+2. **Open a Pull Request** on GitHub against the ``master`` branch
+
+3. **Describe your changes** in the PR description:
+
+ - What does this PR do?
+ - Why is this change needed?
+ - How was it tested?
+
+4. **Wait for CI checks** to pass
+
+5. **Address review feedback** if requested
+
+Code Style
+----------
+
+We follow these conventions:
+
+- **PEP 8** style guide
+- **Line length**: 88 characters maximum
+- **Docstrings**: NumPy style format
+- **Type hints**: Encouraged but not required
+
+Example docstring:
+
+.. code-block:: python
+
+ def my_function(param1, param2):
+ """Short description of the function.
+
+ Parameters
+ ----------
+ param1 : int
+ Description of param1.
+ param2 : str
+ Description of param2.
+
+ Returns
+ -------
+ result : bool
+ Description of return value.
+
+ Examples
+ --------
+ >>> my_function(1, "test")
+ True
+ """
+ pass
+
+We use `ruff <https://docs.astral.sh/ruff/>`_ for linting and formatting.
+Configuration is in ``ruff.toml``.
+
+Testing
+-------
+
+Tests are located in ``trx/tests/``. We use pytest for testing.
+
+Running Tests
+~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+ # Run all tests
+ spin test
+
+ # Run tests matching a pattern
+ spin test -m memmap
+
+ # Run with verbose output
+ spin test -v
+
+ # Run a specific test file
+ pytest trx/tests/test_memmap.py
+
+Writing Tests
+~~~~~~~~~~~~~
+
+- Place tests in ``trx/tests/``
+- Name test files ``test_*.py``
+- Name test functions ``test_*``
+- Use pytest fixtures for common setup
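+
+For example, a minimal test module might look like this (the file name and
+assertion below are only illustrative):
+
+.. code-block:: python
+
+   # trx/tests/test_example.py (hypothetical file)
+   import numpy as np
+   from numpy.testing import assert_array_equal
+
+
+   def test_array_copy_is_equal():
+       # A copy of an array should compare equal to the original.
+       data = np.array([1.0, 2.0, 3.0], dtype=np.float32)
+       assert_array_equal(data.copy(), data)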
+
+Documentation
+-------------
+
+Documentation is built with Sphinx and hosted on GitHub Pages.
+
+Building Docs
+~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+ spin docs # Build documentation
+ spin docs --clean # Clean build
+ spin docs --open # Build and open in browser
+
+Writing Documentation
+~~~~~~~~~~~~~~~~~~~~~
+
+- Documentation source is in ``docs/source/``
+- Use reStructuredText format
+- API documentation is auto-generated from docstrings
+
+Getting Help
+------------
+
+- **GitHub Issues**: For bugs and feature requests
+- **GitHub Discussions**: For questions and discussions
+
+Thank you for contributing to TRX-Python!
diff --git a/docs/source/dev.rst b/docs/source/dev.rst
new file mode 100644
index 0000000..5aacdfd
--- /dev/null
+++ b/docs/source/dev.rst
@@ -0,0 +1,342 @@
+Developer Guide
+===============
+
+This guide provides detailed information for developers working on TRX-Python.
+
+Installation for Development
+----------------------------
+
+Prerequisites
+~~~~~~~~~~~~~
+
+- Python 3.11 or later (Python 3.12+ recommended)
+- Git
+- pip
+
+Setting Up Your Environment
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+1. **Clone the repository**:
+
+ .. code-block:: bash
+
+ # If you're a contributor, fork first then clone your fork
+ git clone https://github.com/YOUR_USERNAME/trx-python.git
+ cd trx-python
+
+2. **Install with all development dependencies**:
+
+ .. code-block:: bash
+
+ pip install -e ".[all]"
+
+ This installs:
+
+ - Core dependencies (numpy, nibabel, deepdiff, typer)
+ - Development tools (spin, setuptools_scm)
+ - Documentation tools (sphinx, numpydoc)
+ - Style tools (ruff, pre-commit)
+ - Testing tools (pytest, pytest-cov)
+
+3. **Set up the development environment**:
+
+ .. code-block:: bash
+
+ spin setup
+
+ This command:
+
+ - Adds upstream remote if missing
+ - Fetches version tags for correct ``setuptools_scm`` version detection
+
+Using Spin
+----------
+
+We use `spin <https://github.com/scientific-python/spin>`_ for development workflow.
+Spin provides a consistent interface for common development tasks.
+
+Available Commands
+~~~~~~~~~~~~~~~~~~
+
+Run ``spin`` without arguments to see all available commands:
+
+.. code-block:: bash
+
+ spin
+
+**Setup Commands:**
+
+.. code-block:: bash
+
+ spin setup # Configure development environment
+
+**Build Commands:**
+
+.. code-block:: bash
+
+ spin install # Install package in editable mode
+
+**Test Commands:**
+
+.. code-block:: bash
+
+ spin test # Run all tests
+ spin test -m NAME # Run tests matching pattern
+ spin test -v # Verbose output
+ spin lint # Run ruff linting
+ spin lint --fix # Auto-fix linting issues
+
+**Documentation Commands:**
+
+.. code-block:: bash
+
+ spin docs # Build documentation
+ spin docs --clean # Clean and rebuild
+ spin docs --open # Build and open in browser
+
+**Cleanup Commands:**
+
+.. code-block:: bash
+
+ spin clean # Remove temporary files and build artifacts
+
+Code Quality
+------------
+
+Linting with Ruff
+~~~~~~~~~~~~~~~~~
+
+We use `ruff <https://docs.astral.sh/ruff/>`_ for linting and formatting.
+Configuration is in ``ruff.toml``.
+
+.. code-block:: bash
+
+ # Check for issues
+ spin lint
+
+ # Auto-fix issues
+ spin lint --fix
+
+ # Format code
+ ruff format .
+
+ # Check formatting without changes
+ ruff format --check .
+
+Pre-commit Hooks
+~~~~~~~~~~~~~~~~
+
+We recommend using pre-commit hooks to catch issues before committing:
+
+.. code-block:: bash
+
+ # Install pre-commit hooks
+ pre-commit install
+
+ # Run hooks manually on all files
+ pre-commit run --all-files
+
+The hooks run:
+
+- ``ruff`` - Linting with auto-fix
+- ``ruff-format`` - Code formatting
+- ``codespell`` - Spell checking
+
+Testing
+-------
+
+Running Tests
+~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+ # Run all tests
+ spin test
+
+ # Run tests matching a pattern
+ spin test -m memmap
+
+ # Run with pytest directly
+ pytest trx/tests
+
+ # Run with coverage
+ pytest trx/tests --cov=trx --cov-report=term-missing
+
+Test Data
+~~~~~~~~~
+
+Test data is automatically downloaded from Figshare on first run.
+Data is cached in ``~/.tee_ar_ex/``.
+
+You can manually fetch test data:
+
+.. code-block:: python
+
+ from trx.fetcher import fetch_data, get_testing_files_dict
+ fetch_data(get_testing_files_dict())
+
+Writing Tests
+~~~~~~~~~~~~~
+
+- Tests go in ``trx/tests/``
+- Use pytest fixtures for setup/teardown
+- Use ``pytest.mark.skipif`` for conditional tests
+
+Example:
+
+.. code-block:: python
+
+ import pytest
+ import numpy as np
+ from numpy.testing import assert_array_equal
+
+ def test_my_function():
+ result = my_function(input_data)
+ expected = np.array([1, 2, 3])
+ assert_array_equal(result, expected)
+
+ @pytest.mark.skipif(not dipy_available, reason="Dipy required")
+ def test_with_dipy():
+ # Test that requires dipy
+ pass
+
+Documentation
+-------------
+
+Building Documentation
+~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+ # Build docs
+ spin docs
+
+ # Clean build
+ spin docs --clean
+
+ # Build and open in browser
+ spin docs --open
+
+Documentation is built with Sphinx and uses:
+
+- ``pydata-sphinx-theme`` for styling
+- ``sphinx-autoapi`` for API documentation
+- ``numpydoc`` for NumPy-style docstrings
+
+Writing Documentation
+~~~~~~~~~~~~~~~~~~~~~
+
+- Source files are in ``docs/source/``
+- Use reStructuredText format
+- API docs are auto-generated from docstrings
+
+NumPy Docstring Format
+~~~~~~~~~~~~~~~~~~~~~~
+
+All functions and classes should be documented using NumPy-style docstrings:
+
+.. code-block:: python
+
+ def load(filename, reference=None):
+ """Load a tractogram file.
+
+ Parameters
+ ----------
+ filename : str
+ Path to the tractogram file.
+ reference : str, optional
+ Path to reference anatomy for formats that require it.
+
+ Returns
+ -------
+ tractogram : TrxFile or StatefulTractogram
+ The loaded tractogram.
+
+ Raises
+ ------
+ ValueError
+ If the file format is not supported.
+
+ See Also
+ --------
+ save : Save a tractogram to file.
+
+ Examples
+ --------
+ >>> from trx.io import load
+ >>> trx = load("tractogram.trx")
+ """
+ pass
+
+Project Structure
+-----------------
+
+.. code-block:: text
+
+ trx-python/
+ ├── trx/ # Main package
+ │ ├── __init__.py
+ │ ├── cli.py # Command-line interface (Typer)
+ │ ├── fetcher.py # Test data fetching
+ │ ├── io.py # Unified I/O interface
+ │ ├── streamlines_ops.py # Streamline operations
+ │ ├── trx_file_memmap.py # Core TrxFile class
+ │ ├── utils.py # Utility functions
+ │ ├── viz.py # Visualization (optional)
+ │ ├── workflows.py # High-level workflows
+ │ └── tests/ # Test suite
+ ├── docs/ # Documentation
+ │ └── source/
+ ├── .github/ # GitHub Actions workflows
+ │ └── workflows/
+ ├── .spin/ # Spin configuration
+ │ └── cmds.py
+ ├── pyproject.toml # Project configuration
+ ├── ruff.toml # Ruff configuration
+ └── .pre-commit-config.yaml # Pre-commit hooks
+
+Continuous Integration
+----------------------
+
+GitHub Actions runs on every push and pull request:
+
+- **test.yml**: Runs tests on Python 3.11-3.13 across Linux, macOS, Windows
+- **codeformat.yml**: Checks code formatting with pre-commit/ruff
+- **coverage.yml**: Generates code coverage reports
+- **docbuild.yml**: Builds and deploys documentation
+
+Environment Variables
+---------------------
+
+TRX_TMPDIR
+~~~~~~~~~~
+
+Controls where temporary files are stored during memory-mapped operations.
+
+.. code-block:: bash
+
+ # Use a specific directory
+ export TRX_TMPDIR=/path/to/tmp
+
+ # Use current working directory
+ export TRX_TMPDIR=use_working_dir
+
+Default: System temp directory (``/tmp`` on Linux/macOS, ``C:\WINDOWS\Temp`` on Windows)
+
+Release Process
+---------------
+
+Releases are managed via GitHub:
+
+1. Update version in ``pyproject.toml`` if needed
+2. Create a GitHub release with appropriate tag
+3. CI automatically publishes to PyPI
+
+Version Detection
+~~~~~~~~~~~~~~~~~
+
+We use ``setuptools_scm`` for automatic version detection from git tags.
+This requires:
+
+- Proper git tags from upstream
+- Running ``spin setup`` after cloning a fork
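+
+A quick way to confirm that tags were fetched and the version resolves correctly
+(a sketch assuming ``setuptools_scm`` is installed, as it is with the ``dev`` extra):
+
+.. code-block:: python
+
+   from setuptools_scm import get_version
+
+   # Prints the version derived from the nearest git tag; a value starting
+   # with "0.0" usually means upstream tags have not been fetched yet.
+   print(get_version())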
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 84be5d0..55ecf8e 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -42,8 +42,19 @@ Development of TRX is supported by `NIMH grant 1R01MH126699 --help # Show help for a specific command
+
+Available subcommands:
+
+- ``tff concatenate`` - Concatenate multiple tractograms
+- ``tff convert`` - Convert between tractography formats
+- ``tff convert-dsi`` - Fix DSI-Studio TRK files
+- ``tff generate`` - Generate TRX from raw data files
+- ``tff manipulate-dtype`` - Change array data types
+- ``tff compare`` - Simple tractogram comparison
+- ``tff validate`` - Validate and clean TRX files
+- ``tff verify-header`` - Check header compatibility
+- ``tff visualize`` - Visualize tractogram overlap
+
+Standalone Commands
+-------------------
+
+For backward compatibility, standalone commands are also available:
+
+tff_concatenate_tractograms
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Concatenate multiple tractograms into a single output.
+
+- Supports ``trk``, ``tck``, ``vtk``, ``fib``, ``dpy``, and ``trx`` inputs.
+- Flags: ``--delete-dpv``, ``--delete-dps``, ``--delete-groups`` to drop mismatched metadata; ``--reference`` for formats needing an anatomy reference; ``-f`` to overwrite.
+
+.. code-block:: bash
+
+ # Using unified CLI
+ tff concatenate in1.trk in2.trk merged.trx
+
+ # Using standalone command
+ tff_concatenate_tractograms in1.trk in2.trk merged.trx
+
+tff_convert_dsi_studio
+~~~~~~~~~~~~~~~~~~~~~~
+Convert a DSI Studio ``.trk`` with accompanying ``.nii.gz`` reference into a cleaned ``.trk`` or TRX.
+
+.. code-block:: bash
+
+ # Using unified CLI
+ tff convert-dsi input.trk reference.nii.gz cleaned.trk
+
+ # Using standalone command
+ tff_convert_dsi_studio input.trk reference.nii.gz cleaned.trk
+
+tff_convert_tractogram
+~~~~~~~~~~~~~~~~~~~~~~
+General-purpose converter between ``trk``, ``tck``, ``vtk``, ``fib``, ``dpy``, and ``trx``.
+
+- Flags: ``--reference`` for formats needing a NIfTI, ``--positions-dtype``, ``--offsets-dtype``, ``-f`` to overwrite.
+
+.. code-block:: bash
+
+ # Using unified CLI
+ tff convert input.trk output.trx --positions-dtype float32 --offsets-dtype uint64
+
+ # Using standalone command
+ tff_convert_tractogram input.trk output.trx --positions-dtype float32 --offsets-dtype uint64
+
+tff_generate_trx_from_scratch
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Build a TRX file from raw NumPy arrays or CSV streamline coordinates.
+
+- Flags: ``--positions``, ``--offsets``, ``--positions-dtype``, ``--offsets-dtype``, spatial options (``--space``, ``--origin``), and metadata loaders for dpv/dps/groups/dpg.
+
+.. code-block:: bash
+
+ # Using unified CLI
+ tff generate fa.nii.gz output.trx --positions positions.npy --offsets offsets.npy
+
+ # Using standalone command
+ tff_generate_trx_from_scratch fa.nii.gz output.trx --positions positions.npy --offsets offsets.npy
+
+tff_manipulate_datatype
+~~~~~~~~~~~~~~~~~~~~~~~
+Rewrite TRX datasets with new dtypes for positions/offsets/dpv/dps/dpg/groups.
+
+- Accepts per-field dtype arguments and overwrites with ``-f``.
+
+.. code-block:: bash
+
+ # Using unified CLI
+ tff manipulate-dtype input.trx output.trx --positions-dtype float16 --dpv color,uint8
+
+ # Using standalone command
+ tff_manipulate_datatype input.trx output.trx --positions-dtype float16 --dpv color,uint8
+
+tff_simple_compare
+~~~~~~~~~~~~~~~~~~
+Compare two tractograms for quick difference checks.
+
+.. code-block:: bash
+
+ # Using unified CLI
+ tff compare first.trk second.trk
+
+ # Using standalone command
+ tff_simple_compare first.trk second.trk
+
+tff_validate_trx
+~~~~~~~~~~~~~~~~
+Validate a TRX file for consistency and remove invalid streamlines.
+
+.. code-block:: bash
+
+ # Using unified CLI
+ tff validate data.trx --out cleaned.trx
+
+ # Using standalone command
+ tff_validate_trx data.trx --out cleaned.trx
+
+tff_verify_header_compatibility
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Check whether tractogram headers are compatible for operations such as concatenation.
+
+.. code-block:: bash
+
+ # Using unified CLI
+ tff verify-header file1.trk file2.trk
+
+ # Using standalone command
+ tff_verify_header_compatibility file1.trk file2.trk
+
+tff_visualize_overlap
+~~~~~~~~~~~~~~~~~~~~~
+Visualize streamline overlap between tractograms (requires visualization dependencies).
+
+.. code-block:: bash
+
+ # Using unified CLI
+ tff visualize tractogram.trk reference.nii.gz
+
+ # Using standalone command
+ tff_visualize_overlap tractogram.trk reference.nii.gz
+
+Notes
+-----
+- Test datasets for the examples can be fetched with the ``trx.fetcher`` helpers: ``fetch_data(get_testing_files_dict())`` downloads to ``$TRX_HOME`` (default ``~/.tee_ar_ex``); see the snippet after this list.
+- All commands print detailed usage with ``--help``.
+- The unified ``tff`` CLI uses `Typer <https://typer.tiangolo.com/>`_ for terminal output with colors and rich formatting.
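+
+A minimal sketch of that fetch call (network access assumed):
+
+.. code-block:: python
+
+   from trx.fetcher import fetch_data, get_testing_files_dict
+
+   # Downloads the example datasets to $TRX_HOME (default: ~/.tee_ar_ex)
+   fetch_data(get_testing_files_dict())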
diff --git a/pyproject.toml b/pyproject.toml
index 6d97948..46c5a1c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,90 @@
[build-system]
-requires = ["setuptools >= 42.0", "wheel", "setuptools_scm[toml] >= 7"]
-build-backend = "setuptools.build_meta:__legacy__"
+requires = ["setuptools >= 64", "wheel", "setuptools_scm[toml] >= 7"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "trx-python"
+dynamic = ["version"]
+description = "Experiments with new file format for tractography"
+readme = "README.md"
+license = {text = "BSD License"}
+requires-python = ">=3.11"
+authors = [
+ {name = "The TRX developers"}
+]
+classifiers = [
+ "Development Status :: 3 - Alpha",
+ "Environment :: Console",
+ "Intended Audience :: Science/Research",
+ "License :: OSI Approved :: BSD License",
+ "Operating System :: OS Independent",
+ "Programming Language :: Python",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
+ "Programming Language :: Python :: 3.13",
+ "Topic :: Scientific/Engineering",
+]
+dependencies = [
+ "deepdiff",
+ "nibabel >= 5",
+ "numpy >= 1.22",
+ "typer >= 0.9",
+]
+
+[project.optional-dependencies]
+dev = [
+ "spin >= 0.13",
+ "setuptools_scm",
+]
+doc = [
+ "astroid >= 4.0.0",
+ "sphinx >= 9.0.0",
+ "pydata-sphinx-theme >= 0.16.1",
+ "sphinx-autoapi >= 3.0.0",
+ "numpydoc",
+]
+style = [
+ "codespell",
+ "pre-commit",
+ "ruff",
+]
+test = [
+ "psutil",
+ "pytest >= 7",
+ "pytest-console-scripts >= 0",
+ "pytest-cov",
+]
+all = [
+ "trx-python[dev]",
+ "trx-python[doc]",
+ "trx-python[style]",
+ "trx-python[test]",
+]
+
+[project.urls]
+Homepage = "https://github.com/tee-ar-ex/trx-python"
+Documentation = "https://tee-ar-ex.github.io/trx-python/"
+Repository = "https://github.com/tee-ar-ex/trx-python"
+
+[project.scripts]
+tff = "trx.cli:main"
+tff_concatenate_tractograms = "trx.cli:concatenate_tractograms_cmd"
+tff_convert_dsi_studio = "trx.cli:convert_dsi_cmd"
+tff_convert_tractogram = "trx.cli:convert_cmd"
+tff_generate_trx_from_scratch = "trx.cli:generate_cmd"
+tff_manipulate_datatype = "trx.cli:manipulate_dtype_cmd"
+tff_simple_compare = "trx.cli:compare_cmd"
+tff_validate_trx = "trx.cli:validate_cmd"
+tff_verify_header_compatibility = "trx.cli:verify_header_cmd"
+tff_visualize_overlap = "trx.cli:visualize_cmd"
+
+[tool.setuptools]
+packages = ["trx"]
+include-package-data = true
+
+[tool.setuptools.dynamic]
+version = {attr = "trx._version.__version__"}
[tool.setuptools_scm]
write_to = "trx/_version.py"
@@ -10,3 +94,13 @@ __version__ = "{version}"
"""
fallback_version = "0.0"
local_scheme = "no-local-version"
+
+[tool.spin]
+package = "trx"
+
+[tool.spin.commands]
+"Setup" = [".spin/cmds.py:setup"]
+"Build" = ["spin.cmds.pip.install"]
+"Test" = [".spin/cmds.py:test", ".spin/cmds.py:lint"]
+"Docs" = [".spin/cmds.py:docs"]
+"Clean" = [".spin/cmds.py:clean"]
diff --git a/ruff.toml b/ruff.toml
new file mode 100644
index 0000000..0cfb2a9
--- /dev/null
+++ b/ruff.toml
@@ -0,0 +1,41 @@
+target-version = "py312"
+
+line-length = 88
+force-exclude = true
+extend-exclude = [
+ "__pycache__",
+ "build",
+ "_version.py",
+ "docs/**",
+]
+
+[lint]
+select = [
+ "F", # Pyflakes
+ "E", # pycodestyle errors
+ "W", # pycodestyle warnings
+ "C", # mccabe complexity
+ "B", # flake8-bugbear
+ "I", # isort
+]
+ignore = [
+ "B905", # zip without explicit strict parameter
+ "C901", # too complex
+ "E203", # whitespace before ':'
+]
+
+[lint.extend-per-file-ignores]
+"trx/tests/**" = ["B011"]
+
+[lint.isort]
+case-sensitive = true
+combine-as-imports = true
+force-sort-within-sections = true
+known-first-party = ["trx"]
+no-sections = false
+order-by-type = true
+relative-imports-order = "closest-to-furthest"
+section-order = ["future", "standard-library", "third-party", "first-party", "local-folder"]
+
+[format]
+quote-style = "double"
diff --git a/scripts/tests/test_workflows.py b/scripts/tests/test_workflows.py
deleted file mode 100644
index 3ae5e95..0000000
--- a/scripts/tests/test_workflows.py
+++ /dev/null
@@ -1,322 +0,0 @@
-#! /usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-from deepdiff import DeepDiff
-import os
-import tempfile
-
-import pytest
-import numpy as np
-from numpy.testing import assert_equal, assert_array_equal, assert_allclose
-try:
- from dipy.io.streamline import load_tractogram
- dipy_available = True
-except ImportError:
- dipy_available = False
-
-from trx.fetcher import (get_testing_files_dict,
- fetch_data, get_home)
-import trx.trx_file_memmap as tmm
-from trx.workflows import (convert_dsi_studio,
- convert_tractogram,
- manipulate_trx_datatype,
- generate_trx_from_scratch,
- validate_tractogram,)
-
-
-# If they already exist, this only takes 5 seconds (check md5sum)
-fetch_data(get_testing_files_dict(), keys=['DSI.zip', 'trx_from_scratch.zip'])
-
-
-def test_help_option_convert_dsi(script_runner):
- ret = script_runner.run('tff_convert_dsi_studio.py', '--help')
- assert ret.success
-
-
-def test_help_option_convert(script_runner):
- ret = script_runner.run('tff_convert_tractogram.py', '--help')
- assert ret.success
-
-
-def test_help_option_generate_trx_from_scratch(script_runner):
- ret = script_runner.run('tff_generate_trx_from_scratch.py', '--help')
- assert ret.success
-
-
-@pytest.mark.skipif(not dipy_available,
- reason='Dipy is not installed.')
-def test_execution_convert_dsi():
- with tempfile.TemporaryDirectory() as tmp_dir:
- in_trk = os.path.join(get_home(), 'DSI',
- 'CC.trk.gz')
- in_nii = os.path.join(get_home(), 'DSI',
- 'CC.nii.gz')
- exp_data = os.path.join(get_home(), 'DSI',
- 'CC_fix_data.npy')
- exp_offsets = os.path.join(get_home(), 'DSI',
- 'CC_fix_offsets.npy')
- out_fix_path = os.path.join(tmp_dir, 'fixed.trk')
- convert_dsi_studio(in_trk, in_nii, out_fix_path,
- remove_invalid=False,
- keep_invalid=True)
-
- data_fix = np.load(exp_data)
- offsets_fix = np.load(exp_offsets)
-
- sft = load_tractogram(out_fix_path, 'same')
- assert_equal(sft.streamlines._data, data_fix)
- assert_equal(sft.streamlines._offsets, offsets_fix)
-
-
-@pytest.mark.skipif(not dipy_available,
- reason='Dipy is not installed.')
-def test_execution_convert_to_trx():
- with tempfile.TemporaryDirectory() as tmp_dir:
- in_trk = os.path.join(get_home(), 'DSI',
- 'CC_fix.trk')
- exp_data = os.path.join(get_home(), 'DSI',
- 'CC_fix_data.npy')
- exp_offsets = os.path.join(get_home(), 'DSI',
- 'CC_fix_offsets.npy')
- out_trx_path = os.path.join(tmp_dir, 'CC_fix.trx')
- convert_tractogram(in_trk, out_trx_path, None)
-
- data_fix = np.load(exp_data)
- offsets_fix = np.load(exp_offsets)
-
- trx = tmm.load(out_trx_path)
- assert_equal(trx.streamlines._data.dtype, np.float32)
- assert_equal(trx.streamlines._offsets.dtype, np.uint32)
- assert_array_equal(trx.streamlines._data, data_fix)
- assert_array_equal(trx.streamlines._offsets, offsets_fix)
- trx.close()
-
-
-@pytest.mark.skipif(not dipy_available,
- reason='Dipy is not installed.')
-def test_execution_convert_from_trx():
- with tempfile.TemporaryDirectory() as tmp_dir:
- in_trk = os.path.join(get_home(), 'DSI',
- 'CC_fix.trk')
- in_nii = os.path.join(get_home(), 'DSI',
- 'CC.nii.gz')
- exp_data = os.path.join(get_home(), 'DSI',
- 'CC_fix_data.npy')
- exp_offsets = os.path.join(get_home(), 'DSI',
- 'CC_fix_offsets.npy')
-
- # Sequential conversions
- out_trx_path = os.path.join(tmp_dir, 'CC_fix.trx')
- out_trk_path = os.path.join(tmp_dir, 'CC_fix.trk')
- out_tck_path = os.path.join(tmp_dir, 'CC_fix.tck')
- convert_tractogram(in_trk, out_trx_path, None)
- convert_tractogram(out_trx_path, out_tck_path, None)
- convert_tractogram(out_trx_path, out_trk_path, None)
-
- data_fix = np.load(exp_data)
- offsets_fix = np.load(exp_offsets)
-
- sft = load_tractogram(out_trk_path, 'same')
- assert_equal(sft.streamlines._data, data_fix)
- assert_equal(sft.streamlines._offsets, offsets_fix)
-
- sft = load_tractogram(out_tck_path, in_nii)
- assert_equal(sft.streamlines._data, data_fix)
- assert_equal(sft.streamlines._offsets, offsets_fix)
-
-
-@pytest.mark.skipif(not dipy_available,
- reason='Dipy is not installed.')
-def test_execution_convert_dtype_p16_o64():
- with tempfile.TemporaryDirectory() as tmp_dir:
- in_trk = os.path.join(get_home(), 'DSI',
- 'CC_fix.trk')
- out_convert_path = os.path.join(tmp_dir, 'CC_fix_p16_o64.trx')
- convert_tractogram(in_trk, out_convert_path, None,
- pos_dtype='float16', offsets_dtype='uint64')
-
- trx = tmm.load(out_convert_path)
- assert_equal(trx.streamlines._data.dtype, np.float16)
- assert_equal(trx.streamlines._offsets.dtype, np.uint64)
- trx.close()
-
-
-@pytest.mark.skipif(not dipy_available,
- reason='Dipy is not installed.')
-def test_execution_convert_dtype_p64_o32():
- with tempfile.TemporaryDirectory() as tmp_dir:
- in_trk = os.path.join(get_home(), 'DSI',
- 'CC_fix.trk')
- out_convert_path = os.path.join(tmp_dir, 'CC_fix_p16_o64.trx')
- convert_tractogram(in_trk, out_convert_path, None,
- pos_dtype='float64', offsets_dtype='uint32')
-
- trx = tmm.load(out_convert_path)
- assert_equal(trx.streamlines._data.dtype, np.float64)
- assert_equal(trx.streamlines._offsets.dtype, np.uint32)
- trx.close()
-
-
-def test_execution_generate_trx_from_scratch():
- with tempfile.TemporaryDirectory() as tmp_dir:
- reference_fa = os.path.join(get_home(), 'trx_from_scratch',
- 'fa.nii.gz')
- raw_arr_dir = os.path.join(get_home(), 'trx_from_scratch',
- 'test_npy')
- expected_trx = os.path.join(get_home(), 'trx_from_scratch',
- 'expected.trx')
-
- dpv = [(os.path.join(raw_arr_dir, 'dpv_cx.npy'), 'uint8'),
- (os.path.join(raw_arr_dir, 'dpv_cy.npy'), 'uint8'),
- (os.path.join(raw_arr_dir, 'dpv_cz.npy'), 'uint8')]
- dps = [(os.path.join(raw_arr_dir, 'dps_algo.npy'), 'uint8'),
- (os.path.join(raw_arr_dir, 'dps_cw.npy'), 'float64')]
- dpg = [('g_AF_L', os.path.join(raw_arr_dir, 'dpg_AF_L_mean_fa.npy'), 'float32'),
- ('g_AF_R', os.path.join(raw_arr_dir, 'dpg_AF_R_mean_fa.npy'), 'float32'),
- ('g_AF_L', os.path.join(raw_arr_dir, 'dpg_AF_L_volume.npy'), 'float32')]
- groups = [(os.path.join(raw_arr_dir, 'g_AF_L.npy'), 'int32'),
- (os.path.join(raw_arr_dir, 'g_AF_R.npy'), 'int32'),
- (os.path.join(raw_arr_dir, 'g_CST_L.npy'), 'int32')]
-
- out_gen_path = os.path.join(tmp_dir, 'generated.trx')
- generate_trx_from_scratch(reference_fa, out_gen_path,
- positions=os.path.join(raw_arr_dir,
- 'positions.npy'),
- offsets=os.path.join(raw_arr_dir,
- 'offsets.npy'),
- positions_dtype='float16',
- offsets_dtype='uint64',
- space_str='rasmm', origin_str='nifti',
- verify_invalid=False, dpv=dpv, dps=dps,
- groups=groups, dpg=dpg)
- exp_trx = tmm.load(expected_trx)
- gen_trx = tmm.load(out_gen_path)
-
- assert DeepDiff(exp_trx.get_dtype_dict(),
- gen_trx.get_dtype_dict()) == {}
-
- assert_allclose(exp_trx.streamlines._data, gen_trx.streamlines._data,
- atol=0.1, rtol=0.1)
- assert_equal(exp_trx.streamlines._offsets,
- gen_trx.streamlines._offsets)
-
- for key in exp_trx.data_per_vertex.keys():
- assert_equal(exp_trx.data_per_vertex[key]._data,
- gen_trx.data_per_vertex[key]._data)
- assert_equal(exp_trx.data_per_vertex[key]._offsets,
- gen_trx.data_per_vertex[key]._offsets)
- for key in exp_trx.data_per_streamline.keys():
- assert_equal(exp_trx.data_per_streamline[key],
- gen_trx.data_per_streamline[key])
- for key in exp_trx.groups.keys():
- assert_equal(exp_trx.groups[key], gen_trx.groups[key])
-
- for group in exp_trx.groups.keys():
- if group in exp_trx.data_per_group:
- for key in exp_trx.data_per_group[group].keys():
- assert_equal(exp_trx.data_per_group[group][key],
- gen_trx.data_per_group[group][key])
- exp_trx.close()
- gen_trx.close()
-
-
-@pytest.mark.skipif(not dipy_available,
- reason='Dipy is not installed.')
-def test_execution_concatenate_validate_trx():
- with tempfile.TemporaryDirectory() as tmp_dir:
- trx1 = tmm.load(os.path.join(get_home(), 'gold_standard',
- 'gs.trx'))
- trx2 = tmm.load(os.path.join(get_home(), 'gold_standard',
- 'gs.trx'))
- # trx2.streamlines._data += 0.001
- trx = tmm.concatenate([trx1, trx2], preallocation=False)
-
- # Right size
- assert_equal(len(trx.streamlines), 2*len(trx1.streamlines))
-
- # Right data
- end_idx = trx1.header['NB_VERTICES']
- assert_allclose(
- trx.streamlines._data[:end_idx], trx1.streamlines._data)
- assert_allclose(
- trx.streamlines._data[end_idx:], trx2.streamlines._data)
-
- # Right data_per_*
- for key in trx.data_per_vertex.keys():
- assert_equal(trx.data_per_vertex[key]._data[:end_idx],
- trx1.data_per_vertex[key]._data)
- assert_equal(trx.data_per_vertex[key]._data[end_idx:],
- trx2.data_per_vertex[key]._data)
-
- end_idx = trx1.header['NB_STREAMLINES']
- for key in trx.data_per_streamline.keys():
- assert_equal(trx.data_per_streamline[key][:end_idx],
- trx1.data_per_streamline[key])
- assert_equal(trx.data_per_streamline[key][end_idx:],
- trx2.data_per_streamline[key])
-
- # Validate
- out_concat_path = os.path.join(tmp_dir, 'concat.trx')
- out_valid_path = os.path.join(tmp_dir, 'valid.trx')
- tmm.save(trx, out_concat_path)
- validate_tractogram(out_concat_path, None, out_valid_path,
- remove_identical_streamlines=True,
- precision=0)
- trx_val = tmm.load(out_valid_path)
-
- # # Right dtype and size
- assert DeepDiff(trx.get_dtype_dict(), trx_val.get_dtype_dict()) == {}
- assert_equal(len(trx1.streamlines), len(trx_val.streamlines))
-
- trx.close()
- trx1.close()
- trx2.close()
- trx_val.close()
-
-
-@pytest.mark.skipif(not dipy_available,
- reason='Dipy is not installed.')
-def test_execution_manipulate_trx_datatype():
- with tempfile.TemporaryDirectory() as tmp_dir:
- expected_trx = os.path.join(get_home(), 'trx_from_scratch',
- 'expected.trx')
- trx = tmm.load(expected_trx)
-
- expected_dtype = {'positions': np.dtype('float16'),
- 'offsets': np.dtype('uint64'),
- 'dpv': {'dpv_cx': np.dtype('uint8'),
- 'dpv_cy': np.dtype('uint8'),
- 'dpv_cz': np.dtype('uint8')},
- 'dps': {'dps_algo': np.dtype('uint8'),
- 'dps_cw': np.dtype('float64')},
- 'dpg': {'g_AF_L':
- {'dpg_AF_L_mean_fa': np.dtype('float32'),
- 'dpg_AF_L_volume': np.dtype('float32')},
- 'g_AF_R':
- {'dpg_AF_R_mean_fa': np.dtype('float32')}},
- 'groups': {'g_AF_L': np.dtype('int32'),
- 'g_AF_R': np.dtype('int32')}}
-
- assert DeepDiff(trx.get_dtype_dict(), expected_dtype) == {}
- trx.close()
-
- generated_dtype = {'positions': np.dtype('float32'),
- 'offsets': np.dtype('uint32'),
- 'dpv': {'dpv_cx': np.dtype('uint16'),
- 'dpv_cy': np.dtype('uint16'),
- 'dpv_cz': np.dtype('uint16')},
- 'dps': {'dps_algo': np.dtype('uint8'),
- 'dps_cw': np.dtype('float32')},
- 'dpg': {'g_AF_L':
- {'dpg_AF_L_mean_fa': np.dtype('float64'),
- 'dpg_AF_L_volume': np.dtype('float32')},
- 'g_AF_R':
- {'dpg_AF_R_mean_fa': np.dtype('float64')}},
- 'groups': {'g_AF_L': np.dtype('uint16'),
- 'g_AF_R': np.dtype('uint16')}}
-
- out_gen_path = os.path.join(tmp_dir, 'generated.trx')
- manipulate_trx_datatype(expected_trx, out_gen_path, generated_dtype)
- trx = tmm.load(out_gen_path)
- assert DeepDiff(trx.get_dtype_dict(), generated_dtype) == {}
- trx.close()
diff --git a/scripts/tff_concatenate_tractograms.py b/scripts/tff_concatenate_tractograms.py
deleted file mode 100755
index 30ac7a9..0000000
--- a/scripts/tff_concatenate_tractograms.py
+++ /dev/null
@@ -1,74 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-"""
-Concatenate multiple tractograms into one.
-
-If the data_per_point or data_per_streamline is not the same for all
-tractograms, the data must be deleted first.
-"""
-
-import argparse
-import os
-
-from trx.io import load, save
-from trx.trx_file_memmap import TrxFile, concatenate
-
-
-def _build_arg_parser():
- p = argparse.ArgumentParser(description=__doc__,
- formatter_class=argparse.RawTextHelpFormatter)
-
- p.add_argument('in_tractograms', nargs='+',
- help='Tractogram filename. Format must be one of \n'
- 'trk, tck, vtk, fib, dpy, trx.')
- p.add_argument('out_tractogram',
- help='Filename of the concatenated tractogram.')
-
- p.add_argument('--delete_dpv', action='store_true',
- help='Delete the dpv if it exists. '
- 'Required if not all input has the same metadata.')
- p.add_argument('--delete_dps', action='store_true',
- help='Delete the dps if it exists. '
- 'Required if not all input has the same metadata.')
- p.add_argument('--delete_groups', action='store_true',
- help='Delete the groups if it exists. '
- 'Required if not all input has the same metadata.')
- p.add_argument('--reference',
- help='Reference anatomy for tck/vtk/fib/dpy file\n'
- 'support (.nii or .nii.gz).')
- p.add_argument('-f', dest='overwrite', action='store_true',
- help='Force overwriting of the output files.')
-
- return p
-
-
-def main():
- parser = _build_arg_parser()
- args = parser.parse_args()
-
- if os.path.isfile(args.out_tractogram) and not args.overwrite:
- raise IOError('{} already exists, use -f to overwrite.'.format(
- args.out_tractogram))
-
- trx_list = []
- has_group = False
- for filename in args.in_tractograms:
- tractogram_obj = load(filename, args.reference)
-
- if not isinstance(tractogram_obj, TrxFile):
- tractogram_obj = TrxFile.from_sft(tractogram_obj)
- elif len(tractogram_obj.groups):
- has_group = True
- trx_list.append(tractogram_obj)
-
- trx = concatenate(trx_list, delete_dpv=args.delete_dpv,
- delete_dps=args.delete_dps,
- delete_groups=args.delete_groups or not has_group,
- check_space_attributes=True,
- preallocation=False)
- save(trx, args.out_tractogram)
-
-
-if __name__ == "__main__":
- main()
diff --git a/scripts/tff_convert_dsi_studio.py b/scripts/tff_convert_dsi_studio.py
deleted file mode 100644
index b99111f..0000000
--- a/scripts/tff_convert_dsi_studio.py
+++ /dev/null
@@ -1,63 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-"""
-This script is made to fix DSI-Studio TRK file (unknown space/convention) to
-make it compatible with TrackVis, MI-Brain, Dipy Horizon (Stateful Tractogram).
-
-The script makes the tractogram match an anatomy exported from DSI-Studio.
-
-This script was tested on various datasets and worked on all of them. However,
-always verify the results; if a specific case does not work, open an issue
-on the Scilpy GitHub repository.
-
-WARNING: This script is still experimental, DSI-Studio evolves quickly and
-results may vary depending on the data itself as well as DSI-studio version.
-"""
-
-import argparse
-import os
-
-from trx.workflows import convert_dsi_studio
-
-
-def _build_arg_parser():
- p = argparse.ArgumentParser(
- description=__doc__, formatter_class=argparse.RawTextHelpFormatter)
-
- p.add_argument('in_dsi_tractogram', metavar='IN_DSI_TRACTOGRAM',
- help='Path of the input tractogram file from DSI studio '
- '(.trk).')
- p.add_argument('in_dsi_fa', metavar='IN_DSI_FA',
- help='Path of the input FA from DSI Studio (.nii.gz).')
- p.add_argument('out_tractogram', metavar='OUT_TRACTOGRAM',
- help='Path of the output tractogram file.')
-
- invalid = p.add_mutually_exclusive_group()
- invalid.add_argument('--remove_invalid', action='store_true',
- help='Remove the streamlines landing out of the '
- 'bounding box.')
- invalid.add_argument('--keep_invalid', action='store_true',
- help='Keep the streamlines landing out of the '
- 'bounding box.')
- p.add_argument('-f', dest='overwrite', action='store_true',
- help='Force overwriting of the output files.')
-
- return p
-
-
-def main():
- parser = _build_arg_parser()
- args = parser.parse_args()
-
- if os.path.isfile(args.out_tractogram) and not args.overwrite:
- raise IOError('{} already exists, use -f to overwrite.'.format(
- args.out_tractogram))
-
- convert_dsi_studio(args.in_dsi_tractogram, args.in_dsi_fa,
- args.out_tractogram, remove_invalid=args.remove_invalid,
- keep_invalid=args.keep_invalid)
-
-
-if __name__ == "__main__":
- main()
diff --git a/scripts/tff_convert_tractogram.py b/scripts/tff_convert_tractogram.py
deleted file mode 100644
index cffee68..0000000
--- a/scripts/tff_convert_tractogram.py
+++ /dev/null
@@ -1,59 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-"""
-Conversion of '.tck', '.trk', '.fib', '.vtk', '.trx' and 'dpy' files using
-updated file format standard. TCK file always needs a reference file, a NIFTI,
-for conversion. The FIB file format is in fact a VTK, MITK Diffusion supports
-it.
-"""
-
-import argparse
-import os
-
-from trx.workflows import convert_tractogram
-
-
-def _build_arg_parser():
- p = argparse.ArgumentParser(description=__doc__,
- formatter_class=argparse.RawTextHelpFormatter)
-
- p.add_argument('in_tractogram', metavar='IN_TRACTOGRAM',
- help='Tractogram filename. Format must be one of \n'
- 'trk, tck, vtk, fib, dpy, trx.')
- p.add_argument('out_tractogram', metavar='OUT_TRACTOGRAM',
- help='Output filename. Format must be one of \n'
- 'trk, tck, vtk, fib, dpy, trx.')
-
- p.add_argument('--reference',
- help='Reference anatomy for tck/vtk/fib/dpy file\n'
- 'support (.nii or .nii.gz).')
-
- p2 = p.add_argument_group(title='Data type options')
- p2.add_argument('--positions_dtype', default='float32',
- choices=['float16', 'float32', 'float64'],
- help='Specify the datatype for positions for trx. [%(default)s]')
- p2.add_argument('--offsets_dtype', default='uint64',
- choices=['uint32', 'uint64'],
- help='Specify the datatype for offsets for trx. [%(default)s]')
- p.add_argument('-f', dest='overwrite', action='store_true',
- help='Force overwriting of the output files.')
-
- return p
-
-
-def main():
- parser = _build_arg_parser()
- args = parser.parse_args()
-
- if os.path.isfile(args.out_tractogram) and not args.overwrite:
- raise IOError('{} already exists, use -f to overwrite.'.format(
- args.out_tractogram))
-
- convert_tractogram(args.in_tractogram, args.out_tractogram, args.reference,
- pos_dtype=args.positions_dtype,
- offsets_dtype=args.offsets_dtype)
-
-
-if __name__ == "__main__":
- main()
diff --git a/scripts/tff_generate_trx_from_scratch.py b/scripts/tff_generate_trx_from_scratch.py
deleted file mode 100755
index 38fd1f9..0000000
--- a/scripts/tff_generate_trx_from_scratch.py
+++ /dev/null
@@ -1,140 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-"""
-Generate TRX file from a collection of CSV, TXT or NPY files by individually
-specifying positions, offsets, data_per_vertex, data_per_streamlines,
-groups and data_per_group. Each file must have its data type specified by the
-users.
-
-A reference file must be provided (NIFTI) and the option --verify_invalid will
-remove invalid streamlines (outside of the bounding box in VOX space).
-
-All dimensions (nbr_vertices and nbr_streamlines) and groups/dpg must match
-otherwise the script will (likely) crash.
-
-Each instance of --dps, --dpv, --groups require 2 arguments (FILE, DTYPE).
---dpg requires 3 arguments (GROUP, FILE, DTYPE).
-The choice of DTYPE are:
- - (u)int8, (u)int16, (u)int32, (u)int64
- - float16, float32, float64
- - bool
-
-Example command:
-tff_generate_trx_from_scratch.py fa.nii.gz generated.trx -f \
- --positions test_npy/positions.npy --positions_dtype float16 \
- --offsets test_npy/offsets.npy --offsets_dtype uint32 \
- --dpv test_npy/dpv_cx.npy uint8 \
- --dpv test_npy/dpv_cy.npy uint8 \
- --dpv test_npy/dpv_cz.npy uint8 \
- --dps test_npy/dps_algo.npy uint8 \
- --dps test_npy/dps_cw.npy float64 \
- --groups test_npy/g_AF_L.npy int32 \
- --groups test_npy/g_AF_R.npy int32 \
- --dpg g_AF_L test_npy/dpg_AF_L_mean_fa.npy float32 \
- --dpg g_AF_R test_npy/dpg_AF_R_mean_fa.npy float32 \
- --dpg g_AF_L test_npy/dpg_AF_L_volume.npy float32
-"""
-
-import argparse
-import os
-
-from trx.workflows import generate_trx_from_scratch
-
-
-def _build_arg_parser():
- p = argparse.ArgumentParser(description=__doc__,
- formatter_class=argparse.RawTextHelpFormatter)
- p.add_argument('reference',
- help='Reference anatomy for tck/vtk/fib/dpy file\n'
- 'support (.nii or .nii.gz).')
- p.add_argument('out_tractogram', metavar='OUT_TRACTOGRAM',
- help='Output filename. Format must be one of\n'
- 'trk, tck, vtk, fib, dpy, trx.')
-
- p1 = p.add_argument_group(title='Positions options')
- p1.add_argument('--positions', metavar='POSITIONS',
- help='Binary file containing the streamlines coordinates.'
- '\nMust be Nx3 (.npy)')
- p1.add_argument('--offsets', metavar='OFFSETS',
- help='Binary file containing the streamlines offsets (.npy)')
- p1.add_argument('--positions_csv', metavar='POSITIONS',
- help='CSV file containing the streamlines coordinates.'
- '\nRows for each streamlines organized as x1,y1,z1,\n'
- 'x2,y2,z2,...,xN,yN,zN')
- p1.add_argument('--space', choices=['RASMM', 'VOXMM', 'VOX'],
- default='RASMM',
- help='Space in which the coordinates are declared.'
- '[%(default)s]\nNon-default option requires Dipy.')
- p1.add_argument('--origin', choices=['NIFTI', 'TRACKVIS'],
- default='NIFTI',
- help='Origin in which the coordinates are declared. '
- '[%(default)s]\nNon-default option requires Dipy.')
- p2 = p.add_argument_group(title='Data type options')
- p2.add_argument('--positions_dtype', default='float32',
- choices=['float16', 'float32', 'float64'],
- help='Specify the datatype for positions for trx. '
- '[%(default)s]')
- p2.add_argument('--offsets_dtype', default='uint64',
- choices=['uint32', 'uint64'],
- help='Specify the datatype for offsets for trx. '
- '[%(default)s]')
-
- p3 = p.add_argument_group(title='Streamlines metadata options')
- p3.add_argument('--dpv', metavar=('FILE', 'DTYPE'), nargs=2,
- action='append',
- help='Binary file containing data_per_vertex.\n Must have'
- 'NB_VERTICES as first dimension (.npy)')
- p3.add_argument('--dps', metavar=('FILE', 'DTYPE'), nargs=2,
- action='append',
- help='Binary file containing data_per_streamline.\n Must have'
- 'NB_STREAMLINES as first dimension (.npy)')
- p3.add_argument('--groups', metavar=('FILE', 'DTYPE'), nargs=2,
- action='append',
- help='Binary file containing a sparse group (indices).\n '
- 'Indices should be lower than NB_STREAMLINES (.npy)')
- p3.add_argument('--dpg', metavar=('GROUP', 'FILE', 'DTYPE'), nargs=3,
- action='append',
- help='Binary file containing data_per_group.\n Must have'
- '(1,) as first dimension (.npy)')
-
- p.add_argument('--verify_invalid', action='store_true',
- help='Verify that the positions are all valid.\n'
- 'None outside of the bounding box in VOX space.\n'
- 'Requires Dipy (due to use of SFT).')
- p.add_argument('-f', dest='overwrite', action='store_true',
- help='Force overwriting of the output files.')
-
- return p
-
-
-def main():
- parser = _build_arg_parser()
- args = parser.parse_args()
-
- if os.path.isfile(args.out_tractogram) and not args.overwrite:
- raise IOError('{} already exists, use -f to overwrite.'.format(
- args.out_tractogram))
-
- if not args.positions and not args.positions_csv:
- parser.error('At least one positions options must be used.')
- if args.positions_csv and args.positions:
- parser.error('Cannot use both positions options.')
- if args.positions and args.offsets is None:
- parser.error('--offsets must be provided if --positions is used.')
- if args.offsets and args.positions is None:
- parser.error('--positions must be provided if --offsets is used.')
-
- generate_trx_from_scratch(args.reference, args.out_tractogram,
- positions_csv=args.positions_csv,
- positions=args.positions, offsets=args.offsets,
- positions_dtype=args.positions_dtype,
- offsets_dtype=args.offsets_dtype,
- space_str=args.space, origin_str=args.origin,
- verify_invalid=args.verify_invalid,
- dpv=args.dpv, dps=args.dps,
- groups=args.groups, dpg=args.dpg)
-
-
-if __name__ == "__main__":
- main()
diff --git a/scripts/tff_manipulate_datatype.py b/scripts/tff_manipulate_datatype.py
deleted file mode 100644
index 4c9c074..0000000
--- a/scripts/tff_manipulate_datatype.py
+++ /dev/null
@@ -1,103 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-"""
-Manipulate a TRX file internal array to change their data type.
-
-Each instance of --dps, --dpv, --groups require 2 arguments (FILE, DTYPE).
---dpg requires 3 arguments (GROUP, FILE, DTYPE).
-The choice of DTYPE are:
- - (u)int8, (u)int16, (u)int32, (u)int64
- - float16, float32, float64
- - bool
-
-Example command:
-tff_manipulate_datatype.py input.trx output.trx \
- --position float16 --offsets uint64 \
- --dpv color_x uint8 --dpv color_y uint8 --dpv color_z uint8 \
- --dpv fa float16 --dps algo uint8 --dps clusters_QB uint16 \
- --dps commit_colors uint8 --dps commit_weights float16 \
- --group CC uint64 --dpg CC mean_fa float64
-"""
-
-import argparse
-import os
-
-import numpy as np
-from trx.workflows import manipulate_trx_datatype
-
-
-def _build_arg_parser():
- p = argparse.ArgumentParser(description=__doc__,
- formatter_class=argparse.RawTextHelpFormatter)
- p.add_argument('in_tractogram',
- help='Input TRX file.')
- p.add_argument('out_tractogram',
- help='Output filename. Format must be one of\n'
- 'trk, tck, vtk, fib, dpy, trx.')
-
- p2 = p.add_argument_group(title='Data type options')
- p2.add_argument('--positions_dtype',
- choices=['float16', 'float32', 'float64'],
- help='Specify the datatype for positions for trx. '
- '[%(choices)s]')
- p2.add_argument('--offsets_dtype',
- choices=['uint32', 'uint64'],
- help='Specify the datatype for offsets for trx. '
- '[%(choices)s]')
-
- p3 = p.add_argument_group(title='Streamlines metadata options')
- p3.add_argument('--dpv', metavar=('NAME', 'DTYPE'), nargs=2,
- action='append',
- help='Specify the datatype for a specific data_per_vertex.')
- p3.add_argument('--dps', metavar=('NAME', 'DTYPE'), nargs=2,
- action='append',
- help='Specify the datatype for a specific data_per_streamline.')
- p3.add_argument('--groups', metavar=('NAME', 'DTYPE'), nargs=2,
- action='append',
- help='Specify the datatype for a specific group.')
- p3.add_argument('--dpg', metavar=('GROUP', 'NAME', 'DTYPE'), nargs=3,
- action='append',
- help='Specify the datatype for a specific data_per_group.')
- p.add_argument('-f', dest='overwrite', action='store_true',
- help='Force overwriting of the output files.')
-
- return p
-
-
-def main():
- parser = _build_arg_parser()
- args = parser.parse_args()
-
- if os.path.isfile(args.out_tractogram) and not args.overwrite:
- raise IOError('{} already exists, use -f to overwrite.'.format(
- args.out_tractogram))
-
- dtype_dict = {}
- if args.positions_dtype:
- dtype_dict['positions'] = np.dtype(args.positions_dtype)
- if args.offsets_dtype:
- dtype_dict['offsets'] = np.dtype(args.offsets_dtype)
- if args.dpv:
- dtype_dict['dpv'] = {}
- for name, dtype in args.dpv:
- dtype_dict['dpv'][name] = np.dtype(dtype)
- if args.dps:
- dtype_dict['dps'] = {}
- for name, dtype in args.dps:
- dtype_dict['dps'][name] = np.dtype(dtype)
- if args.groups:
- dtype_dict['groups'] = {}
- for name, dtype in args.groups:
- dtype_dict['groups'][name] = np.dtype(dtype)
- if args.dpg:
- dtype_dict['dpg'] = {}
- for group, name, dtype in args.dpg:
- dtype_dict['dpg'][group] = {name: np.dtype(dtype)}
-
- manipulate_trx_datatype(
- args.in_tractogram, args.out_tractogram, dtype_dict)
-
-
-if __name__ == "__main__":
- main()
diff --git a/scripts/tff_simple_compare.py b/scripts/tff_simple_compare.py
deleted file mode 100755
index 4e3e146..0000000
--- a/scripts/tff_simple_compare.py
+++ /dev/null
@@ -1,36 +0,0 @@
-#!/usr/bin/env python
-
-""" Simple comparison of tractogram by subtracting the coordinates' data.
-Does not account for shuffling of streamlines. Simple A-B operations.
-
-Differences below 1e-3 are expected for affines with large rotation/scaling.
-Differences below 1e-6 are expected for isotropic data with small rotation.
-"""
-
-import argparse
-
-from trx.workflows import tractogram_simple_compare
-
-
-def _build_arg_parser():
- p = argparse.ArgumentParser(description=__doc__,
- formatter_class=argparse.RawTextHelpFormatter)
-
- p.add_argument('in_tractograms', nargs=2, metavar='IN_TRACTOGRAM',
- help='Tractogram filename. Format must be one of \n'
- 'trk, tck, vtk, fib, dpy, trx.')
- p.add_argument('--reference', metavar='REFERENCE',
- help='Reference anatomy for tck/vtk/fib/dpy file\n'
- 'support (.nii or .nii.gz).')
- return p
-
-
-def main():
- parser = _build_arg_parser()
- args = parser.parse_args()
-
- tractogram_simple_compare(args.in_tractograms, args.reference)
-
-
-if __name__ == "__main__":
- main()
diff --git a/scripts/tff_validate_trx.py b/scripts/tff_validate_trx.py
deleted file mode 100755
index 03c808b..0000000
--- a/scripts/tff_validate_trx.py
+++ /dev/null
@@ -1,64 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-"""
-Validate TRX file.
-
-Removes streamlines that are out of the volume bounding box. In voxel space,
-no negative coordinate and no above volume dimension coordinate are possible.
-Any streamline that do not respect these two conditions are removed.
-
-Also removes streamlines with single or no point.
-The --remove_identical_streamlines option will remove identical streamlines.
-'identical' is defined as having the same number of points and the same
-points coordinates (to a specified precision, using a hash table).
-"""
-
-import argparse
-import os
-
-from trx.workflows import validate_tractogram
-
-
-def _build_arg_parser():
- p = argparse.ArgumentParser(description=__doc__,
- formatter_class=argparse.RawTextHelpFormatter)
-
- p.add_argument('in_tractogram',
- help='Tractogram filename. Format must be one of \n'
- 'trk, tck, vtk, fib, dpy, trx.')
- p.add_argument('--out_tractogram',
- help='Filename of the tractogram after removing invalid '
- 'streamlines.')
- p.add_argument('--remove_identical_streamlines', action='store_true',
- help='Remove identical streamlines from the set.')
- p.add_argument('--precision', type=int, default=1,
- help='Number of decimals to keep when hashing the points '
- 'of streamlines [%(default)s].')
-
- p.add_argument('--reference',
- help='Reference anatomy for tck/vtk/fib/dpy file\n'
- 'support (.nii or .nii.gz).')
- p.add_argument('-f', dest='overwrite', action='store_true',
- help='Force overwriting of the output files.')
-
- return p
-
-
-def main():
- parser = _build_arg_parser()
- args = parser.parse_args()
-
- if args.out_tractogram and os.path.isfile(args.out_tractogram) \
- and not args.overwrite:
- raise IOError('{} already exists, use -f to overwrite.'.format(
- args.out_tractogram))
-
- validate_tractogram(args.in_tractogram, reference=args.reference,
- out_tractogram=args.out_tractogram,
- remove_identical_streamlines=args.remove_identical_streamlines,
- precision=args.precision)
-
-
-if __name__ == "__main__":
- main()
diff --git a/scripts/tff_verify_header_compatibility.py b/scripts/tff_verify_header_compatibility.py
deleted file mode 100644
index 629edda..0000000
--- a/scripts/tff_verify_header_compatibility.py
+++ /dev/null
@@ -1,33 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-"""
-Will compare all input files against the first one for the compatibility
-of their spatial attributes.
-
-Spatial attributes are: affine, dimensions, voxel sizes and voxel order.
-"""
-
-import argparse
-from trx.workflows import verify_header_compatibility
-
-
-def _build_arg_parser():
- p = argparse.ArgumentParser(description=__doc__,
- formatter_class=argparse.RawTextHelpFormatter)
-
- p.add_argument('in_files', nargs='+',
- help='List of file to compare (trk, trx and nii).')
-
- return p
-
-
-def main():
- parser = _build_arg_parser()
- args = parser.parse_args()
-
- verify_header_compatibility(args.in_files)
-
-
-if __name__ == "__main__":
- main()
diff --git a/scripts/tff_visualize_overlap.py b/scripts/tff_visualize_overlap.py
deleted file mode 100644
index 5e530f8..0000000
--- a/scripts/tff_visualize_overlap.py
+++ /dev/null
@@ -1,39 +0,0 @@
-#!/usr/bin/env python
-
-"""
-Display a tractogram and its density map (computed from Dipy) in rasmm,
-voxmm and vox space with its bounding box.
-"""
-
-import argparse
-
-from trx.workflows import tractogram_visualize_overlap
-
-
-def _build_arg_parser():
- p = argparse.ArgumentParser(description=__doc__,
- formatter_class=argparse.RawTextHelpFormatter)
-
- p.add_argument('in_tractogram', metavar='IN_TRACTOGRAM',
- help='Tractogram filename. Format must be one of \n'
- 'trk, tck, vtk, fib, dpy, trx.')
- p.add_argument('reference',
- help='Reference anatomy for tck/vtk/fib/dpy file\n'
- 'support (nii or nii.gz).')
- p.add_argument('--remove_invalid', action='store_true',
- help='Removes invalid streamlines to avoid the density_map'
- 'function to crash.')
-
- return p
-
-
-def main():
- parser = _build_arg_parser()
- args = parser.parse_args()
-
- tractogram_visualize_overlap(args.in_tractogram, args.reference,
- args.remove_invalid)
-
-
-if __name__ == "__main__":
- main()
diff --git a/setup.cfg b/setup.cfg
deleted file mode 100644
index 13177b3..0000000
--- a/setup.cfg
+++ /dev/null
@@ -1,49 +0,0 @@
-[metadata]
-name = trx-python
-url = https://github.com/tee-ar-ex/trx-python
-classifiers =
- Development Status :: 3 - Alpha
- Environment :: Console
- Intended Audience :: Science/Research
- License :: OSI Approved :: BSD License
- Operating System :: OS Independent
- Programming Language :: Python
- Topic :: Scientific/Engineering
-
-license = BSD License
-description = Experiments with new file format for tractography
-long_description = file: README.md
-long_description_content_type = text/markdown
-platforms = OS Independent
-
-packages = find:
-include_package_data = True
-
-
-[options]
-python_requires = >=3.8
-setup_requires =
- packaging >= 19.0
- cython >= 0.29
-install_requires =
- setuptools_scm
- deepdiff
- nibabel >= 5
- numpy >= 1.22
-
-[options.extras_require]
-doc = astroid==2.15.8
- sphinx
- pydata-sphinx-theme
- sphinx-autoapi
- numpydoc
-
-test =
- flake8
- psutil
- pytest >= 7
- pytest-console-scripts >= 0
-
-all =
- %(doc)s
- %(test)s
diff --git a/setup.py b/setup.py
deleted file mode 100644
index 72e2115..0000000
--- a/setup.py
+++ /dev/null
@@ -1,33 +0,0 @@
-import glob
-import os.path as op
-import string
-from setuptools_scm import get_version
-from setuptools import setup
-
-
-def local_version(version):
- """
- Patch in a version that can be uploaded to test PyPI
- """
- scm_version = get_version()
- if "dev" in scm_version:
- gh_in_int = []
- for char in version.node:
- if char.isdigit():
- gh_in_int.append(str(char))
- else:
- gh_in_int.append(str(string.ascii_letters.find(char)))
- return "".join(gh_in_int)
- else:
- return ""
-
-
-opts = dict(use_scm_version={
- "root": ".", "relative_to": __file__,
- "write_to": op.join("trx", "version.py"),
- "local_scheme": local_version},
- scripts=glob.glob("scripts/*.py"))
-
-
-if __name__ == '__main__':
- setup(**opts)
diff --git a/tools/update_switcher.py b/tools/update_switcher.py
new file mode 100644
index 0000000..4a66988
--- /dev/null
+++ b/tools/update_switcher.py
@@ -0,0 +1,132 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""Update switcher.json for documentation version switching.
+
+This script maintains the version switcher JSON file used by pydata-sphinx-theme
+to enable users to switch between different documentation versions.
+"""
+
+import argparse
+import json
+from pathlib import Path
+import sys
+
+BASE_URL = "https://tee-ar-ex.github.io/trx-python"
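+
+# Each switcher.json entry follows the pydata-sphinx-theme switcher convention,
+# for example (the version number is illustrative):
+#   {"name": "0.5.0", "version": "0.5.0", "url": BASE_URL + "/0.5.0/"}
+# with the "stable" entry additionally carrying "preferred": true.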
+
+
+def load_switcher(path):
+ """Load existing switcher.json or return empty list."""
+ try:
+ with open(path, "r") as f:
+ return json.load(f)
+ except (FileNotFoundError, json.JSONDecodeError):
+ return []
+
+
+def save_switcher(path, versions):
+ """Save switcher.json with proper formatting."""
+ with open(path, "w") as f:
+ json.dump(versions, f, indent=4)
+ f.write("\n")
+
+
+def ensure_dev_entry(versions):
+ """Ensure dev entry exists in versions list."""
+ dev_exists = any(v.get("version") == "dev" for v in versions)
+ if not dev_exists:
+ versions.insert(0, {"name": "dev", "version": "dev", "url": f"{BASE_URL}/dev/"})
+ return versions
+
+
+def ensure_stable_entry(versions):
+ """Ensure stable entry exists with preferred flag."""
+ stable_idx = next(
+ (i for i, v in enumerate(versions) if v.get("version") == "stable"), None
+ )
+ if stable_idx is not None:
+ versions[stable_idx]["preferred"] = True
+ else:
+ versions.append(
+ {
+ "name": "stable",
+ "version": "stable",
+ "url": f"{BASE_URL}/stable/",
+ "preferred": True,
+ }
+ )
+ return versions
+
+
+def add_version(versions, version):
+ """Add a new version entry to the versions list.
+
+ Parameters
+ ----------
+ versions : list
+ List of version entries.
+ version : str
+ Version string to add (e.g., "0.5.0").
+
+ Returns
+ -------
+ list
+ Updated versions list.
+ """
+ # Remove 'preferred' from all existing entries
+ for v in versions:
+ v.pop("preferred", None)
+
+ # Check if this version already exists
+ version_exists = any(v.get("version") == version for v in versions)
+
+ if not version_exists:
+ new_entry = {
+ "name": version,
+ "version": version,
+ "url": f"{BASE_URL}/{version}/",
+ }
+ # Find dev entry index to insert after it
+ dev_idx = next(
+ (i for i, v in enumerate(versions) if v.get("version") == "dev"), -1
+ )
+ if dev_idx >= 0:
+ versions.insert(dev_idx + 1, new_entry)
+ else:
+ versions.insert(0, new_entry)
+
+ return versions
+
+
+def main():
+ """Main entry point."""
+ parser = argparse.ArgumentParser(
+ description="Update switcher.json for documentation version switching"
+ )
+ parser.add_argument("switcher_path", type=Path, help="Path to switcher.json file")
+ parser.add_argument("--version", type=str, help="New version to add (e.g., 0.5.0)")
+
+ args = parser.parse_args()
+
+ # Load existing versions
+ versions = load_switcher(args.switcher_path)
+
+ # Add new version if specified
+ if args.version:
+ versions = add_version(versions, args.version)
+
+ # Ensure required entries exist
+ versions = ensure_dev_entry(versions)
+ versions = ensure_stable_entry(versions)
+
+ # Save updated switcher.json
+ save_switcher(args.switcher_path, versions)
+
+ # Print result for CI logs
+ print(f"Updated {args.switcher_path}:")
+ print(json.dumps(versions, indent=4))
+
+ return 0
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/trx/cli.py b/trx/cli.py
new file mode 100644
index 0000000..cae5ed0
--- /dev/null
+++ b/trx/cli.py
@@ -0,0 +1,815 @@
+# -*- coding: utf-8 -*-
+"""
+TRX Command Line Interface.
+
+This module provides a unified CLI for all TRX file format operations using Typer.
+"""
+
+from pathlib import Path
+from typing import List, Optional
+
+import numpy as np
+import typer
+from typing_extensions import Annotated
+
+from trx.io import load, save
+from trx.trx_file_memmap import TrxFile, concatenate
+from trx.workflows import (
+ convert_dsi_studio,
+ convert_tractogram,
+ generate_trx_from_scratch,
+ manipulate_trx_datatype,
+ tractogram_simple_compare,
+ tractogram_visualize_overlap,
+ validate_tractogram,
+ verify_header_compatibility,
+)
+
+app = typer.Typer(
+ name="tff",
+ help="TRX File Format Tools - CLI for brain tractography data manipulation.",
+ add_completion=False,
+ rich_markup_mode="rich",
+)
+
+
+def _check_overwrite(filepath: Path, overwrite: bool) -> None:
+ """Check if file exists and raise error if overwrite is not enabled.
+
+ Parameters
+ ----------
+ filepath : Path
+ Path to the output file.
+ overwrite : bool
+ If True, allow overwriting existing files.
+
+ Raises
+ ------
+ typer.Exit
+ If file exists and overwrite is False.
+ """
+ if filepath.is_file() and not overwrite:
+ typer.echo(
+ typer.style(
+ f"Error: {filepath} already exists. Use --force to overwrite.",
+ fg=typer.colors.RED,
+ ),
+ err=True,
+ )
+ raise typer.Exit(code=1)
+
+
+@app.command("concatenate")
+def concatenate_tractograms(
+ in_tractograms: Annotated[
+ List[Path],
+ typer.Argument(
+ help="Input tractogram files. Format: trk, tck, vtk, fib, dpy, trx.",
+ ),
+ ],
+ out_tractogram: Annotated[
+ Path,
+ typer.Argument(help="Output filename for the concatenated tractogram."),
+ ],
+ delete_dpv: Annotated[
+ bool,
+ typer.Option(
+ "--delete-dpv",
+ help="Delete data_per_vertex if not all inputs have the same metadata.",
+ ),
+ ] = False,
+ delete_dps: Annotated[
+ bool,
+ typer.Option(
+ "--delete-dps",
+ help="Delete data_per_streamline if not all inputs have the same metadata.",
+ ),
+ ] = False,
+ delete_groups: Annotated[
+ bool,
+ typer.Option(
+ "--delete-groups",
+ help="Delete groups if not all inputs have the same metadata.",
+ ),
+ ] = False,
+ reference: Annotated[
+ Optional[Path],
+ typer.Option(
+ "--reference",
+ "-r",
+ help="Reference anatomy for tck/vtk/fib/dpy files (.nii or .nii.gz).",
+ ),
+ ] = None,
+ force: Annotated[
+ bool,
+ typer.Option("--force", "-f", help="Force overwriting of output files."),
+ ] = False,
+) -> None:
+ """Concatenate multiple tractograms into one.
+
+    If the data_per_vertex or data_per_streamline is not the same for all
+    tractograms, the corresponding data must be deleted first using the
+    appropriate flags.
+ """
+ _check_overwrite(out_tractogram, force)
+
+ ref = str(reference) if reference else None
+
+ trx_list = []
+ has_group = False
+ for filename in in_tractograms:
+ tractogram_obj = load(str(filename), ref)
+
+ if not isinstance(tractogram_obj, TrxFile):
+ tractogram_obj = TrxFile.from_sft(tractogram_obj)
+ elif len(tractogram_obj.groups):
+ has_group = True
+ trx_list.append(tractogram_obj)
+
+ trx = concatenate(
+ trx_list,
+ delete_dpv=delete_dpv,
+ delete_dps=delete_dps,
+ delete_groups=delete_groups or not has_group,
+ check_space_attributes=True,
+ preallocation=False,
+ )
+ save(trx, str(out_tractogram))
+
+ typer.echo(
+ typer.style(
+ f"Successfully concatenated {len(in_tractograms)} tractograms "
+ f"to {out_tractogram}",
+ fg=typer.colors.GREEN,
+ )
+ )
+
+
+@app.command("convert")
+def convert(
+ in_tractogram: Annotated[
+ Path,
+ typer.Argument(help="Input tractogram. Format: trk, tck, vtk, fib, dpy, trx."),
+ ],
+ out_tractogram: Annotated[
+ Path,
+ typer.Argument(help="Output tractogram. Format: trk, tck, vtk, fib, dpy, trx."),
+ ],
+ reference: Annotated[
+ Optional[Path],
+ typer.Option(
+ "--reference",
+ "-r",
+ help="Reference anatomy for tck/vtk/fib/dpy files (.nii or .nii.gz).",
+ ),
+ ] = None,
+ positions_dtype: Annotated[
+ str,
+ typer.Option(
+ "--positions-dtype",
+ help="Datatype for positions in TRX output.",
+ ),
+ ] = "float32",
+ offsets_dtype: Annotated[
+ str,
+ typer.Option(
+ "--offsets-dtype",
+ help="Datatype for offsets in TRX output.",
+ ),
+ ] = "uint64",
+ force: Annotated[
+ bool,
+ typer.Option("--force", "-f", help="Force overwriting of output files."),
+ ] = False,
+) -> None:
+ """Convert tractograms between formats.
+
+ Supports conversion of .tck, .trk, .fib, .vtk, .trx and .dpy files.
+ TCK files always need a reference NIFTI file for conversion.
+ """
+ _check_overwrite(out_tractogram, force)
+
+ ref = str(reference) if reference else None
+ convert_tractogram(
+ str(in_tractogram),
+ str(out_tractogram),
+ ref,
+ pos_dtype=positions_dtype,
+ offsets_dtype=offsets_dtype,
+ )
+
+ typer.echo(
+ typer.style(
+ f"Successfully converted {in_tractogram} to {out_tractogram}",
+ fg=typer.colors.GREEN,
+ )
+ )
+
+
+@app.command("convert-dsi")
+def convert_dsi(
+ in_dsi_tractogram: Annotated[
+ Path,
+ typer.Argument(help="Input tractogram from DSI Studio (.trk)."),
+ ],
+ in_dsi_fa: Annotated[
+ Path,
+ typer.Argument(help="Input FA from DSI Studio (.nii.gz)."),
+ ],
+ out_tractogram: Annotated[
+ Path,
+ typer.Argument(help="Output tractogram file."),
+ ],
+ remove_invalid: Annotated[
+ bool,
+ typer.Option(
+ "--remove-invalid",
+ help="Remove streamlines landing out of the bounding box.",
+ ),
+ ] = False,
+ keep_invalid: Annotated[
+ bool,
+ typer.Option(
+ "--keep-invalid",
+ help="Keep streamlines landing out of the bounding box.",
+ ),
+ ] = False,
+ force: Annotated[
+ bool,
+ typer.Option("--force", "-f", help="Force overwriting of output files."),
+ ] = False,
+) -> None:
+ """Fix DSI-Studio TRK files for compatibility.
+
+    This command fixes DSI-Studio TRK files (unknown space/convention) to make
+    them compatible with TrackVis, MI-Brain, and Dipy Horizon.
+
+    [bold yellow]WARNING:[/bold yellow] This command is experimental. DSI-Studio
+    evolves quickly and results may vary with the data and DSI-Studio version.
+ """
+ _check_overwrite(out_tractogram, force)
+
+ if remove_invalid and keep_invalid:
+ typer.echo(
+ typer.style(
+ "Error: Cannot use both --remove-invalid and --keep-invalid.",
+ fg=typer.colors.RED,
+ ),
+ err=True,
+ )
+ raise typer.Exit(code=1)
+
+ convert_dsi_studio(
+ str(in_dsi_tractogram),
+ str(in_dsi_fa),
+ str(out_tractogram),
+ remove_invalid=remove_invalid,
+ keep_invalid=keep_invalid,
+ )
+
+ typer.echo(
+ typer.style(
+ f"Successfully converted DSI-Studio tractogram to {out_tractogram}",
+ fg=typer.colors.GREEN,
+ )
+ )
+
+
+@app.command("generate")
+def generate(
+ reference: Annotated[
+ Path,
+ typer.Argument(help="Reference anatomy (.nii or .nii.gz)."),
+ ],
+ out_tractogram: Annotated[
+ Path,
+ typer.Argument(help="Output tractogram. Format: trk, tck, vtk, fib, dpy, trx."),
+ ],
+ positions: Annotated[
+ Optional[Path],
+ typer.Option(
+ "--positions",
+ help="Binary file with streamline coordinates (Nx3 .npy).",
+ ),
+ ] = None,
+ offsets: Annotated[
+ Optional[Path],
+ typer.Option(
+ "--offsets",
+ help="Binary file with streamline offsets (.npy).",
+ ),
+ ] = None,
+ positions_csv: Annotated[
+ Optional[Path],
+ typer.Option(
+ "--positions-csv",
+ help="CSV file with streamline coordinates (x1,y1,z1,x2,y2,z2,...).",
+ ),
+ ] = None,
+ space: Annotated[
+ str,
+ typer.Option(
+ "--space",
+ help="Coordinate space. Non-default requires Dipy.",
+ ),
+ ] = "RASMM",
+ origin: Annotated[
+ str,
+ typer.Option(
+ "--origin",
+ help="Coordinate origin. Non-default requires Dipy.",
+ ),
+ ] = "NIFTI",
+ positions_dtype: Annotated[
+ str,
+ typer.Option("--positions-dtype", help="Datatype for positions."),
+ ] = "float32",
+ offsets_dtype: Annotated[
+ str,
+ typer.Option("--offsets-dtype", help="Datatype for offsets."),
+ ] = "uint64",
+ dpv: Annotated[
+ Optional[List[str]],
+ typer.Option(
+ "--dpv",
+ help="Data per vertex: FILE,DTYPE (e.g., color.npy,uint8).",
+ ),
+ ] = None,
+ dps: Annotated[
+ Optional[List[str]],
+ typer.Option(
+ "--dps",
+ help="Data per streamline: FILE,DTYPE (e.g., algo.npy,uint8).",
+ ),
+ ] = None,
+ groups: Annotated[
+ Optional[List[str]],
+ typer.Option(
+ "--groups",
+ help="Groups: FILE,DTYPE (e.g., AF_L.npy,int32).",
+ ),
+ ] = None,
+ dpg: Annotated[
+ Optional[List[str]],
+ typer.Option(
+ "--dpg",
+ help="Data per group: GROUP,FILE,DTYPE (e.g., AF_L,mean_fa.npy,float32).",
+ ),
+ ] = None,
+ verify_invalid: Annotated[
+ bool,
+ typer.Option(
+ "--verify-invalid",
+ help="Verify positions are valid (within bounding box). Requires Dipy.",
+ ),
+ ] = False,
+ force: Annotated[
+ bool,
+ typer.Option("--force", "-f", help="Force overwriting of output files."),
+ ] = False,
+) -> None:
+ """Generate TRX file from raw data files.
+
+ Create a TRX file from CSV, TXT, or NPY files by specifying positions,
+ offsets, data_per_vertex, data_per_streamlines, groups, and data_per_group.
+
+ Each --dpv, --dps, --groups option requires FILE,DTYPE format.
+ Each --dpg option requires GROUP,FILE,DTYPE format.
+
+ Valid DTYPEs: (u)int8, (u)int16, (u)int32, (u)int64, float16, float32, float64, bool
+ """
+ _check_overwrite(out_tractogram, force)
+
+ # Validate input combinations
+ if not positions and not positions_csv:
+ typer.echo(
+ typer.style(
+ "Error: At least one positions option must be provided "
+ "(--positions or --positions-csv).",
+ fg=typer.colors.RED,
+ ),
+ err=True,
+ )
+ raise typer.Exit(code=1)
+ if positions_csv and positions:
+ typer.echo(
+ typer.style(
+ "Error: Cannot use both --positions and --positions-csv.",
+ fg=typer.colors.RED,
+ ),
+ err=True,
+ )
+ raise typer.Exit(code=1)
+ if positions and offsets is None:
+ typer.echo(
+ typer.style(
+ "Error: --offsets must be provided if --positions is used.",
+ fg=typer.colors.RED,
+ ),
+ err=True,
+ )
+ raise typer.Exit(code=1)
+ if offsets and positions is None:
+ typer.echo(
+ typer.style(
+ "Error: --positions must be provided if --offsets is used.",
+ fg=typer.colors.RED,
+ ),
+ err=True,
+ )
+ raise typer.Exit(code=1)
+
+ # Parse comma-separated arguments to tuples
+ dpv_list = None
+ if dpv:
+ dpv_list = [tuple(item.split(",")) for item in dpv]
+
+ dps_list = None
+ if dps:
+ dps_list = [tuple(item.split(",")) for item in dps]
+
+ groups_list = None
+ if groups:
+ groups_list = [tuple(item.split(",")) for item in groups]
+
+ dpg_list = None
+ if dpg:
+ dpg_list = [tuple(item.split(",")) for item in dpg]
+
+ generate_trx_from_scratch(
+ str(reference),
+ str(out_tractogram),
+ positions_csv=str(positions_csv) if positions_csv else None,
+ positions=str(positions) if positions else None,
+ offsets=str(offsets) if offsets else None,
+ positions_dtype=positions_dtype,
+ offsets_dtype=offsets_dtype,
+ space_str=space,
+ origin_str=origin,
+ verify_invalid=verify_invalid,
+ dpv=dpv_list,
+ dps=dps_list,
+ groups=groups_list,
+ dpg=dpg_list,
+ )
+
+ typer.echo(
+ typer.style(
+ f"Successfully generated {out_tractogram}",
+ fg=typer.colors.GREEN,
+ )
+ )
+
+
+@app.command("manipulate-dtype")
+def manipulate_dtype(
+ in_tractogram: Annotated[
+ Path,
+ typer.Argument(help="Input TRX file."),
+ ],
+ out_tractogram: Annotated[
+ Path,
+ typer.Argument(help="Output tractogram file."),
+ ],
+ positions_dtype: Annotated[
+ Optional[str],
+ typer.Option(
+ "--positions-dtype",
+ help="Datatype for positions (float16, float32, float64).",
+ ),
+ ] = None,
+ offsets_dtype: Annotated[
+ Optional[str],
+ typer.Option(
+ "--offsets-dtype",
+ help="Datatype for offsets (uint32, uint64).",
+ ),
+ ] = None,
+ dpv: Annotated[
+ Optional[List[str]],
+ typer.Option(
+ "--dpv",
+ help="Data per vertex dtype: NAME,DTYPE (e.g., color_x,uint8).",
+ ),
+ ] = None,
+ dps: Annotated[
+ Optional[List[str]],
+ typer.Option(
+ "--dps",
+ help="Data per streamline dtype: NAME,DTYPE (e.g., algo,uint8).",
+ ),
+ ] = None,
+ groups: Annotated[
+ Optional[List[str]],
+ typer.Option(
+ "--groups",
+ help="Groups dtype: NAME,DTYPE (e.g., CC,uint64).",
+ ),
+ ] = None,
+ dpg: Annotated[
+ Optional[List[str]],
+ typer.Option(
+ "--dpg",
+ help="Data per group dtype: GROUP,NAME,DTYPE (e.g., CC,mean_fa,float64).",
+ ),
+ ] = None,
+ force: Annotated[
+ bool,
+ typer.Option("--force", "-f", help="Force overwriting of output files."),
+ ] = False,
+) -> None: # noqa: C901
+ """Manipulate TRX file internal array data types.
+
+ Change the data types of positions, offsets, data_per_vertex,
+ data_per_streamline, groups, and data_per_group arrays.
+
+ Valid DTYPEs: (u)int8, (u)int16, (u)int32, (u)int64, float16, float32, float64, bool
+ """
+ _check_overwrite(out_tractogram, force)
+
+ dtype_dict = {}
+ if positions_dtype:
+ dtype_dict["positions"] = np.dtype(positions_dtype)
+ if offsets_dtype:
+ dtype_dict["offsets"] = np.dtype(offsets_dtype)
+ if dpv:
+ dtype_dict["dpv"] = {}
+ for item in dpv:
+ name, dtype = item.split(",")
+ dtype_dict["dpv"][name] = np.dtype(dtype)
+ if dps:
+ dtype_dict["dps"] = {}
+ for item in dps:
+ name, dtype = item.split(",")
+ dtype_dict["dps"][name] = np.dtype(dtype)
+ if groups:
+ dtype_dict["groups"] = {}
+ for item in groups:
+ name, dtype = item.split(",")
+ dtype_dict["groups"][name] = np.dtype(dtype)
+ if dpg:
+ dtype_dict["dpg"] = {}
+ for item in dpg:
+ parts = item.split(",")
+ group, name, dtype = parts[0], parts[1], parts[2]
+ if group not in dtype_dict["dpg"]:
+ dtype_dict["dpg"][group] = {}
+ dtype_dict["dpg"][group][name] = np.dtype(dtype)
+
+ manipulate_trx_datatype(str(in_tractogram), str(out_tractogram), dtype_dict)
+
+ typer.echo(
+ typer.style(
+ f"Successfully manipulated datatypes and saved to {out_tractogram}",
+ fg=typer.colors.GREEN,
+ )
+ )
+
+
+@app.command("compare")
+def compare(
+ in_tractogram1: Annotated[
+ Path,
+ typer.Argument(help="First tractogram file."),
+ ],
+ in_tractogram2: Annotated[
+ Path,
+ typer.Argument(help="Second tractogram file."),
+ ],
+ reference: Annotated[
+ Optional[Path],
+ typer.Option(
+ "--reference",
+ "-r",
+ help="Reference anatomy for tck/vtk/fib/dpy files (.nii or .nii.gz).",
+ ),
+ ] = None,
+) -> None:
+ """Simple comparison of tractograms by subtracting coordinates.
+
+ Does not account for shuffling of streamlines. Simple A-B operations.
+
+ Differences below 1e-3 are expected for affines with large rotation/scaling.
+ Differences below 1e-6 are expected for isotropic data with small rotation.
+ """
+ ref = str(reference) if reference else None
+ tractogram_simple_compare([str(in_tractogram1), str(in_tractogram2)], ref)
+
+
+@app.command("validate")
+def validate(
+ in_tractogram: Annotated[
+ Path,
+ typer.Argument(help="Input tractogram. Format: trk, tck, vtk, fib, dpy, trx."),
+ ],
+ out_tractogram: Annotated[
+ Optional[Path],
+ typer.Option(
+ "--out",
+ "-o",
+ help="Output tractogram after removing invalid streamlines.",
+ ),
+ ] = None,
+ remove_identical: Annotated[
+ bool,
+ typer.Option(
+ "--remove-identical",
+ help="Remove identical streamlines from the set.",
+ ),
+ ] = False,
+ precision: Annotated[
+ int,
+ typer.Option(
+ "--precision",
+ "-p",
+ help="Number of decimals when hashing streamline points.",
+ ),
+ ] = 1,
+ reference: Annotated[
+ Optional[Path],
+ typer.Option(
+ "--reference",
+ "-r",
+ help="Reference anatomy for tck/vtk/fib/dpy files (.nii or .nii.gz).",
+ ),
+ ] = None,
+ force: Annotated[
+ bool,
+ typer.Option("--force", "-f", help="Force overwriting of output files."),
+ ] = False,
+) -> None:
+ """Validate TRX file and remove invalid streamlines.
+
+ Removes streamlines that are out of the volume bounding box (in voxel space,
+ no negative coordinates or coordinates above volume dimensions).
+
+ Also removes streamlines with single or no points.
+ Use --remove-identical to remove duplicate streamlines based on precision.
+ """
+ if out_tractogram:
+ _check_overwrite(out_tractogram, force)
+
+ ref = str(reference) if reference else None
+ out = str(out_tractogram) if out_tractogram else None
+
+ validate_tractogram(
+ str(in_tractogram),
+ reference=ref,
+ out_tractogram=out,
+ remove_identical_streamlines=remove_identical,
+ precision=precision,
+ )
+
+ if out_tractogram:
+ typer.echo(
+ typer.style(
+ f"Validation complete. Output saved to {out_tractogram}",
+ fg=typer.colors.GREEN,
+ )
+ )
+ else:
+ typer.echo(
+ typer.style(
+ "Validation complete.",
+ fg=typer.colors.GREEN,
+ )
+ )
+
+
+@app.command("verify-header")
+def verify_header(
+ in_files: Annotated[
+ List[Path],
+ typer.Argument(help="Files to compare (trk, trx, and nii)."),
+ ],
+) -> None:
+ """Compare spatial attributes of input files.
+
+ Compares all input files against the first one for compatibility of
+ spatial attributes: affine, dimensions, voxel sizes, and voxel order.
+ """
+ verify_header_compatibility([str(f) for f in in_files])
+
+
+@app.command("visualize")
+def visualize(
+ in_tractogram: Annotated[
+ Path,
+ typer.Argument(help="Input tractogram. Format: trk, tck, vtk, fib, dpy, trx."),
+ ],
+ reference: Annotated[
+ Path,
+ typer.Argument(help="Reference anatomy (.nii or .nii.gz)."),
+ ],
+ remove_invalid: Annotated[
+ bool,
+ typer.Option(
+ "--remove-invalid",
+ help="Remove invalid streamlines to avoid density_map crash.",
+ ),
+ ] = False,
+) -> None:
+ """Display tractogram and density map with bounding box.
+
+ Shows the tractogram and its density map (computed from Dipy) in
+ rasmm, voxmm, and vox space with its bounding box.
+ """
+ tractogram_visualize_overlap(
+ str(in_tractogram),
+ str(reference),
+ remove_invalid,
+ )
+
+
+def main():
+ """Entry point for the TRX CLI."""
+ app()
+
+
+# Standalone entry points for backward compatibility
+# These create individual Typer apps for each command
+
+
+def _create_standalone_app(command_func, name: str, help_text: str):
+ """Create a standalone Typer app for a single command.
+
+ Parameters
+ ----------
+ command_func : callable
+ The command function to wrap.
+ name : str
+ Name of the command.
+ help_text : str
+ Help text for the command.
+
+ Returns
+ -------
+ callable
+ Entry point function.
+ """
+ standalone = typer.Typer(
+ name=name,
+ help=help_text,
+ add_completion=False,
+ rich_markup_mode="rich",
+ )
+ standalone.command()(command_func)
+ return lambda: standalone()
+
+
+concatenate_tractograms_cmd = _create_standalone_app(
+ concatenate_tractograms,
+ "tff_concatenate_tractograms",
+ "Concatenate multiple tractograms into one.",
+)
+
+convert_dsi_cmd = _create_standalone_app(
+ convert_dsi,
+ "tff_convert_dsi_studio",
+ "Fix DSI-Studio TRK files for compatibility.",
+)
+
+convert_cmd = _create_standalone_app(
+ convert,
+ "tff_convert_tractogram",
+ "Convert tractograms between formats.",
+)
+
+generate_cmd = _create_standalone_app(
+ generate,
+ "tff_generate_trx_from_scratch",
+ "Generate TRX file from raw data files.",
+)
+
+manipulate_dtype_cmd = _create_standalone_app(
+ manipulate_dtype,
+ "tff_manipulate_datatype",
+ "Manipulate TRX file internal array data types.",
+)
+
+compare_cmd = _create_standalone_app(
+ compare,
+ "tff_simple_compare",
+ "Simple comparison of tractograms by subtracting coordinates.",
+)
+
+validate_cmd = _create_standalone_app(
+ validate,
+ "tff_validate_trx",
+ "Validate TRX file and remove invalid streamlines.",
+)
+
+verify_header_cmd = _create_standalone_app(
+ verify_header,
+ "tff_verify_header_compatibility",
+ "Compare spatial attributes of input files.",
+)
+
+visualize_cmd = _create_standalone_app(
+ visualize,
+ "tff_visualize_overlap",
+ "Display tractogram and density map with bounding box.",
+)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/trx/fetcher.py b/trx/fetcher.py
index 469fb8a..4aa3d1b 100644
--- a/trx/fetcher.py
+++ b/trx/fetcher.py
@@ -6,45 +6,78 @@
import shutil
import urllib.request
+TEST_DATA_REPO = "tee-ar-ex/trx-test-data"
+TEST_DATA_TAG = "v0.1.0"
+# GitHub release API entrypoint for metadata (asset list, sizes, etc.).
+TEST_DATA_API_URL = (
+ f"https://api.github.com/repos/{TEST_DATA_REPO}/releases/tags/{TEST_DATA_TAG}"
+)
+# Direct download base for release assets.
+TEST_DATA_BASE_URL = (
+ f"https://github.com/{TEST_DATA_REPO}/releases/download/{TEST_DATA_TAG}"
+)
+
def get_home():
- """ Set a user-writeable file-system location to put files """
- if 'TRX_HOME' in os.environ:
- trx_home = os.environ['TRX_HOME']
+ """Set a user-writeable file-system location to put files"""
+ if "TRX_HOME" in os.environ:
+ trx_home = os.environ["TRX_HOME"]
else:
- trx_home = os.path.join(os.path.expanduser('~'), '.tee_ar_ex')
+ trx_home = os.path.join(os.path.expanduser("~"), ".tee_ar_ex")
return trx_home
def get_testing_files_dict():
- """ Get dictionary linking zip file to their Figshare URL & MD5SUM """
+ """Get dictionary linking zip file to their GitHub release URL & checksums.
+
+ Assets are hosted under the v0.1.0 release of tee-ar-ex/trx-test-data.
+ If URLs change, check TEST_DATA_API_URL to discover the latest asset
+ locations.
+ """
return {
- 'DSI.zip':
- ('https://figshare.com/ndownloader/files/37624154',
- 'b847f053fc694d55d935c0be0e5268f7'), # V1 (27.09.2022)
- 'memmap_test_data.zip':
- ('https://figshare.com/ndownloader/files/37624148',
- '03f7651a0f9e3eeabee9aed0ad5f69e1'), # V2 (27.09.2022)
- 'trx_from_scratch.zip':
- ('https://figshare.com/ndownloader/files/37624151',
- 'd9f220a095ce7f027772fcd9451a2ee5'), # V2 (27.09.2022)
- 'gold_standard.zip':
- ('https://figshare.com/ndownloader/files/38146098',
- '57e3f9951fe77245684ede8688af3ae8') # V1 (8.11.2022)
+ "DSI.zip": (
+ f"{TEST_DATA_BASE_URL}/DSI.zip",
+ "b847f053fc694d55d935c0be0e5268f7", # md5
+ "1b09ce8b4b47b2600336c558fdba7051218296e8440e737364f2c4b8ebae666c",
+ ),
+ "memmap_test_data.zip": (
+ f"{TEST_DATA_BASE_URL}/memmap_test_data.zip",
+ "03f7651a0f9e3eeabee9aed0ad5f69e1", # md5
+ "98ba89d7a9a7baa2d37956a0a591dce9bb4581bd01296ad5a596706ee90a52ef",
+ ),
+ "trx_from_scratch.zip": (
+ f"{TEST_DATA_BASE_URL}/trx_from_scratch.zip",
+ "d9f220a095ce7f027772fcd9451a2ee5", # md5
+ "f98ab6da6a6065527fde4b0b6aa40f07583e925d952182e9bbd0febd55c0f6b2",
+ ),
+ "gold_standard.zip": (
+ f"{TEST_DATA_BASE_URL}/gold_standard.zip",
+ "57e3f9951fe77245684ede8688af3ae8", # md5
+ "35a0b633560cc2b0d8ecda885aa72d06385499e0cd1ca11a956b0904c3358f01",
+ ),
}
def md5sum(filename):
- """ Compute one md5 checksum for a file """
+ """Compute one md5 checksum for a file"""
h = hashlib.md5()
- with open(filename, 'rb') as f:
- for chunk in iter(lambda: f.read(128 * h.block_size), b''):
+ with open(filename, "rb") as f:
+ for chunk in iter(lambda: f.read(128 * h.block_size), b""):
+ h.update(chunk)
+ return h.hexdigest()
+
+
+def sha256sum(filename):
+ """Compute one sha256 checksum for a file"""
+ h = hashlib.sha256()
+ with open(filename, "rb") as f:
+ for chunk in iter(lambda: f.read(128 * h.block_size), b""):
h.update(chunk)
return h.hexdigest()
-def fetch_data(files_dict, keys=None):
- """ Downloads files to folder and checks their md5 checksums
+def fetch_data(files_dict, keys=None): # noqa: C901
+ """Downloads files to folder and checks their md5 checksums
Parameters
----------
@@ -71,23 +104,33 @@ def fetch_data(files_dict, keys=None):
keys = [keys]
for f in keys:
- url, expected_md5 = files_dict[f]
+ file_entry = files_dict[f]
+ if len(file_entry) == 2:
+ url, expected_md5 = file_entry
+ expected_sha = None
+ else:
+ url, expected_md5, expected_sha = file_entry
full_path = os.path.join(trx_home, f)
- logging.info('Downloading {} to {}'.format(f, trx_home))
+ logging.info("Downloading {} to {}".format(f, trx_home))
if not os.path.exists(full_path):
urllib.request.urlretrieve(url, full_path)
actual_md5 = md5sum(full_path)
if expected_md5 != actual_md5:
raise ValueError(
- f'Md5sum for {f} does not match. '
- 'Please remove the file to download it again: ' +
- full_path
+ f"Md5sum for {f} does not match. "
+ "Please remove the file to download it again: " + full_path
)
- if f.endswith('.zip'):
+ if expected_sha is not None:
+ actual_sha = sha256sum(full_path)
+ if expected_sha != actual_sha:
+ raise ValueError(
+ f"SHA256 for {f} does not match. "
+ "Please remove the file to download it again: " + full_path
+ )
+
+ if f.endswith(".zip"):
dst_dir = os.path.join(trx_home, f[:-4])
- shutil.unpack_archive(full_path,
- extract_dir=dst_dir,
- format='zip')
+ shutil.unpack_archive(full_path, extract_dir=dst_dir, format="zip")
diff --git a/trx/io.py b/trx/io.py
index e292778..86c8156 100644
--- a/trx/io.py
+++ b/trx/io.py
@@ -1,12 +1,13 @@
# -*- coding: utf-8 -*-
-import os
import logging
-import tempfile
+import os
import sys
+import tempfile
try:
- import dipy
+ import dipy # noqa: F401
+
dipy_available = True
except ImportError:
dipy_available = False
@@ -15,58 +16,61 @@
def get_trx_tmp_dir():
- if os.getenv('TRX_TMPDIR') is not None:
- if os.getenv('TRX_TMPDIR') == 'use_working_dir':
+ if os.getenv("TRX_TMPDIR") is not None:
+ if os.getenv("TRX_TMPDIR") == "use_working_dir":
trx_tmp_dir = os.getcwd()
else:
- trx_tmp_dir = os.getenv('TRX_TMPDIR')
+ trx_tmp_dir = os.getenv("TRX_TMPDIR")
else:
trx_tmp_dir = tempfile.gettempdir()
if sys.version_info[1] >= 10:
- return tempfile.TemporaryDirectory(dir=trx_tmp_dir, prefix='trx_',
- ignore_cleanup_errors=True)
+ return tempfile.TemporaryDirectory(
+ dir=trx_tmp_dir, prefix="trx_", ignore_cleanup_errors=True
+ )
else:
- return tempfile.TemporaryDirectory(dir=trx_tmp_dir, prefix='trx_')
+ return tempfile.TemporaryDirectory(dir=trx_tmp_dir, prefix="trx_")
-def load_sft_with_reference(filepath, reference=None,
- bbox_check=True):
+def load_sft_with_reference(filepath, reference=None, bbox_check=True):
if not dipy_available:
- logging.error('Dipy library is missing, cannot use functions related '
- 'to the StatefulTractogram.')
+ logging.error(
+ "Dipy library is missing, cannot use functions related "
+ "to the StatefulTractogram."
+ )
return None
from dipy.io.streamline import load_tractogram
# Force the usage of --reference for all file formats without an header
_, ext = os.path.splitext(filepath)
- if ext == '.trk':
- if reference is not None and reference != 'same':
- logging.warning('Reference is discarded for this file format '
- '{}.'.format(filepath))
- sft = load_tractogram(filepath, 'same',
- bbox_valid_check=bbox_check)
- elif ext in ['.tck', '.fib', '.vtk', '.dpy']:
- if reference is None or reference == 'same':
- raise IOError('--reference is required for this file format '
- '{}.'.format(filepath))
+ if ext == ".trk":
+ if reference is not None and reference != "same":
+ logging.warning(
+ "Reference is discarded for this file format {}.".format(filepath)
+ )
+ sft = load_tractogram(filepath, "same", bbox_valid_check=bbox_check)
+ elif ext in [".tck", ".fib", ".vtk", ".dpy"]:
+ if reference is None or reference == "same":
+ raise IOError(
+ "--reference is required for this file format {}.".format(filepath)
+ )
else:
- sft = load_tractogram(filepath, reference,
- bbox_valid_check=bbox_check)
+ sft = load_tractogram(filepath, reference, bbox_valid_check=bbox_check)
else:
- raise IOError('{} is an unsupported file format'.format(filepath))
+ raise IOError("{} is an unsupported file format".format(filepath))
return sft
def load(tractogram_filename, reference):
import trx.trx_file_memmap as tmm
+
in_ext = split_name_with_gz(tractogram_filename)[1]
- if in_ext != '.trx' and not os.path.isdir(tractogram_filename):
- tractogram_obj = load_sft_with_reference(tractogram_filename,
- reference,
- bbox_check=False)
+ if in_ext != ".trx" and not os.path.isdir(tractogram_filename):
+ tractogram_obj = load_sft_with_reference(
+ tractogram_filename, reference, bbox_check=False
+ )
else:
tractogram_obj = tmm.load(tractogram_filename)
@@ -75,20 +79,24 @@ def load(tractogram_filename, reference):
def save(tractogram_obj, tractogram_filename, bbox_valid_check=False):
if not dipy_available:
- logging.error('Dipy library is missing, cannot use functions related '
- 'to the StatefulTractogram.')
+ logging.error(
+ "Dipy library is missing, cannot use functions related "
+ "to the StatefulTractogram."
+ )
return None
from dipy.io.stateful_tractogram import StatefulTractogram
from dipy.io.streamline import save_tractogram
+
import trx.trx_file_memmap as tmm
out_ext = split_name_with_gz(tractogram_filename)[1]
- if out_ext != '.trx':
+ if out_ext != ".trx":
if not isinstance(tractogram_obj, StatefulTractogram):
tractogram_obj = tractogram_obj.to_sft()
- save_tractogram(tractogram_obj, tractogram_filename,
- bbox_valid_check=bbox_valid_check)
+ save_tractogram(
+ tractogram_obj, tractogram_filename, bbox_valid_check=bbox_valid_check
+ )
else:
if not isinstance(tractogram_obj, tmm.TrxFile):
tractogram_obj = tmm.TrxFile.from_sft(tractogram_obj)
diff --git a/trx/streamlines_ops.py b/trx/streamlines_ops.py
index d8aa12a..007aaba 100644
--- a/trx/streamlines_ops.py
+++ b/trx/streamlines_ops.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
-import itertools
from functools import reduce
+import itertools
import numpy as np
@@ -89,7 +89,7 @@ def hash_streamlines(streamlines, start_index=0, precision=None):
def perform_streamlines_operation(operation, streamlines, precision=0):
- """Peforms an operation on a list of list of streamlines
+ """Performs an operation on a list of list of streamlines
Given a list of list of streamlines, this function applies the operation
to the first two lists of streamlines. The result in then used recursively
@@ -123,8 +123,7 @@ def perform_streamlines_operation(operation, streamlines, precision=0):
# Hash the streamlines using the desired precision.
indices = np.cumsum([0] + [len(s) for s in streamlines[:-1]])
- hashes = [hash_streamlines(s, i, precision) for
- s, i in zip(streamlines, indices)]
+ hashes = [hash_streamlines(s, i, precision) for s, i in zip(streamlines, indices)]
# Perform the operation on the hashes and get the output streamlines.
to_keep = reduce(operation, hashes)
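`perform_streamlines_operation` reduces the per-list hashes pairwise, so set operations such as `intersection`, `union`, and `difference` apply across any number of streamline lists. A short sketch of the call, using randomly generated streamlines:

```python
import numpy as np

from trx.streamlines_ops import intersection, perform_streamlines_operation

set_a = [np.random.random((30, 3)) for _ in range(5)]
set_b = list(set_a[:2])  # shares its first two streamlines with set_a

# precision controls the rounding applied before hashing each streamline;
# the call returns the surviving streamlines and their indices.
common, indices = perform_streamlines_operation(
    intersection, [set_a, set_b], precision=1
)
```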
diff --git a/trx/tests/test_cli.py b/trx/tests/test_cli.py
new file mode 100644
index 0000000..5860a26
--- /dev/null
+++ b/trx/tests/test_cli.py
@@ -0,0 +1,420 @@
+# -*- coding: utf-8 -*-
+"""Tests for CLI commands and workflow functions."""
+
+import os
+import tempfile
+
+from deepdiff import DeepDiff
+import numpy as np
+from numpy.testing import assert_allclose, assert_array_equal, assert_equal
+import pytest
+
+try:
+ from dipy.io.streamline import load_tractogram
+
+ dipy_available = True
+except ImportError:
+ dipy_available = False
+
+from trx.fetcher import fetch_data, get_home, get_testing_files_dict
+import trx.trx_file_memmap as tmm
+from trx.workflows import (
+ convert_dsi_studio,
+ convert_tractogram,
+ generate_trx_from_scratch,
+ manipulate_trx_datatype,
+ validate_tractogram,
+)
+
+# If they already exist, this only takes 5 seconds (check md5sum)
+fetch_data(get_testing_files_dict(), keys=["DSI.zip", "trx_from_scratch.zip"])
+
+
+# Tests for standalone CLI commands (tff_* commands)
+class TestStandaloneCommands:
+ """Tests for standalone CLI commands."""
+
+ def test_help_option_convert_dsi(self, script_runner):
+ ret = script_runner.run(["tff_convert_dsi_studio", "--help"])
+ assert ret.success
+
+ def test_help_option_convert(self, script_runner):
+ ret = script_runner.run(["tff_convert_tractogram", "--help"])
+ assert ret.success
+
+ def test_help_option_generate_trx_from_scratch(self, script_runner):
+ ret = script_runner.run(["tff_generate_trx_from_scratch", "--help"])
+ assert ret.success
+
+ def test_help_option_concatenate(self, script_runner):
+ ret = script_runner.run(["tff_concatenate_tractograms", "--help"])
+ assert ret.success
+
+ def test_help_option_manipulate(self, script_runner):
+ ret = script_runner.run(["tff_manipulate_datatype", "--help"])
+ assert ret.success
+
+ def test_help_option_compare(self, script_runner):
+ ret = script_runner.run(["tff_simple_compare", "--help"])
+ assert ret.success
+
+ def test_help_option_validate(self, script_runner):
+ ret = script_runner.run(["tff_validate_trx", "--help"])
+ assert ret.success
+
+ def test_help_option_verify_header(self, script_runner):
+ ret = script_runner.run(["tff_verify_header_compatibility", "--help"])
+ assert ret.success
+
+ def test_help_option_visualize(self, script_runner):
+ ret = script_runner.run(["tff_visualize_overlap", "--help"])
+ assert ret.success
+
+
+# Tests for unified tff CLI
+class TestUnifiedCLI:
+ """Tests for the unified tff CLI."""
+
+ def test_tff_help(self, script_runner):
+ ret = script_runner.run(["tff", "--help"])
+ assert ret.success
+
+ def test_tff_concatenate_help(self, script_runner):
+ ret = script_runner.run(["tff", "concatenate", "--help"])
+ assert ret.success
+
+ def test_tff_convert_help(self, script_runner):
+ ret = script_runner.run(["tff", "convert", "--help"])
+ assert ret.success
+
+ def test_tff_convert_dsi_help(self, script_runner):
+ ret = script_runner.run(["tff", "convert-dsi", "--help"])
+ assert ret.success
+
+ def test_tff_generate_help(self, script_runner):
+ ret = script_runner.run(["tff", "generate", "--help"])
+ assert ret.success
+
+ def test_tff_manipulate_dtype_help(self, script_runner):
+ ret = script_runner.run(["tff", "manipulate-dtype", "--help"])
+ assert ret.success
+
+ def test_tff_compare_help(self, script_runner):
+ ret = script_runner.run(["tff", "compare", "--help"])
+ assert ret.success
+
+ def test_tff_validate_help(self, script_runner):
+ ret = script_runner.run(["tff", "validate", "--help"])
+ assert ret.success
+
+ def test_tff_verify_header_help(self, script_runner):
+ ret = script_runner.run(["tff", "verify-header", "--help"])
+ assert ret.success
+
+ def test_tff_visualize_help(self, script_runner):
+ ret = script_runner.run(["tff", "visualize", "--help"])
+ assert ret.success
+
+
+# Tests for workflow functions
+class TestWorkflowFunctions:
+ """Tests for workflow functions."""
+
+ @pytest.mark.skipif(not dipy_available, reason="Dipy is not installed.")
+ def test_execution_convert_dsi(self):
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ in_trk = os.path.join(get_home(), "DSI", "CC.trk.gz")
+ in_nii = os.path.join(get_home(), "DSI", "CC.nii.gz")
+ exp_data = os.path.join(get_home(), "DSI", "CC_fix_data.npy")
+ exp_offsets = os.path.join(get_home(), "DSI", "CC_fix_offsets.npy")
+ out_fix_path = os.path.join(tmp_dir, "fixed.trk")
+ convert_dsi_studio(
+ in_trk, in_nii, out_fix_path, remove_invalid=False, keep_invalid=True
+ )
+
+ data_fix = np.load(exp_data)
+ offsets_fix = np.load(exp_offsets)
+
+ sft = load_tractogram(out_fix_path, "same")
+ assert_equal(sft.streamlines._data, data_fix)
+ assert_equal(sft.streamlines._offsets, offsets_fix)
+
+ @pytest.mark.skipif(not dipy_available, reason="Dipy is not installed.")
+ def test_execution_convert_to_trx(self):
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ in_trk = os.path.join(get_home(), "DSI", "CC_fix.trk")
+ exp_data = os.path.join(get_home(), "DSI", "CC_fix_data.npy")
+ exp_offsets = os.path.join(get_home(), "DSI", "CC_fix_offsets.npy")
+ out_trx_path = os.path.join(tmp_dir, "CC_fix.trx")
+ convert_tractogram(in_trk, out_trx_path, None)
+
+ data_fix = np.load(exp_data)
+ offsets_fix = np.load(exp_offsets)
+
+ trx = tmm.load(out_trx_path)
+ assert_equal(trx.streamlines._data.dtype, np.float32)
+ assert_equal(trx.streamlines._offsets.dtype, np.uint32)
+ assert_array_equal(trx.streamlines._data, data_fix)
+ assert_array_equal(trx.streamlines._offsets, offsets_fix)
+ trx.close()
+
+ @pytest.mark.skipif(not dipy_available, reason="Dipy is not installed.")
+ def test_execution_convert_from_trx(self):
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ in_trk = os.path.join(get_home(), "DSI", "CC_fix.trk")
+ in_nii = os.path.join(get_home(), "DSI", "CC.nii.gz")
+ exp_data = os.path.join(get_home(), "DSI", "CC_fix_data.npy")
+ exp_offsets = os.path.join(get_home(), "DSI", "CC_fix_offsets.npy")
+
+ # Sequential conversions
+ out_trx_path = os.path.join(tmp_dir, "CC_fix.trx")
+ out_trk_path = os.path.join(tmp_dir, "CC_fix.trk")
+ out_tck_path = os.path.join(tmp_dir, "CC_fix.tck")
+ convert_tractogram(in_trk, out_trx_path, None)
+ convert_tractogram(out_trx_path, out_tck_path, None)
+ convert_tractogram(out_trx_path, out_trk_path, None)
+
+ data_fix = np.load(exp_data)
+ offsets_fix = np.load(exp_offsets)
+
+ sft = load_tractogram(out_trk_path, "same")
+ assert_equal(sft.streamlines._data, data_fix)
+ assert_equal(sft.streamlines._offsets, offsets_fix)
+
+ sft = load_tractogram(out_tck_path, in_nii)
+ assert_equal(sft.streamlines._data, data_fix)
+ assert_equal(sft.streamlines._offsets, offsets_fix)
+
+ @pytest.mark.skipif(not dipy_available, reason="Dipy is not installed.")
+ def test_execution_convert_dtype_p16_o64(self):
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ in_trk = os.path.join(get_home(), "DSI", "CC_fix.trk")
+ out_convert_path = os.path.join(tmp_dir, "CC_fix_p16_o64.trx")
+ convert_tractogram(
+ in_trk,
+ out_convert_path,
+ None,
+ pos_dtype="float16",
+ offsets_dtype="uint64",
+ )
+
+ trx = tmm.load(out_convert_path)
+ assert_equal(trx.streamlines._data.dtype, np.float16)
+ assert_equal(trx.streamlines._offsets.dtype, np.uint64)
+ trx.close()
+
+ @pytest.mark.skipif(not dipy_available, reason="Dipy is not installed.")
+ def test_execution_convert_dtype_p64_o32(self):
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ in_trk = os.path.join(get_home(), "DSI", "CC_fix.trk")
+ out_convert_path = os.path.join(tmp_dir, "CC_fix_p16_o64.trx")
+ convert_tractogram(
+ in_trk,
+ out_convert_path,
+ None,
+ pos_dtype="float64",
+ offsets_dtype="uint32",
+ )
+
+ trx = tmm.load(out_convert_path)
+ assert_equal(trx.streamlines._data.dtype, np.float64)
+ assert_equal(trx.streamlines._offsets.dtype, np.uint32)
+ trx.close()
+
+ def test_execution_generate_trx_from_scratch(self):
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ reference_fa = os.path.join(get_home(), "trx_from_scratch", "fa.nii.gz")
+ raw_arr_dir = os.path.join(get_home(), "trx_from_scratch", "test_npy")
+ expected_trx = os.path.join(get_home(), "trx_from_scratch", "expected.trx")
+
+ dpv = [
+ (os.path.join(raw_arr_dir, "dpv_cx.npy"), "uint8"),
+ (os.path.join(raw_arr_dir, "dpv_cy.npy"), "uint8"),
+ (os.path.join(raw_arr_dir, "dpv_cz.npy"), "uint8"),
+ ]
+ dps = [
+ (os.path.join(raw_arr_dir, "dps_algo.npy"), "uint8"),
+ (os.path.join(raw_arr_dir, "dps_cw.npy"), "float64"),
+ ]
+ dpg = [
+ (
+ "g_AF_L",
+ os.path.join(raw_arr_dir, "dpg_AF_L_mean_fa.npy"),
+ "float32",
+ ),
+ (
+ "g_AF_R",
+ os.path.join(raw_arr_dir, "dpg_AF_R_mean_fa.npy"),
+ "float32",
+ ),
+ ("g_AF_L", os.path.join(raw_arr_dir, "dpg_AF_L_volume.npy"), "float32"),
+ ]
+ groups = [
+ (os.path.join(raw_arr_dir, "g_AF_L.npy"), "int32"),
+ (os.path.join(raw_arr_dir, "g_AF_R.npy"), "int32"),
+ (os.path.join(raw_arr_dir, "g_CST_L.npy"), "int32"),
+ ]
+
+ out_gen_path = os.path.join(tmp_dir, "generated.trx")
+ generate_trx_from_scratch(
+ reference_fa,
+ out_gen_path,
+ positions=os.path.join(raw_arr_dir, "positions.npy"),
+ offsets=os.path.join(raw_arr_dir, "offsets.npy"),
+ positions_dtype="float16",
+ offsets_dtype="uint64",
+ space_str="rasmm",
+ origin_str="nifti",
+ verify_invalid=False,
+ dpv=dpv,
+ dps=dps,
+ groups=groups,
+ dpg=dpg,
+ )
+ exp_trx = tmm.load(expected_trx)
+ gen_trx = tmm.load(out_gen_path)
+
+ assert DeepDiff(exp_trx.get_dtype_dict(), gen_trx.get_dtype_dict()) == {}
+
+ assert_allclose(
+ exp_trx.streamlines._data, gen_trx.streamlines._data, atol=0.1, rtol=0.1
+ )
+ assert_equal(exp_trx.streamlines._offsets, gen_trx.streamlines._offsets)
+
+ for key in exp_trx.data_per_vertex.keys():
+ assert_equal(
+ exp_trx.data_per_vertex[key]._data,
+ gen_trx.data_per_vertex[key]._data,
+ )
+ assert_equal(
+ exp_trx.data_per_vertex[key]._offsets,
+ gen_trx.data_per_vertex[key]._offsets,
+ )
+ for key in exp_trx.data_per_streamline.keys():
+ assert_equal(
+ exp_trx.data_per_streamline[key], gen_trx.data_per_streamline[key]
+ )
+ for key in exp_trx.groups.keys():
+ assert_equal(exp_trx.groups[key], gen_trx.groups[key])
+
+ for group in exp_trx.groups.keys():
+ if group in exp_trx.data_per_group:
+ for key in exp_trx.data_per_group[group].keys():
+ assert_equal(
+ exp_trx.data_per_group[group][key],
+ gen_trx.data_per_group[group][key],
+ )
+ exp_trx.close()
+ gen_trx.close()
+
+ @pytest.mark.skipif(not dipy_available, reason="Dipy is not installed.")
+ def test_execution_concatenate_validate_trx(self):
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ trx1 = tmm.load(os.path.join(get_home(), "gold_standard", "gs.trx"))
+ trx2 = tmm.load(os.path.join(get_home(), "gold_standard", "gs.trx"))
+ trx = tmm.concatenate([trx1, trx2], preallocation=False)
+
+ # Right size
+ assert_equal(len(trx.streamlines), 2 * len(trx1.streamlines))
+
+ # Right data
+ end_idx = trx1.header["NB_VERTICES"]
+ assert_allclose(trx.streamlines._data[:end_idx], trx1.streamlines._data)
+ assert_allclose(trx.streamlines._data[end_idx:], trx2.streamlines._data)
+
+ # Right data_per_*
+ for key in trx.data_per_vertex.keys():
+ assert_equal(
+ trx.data_per_vertex[key]._data[:end_idx],
+ trx1.data_per_vertex[key]._data,
+ )
+ assert_equal(
+ trx.data_per_vertex[key]._data[end_idx:],
+ trx2.data_per_vertex[key]._data,
+ )
+
+ end_idx = trx1.header["NB_STREAMLINES"]
+ for key in trx.data_per_streamline.keys():
+ assert_equal(
+ trx.data_per_streamline[key][:end_idx],
+ trx1.data_per_streamline[key],
+ )
+ assert_equal(
+ trx.data_per_streamline[key][end_idx:],
+ trx2.data_per_streamline[key],
+ )
+
+ # Validate
+ out_concat_path = os.path.join(tmp_dir, "concat.trx")
+ out_valid_path = os.path.join(tmp_dir, "valid.trx")
+ tmm.save(trx, out_concat_path)
+ validate_tractogram(
+ out_concat_path,
+ None,
+ out_valid_path,
+ remove_identical_streamlines=True,
+ precision=0,
+ )
+ trx_val = tmm.load(out_valid_path)
+
+ # Right dtype and size
+ assert DeepDiff(trx.get_dtype_dict(), trx_val.get_dtype_dict()) == {}
+ assert_equal(len(trx1.streamlines), len(trx_val.streamlines))
+
+ trx.close()
+ trx1.close()
+ trx2.close()
+ trx_val.close()
+
+ @pytest.mark.skipif(not dipy_available, reason="Dipy is not installed.")
+ def test_execution_manipulate_trx_datatype(self):
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ expected_trx = os.path.join(get_home(), "trx_from_scratch", "expected.trx")
+ trx = tmm.load(expected_trx)
+
+ expected_dtype = {
+ "positions": np.dtype("float16"),
+ "offsets": np.dtype("uint64"),
+ "dpv": {
+ "dpv_cx": np.dtype("uint8"),
+ "dpv_cy": np.dtype("uint8"),
+ "dpv_cz": np.dtype("uint8"),
+ },
+ "dps": {"dps_algo": np.dtype("uint8"), "dps_cw": np.dtype("float64")},
+ "dpg": {
+ "g_AF_L": {
+ "dpg_AF_L_mean_fa": np.dtype("float32"),
+ "dpg_AF_L_volume": np.dtype("float32"),
+ },
+ "g_AF_R": {"dpg_AF_R_mean_fa": np.dtype("float32")},
+ },
+ "groups": {"g_AF_L": np.dtype("int32"), "g_AF_R": np.dtype("int32")},
+ }
+
+ assert DeepDiff(trx.get_dtype_dict(), expected_dtype) == {}
+ trx.close()
+
+ generated_dtype = {
+ "positions": np.dtype("float32"),
+ "offsets": np.dtype("uint32"),
+ "dpv": {
+ "dpv_cx": np.dtype("uint16"),
+ "dpv_cy": np.dtype("uint16"),
+ "dpv_cz": np.dtype("uint16"),
+ },
+ "dps": {"dps_algo": np.dtype("uint8"), "dps_cw": np.dtype("float32")},
+ "dpg": {
+ "g_AF_L": {
+ "dpg_AF_L_mean_fa": np.dtype("float64"),
+ "dpg_AF_L_volume": np.dtype("float32"),
+ },
+ "g_AF_R": {"dpg_AF_R_mean_fa": np.dtype("float64")},
+ },
+ "groups": {"g_AF_L": np.dtype("uint16"), "g_AF_R": np.dtype("uint16")},
+ }
+
+ out_gen_path = os.path.join(tmp_dir, "generated.trx")
+ manipulate_trx_datatype(expected_trx, out_gen_path, generated_dtype)
+ trx = tmm.load(out_gen_path)
+ assert DeepDiff(trx.get_dtype_dict(), generated_dtype) == {}
+ trx.close()
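The datatype-manipulation workflow exercised above takes a nested dictionary mirroring `TrxFile.get_dtype_dict()`; the test supplies an entry for every array present in the file. A hedged sketch with hypothetical paths and keys:

```python
import numpy as np

from trx.workflows import manipulate_trx_datatype

# Hypothetical file and keys; each entry names an existing array and the dtype
# it should be cast to in the output TRX.
new_dtypes = {
    "positions": np.dtype("float32"),
    "offsets": np.dtype("uint32"),
    "dpv": {"dpv_cx": np.dtype("uint16")},
    "dps": {"dps_algo": np.dtype("uint8")},
    "groups": {"g_AF_L": np.dtype("uint16")},
    "dpg": {"g_AF_L": {"dpg_AF_L_volume": np.dtype("float32")}},
}
manipulate_trx_datatype("input.trx", "output.trx", new_dtypes)
```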
diff --git a/trx/tests/test_io.py b/trx/tests/test_io.py
index 33d02bf..c5e823d 100644
--- a/trx/tests/test_io.py
+++ b/trx/tests/test_io.py
@@ -2,69 +2,64 @@
from copy import deepcopy
import os
-import psutil
from tempfile import TemporaryDirectory
import zipfile
-import pytest
import numpy as np
from numpy.testing import assert_allclose
+import psutil
+import pytest
try:
- from dipy.io.streamline import save_tractogram, load_tractogram
+ from dipy.io.streamline import load_tractogram, save_tractogram
+
dipy_available = True
except ImportError:
dipy_available = False
+from trx.fetcher import fetch_data, get_home, get_testing_files_dict
+from trx.io import load, save
import trx.trx_file_memmap as tmm
from trx.trx_file_memmap import TrxFile
-from trx.io import load, save, get_trx_tmp_dir
-from trx.fetcher import (get_testing_files_dict,
- fetch_data, get_home)
-
-fetch_data(get_testing_files_dict(), keys=['gold_standard.zip'])
+fetch_data(get_testing_files_dict(), keys=["gold_standard.zip"])
-@pytest.mark.parametrize("path", [("gs.trk"), ("gs.tck"),
- ("gs.vtk")])
-@pytest.mark.skipif(not dipy_available, reason='Dipy is not installed.')
+@pytest.mark.parametrize("path", [("gs.trk"), ("gs.tck"), ("gs.vtk")])
+@pytest.mark.skipif(not dipy_available, reason="Dipy is not installed.")
def test_seq_ops_sft(path):
with TemporaryDirectory() as tmp_dir:
- gs_dir = os.path.join(get_home(), 'gold_standard')
+ gs_dir = os.path.join(get_home(), "gold_standard")
path = os.path.join(tmp_dir, path)
- obj = load(os.path.join(gs_dir, 'gs.trx'),
- os.path.join(gs_dir, 'gs.nii'))
+ obj = load(os.path.join(gs_dir, "gs.trx"), os.path.join(gs_dir, "gs.nii"))
sft_1 = obj.to_sft()
save_tractogram(sft_1, path)
obj.close()
- save_tractogram(sft_1, os.path.join(tmp_dir, 'tmp.trk'))
+ save_tractogram(sft_1, os.path.join(tmp_dir, "tmp.trk"))
- sft_2 = load_tractogram(os.path.join(tmp_dir, 'tmp.trk'), 'same')
+ _ = load_tractogram(os.path.join(tmp_dir, "tmp.trk"), "same")
def test_seq_ops_trx():
with TemporaryDirectory() as tmp_dir:
- gs_dir = os.path.join(get_home(), 'gold_standard')
- path = os.path.join(gs_dir, 'gs.trx')
+ gs_dir = os.path.join(get_home(), "gold_standard")
+ path = os.path.join(gs_dir, "gs.trx")
trx_1 = tmm.load(path)
- tmm.save(trx_1, os.path.join(tmp_dir, 'tmp.trx'))
+ tmm.save(trx_1, os.path.join(tmp_dir, "tmp.trx"))
trx_1.close()
- trx_2 = tmm.load(os.path.join(tmp_dir, 'tmp.trx'))
+ trx_2 = tmm.load(os.path.join(tmp_dir, "tmp.trx"))
trx_2.close()
-@pytest.mark.parametrize("path", [("gs.trx"), ("gs.trk"), ("gs.tck"),
- ("gs.vtk")])
-@pytest.mark.skipif(not dipy_available, reason='Dipy is not installed.')
+@pytest.mark.parametrize("path", [("gs.trx"), ("gs.trk"), ("gs.tck"), ("gs.vtk")])
+@pytest.mark.skipif(not dipy_available, reason="Dipy is not installed.")
def test_load_vox(path):
- gs_dir = os.path.join(get_home(), 'gold_standard')
+ gs_dir = os.path.join(get_home(), "gold_standard")
path = os.path.join(gs_dir, path)
- coord = np.loadtxt(os.path.join(get_home(), 'gold_standard',
- 'gs_vox_space.txt'))
- obj = load(path, os.path.join(gs_dir, 'gs.nii'))
+ coord = np.loadtxt(os.path.join(get_home(), "gold_standard", "gs_vox_space.txt"))
+ obj = load(path, os.path.join(gs_dir, "gs.nii"))
sft = obj.to_sft() if isinstance(obj, TrxFile) else obj
sft.to_vox()
@@ -74,15 +69,13 @@ def test_load_vox(path):
obj.close()
-@pytest.mark.parametrize("path", [("gs.trx"), ("gs.trk"), ("gs.tck"),
- ("gs.vtk")])
-@pytest.mark.skipif(not dipy_available, reason='Dipy is not installed.')
+@pytest.mark.parametrize("path", [("gs.trx"), ("gs.trk"), ("gs.tck"), ("gs.vtk")])
+@pytest.mark.skipif(not dipy_available, reason="Dipy is not installed.")
def test_load_voxmm(path):
- gs_dir = os.path.join(get_home(), 'gold_standard')
+ gs_dir = os.path.join(get_home(), "gold_standard")
path = os.path.join(gs_dir, path)
- coord = np.loadtxt(os.path.join(get_home(), 'gold_standard',
- 'gs_voxmm_space.txt'))
- obj = load(path, os.path.join(gs_dir, 'gs.nii'))
+ coord = np.loadtxt(os.path.join(get_home(), "gold_standard", "gs_voxmm_space.txt"))
+ obj = load(path, os.path.join(gs_dir, "gs.nii"))
sft = obj.to_sft() if isinstance(obj, TrxFile) else obj
sft.to_voxmm()
@@ -93,25 +86,25 @@ def test_load_voxmm(path):
@pytest.mark.parametrize("path", [("gs.trk"), ("gs.trx"), ("gs_fldr.trx")])
-@pytest.mark.skipif(not dipy_available, reason='Dipy is not installed.')
+@pytest.mark.skipif(not dipy_available, reason="Dipy is not installed.")
def test_multi_load_save_rasmm(path):
with TemporaryDirectory() as tmp_gs_dir:
- gs_dir = os.path.join(get_home(), 'gold_standard')
+ gs_dir = os.path.join(get_home(), "gold_standard")
basename, ext = os.path.splitext(path)
path = os.path.join(gs_dir, path)
- coord = np.loadtxt(os.path.join(get_home(), 'gold_standard',
- 'gs_rasmm_space.txt'))
+ coord = np.loadtxt(
+ os.path.join(get_home(), "gold_standard", "gs_rasmm_space.txt")
+ )
- obj = load(path, os.path.join(gs_dir, 'gs.nii'))
+ obj = load(path, os.path.join(gs_dir, "gs.nii"))
for i in range(3):
- out_path = os.path.join(
- tmp_gs_dir, '{}_tmp{}_{}'.format(basename, i, ext))
+ out_path = os.path.join(tmp_gs_dir, "{}_tmp{}_{}".format(basename, i, ext))
save(obj, out_path)
if isinstance(obj, TrxFile):
obj.close()
- obj = load(out_path, os.path.join(gs_dir, 'gs.nii'))
+ obj = load(out_path, os.path.join(gs_dir, "gs.nii"))
assert_allclose(obj.streamlines._data, coord, rtol=1e-04, atol=1e-06)
if isinstance(obj, TrxFile):
@@ -119,9 +112,9 @@ def test_multi_load_save_rasmm(path):
@pytest.mark.parametrize("path", [("gs.trx"), ("gs_fldr.trx")])
-@pytest.mark.skipif(not dipy_available, reason='Dipy is not installed.')
+@pytest.mark.skipif(not dipy_available, reason="Dipy is not installed.")
def test_delete_tmp_gs_dir(path):
- gs_dir = os.path.join(get_home(), 'gold_standard')
+ gs_dir = os.path.join(get_home(), "gold_standard")
path = os.path.join(gs_dir, path)
trx1 = tmm.load(path)
@@ -131,10 +124,12 @@ def test_delete_tmp_gs_dir(path):
sft = trx1.to_sft()
trx1.close()
- coord_rasmm = np.loadtxt(os.path.join(get_home(), 'gold_standard',
- 'gs_rasmm_space.txt'))
- coord_vox = np.loadtxt(os.path.join(get_home(), 'gold_standard',
- 'gs_vox_space.txt'))
+ coord_rasmm = np.loadtxt(
+ os.path.join(get_home(), "gold_standard", "gs_rasmm_space.txt")
+ )
+ coord_vox = np.loadtxt(
+ os.path.join(get_home(), "gold_standard", "gs_vox_space.txt")
+ )
# The folder trx representation does not need tmp files
if os.path.isfile(path):
@@ -144,33 +139,38 @@ def test_delete_tmp_gs_dir(path):
# Reloading the TRX and checking its data, then closing
trx2 = tmm.load(path)
- assert_allclose(trx2.streamlines._data,
- sft.streamlines._data, rtol=1e-04, atol=1e-06)
+ assert_allclose(
+ trx2.streamlines._data, sft.streamlines._data, rtol=1e-04, atol=1e-06
+ )
trx2.close()
sft.to_vox()
assert_allclose(sft.streamlines._data, coord_vox, rtol=1e-04, atol=1e-06)
trx3 = tmm.load(path)
- assert_allclose(trx3.streamlines._data,
- coord_rasmm, rtol=1e-04, atol=1e-06)
+ assert_allclose(trx3.streamlines._data, coord_rasmm, rtol=1e-04, atol=1e-06)
trx3.close()
@pytest.mark.parametrize("path", [("gs.trx")])
-@pytest.mark.skipif(not dipy_available, reason='Dipy is not installed.')
+@pytest.mark.skipif(not dipy_available, reason="Dipy is not installed.")
def test_close_tmp_files(path):
- gs_dir = os.path.join(get_home(), 'gold_standard')
+ gs_dir = os.path.join(get_home(), "gold_standard")
path = os.path.join(gs_dir, path)
trx = tmm.load(path)
process = psutil.Process(os.getpid())
open_files = process.open_files()
- expected_content = ['offsets.uint32', 'positions.3.float32',
- 'header.json', 'random_coord.3.float32',
- 'color_y.float32', 'color_x.float32',
- 'color_z.float32']
+ expected_content = [
+ "offsets.uint32",
+ "positions.3.float32",
+ "header.json",
+ "random_coord.3.float32",
+ "color_y.float32",
+ "color_x.float32",
+ "color_z.float32",
+ ]
count = 0
for open_file in open_files:
@@ -192,18 +192,18 @@ def test_close_tmp_files(path):
@pytest.mark.parametrize("tmp_path", [("~"), ("use_working_dir")])
def test_change_tmp_dir(tmp_path):
- gs_dir = os.path.join(get_home(), 'gold_standard')
- path = os.path.join(gs_dir, 'gs.trx')
+ gs_dir = os.path.join(get_home(), "gold_standard")
+ path = os.path.join(gs_dir, "gs.trx")
- if tmp_path == 'use_working_dir':
- os.environ['TRX_TMPDIR'] = 'use_working_dir'
+ if tmp_path == "use_working_dir":
+ os.environ["TRX_TMPDIR"] = "use_working_dir"
else:
- os.environ['TRX_TMPDIR'] = os.path.expanduser(tmp_path)
+ os.environ["TRX_TMPDIR"] = os.path.expanduser(tmp_path)
trx = tmm.load(path)
tmp_gs_dir = deepcopy(trx._uncompressed_folder_handle.name)
- if tmp_path == 'use_working_dir':
+ if tmp_path == "use_working_dir":
assert os.path.dirname(tmp_gs_dir) == os.getcwd()
else:
assert os.path.dirname(tmp_gs_dir) == os.path.expanduser(tmp_path)
@@ -214,7 +214,7 @@ def test_change_tmp_dir(tmp_path):
@pytest.mark.parametrize("path", [("gs.trx"), ("gs_fldr.trx")])
def test_complete_dir_from_trx(path):
- gs_dir = os.path.join(get_home(), 'gold_standard')
+ gs_dir = os.path.join(get_home(), "gold_standard")
path = os.path.join(gs_dir, path)
trx = tmm.load(path)
@@ -227,25 +227,35 @@ def test_complete_dir_from_trx(path):
for dirpath, _, filenames in os.walk(dir_to_check):
for filename in filenames:
full_path = os.path.join(dirpath, filename)
- cut_path = full_path.split(dir_to_check)[1][1:].replace('\\', '/')
+ cut_path = full_path.split(dir_to_check)[1][1:].replace("\\", "/")
file_paths.append(cut_path)
- expected_content = ['offsets.uint32', 'positions.3.float32',
- 'header.json', 'dps/random_coord.3.float32',
- 'dpv/color_y.float32', 'dpv/color_x.float32',
- 'dpv/color_z.float32']
+ expected_content = [
+ "offsets.uint32",
+ "positions.3.float32",
+ "header.json",
+ "dps/random_coord.3.float32",
+ "dpv/color_y.float32",
+ "dpv/color_x.float32",
+ "dpv/color_z.float32",
+ ]
assert set(file_paths) == set(expected_content)
def test_complete_zip_from_trx():
- gs_dir = os.path.join(get_home(), 'gold_standard')
- path = os.path.join(gs_dir, 'gs.trx')
+ gs_dir = os.path.join(get_home(), "gold_standard")
+ path = os.path.join(gs_dir, "gs.trx")
with zipfile.ZipFile(path, mode="r") as zf:
zip_file_list = zf.namelist()
- expected_content = ['offsets.uint32', 'positions.3.float32',
- 'header.json', 'dps/random_coord.3.float32',
- 'dpv/color_y.float32', 'dpv/color_x.float32',
- 'dpv/color_z.float32']
+ expected_content = [
+ "offsets.uint32",
+ "positions.3.float32",
+ "header.json",
+ "dps/random_coord.3.float32",
+ "dpv/color_y.float32",
+ "dpv/color_x.float32",
+ "dpv/color_z.float32",
+ ]
assert set(zip_file_list) == set(expected_content)
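As `test_change_tmp_dir` shows, the location of the uncompressed memmaps is driven by the `TRX_TMPDIR` environment variable: the sentinel value `use_working_dir` places them under the current working directory, while any other value is expanded as a directory path. A small sketch (the `gs.trx` path is hypothetical):

```python
import os

import trx.trx_file_memmap as tmm

os.environ["TRX_TMPDIR"] = "use_working_dir"

trx_obj = tmm.load("gs.trx")
print(trx_obj._uncompressed_folder_handle.name)  # lives under os.getcwd()
trx_obj.close()
```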
diff --git a/trx/tests/test_memmap.py b/trx/tests/test_memmap.py
index a671c08..529600c 100644
--- a/trx/tests/test_memmap.py
+++ b/trx/tests/test_memmap.py
@@ -2,24 +2,23 @@
import os
-from nibabel.streamlines.tests.test_tractogram import make_dummy_streamline
from nibabel.streamlines import LazyTractogram
+from nibabel.streamlines.tests.test_tractogram import make_dummy_streamline
import numpy as np
import pytest
try:
- import dipy
+ import dipy # noqa: F401
+
dipy_available = True
except ImportError:
dipy_available = False
+from trx.fetcher import fetch_data, get_home, get_testing_files_dict
from trx.io import get_trx_tmp_dir
import trx.trx_file_memmap as tmm
-from trx.fetcher import (get_testing_files_dict,
- fetch_data, get_home)
-
-fetch_data(get_testing_files_dict(), keys=['memmap_test_data.zip'])
+fetch_data(get_testing_files_dict(), keys=["memmap_test_data.zip"])
tmp_dir = get_trx_tmp_dir()
@@ -36,11 +35,9 @@
def test__generate_filename_from_data(
arr, expected, value_error, filename="mean_fa.bit"
):
-
if value_error:
with pytest.raises(ValueError):
- new_fn = tmm._generate_filename_from_data(arr=arr,
- filename=filename)
+ new_fn = tmm._generate_filename_from_data(arr=arr, filename=filename)
assert new_fn is None
else:
new_fn = tmm._generate_filename_from_data(arr=arr, filename=filename)
@@ -55,8 +52,7 @@ def test__generate_filename_from_data(
("mean_fa", None, True),
("mean_fa.5.4.int32", None, True),
pytest.param(
- "mean_fa.fa", None, True, marks=pytest.mark.xfail,
- id="invalid extension"
+ "mean_fa.fa", None, True, marks=pytest.mark.xfail, id="invalid extension"
),
],
)
@@ -72,16 +68,18 @@ def test__split_ext_with_dimensionality(filename, expected, value_error):
"offsets,nb_vertices,expected",
[
(np.array(range(5), dtype=np.int16), 4, np.array([1, 1, 1, 1, 0])),
- (np.array([0, 1, 1, 3, 4], dtype=np.int32),
- 4, np.array([1, 0, 2, 1, 0])),
+ (np.array([0, 1, 1, 3, 4], dtype=np.int32), 4, np.array([1, 0, 2, 1, 0])),
(np.array(range(4), dtype=np.uint64), 4, np.array([1, 1, 1, 1])),
- pytest.param(np.array([0, 1, 0, 3, 4], dtype=np.int16), 4,
- np.array([1, 3, 0, 1, 0]), marks=pytest.mark.xfail,
- id="offsets not sorted"),
+ pytest.param(
+ np.array([0, 1, 0, 3, 4], dtype=np.int16),
+ 4,
+ np.array([1, 3, 0, 1, 0]),
+ marks=pytest.mark.xfail,
+ id="offsets not sorted",
+ ),
],
)
def test__compute_lengths(offsets, nb_vertices, expected):
-
offsets = tmm._append_last_offsets(offsets, nb_vertices)
lengths = tmm._compute_lengths(offsets=offsets)
assert np.array_equal(lengths, expected)
@@ -120,8 +118,7 @@ def test__dichotomic_search(arr, l_bound, r_bound, expected):
@pytest.mark.parametrize(
"basename, create, expected",
[
- ("offsets.int16", True, np.array(range(12), dtype=np.int16).reshape((
- 3, 4))),
+ ("offsets.int16", True, np.array(range(12), dtype=np.int16).reshape((3, 4))),
("offsets.float32", False, None),
],
)
@@ -132,18 +129,15 @@ def test__create_memmap(basename, create, expected):
filename = os.path.join(dirname, basename)
fp = np.memmap(filename, dtype=np.int16, mode="w+", shape=(3, 4))
fp[:] = expected[:]
- mmarr = tmm._create_memmap(filename=filename, shape=(3, 4),
- dtype=np.int16)
+ mmarr = tmm._create_memmap(filename=filename, shape=(3, 4), dtype=np.int16)
assert np.array_equal(mmarr, expected)
else:
with get_trx_tmp_dir() as dirname:
filename = os.path.join(dirname, basename)
- mmarr = tmm._create_memmap(filename=filename, shape=(0,),
- dtype=np.int16)
+ mmarr = tmm._create_memmap(filename=filename, shape=(0,), dtype=np.int16)
assert os.path.isfile(filename)
- assert np.array_equal(mmarr, np.zeros(
- shape=(0,), dtype=np.float32))
+ assert np.array_equal(mmarr, np.zeros(shape=(0,), dtype=np.float32))
# need dpg test with missing keys
@@ -157,7 +151,7 @@ def test__create_memmap(basename, create, expected):
],
)
def test_load(path, check_dpg, value_error):
- path = os.path.join(get_home(), 'memmap_test_data', path)
+ path = os.path.join(get_home(), "memmap_test_data", path)
# Need to perhaps improve test
if value_error:
with pytest.raises(ValueError):
@@ -165,25 +159,24 @@ def test_load(path, check_dpg, value_error):
tmm.load(input_obj=path, check_dpg=check_dpg), tmm.TrxFile
)
else:
- assert isinstance(tmm.load(input_obj=path, check_dpg=check_dpg),
- tmm.TrxFile)
+ assert isinstance(tmm.load(input_obj=path, check_dpg=check_dpg), tmm.TrxFile)
@pytest.mark.parametrize("path", [("small.trx")])
def test_load_zip(path):
- path = os.path.join(get_home(), 'memmap_test_data', path)
+ path = os.path.join(get_home(), "memmap_test_data", path)
assert isinstance(tmm.load_from_zip(path), tmm.TrxFile)
@pytest.mark.parametrize("path", [("small_fldr.trx")])
def test_load_directory(path):
- path = os.path.join(get_home(), 'memmap_test_data', path)
+ path = os.path.join(get_home(), "memmap_test_data", path)
assert isinstance(tmm.load_from_directory(path), tmm.TrxFile)
@pytest.mark.parametrize("path", [("small.trx")])
def test_concatenate(path):
- path = os.path.join(get_home(), 'memmap_test_data', path)
+ path = os.path.join(get_home(), "memmap_test_data", path)
trx1 = tmm.load(path)
trx2 = tmm.load(path)
concat = tmm.concatenate([trx1, trx2])
@@ -196,10 +189,9 @@ def test_concatenate(path):
@pytest.mark.parametrize("path", [("small.trx")])
def test_resize(path):
- path = os.path.join(get_home(), 'memmap_test_data', path)
+ path = os.path.join(get_home(), "memmap_test_data", path)
trx1 = tmm.load(path)
- concat = tmm.TrxFile(nb_vertices=1000000, nb_streamlines=10000,
- init_as=trx1)
+ concat = tmm.TrxFile(nb_vertices=1000000, nb_streamlines=10000, init_as=trx1)
tmm.concatenate([concat, trx1], preallocation=True, delete_groups=True)
concat.resize()
@@ -209,18 +201,11 @@ def test_resize(path):
concat.close()
-@pytest.mark.parametrize(
- "path, buffer",
- [
- ("small.trx", 10000),
- ("small.trx", 0)
- ]
-)
+@pytest.mark.parametrize("path, buffer", [("small.trx", 10000), ("small.trx", 0)])
def test_append(path, buffer):
- path = os.path.join(get_home(), 'memmap_test_data', path)
+ path = os.path.join(get_home(), "memmap_test_data", path)
trx1 = tmm.load(path)
- concat = tmm.TrxFile(nb_vertices=1, nb_streamlines=1,
- init_as=trx1)
+ concat = tmm.TrxFile(nb_vertices=1, nb_streamlines=1, init_as=trx1)
concat.append(trx1, extra_buffer=buffer)
if buffer > 0:
@@ -234,7 +219,7 @@ def test_append(path, buffer):
@pytest.mark.parametrize("path, buffer", [("small.trx", 10000)])
@pytest.mark.skipif(not dipy_available, reason="Dipy is not installed")
def test_append_StatefulTractogram(path, buffer):
- path = os.path.join(get_home(), 'memmap_test_data', path)
+ path = os.path.join(get_home(), "memmap_test_data", path)
trx = tmm.load(path)
obj = trx.to_sft()
concat = tmm.TrxFile(nb_vertices=1, nb_streamlines=1, init_as=trx)
@@ -250,7 +235,7 @@ def test_append_StatefulTractogram(path, buffer):
@pytest.mark.parametrize("path, buffer", [("small.trx", 10000)])
def test_append_Tractogram(path, buffer):
- path = os.path.join(get_home(), 'memmap_test_data', path)
+ path = os.path.join(get_home(), "memmap_test_data", path)
trx = tmm.load(path)
obj = trx.to_tractogram()
concat = tmm.TrxFile(nb_vertices=1, nb_streamlines=1, init_as=trx)
@@ -264,12 +249,17 @@ def test_append_Tractogram(path, buffer):
concat.close()
-@pytest.mark.parametrize("path, size, buffer", [("small.trx", 50, 10000),
- ("small.trx", 0, 10000),
- ("small.trx", 25000, 10000),
- ("small.trx", 50, 0),
- ("small.trx", 0, 0),
- ("small.trx", 25000, 10000)])
+@pytest.mark.parametrize(
+ "path, size, buffer",
+ [
+ ("small.trx", 50, 10000),
+ ("small.trx", 0, 10000),
+ ("small.trx", 25000, 10000),
+ ("small.trx", 50, 0),
+ ("small.trx", 0, 0),
+ ("small.trx", 25000, 10000),
+ ],
+)
def test_from_lazy_tractogram(path, size, buffer):
_ = np.random.RandomState(1776)
streamlines = []
@@ -281,32 +271,36 @@ def test_from_lazy_tractogram(path, size, buffer):
data = make_dummy_streamline(i)
streamline, data_per_point, data_for_streamline = data
streamlines.append(streamline)
- fa.append(data_per_point['fa'].astype(np.float16))
- commit_weights.append(
- data_for_streamline['mean_curvature'].astype(np.float32))
- clusters_QB.append(
- data_for_streamline['mean_torsion'].astype(np.uint16))
-
- def streamlines_func(): return (e for e in streamlines)
- data_per_point_func = {'fa': lambda: (e for e in fa)}
+ fa.append(data_per_point["fa"].astype(np.float16))
+ commit_weights.append(data_for_streamline["mean_curvature"].astype(np.float32))
+ clusters_QB.append(data_for_streamline["mean_torsion"].astype(np.uint16))
+
+ def streamlines_func():
+ return (e for e in streamlines)
+
+ data_per_point_func = {"fa": lambda: (e for e in fa)}
data_per_streamline_func = {
- 'commit_weights': lambda: (e for e in commit_weights),
- 'clusters_QB': lambda: (e for e in clusters_QB)}
-
- obj = LazyTractogram(streamlines_func,
- data_per_streamline_func,
- data_per_point_func,
- affine_to_rasmm=np.eye(4))
-
- dtype_dict = {'positions': np.float32, 'offsets': np.uint32,
- 'dpv': {'fa': np.float16},
- 'dps': {'commit_weights': np.float32,
- 'clusters_QB': np.uint16}}
- path = os.path.join(get_home(), 'memmap_test_data', path)
- trx = tmm.TrxFile.from_lazy_tractogram(obj, reference=path,
- extra_buffer=buffer,
- chunk_size=1000,
- dtype_dict=dtype_dict)
+ "commit_weights": lambda: (e for e in commit_weights),
+ "clusters_QB": lambda: (e for e in clusters_QB),
+ }
+
+ obj = LazyTractogram(
+ streamlines_func,
+ data_per_streamline_func,
+ data_per_point_func,
+ affine_to_rasmm=np.eye(4),
+ )
+
+ dtype_dict = {
+ "positions": np.float32,
+ "offsets": np.uint32,
+ "dpv": {"fa": np.float16},
+ "dps": {"commit_weights": np.float32, "clusters_QB": np.uint16},
+ }
+ path = os.path.join(get_home(), "memmap_test_data", path)
+ trx = tmm.TrxFile.from_lazy_tractogram(
+ obj, reference=path, extra_buffer=buffer, chunk_size=1000, dtype_dict=dtype_dict
+ )
assert len(trx) == len(gen_range)
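`TrxFile.from_lazy_tractogram` consumes generator functions wrapped in a `LazyTractogram` and writes the result chunk by chunk with the dtypes given in `dtype_dict`. A minimal sketch mirroring the parameters used in `test_from_lazy_tractogram`; the reference file is hypothetical, and omitting `dpv`/`dps` entries for a tractogram without such data is an assumption here:

```python
import numpy as np
from nibabel.streamlines import LazyTractogram

import trx.trx_file_memmap as tmm

streamlines = [np.random.random((10, 3)).astype(np.float32) for _ in range(5)]
lazy = LazyTractogram(lambda: (s for s in streamlines), affine_to_rasmm=np.eye(4))

dtype_dict = {"positions": np.float32, "offsets": np.uint32}
trx_obj = tmm.TrxFile.from_lazy_tractogram(
    lazy, reference="small.trx", chunk_size=1000, dtype_dict=dtype_dict
)
```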
diff --git a/trx/tests/test_streamlines_ops.py b/trx/tests/test_streamlines_ops.py
index 3cd2678..268606b 100644
--- a/trx/tests/test_streamlines_ops.py
+++ b/trx/tests/test_streamlines_ops.py
@@ -3,11 +3,17 @@
import numpy as np
import pytest
-from trx.streamlines_ops import (perform_streamlines_operation,
- intersection, union, difference)
+from trx.streamlines_ops import (
+ difference,
+ intersection,
+ perform_streamlines_operation,
+ union,
+)
-streamlines_ori = [np.ones(90).reshape((30, 3)),
- np.arange(90).reshape((30, 3)) + 0.3333]
+streamlines_ori = [
+ np.ones(90).reshape((30, 3)),
+ np.arange(90).reshape((30, 3)) + 0.3333,
+]
@pytest.mark.parametrize(
@@ -27,16 +33,15 @@ def test_intersection(precision, noise, expected):
streamlines_new = []
for i in range(5):
if i < 4:
- streamlines_new.append(streamlines_ori[1] +
- np.random.random((30, 3)))
+ streamlines_new.append(streamlines_ori[1] + np.random.random((30, 3)))
else:
- streamlines_new.append(streamlines_ori[1] +
- noise * np.random.random((30, 3)))
+ streamlines_new.append(
+ streamlines_ori[1] + noise * np.random.random((30, 3))
+ )
# print(streamlines_new)
- _, indices_uniq = perform_streamlines_operation(intersection,
- [streamlines_new,
- streamlines_ori],
- precision=precision)
+ _, indices_uniq = perform_streamlines_operation(
+ intersection, [streamlines_new, streamlines_ori], precision=precision
+ )
indices_uniq = indices_uniq.tolist()
assert indices_uniq == expected
@@ -58,16 +63,15 @@ def test_union(precision, noise, expected):
streamlines_new = []
for i in range(5):
if i < 4:
- streamlines_new.append(streamlines_ori[1] +
- np.random.random((30, 3)))
+ streamlines_new.append(streamlines_ori[1] + np.random.random((30, 3)))
else:
- streamlines_new.append(streamlines_ori[1] +
- noise * np.random.random((30, 3)))
+ streamlines_new.append(
+ streamlines_ori[1] + noise * np.random.random((30, 3))
+ )
- unique_streamlines, _ = perform_streamlines_operation(union,
- [streamlines_new,
- streamlines_ori],
- precision=precision)
+ unique_streamlines, _ = perform_streamlines_operation(
+ union, [streamlines_new, streamlines_ori], precision=precision
+ )
assert len(unique_streamlines) == expected
@@ -88,14 +92,13 @@ def test_difference(precision, noise, expected):
streamlines_new = []
for i in range(5):
if i < 4:
- streamlines_new.append(streamlines_ori[1] +
- np.random.random((30, 3)))
+ streamlines_new.append(streamlines_ori[1] + np.random.random((30, 3)))
else:
- streamlines_new.append(streamlines_ori[1] +
- noise * np.random.random((30, 3)))
+ streamlines_new.append(
+ streamlines_ori[1] + noise * np.random.random((30, 3))
+ )
- unique_streamlines, _ = perform_streamlines_operation(difference,
- [streamlines_new,
- streamlines_ori],
- precision=precision)
+ unique_streamlines, _ = perform_streamlines_operation(
+ difference, [streamlines_new, streamlines_ori], precision=precision
+ )
assert len(unique_streamlines) == expected
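The `concatenate` refactor below splits the checks into private helpers (`_filter_empty_trx_files`, `_check_space_attributes`, `_verify_dpv_coherence`, ...) without changing the public signature. A usage sketch matching the behaviour exercised in the tests (input paths are hypothetical):

```python
import trx.trx_file_memmap as tmm

trx1 = tmm.load("a.trx")
trx2 = tmm.load("b.trx")

# With preallocation=False a new TrxFile of the combined size is created and
# every non-empty input is copied into it.
merged = tmm.concatenate([trx1, trx2], preallocation=False)
tmm.save(merged, "merged.trx")

for obj in (trx1, trx2, merged):
    obj.close()
```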
diff --git a/trx/trx_file_memmap.py b/trx/trx_file_memmap.py
index c617ebb..5a70e82 100644
--- a/trx/trx_file_memmap.py
+++ b/trx/trx_file_memmap.py
@@ -5,7 +5,7 @@
import logging
import os
import shutil
-from typing import Any, List, Tuple, Type, Union, Optional
+from typing import Any, List, Optional, Tuple, Type, Union
import zipfile
import nibabel as nib
@@ -13,18 +13,21 @@
from nibabel.nifti1 import Nifti1Header, Nifti1Image
from nibabel.orientations import aff2axcodes
from nibabel.streamlines.array_sequence import ArraySequence
+from nibabel.streamlines.tractogram import LazyTractogram, Tractogram
from nibabel.streamlines.trk import TrkFile
-from nibabel.streamlines.tractogram import Tractogram, LazyTractogram
import numpy as np
from trx.io import get_trx_tmp_dir
-from trx.utils import (append_generator_to_dict,
- close_or_delete_mmap,
- convert_data_dict_to_tractogram,
- get_reference_info_wrapper)
+from trx.utils import (
+ append_generator_to_dict,
+ close_or_delete_mmap,
+ convert_data_dict_to_tractogram,
+ get_reference_info_wrapper,
+)
try:
- import dipy
+ import dipy # noqa: F401
+
dipy_available = True
except ImportError:
dipy_available = False
@@ -42,9 +45,12 @@ def _append_last_offsets(nib_offsets: np.ndarray, nb_vertices: int) -> np.ndarra
Returns:
Offsets -- np.ndarray (VTK convention)
"""
- def is_sorted(a): return np.all(a[:-1] <= a[1:])
+
+ def is_sorted(a):
+ return np.all(a[:-1] <= a[1:])
+
if not is_sorted(nib_offsets):
- raise ValueError('Offsets must be sorted values.')
+ raise ValueError("Offsets must be sorted values.")
return np.append(nib_offsets, nb_vertices).astype(nib_offsets.dtype)
@@ -64,7 +70,7 @@ def _generate_filename_from_data(arr: np.ndarray, filename: str) -> str:
logging.warning("Will overwrite provided extension if needed.")
dtype = arr.dtype
- dtype = "bit" if dtype == bool else dtype.name
+ dtype = "bit" if dtype is np.dtype(bool) else dtype.name
if arr.ndim == 1:
new_filename = "{}.{}".format(base, dtype)
@@ -196,8 +202,7 @@ def _create_memmap(
if shape[0]:
return np.memmap(
- filename, mode=mode, offset=offset, shape=shape, dtype=dtype,
- order=order
+ filename, mode=mode, offset=offset, shape=shape, dtype=dtype, order=order
)
else:
if not os.path.isfile(filename):
@@ -233,8 +238,7 @@ def load(input_obj: str, check_dpg: bool = True) -> Type["TrxFile"]:
trx = load_from_directory(tmp_dir.name)
trx._uncompressed_folder_handle = tmp_dir
logging.info(
- "File was compressed, call the close() function before "
- "exiting."
+ "File was compressed, call the close() function before exiting."
)
else:
trx = load_from_zip(input_obj)
@@ -248,8 +252,7 @@ def load(input_obj: str, check_dpg: bool = True) -> Type["TrxFile"]:
for dpg in trx.data_per_group.keys():
if dpg not in trx.groups.keys():
raise ValueError(
- "An undeclared group ({}) has " "data_per_group.".format(
- dpg)
+ "An undeclared group ({}) has data_per_group.".format(dpg)
)
return trx
@@ -270,8 +273,7 @@ def load_from_zip(filename: str) -> Type["TrxFile"]:
header["VOXEL_TO_RASMM"] = np.reshape(
header["VOXEL_TO_RASMM"], (4, 4)
).astype(np.float32)
- header["DIMENSIONS"] = np.array(
- header["DIMENSIONS"], dtype=np.uint16)
+ header["DIMENSIONS"] = np.array(header["DIMENSIONS"], dtype=np.uint16)
files_pointer_size = {}
for zip_info in zf.filelist:
@@ -282,8 +284,7 @@ def load_from_zip(filename: str) -> Type["TrxFile"]:
if not _is_dtype_valid(ext):
continue
- raise ValueError(
- "The dtype {} is not supported".format(elem_filename))
+ raise ValueError("The dtype {} is not supported".format(elem_filename))
if ext == ".bit":
ext = ".bool"
@@ -318,11 +319,12 @@ def load_from_directory(directory: str) -> Type["TrxFile"]:
directory = os.path.abspath(directory)
with open(os.path.join(directory, "header.json")) as header:
header = json.load(header)
- header["VOXEL_TO_RASMM"] = np.reshape(header["VOXEL_TO_RASMM"],
- (4, 4)).astype(np.float32)
+ header["VOXEL_TO_RASMM"] = np.reshape(header["VOXEL_TO_RASMM"], (4, 4)).astype(
+ np.float32
+ )
header["DIMENSIONS"] = np.array(header["DIMENSIONS"], dtype=np.uint16)
files_pointer_size = {}
- for root, dirs, files in os.walk(directory):
+ for root, _dirs, files in os.walk(directory):
for name in files:
elem_filename = os.path.join(root, name)
_, ext = os.path.splitext(elem_filename)
@@ -346,78 +348,51 @@ def load_from_directory(directory: str) -> Type["TrxFile"]:
else:
raise ValueError("Wrong size or datatype")
- return TrxFile._create_trx_from_pointer(header, files_pointer_size,
- root=directory)
-
+ return TrxFile._create_trx_from_pointer(header, files_pointer_size, root=directory)
-def concatenate(
- trx_list: List["TrxFile"],
- delete_dpv: bool = False,
- delete_dps: bool = False,
- delete_groups: bool = False,
- check_space_attributes: bool = True,
- preallocation: bool = False,
-) -> "TrxFile":
- """Concatenate multiple TrxFile together, support preallocation
-
- Keyword arguments:
- trx_list -- A list containing TrxFiles to concatenate
- delete_dpv -- Delete dpv keys that do not exist in all the provided
- TrxFiles
- delete_dps -- Delete dps keys that do not exist in all the provided
- TrxFile
- delete_groups -- Delete all the groups that currently exist in the
- TrxFiles
- check_space_attributes -- Verify that dimensions and size of data are
- similar between all the TrxFiles
- preallocation -- Preallocated TrxFile has already been generated and
- is the first element in trx_list
- (Note: delete_groups must be set to True as well)
- Returns:
- TrxFile representing the concatenated data
+def _filter_empty_trx_files(trx_list: List["TrxFile"]) -> List["TrxFile"]:
+ """Remove empty TrxFiles from the list."""
+ return [curr_trx for curr_trx in trx_list if curr_trx.header["NB_STREAMLINES"] > 0]
- """
- trx_list = [
- curr_trx for curr_trx in trx_list if curr_trx.header["NB_STREAMLINES"] > 0
- ]
- if len(trx_list) == 0:
- logging.warning("Inputs of concatenation were empty.")
- return TrxFile()
- ref_trx = trx_list[0]
+def _get_all_data_keys(trx_list: List["TrxFile"]) -> Tuple[set, set]:
+ """Get all dps and dpv keys from the TrxFile list."""
all_dps = []
all_dpv = []
for curr_trx in trx_list:
all_dps.extend(list(curr_trx.data_per_streamline.keys()))
all_dpv.extend(list(curr_trx.data_per_vertex.keys()))
- all_dps, all_dpv = set(all_dps), set(all_dpv)
+ return set(all_dps), set(all_dpv)
- if check_space_attributes:
- for curr_trx in trx_list[1:]:
- if not np.allclose(
- ref_trx.header["VOXEL_TO_RASMM"], curr_trx.header["VOXEL_TO_RASMM"]
- ) or not np.array_equal(
- ref_trx.header["DIMENSIONS"], curr_trx.header["DIMENSIONS"]
- ):
- raise ValueError("Wrong space attributes.")
- if preallocation and not delete_groups:
- raise ValueError(
- "Groups are variables, cannot be handled with " "preallocation"
- )
+def _check_space_attributes(trx_list: List["TrxFile"]) -> None:
+ """Verify that space attributes are consistent across TrxFiles."""
+ ref_trx = trx_list[0]
+ for curr_trx in trx_list[1:]:
+ if not np.allclose(
+ ref_trx.header["VOXEL_TO_RASMM"], curr_trx.header["VOXEL_TO_RASMM"]
+ ) or not np.array_equal(
+ ref_trx.header["DIMENSIONS"], curr_trx.header["DIMENSIONS"]
+ ):
+ raise ValueError("Wrong space attributes.")
- # Verifying the validity of fixed-size arrays, coherence between inputs
+
+def _verify_dpv_coherence(
+ trx_list: List["TrxFile"], all_dpv: set, ref_trx: "TrxFile", delete_dpv: bool
+) -> None:
+ """Verify dpv coherence across TrxFiles."""
for curr_trx in trx_list:
for key in all_dpv:
- if key not in ref_trx.data_per_vertex.keys() \
- or key not in curr_trx.data_per_vertex.keys():
+ if (
+ key not in ref_trx.data_per_vertex.keys()
+ or key not in curr_trx.data_per_vertex.keys()
+ ):
if not delete_dpv:
logging.debug(
"{} dpv key does not exist in all TrxFile.".format(key)
)
- raise ValueError(
- "TrxFile must be sharing identical dpv " "keys.")
+ raise ValueError("TrxFile must be sharing identical dpv keys.")
elif (
ref_trx.data_per_vertex[key]._data.dtype
!= curr_trx.data_per_vertex[key]._data.dtype
@@ -428,17 +403,22 @@ def concatenate(
)
raise ValueError("Shared dpv key, has different dtype.")
+
+def _verify_dps_coherence(
+ trx_list: List["TrxFile"], all_dps: set, ref_trx: "TrxFile", delete_dps: bool
+) -> None:
+ """Verify dps coherence across TrxFiles."""
for curr_trx in trx_list:
for key in all_dps:
- if key not in ref_trx.data_per_streamline.keys() \
- or key not in curr_trx.data_per_streamline.keys():
+ if (
+ key not in ref_trx.data_per_streamline.keys()
+ or key not in curr_trx.data_per_streamline.keys()
+ ):
if not delete_dps:
logging.debug(
- "{} dps key does not exist in all " "TrxFile.".format(
- key)
+ "{} dps key does not exist in all TrxFile.".format(key)
)
- raise ValueError(
- "TrxFile must be sharing identical dps " "keys.")
+ raise ValueError("TrxFile must be sharing identical dps keys.")
elif (
ref_trx.data_per_streamline[key].dtype
!= curr_trx.data_per_streamline[key].dtype
@@ -449,79 +429,156 @@ def concatenate(
)
raise ValueError("Shared dps key, has different dtype.")
+
+def _compute_groups_info(trx_list: List["TrxFile"]) -> Tuple[dict, dict]:
+ """Compute group length and dtype information."""
all_groups_len = {}
all_groups_dtype = {}
- # Variable-size arrays do not have to exist in all TrxFile
- if not delete_groups:
- for trx_1 in trx_list:
- for group_key in trx_1.groups.keys():
- # Concatenating groups together
- if group_key in all_groups_len:
- all_groups_len[group_key] += len(trx_1.groups[group_key])
- else:
- all_groups_len[group_key] = len(trx_1.groups[group_key])
- if (
- group_key in all_groups_dtype
- and trx_1.groups[group_key].dtype != all_groups_dtype[group_key]
- ):
- raise ValueError("Shared group key, has different dtype.")
- else:
- all_groups_dtype[group_key] = trx_1.groups[group_key].dtype
- # Once the checks are done, actually concatenate
- to_concat_list = trx_list[1:] if preallocation else trx_list
- if not preallocation:
- nb_vertices = 0
- nb_streamlines = 0
- for curr_trx in to_concat_list:
- curr_strs_len, curr_pts_len = curr_trx._get_real_len()
- nb_streamlines += curr_strs_len
- nb_vertices += curr_pts_len
-
- new_trx = TrxFile(
- nb_vertices=nb_vertices, nb_streamlines=nb_streamlines,
- init_as=ref_trx
+ for trx_1 in trx_list:
+ for group_key in trx_1.groups.keys():
+ if group_key in all_groups_len:
+ all_groups_len[group_key] += len(trx_1.groups[group_key])
+ else:
+ all_groups_len[group_key] = len(trx_1.groups[group_key])
+
+ if (
+ group_key in all_groups_dtype
+ and trx_1.groups[group_key].dtype != all_groups_dtype[group_key]
+ ):
+ raise ValueError("Shared group key, has different dtype.")
+ else:
+ all_groups_dtype[group_key] = trx_1.groups[group_key].dtype
+
+ return all_groups_len, all_groups_dtype
+
+
+def _create_new_trx_for_concatenation(
+ trx_list: List["TrxFile"],
+ ref_trx: "TrxFile",
+ delete_dps: bool,
+ delete_dpv: bool,
+ delete_groups: bool,
+) -> "TrxFile":
+ """Create a new TrxFile for concatenation."""
+ nb_vertices = 0
+ nb_streamlines = 0
+ for curr_trx in trx_list:
+ curr_strs_len, curr_pts_len = curr_trx._get_real_len()
+ nb_streamlines += curr_strs_len
+ nb_vertices += curr_pts_len
+
+ new_trx = TrxFile(
+ nb_vertices=nb_vertices, nb_streamlines=nb_streamlines, init_as=ref_trx
+ )
+ if delete_dps:
+ new_trx.data_per_streamline = {}
+ if delete_dpv:
+ new_trx.data_per_vertex = {}
+ if delete_groups:
+ new_trx.groups = {}
+
+ return new_trx
+
+
+def _setup_groups_for_concatenation(
+ new_trx: "TrxFile",
+ trx_list: List["TrxFile"],
+ all_groups_len: dict,
+ all_groups_dtype: dict,
+ delete_groups: bool,
+) -> None:
+ """Setup groups in the new TrxFile for concatenation."""
+ if delete_groups:
+ return
+
+ tmp_dir = new_trx._uncompressed_folder_handle.name
+
+ for group_key in all_groups_len.keys():
+ if not os.path.isdir(os.path.join(tmp_dir, "groups/")):
+ os.mkdir(os.path.join(tmp_dir, "groups/"))
+
+ dtype = all_groups_dtype[group_key]
+ group_filename = os.path.join(
+ tmp_dir, "groups/{}.{}".format(group_key, dtype.name)
)
- if delete_dps:
- new_trx.data_per_streamline = {}
- if delete_dpv:
- new_trx.data_per_vertex = {}
- if delete_groups:
- new_trx.groups = {}
-
- tmp_dir = new_trx._uncompressed_folder_handle.name
-
- # When memory is allocated on the spot, groups and data_per_group can
- # be concatenated together
- for group_key in all_groups_len.keys():
- if not os.path.isdir(os.path.join(tmp_dir, "groups/")):
- os.mkdir(os.path.join(tmp_dir, "groups/"))
- dtype = all_groups_dtype[group_key]
- group_filename = os.path.join(
- tmp_dir, "groups/" "{}.{}".format(group_key, dtype.name)
- )
- group_len = all_groups_len[group_key]
- new_trx.groups[group_key] = _create_memmap(
- group_filename, mode="w+", shape=(group_len,), dtype=dtype
+ group_len = all_groups_len[group_key]
+ new_trx.groups[group_key] = _create_memmap(
+ group_filename, mode="w+", shape=(group_len,), dtype=dtype
+ )
+
+ pos = 0
+ count = 0
+ for curr_trx in trx_list:
+ curr_len = len(curr_trx.groups[group_key])
+ new_trx.groups[group_key][pos : pos + curr_len] = (
+ curr_trx.groups[group_key] + count
)
- if delete_groups:
- continue
- pos = 0
- count = 0
- for curr_trx in trx_list:
- curr_len = len(curr_trx.groups[group_key])
- new_trx.groups[group_key][pos: pos + curr_len] = \
- curr_trx.groups[group_key] + count
- pos += curr_len
- count += curr_trx.header["NB_STREAMLINES"]
+ pos += curr_len
+ count += curr_trx.header["NB_STREAMLINES"]
+
+def concatenate(
+ trx_list: List["TrxFile"],
+ delete_dpv: bool = False,
+ delete_dps: bool = False,
+ delete_groups: bool = False,
+ check_space_attributes: bool = True,
+ preallocation: bool = False,
+) -> "TrxFile":
+ """Concatenate multiple TrxFile together, support preallocation
+
+ Keyword arguments:
+ trx_list -- A list containing TrxFiles to concatenate
+ delete_dpv -- Delete dpv keys that do not exist in all the provided
+ TrxFiles
+ delete_dps -- Delete dps keys that do not exist in all the provided
+ TrxFile
+ delete_groups -- Delete all the groups that currently exist in the
+ TrxFiles
+ check_space_attributes -- Verify that dimensions and size of data are
+ similar between all the TrxFiles
+ preallocation -- Preallocated TrxFile has already been generated and
+ is the first element in trx_list
+ (Note: delete_groups must be set to True as well)
+
+ Returns:
+ TrxFile representing the concatenated data
+
+ """
+ trx_list = _filter_empty_trx_files(trx_list)
+ if len(trx_list) == 0:
+ logging.warning("Inputs of concatenation were empty.")
+ return TrxFile()
+
+ ref_trx = trx_list[0]
+ all_dps, all_dpv = _get_all_data_keys(trx_list)
+
+ if check_space_attributes:
+ _check_space_attributes(trx_list)
+
+ if preallocation and not delete_groups:
+ raise ValueError("Groups are variables, cannot be handled with preallocation")
+
+ _verify_dpv_coherence(trx_list, all_dpv, ref_trx, delete_dpv)
+ _verify_dps_coherence(trx_list, all_dps, ref_trx, delete_dps)
+
+ all_groups_len, all_groups_dtype = _compute_groups_info(trx_list)
+
+ to_concat_list = trx_list[1:] if preallocation else trx_list
+ if not preallocation:
+ new_trx = _create_new_trx_for_concatenation(
+ to_concat_list, ref_trx, delete_dps, delete_dpv, delete_groups
+ )
+ _setup_groups_for_concatenation(
+ new_trx, trx_list, all_groups_len, all_groups_dtype, delete_groups
+ )
strs_end, pts_end = 0, 0
else:
new_trx = ref_trx
strs_end, pts_end = new_trx._get_real_len()
for curr_trx in to_concat_list:
- # Copy the TrxFile fixed-size info (the right chunk)
strs_end, pts_end = new_trx._copy_fixed_arrays_from(
curr_trx, strs_start=strs_end, pts_start=pts_end
)
@@ -623,10 +680,9 @@ def __init__(
if nb_vertices is None and nb_streamlines is None:
if init_as is not None:
raise ValueError(
- "Cant use init_as without declaring "
- "nb_vertices AND nb_streamlines"
+ "Can't use init_as without declaring nb_vertices AND nb_streamlines"
)
- logging.debug("Intializing empty TrxFile.")
+ logging.debug("Initializing empty TrxFile.")
self.header = {}
# Using the new format default type
tmp_strs = ArraySequence()
@@ -645,16 +701,16 @@ def __init__(
elif nb_vertices is not None and nb_streamlines is not None:
logging.debug(
- "Preallocating TrxFile with size {} streamlines"
- "and {} vertices.".format(nb_streamlines, nb_vertices)
+ "Preallocating TrxFile with size {} streamlinesand {} vertices.".format(
+ nb_streamlines, nb_vertices
+ )
)
trx = self._initialize_empty_trx(
nb_streamlines, nb_vertices, init_as=init_as
)
self.__dict__ = trx.__dict__
else:
- raise ValueError(
- "You must declare both nb_vertices AND " "NB_STREAMLINES")
+ raise ValueError("You must declare both nb_vertices AND NB_STREAMLINES")
self.header["VOXEL_TO_RASMM"] = affine
self.header["DIMENSIONS"] = dimensions
@@ -670,13 +726,11 @@ def __str__(self) -> str:
vox_order = "".join(aff2axcodes(affine))
text = "VOXEL_TO_RASMM: \n{}".format(
- np.array2string(affine, formatter={
- "float_kind": lambda x: "%.6f" % x})
+ np.array2string(affine, formatter={"float_kind": lambda x: "%.6f" % x})
)
text += "\nDIMENSIONS: {}".format(np.array2string(dimensions))
text += "\nVOX_SIZES: {}".format(
- np.array2string(vox_sizes, formatter={
- "float_kind": lambda x: "%.2f" % x})
+ np.array2string(vox_sizes, formatter={"float_kind": lambda x: "%.2f" % x})
)
text += "\nVOX_ORDER: {}".format(vox_order)
@@ -690,8 +744,7 @@ def __str__(self) -> str:
text += "\nstreamline_count: {}".format(strs_len)
text += "\nvertex_count: {}".format(pts_len)
- text += "\ndata_per_vertex keys: {}".format(
- list(self.data_per_vertex.keys()))
+ text += "\ndata_per_vertex keys: {}".format(list(self.data_per_vertex.keys()))
text += "\ndata_per_streamline keys: {}".format(
list(self.data_per_streamline.keys())
)
@@ -718,14 +771,14 @@ def __getitem__(self, key) -> Any:
key += len(self)
key = [key]
elif isinstance(key, slice):
- key = [ii for ii in range(*key.indices(len(self)))]
+ key = list(range(*key.indices(len(self))))
return self.select(key, keep_group=False)
def __deepcopy__(self) -> Type["TrxFile"]:
return self.deepcopy()
- def deepcopy(self) -> Type["TrxFile"]:
+ def deepcopy(self) -> Type["TrxFile"]: # noqa: C901
"""Create a deepcopy of the TrxFile
Returns
@@ -740,7 +793,7 @@ def deepcopy(self) -> Type["TrxFile"]:
if not isinstance(tmp_header["DIMENSIONS"], list):
tmp_header["DIMENSIONS"] = tmp_header["DIMENSIONS"].tolist()
- # tofile() alway write in C-order
+    # tofile() always writes in C-order
if not self._copy_safe:
to_dump = self.streamlines.copy()._data
tmp_header["NB_STREAMLINES"] = len(self.streamlines)
@@ -756,11 +809,13 @@ def deepcopy(self) -> Type["TrxFile"]:
to_dump.tofile(positions_filename)
if not self._copy_safe:
- to_dump = _append_last_offsets(self.streamlines.copy()._offsets,
- self.header["NB_VERTICES"])
+ to_dump = _append_last_offsets(
+ self.streamlines.copy()._offsets, self.header["NB_VERTICES"]
+ )
else:
- to_dump = _append_last_offsets(self.streamlines._offsets,
- self.header["NB_VERTICES"])
+ to_dump = _append_last_offsets(
+ self.streamlines._offsets, self.header["NB_VERTICES"]
+ )
offsets_filename = _generate_filename_from_data(
self.streamlines._offsets, os.path.join(tmp_dir.name, "offsets")
)
@@ -807,8 +862,7 @@ def deepcopy(self) -> Type["TrxFile"]:
os.mkdir(os.path.join(tmp_dir.name, "dpg/", group_key))
to_dump = self.data_per_group[group_key][dpg_key]
dpg_filename = _generate_filename_from_data(
- to_dump, os.path.join(
- tmp_dir.name, "dpg/", group_key, dpg_key)
+ to_dump, os.path.join(tmp_dir.name, "dpg/", group_key, dpg_key)
)
to_dump.tofile(dpg_filename)
@@ -869,32 +923,37 @@ def _copy_fixed_arrays_from(
return strs_start, pts_start
# Mandatory arrays
- self.streamlines._data[pts_start:pts_end] = \
- trx.streamlines._data[0:curr_pts_len]
- self.streamlines._offsets[strs_start:strs_end] = \
- (trx.streamlines._offsets[0:curr_strs_len] + pts_start)
- self.streamlines._lengths[strs_start:strs_end] = \
- trx.streamlines._lengths[0:curr_strs_len]
+ self.streamlines._data[pts_start:pts_end] = trx.streamlines._data[
+ 0:curr_pts_len
+ ]
+ self.streamlines._offsets[strs_start:strs_end] = (
+ trx.streamlines._offsets[0:curr_strs_len] + pts_start
+ )
+ self.streamlines._lengths[strs_start:strs_end] = trx.streamlines._lengths[
+ 0:curr_strs_len
+ ]
# Optional fixed-sized arrays
for dpv_key in self.data_per_vertex.keys():
- self.data_per_vertex[dpv_key]._data[
- pts_start:pts_end
- ] = trx.data_per_vertex[dpv_key]._data[0:curr_pts_len]
+ self.data_per_vertex[dpv_key]._data[pts_start:pts_end] = (
+ trx.data_per_vertex[dpv_key]._data[0:curr_pts_len]
+ )
self.data_per_vertex[dpv_key]._offsets = self.streamlines._offsets
self.data_per_vertex[dpv_key]._lengths = self.streamlines._lengths
for dps_key in self.data_per_streamline.keys():
- self.data_per_streamline[dps_key][
- strs_start:strs_end
- ] = trx.data_per_streamline[dps_key][0:curr_strs_len]
+ self.data_per_streamline[dps_key][strs_start:strs_end] = (
+ trx.data_per_streamline[dps_key][0:curr_strs_len]
+ )
return strs_end, pts_end
@staticmethod
- def _initialize_empty_trx(
- nb_streamlines: int, nb_vertices: int,
- init_as: Optional[Type["TrxFile"]] = None) -> Type["TrxFile"]:
+ def _initialize_empty_trx( # noqa: C901
+ nb_streamlines: int,
+ nb_vertices: int,
+ init_as: Optional[Type["TrxFile"]] = None,
+ ) -> Type["TrxFile"]:
"""Create on-disk memmaps of a certain size (preallocation)
Keyword arguments:
@@ -926,29 +985,24 @@ def _initialize_empty_trx(
lengths_dtype = np.dtype(np.uint32)
logging.debug(
- "Initializing positions with dtype: {}".format(
- positions_dtype.name)
+ "Initializing positions with dtype: {}".format(positions_dtype.name)
)
- logging.debug(
- "Initializing offsets with dtype: {}".format(offsets_dtype.name))
- logging.debug(
- "Initializing lengths with dtype: {}".format(lengths_dtype.name))
+ logging.debug("Initializing offsets with dtype: {}".format(offsets_dtype.name))
+ logging.debug("Initializing lengths with dtype: {}".format(lengths_dtype.name))
# A TrxFile without init_as only contain the essential arrays
positions_filename = os.path.join(
tmp_dir.name, "positions.3.{}".format(positions_dtype.name)
)
trx.streamlines._data = _create_memmap(
- positions_filename, mode="w+", shape=(nb_vertices, 3),
- dtype=positions_dtype
+ positions_filename, mode="w+", shape=(nb_vertices, 3), dtype=positions_dtype
)
offsets_filename = os.path.join(
tmp_dir.name, "offsets.{}".format(offsets_dtype.name)
)
trx.streamlines._offsets = _create_memmap(
- offsets_filename, mode="w+", shape=(nb_streamlines,),
- dtype=offsets_dtype
+ offsets_filename, mode="w+", shape=(nb_streamlines,), dtype=offsets_dtype
)
trx.streamlines._lengths = np.zeros(
shape=(nb_streamlines,), dtype=lengths_dtype
@@ -966,23 +1020,20 @@ def _initialize_empty_trx(
tmp_as = init_as.data_per_vertex[dpv_key]._data
if tmp_as.ndim == 1:
dpv_filename = os.path.join(
- tmp_dir.name, "dpv/" "{}.{}".format(
- dpv_key, dtype.name)
+ tmp_dir.name, "dpv/{}.{}".format(dpv_key, dtype.name)
)
shape = (nb_vertices, 1)
elif tmp_as.ndim == 2:
dim = tmp_as.shape[-1]
shape = (nb_vertices, dim)
dpv_filename = os.path.join(
- tmp_dir.name, "dpv/" "{}.{}.{}".format(
- dpv_key, dim, dtype.name)
+ tmp_dir.name, "dpv/{}.{}.{}".format(dpv_key, dim, dtype.name)
)
else:
raise ValueError("Invalid dimensionality.")
logging.debug(
- "Initializing {} (dpv) with dtype: "
- "{}".format(dpv_key, dtype.name)
+ "Initializing {} (dpv) with dtype: {}".format(dpv_key, dtype.name)
)
trx.data_per_vertex[dpv_key] = ArraySequence()
trx.data_per_vertex[dpv_key]._data = _create_memmap(
@@ -996,23 +1047,22 @@ def _initialize_empty_trx(
tmp_as = init_as.data_per_streamline[dps_key]
if tmp_as.ndim == 1:
dps_filename = os.path.join(
- tmp_dir.name, "dps/" "{}.{}".format(
- dps_key, dtype.name)
+ tmp_dir.name, "dps/{}.{}".format(dps_key, dtype.name)
)
shape = (nb_streamlines,)
elif tmp_as.ndim == 2:
dim = tmp_as.shape[-1]
shape = (nb_streamlines, dim)
dps_filename = os.path.join(
- tmp_dir.name, "dps/" "{}.{}.{}".format(
- dps_key, dim, dtype.name)
+ tmp_dir.name, "dps/{}.{}.{}".format(dps_key, dim, dtype.name)
)
else:
raise ValueError("Invalid dimensionality.")
logging.debug(
- "Initializing {} (dps) with and dtype: "
- "{}".format(dps_key, dtype.name)
+                    "Initializing {} (dps) with dtype: {}".format(
+ dps_key, dtype.name
+ )
)
trx.data_per_streamline[dps_key] = _create_memmap(
dps_filename, mode="w+", shape=shape, dtype=dtype
@@ -1022,7 +1072,7 @@ def _initialize_empty_trx(
return trx
- def _create_trx_from_pointer(
+ def _create_trx_from_pointer( # noqa: C901
header: dict,
dict_pointer_size: dict,
root_zip: Optional[str] = None,
@@ -1039,7 +1089,7 @@ def _create_trx_from_pointer(
root -- The dirname of the ZipFile pointer
Returns:
- A TrxFile constructer from the pointer provided
+    A TrxFile constructed from the pointer provided
"""
# TODO support empty positions, using optional tag?
trx = TrxFile()
@@ -1059,15 +1109,19 @@ def _create_trx_from_pointer(
if root is not None:
# This is for Unix
- if os.name != 'nt' and folder.startswith(root.rstrip("/")):
+ if os.name != "nt" and folder.startswith(root.rstrip("/")):
folder = folder.replace(root, "").lstrip("/")
# These three are for Windows
- elif os.path.isdir(folder) and os.path.basename(folder) in ['dpv', 'dps', 'groups']:
+ elif os.path.isdir(folder) and os.path.basename(folder) in [
+ "dpv",
+ "dps",
+ "groups",
+ ]:
folder = os.path.basename(folder)
- elif os.path.basename(os.path.dirname(folder)) == 'dpg':
- folder = os.path.join('dpg', os.path.basename(folder))
+ elif os.path.basename(os.path.dirname(folder)) == "dpg":
+ folder = os.path.join("dpg", os.path.basename(folder))
else:
- folder = ''
+ folder = ""
# Parse/walk the directory tree
if base == "positions" and folder == "":
@@ -1081,13 +1135,13 @@ def _create_trx_from_pointer(
dtype=ext[1:],
)
elif base == "offsets" and folder == "":
- if size != trx.header["NB_STREAMLINES"]+1 or dim != 1:
+ if size != trx.header["NB_STREAMLINES"] + 1 or dim != 1:
raise ValueError("Wrong offsets size/dimensionality.")
offsets = _create_memmap(
filename,
mode="r+",
offset=mem_adress,
- shape=(trx.header["NB_STREAMLINES"]+1,),
+ shape=(trx.header["NB_STREAMLINES"] + 1,),
dtype=ext[1:],
)
if offsets[-1] != 0:
@@ -1102,8 +1156,7 @@ def _create_trx_from_pointer(
shape = (trx.header["NB_STREAMLINES"], int(nb_scalar))
trx.data_per_streamline[base] = _create_memmap(
- filename, mode="r+", offset=mem_adress, shape=shape,
- dtype=ext[1:]
+ filename, mode="r+", offset=mem_adress, shape=shape, dtype=ext[1:]
)
elif folder == "dpv":
nb_scalar = size / trx.header["NB_VERTICES"]
@@ -1113,8 +1166,7 @@ def _create_trx_from_pointer(
shape = (trx.header["NB_VERTICES"], int(nb_scalar))
trx.data_per_vertex[base] = _create_memmap(
- filename, mode="r+", offset=mem_adress, shape=shape,
- dtype=ext[1:]
+ filename, mode="r+", offset=mem_adress, shape=shape, dtype=ext[1:]
)
elif folder.startswith("dpg"):
if int(size) != dim:
@@ -1128,8 +1180,7 @@ def _create_trx_from_pointer(
if sub_folder not in trx.data_per_group:
trx.data_per_group[sub_folder] = {}
trx.data_per_group[sub_folder][data_name] = _create_memmap(
- filename, mode="r+", offset=mem_adress, shape=shape,
- dtype=ext[1:]
+ filename, mode="r+", offset=mem_adress, shape=shape, dtype=ext[1:]
)
elif folder == "groups":
# Groups are simply indices, nothing else
@@ -1139,13 +1190,11 @@ def _create_trx_from_pointer(
else:
shape = (int(size),)
trx.groups[base] = _create_memmap(
- filename, mode="r+", offset=mem_adress, shape=shape,
- dtype=ext[1:]
+ filename, mode="r+", offset=mem_adress, shape=shape, dtype=ext[1:]
)
else:
logging.error(
- "{} is not part of a valid structure.".format(
- elem_filename)
+ "{} is not part of a valid structure.".format(elem_filename)
)
# All essential array must be declared
@@ -1164,13 +1213,13 @@ def _create_trx_from_pointer(
trx.data_per_vertex[dpv_key]._lengths = lengths
return trx
- def resize(
+ def resize( # noqa: C901
self,
nb_streamlines: Optional[int] = None,
nb_vertices: Optional[int] = None,
delete_dpg: bool = False,
) -> None:
- """Remove the ununsed portion of preallocated memmaps
+ """Remove the unused portion of preallocated memmaps
Keyword arguments:
nb_streamlines -- The number of streamlines to keep
@@ -1185,8 +1234,7 @@ def resize(
if nb_streamlines is not None and nb_streamlines < strs_end:
strs_end = nb_streamlines
logging.info(
- "Resizing (down) memmaps, less streamlines than it "
- "actually contains."
+            "Resizing (down) memmaps, fewer streamlines than it actually contains."
)
if nb_vertices is None:
@@ -1207,8 +1255,7 @@ def resize(
logging.debug("TrxFile of the right size, no resizing.")
return
- trx = self._initialize_empty_trx(
- nb_streamlines, nb_vertices, init_as=self)
+ trx = self._initialize_empty_trx(nb_streamlines, nb_vertices, init_as=self)
logging.info(
"Resizing streamlines from size {} to {}".format(
@@ -1245,8 +1292,7 @@ def resize(
group_name, mode="w+", shape=(len(tmp),), dtype=group_dtype
)
logging.debug(
- "{} group went from {} items to {}".format(
- group_key, ori_len, len(tmp))
+ "{} group went from {} items to {}".format(group_key, ori_len, len(tmp))
)
trx.groups[group_key][:] = tmp
@@ -1277,8 +1323,9 @@ def resize(
dpg_filename, mode="w+", shape=shape, dtype=dpg_dtype
)
- trx.data_per_group[group_key][dpg_key][:] = \
- self.data_per_group[group_key][dpg_key]
+ trx.data_per_group[group_key][dpg_key][:] = self.data_per_group[
+ group_key
+ ][dpg_key]
self.close()
self.__dict__ = trx.__dict__
@@ -1289,23 +1336,29 @@ def get_dtype_dict(self):
Returns
A dictionary containing the dtype for each data element
"""
- dtype_dict = {"positions": self.streamlines._data.dtype,
- "offsets": self.streamlines._offsets.dtype,
- "dpv": {}, "dps": {}, "dpg": {}, "groups": {}}
+ dtype_dict = {
+ "positions": self.streamlines._data.dtype,
+ "offsets": self.streamlines._offsets.dtype,
+ "dpv": {},
+ "dps": {},
+ "dpg": {},
+ "groups": {},
+ }
for key in self.data_per_vertex.keys():
- dtype_dict['dpv'][key] = self.data_per_vertex[key]._data.dtype
+ dtype_dict["dpv"][key] = self.data_per_vertex[key]._data.dtype
for key in self.data_per_streamline.keys():
- dtype_dict['dps'][key] = self.data_per_streamline[key].dtype
+ dtype_dict["dps"][key] = self.data_per_streamline[key].dtype
for group_key in self.data_per_group.keys():
- dtype_dict['groups'][group_key] = self.groups[group_key].dtype
+ dtype_dict["groups"][group_key] = self.groups[group_key].dtype
for group_key in self.data_per_group.keys():
- dtype_dict['dpg'][group_key] = {}
+ dtype_dict["dpg"][group_key] = {}
for dpg_key in self.data_per_group[group_key].keys():
- dtype_dict['dpg'][group_key][dpg_key] = \
- self.data_per_group[group_key][dpg_key].dtype
+ dtype_dict["dpg"][group_key][dpg_key] = self.data_per_group[group_key][
+ dpg_key
+ ].dtype
return dtype_dict
@@ -1314,20 +1367,22 @@ def append(self, obj, extra_buffer: int = 0) -> None:
if dipy_available:
from dipy.io.stateful_tractogram import StatefulTractogram
- if not isinstance(obj, (TrxFile, Tractogram)) \
- and (dipy_available and not isinstance(obj, StatefulTractogram)):
+ if not isinstance(obj, (TrxFile, Tractogram)) and (
+ dipy_available and not isinstance(obj, StatefulTractogram)
+ ):
raise TypeError(
- "{} is not a supported object type for appending.".format(type(obj)))
+ "{} is not a supported object type for appending.".format(type(obj))
+ )
elif isinstance(obj, Tractogram):
- obj = self.from_tractogram(obj, reference=self.header,
- dtype_dict=curr_dtype_dict)
+ obj = self.from_tractogram(
+ obj, reference=self.header, dtype_dict=curr_dtype_dict
+ )
elif dipy_available and isinstance(obj, StatefulTractogram):
obj = self.from_sft(obj, dtype_dict=curr_dtype_dict)
self._append_trx(obj, extra_buffer=extra_buffer)
- def _append_trx(self, trx: Type["TrxFile"],
- extra_buffer: int = 0) -> None:
+ def _append_trx(self, trx: Type["TrxFile"], extra_buffer: int = 0) -> None:
"""Append a TrxFile to another (support buffer)
Keyword arguments:
@@ -1398,14 +1453,12 @@ def select(
lengths_dtype
)
new_trx.header["NB_VERTICES"] = len(new_trx.streamlines._data)
- new_trx.header["NB_STREAMLINES"] = len(
- new_trx.streamlines._lengths)
+ new_trx.header["NB_STREAMLINES"] = len(new_trx.streamlines._lengths)
return new_trx.deepcopy() if copy_safe else new_trx
new_trx.streamlines = (
- self.streamlines[indices].copy(
- ) if copy_safe else self.streamlines[indices]
+ self.streamlines[indices].copy() if copy_safe else self.streamlines[indices]
)
for dpv_key in self.data_per_vertex.keys():
new_trx.data_per_vertex[dpv_key] = (
@@ -1423,14 +1476,12 @@ def select(
# Not keeping group is equivalent to the [] operator
if keep_group:
- logging.warning(
- "Keeping dpg despite affecting the group " "items.")
+ logging.warning("Keeping dpg despite affecting the group items.")
for group_key in self.groups.keys():
# Keep the group indices even when fancy slicing
index = np.argsort(indices)
sorted_x = indices[index]
- sorted_index = np.searchsorted(
- sorted_x, self.groups[group_key])
+ sorted_index = np.searchsorted(sorted_x, self.groups[group_key])
yindex = np.take(index, sorted_index, mode="clip")
mask = indices[yindex] != self.groups[group_key]
intersect = yindex[~mask]
@@ -1443,22 +1494,22 @@ def select(
for dpg_key in self.data_per_group[group_key].keys():
if group_key not in new_trx.data_per_group:
new_trx.data_per_group[group_key] = {}
- new_trx.data_per_group[group_key][
- dpg_key
- ] = self.data_per_group[group_key][dpg_key]
+ new_trx.data_per_group[group_key][dpg_key] = (
+ self.data_per_group[group_key][dpg_key]
+ )
new_trx.header["NB_VERTICES"] = len(new_trx.streamlines._data)
new_trx.header["NB_STREAMLINES"] = len(new_trx.streamlines._lengths)
return new_trx.deepcopy() if copy_safe else new_trx
@staticmethod
- def from_lazy_tractogram(obj: ["LazyTractogram"], reference,
- extra_buffer: int = 0,
- chunk_size: int = 10000,
- dtype_dict: dict = {'positions': np.float32,
- 'offsets': np.uint32,
- 'dpv': {}, 'dps': {}}) \
- -> Type["TrxFile"]:
+ def from_lazy_tractogram(
+ obj: ["LazyTractogram"],
+ reference,
+ extra_buffer: int = 0,
+ chunk_size: int = 10000,
+        dtype_dict: Optional[dict] = None,
+ ) -> Type["TrxFile"]:
"""Append a TrxFile to another (support buffer)
Keyword arguments:
@@ -1468,8 +1519,15 @@ def from_lazy_tractogram(obj: ["LazyTractogram"], reference,
Use 0 for no buffer.
chunk_size -- The number of streamlines to save at a time.
"""
-
- data = {'strs': [], 'dpv': {}, 'dps': {}}
+ if dtype_dict is None:
+ dtype_dict = {
+ "positions": np.float32,
+ "offsets": np.uint32,
+ "dpv": {},
+ "dps": {},
+ }
+
+ data = {"strs": [], "dpv": {}, "dps": {}}
concat = None
count = 0
iterator = iter(obj)
@@ -1484,60 +1542,70 @@ def from_lazy_tractogram(obj: ["LazyTractogram"], reference,
if len(obj.streamlines) == 0:
concat = TrxFile()
else:
- concat = TrxFile.from_tractogram(obj,
- reference=reference,
- dtype_dict=dtype_dict)
+ concat = TrxFile.from_tractogram(
+ obj, reference=reference, dtype_dict=dtype_dict
+ )
elif len(obj.streamlines) > 0:
- curr_obj = TrxFile.from_tractogram(obj,
- reference=reference,
- dtype_dict=dtype_dict)
+ curr_obj = TrxFile.from_tractogram(
+ obj, reference=reference, dtype_dict=dtype_dict
+ )
concat.append(curr_obj)
break
append_generator_to_dict(i, data)
else:
obj = convert_data_dict_to_tractogram(data)
if concat is None:
- concat = TrxFile.from_tractogram(obj,
- reference=reference,
- dtype_dict=dtype_dict)
+ concat = TrxFile.from_tractogram(
+ obj, reference=reference, dtype_dict=dtype_dict
+ )
else:
- curr_obj = TrxFile.from_tractogram(obj,
- reference=reference,
- dtype_dict=dtype_dict)
+ curr_obj = TrxFile.from_tractogram(
+ obj, reference=reference, dtype_dict=dtype_dict
+ )
concat.append(curr_obj, extra_buffer=extra_buffer)
- data = {'strs': [], 'dpv': {}, 'dps': {}}
+ data = {"strs": [], "dpv": {}, "dps": {}}
count = 0
concat.resize()
return concat
@staticmethod
- def from_sft(sft, dtype_dict={}):
+ def from_sft(sft, dtype_dict=None):
"""Generate a valid TrxFile from a StatefulTractogram"""
+ if dtype_dict is None:
+ dtype_dict = {}
if len(sft.dtype_dict) > 0:
dtype_dict = sft.dtype_dict
- if 'dpp' in dtype_dict:
- dtype_dict['dpv'] = dtype_dict.pop('dpp')
+ if "dpp" in dtype_dict:
+ dtype_dict["dpv"] = dtype_dict.pop("dpp")
elif len(dtype_dict) == 0:
- dtype_dict = {'positions': np.float32, 'offsets': np.uint32,
- 'dpv': {}, 'dps': {}}
+ dtype_dict = {
+ "positions": np.float32,
+ "offsets": np.uint32,
+ "dpv": {},
+ "dps": {},
+ }
- positions_dtype = dtype_dict['positions']
- offsets_dtype = dtype_dict['offsets']
+ positions_dtype = dtype_dict["positions"]
+ offsets_dtype = dtype_dict["offsets"]
if not np.issubdtype(positions_dtype, np.floating):
logging.warning(
"Casting positions as {}, considering using a floating point "
- "dtype.".format(positions_dtype))
+ "dtype.".format(positions_dtype)
+ )
if not np.issubdtype(offsets_dtype, np.integer):
logging.warning(
- "Casting offsets as {}, considering using a integer "
- "dtype.".format(offsets_dtype))
+                "Casting offsets as {}, consider using an integer dtype.".format(
+ offsets_dtype
+ )
+ )
- trx = TrxFile(nb_vertices=len(sft.streamlines._data),
- nb_streamlines=len(sft.streamlines))
+ trx = TrxFile(
+ nb_vertices=len(sft.streamlines._data), nb_streamlines=len(sft.streamlines)
+ )
trx.header = {
"DIMENSIONS": sft.dimensions.tolist(),
"VOXEL_TO_RASMM": sft.affine.tolist(),
@@ -1555,24 +1623,26 @@ def from_sft(sft, dtype_dict={}):
tmp_streamlines = deepcopy(sft.streamlines)
# Cast the int64 of Nibabel to uint32
- tmp_streamlines._offsets = tmp_streamlines._offsets.astype(
- offsets_dtype)
+ tmp_streamlines._offsets = tmp_streamlines._offsets.astype(offsets_dtype)
tmp_streamlines._data = tmp_streamlines._data.astype(positions_dtype)
trx.streamlines = tmp_streamlines
for key in sft.data_per_point:
- dtype_to_use = dtype_dict['dpv'][key] if key in dtype_dict['dpv'] \
- else np.float32
- trx.data_per_vertex[key] = \
- sft.data_per_point[key]
- trx.data_per_vertex[key]._data = \
- sft.data_per_point[key]._data.astype(dtype_to_use)
+ dtype_to_use = (
+ dtype_dict["dpv"][key] if key in dtype_dict["dpv"] else np.float32
+ )
+ trx.data_per_vertex[key] = sft.data_per_point[key]
+ trx.data_per_vertex[key]._data = sft.data_per_point[key]._data.astype(
+ dtype_to_use
+ )
for key in sft.data_per_streamline:
- dtype_to_use = dtype_dict['dps'][key] if key in dtype_dict['dps'] \
- else np.float32
+ dtype_to_use = (
+ dtype_dict["dps"][key] if key in dtype_dict["dps"] else np.float32
+ )
trx.data_per_streamline[key] = sft.data_per_streamline[key].astype(
- dtype_to_use)
+ dtype_to_use
+ )
# For safety and for RAM, convert the whole object to memmaps
tmp_dir = get_trx_tmp_dir()
@@ -1588,26 +1658,37 @@ def from_sft(sft, dtype_dict={}):
return trx
@staticmethod
- def from_tractogram(tractogram, reference,
- dtype_dict={'positions': np.float32,
- 'offsets': np.uint32,
- 'dpv': {}, 'dps': {}}):
+ def from_tractogram(
+ tractogram,
+ reference,
+ dtype_dict=None,
+ ):
"""Generate a valid TrxFile from a Nibabel Tractogram"""
-
- positions_dtype = dtype_dict['positions'] if 'positions' in dtype_dict \
- else np.float32
- offsets_dtype = dtype_dict['offsets'] if 'offsets' in dtype_dict \
- else np.uint32
+ if dtype_dict is None:
+ dtype_dict = {
+ "positions": np.float32,
+ "offsets": np.uint32,
+ "dpv": {},
+ "dps": {},
+ }
+
+ positions_dtype = (
+ dtype_dict["positions"] if "positions" in dtype_dict else np.float32
+ )
+ offsets_dtype = dtype_dict["offsets"] if "offsets" in dtype_dict else np.uint32
if not np.issubdtype(positions_dtype, np.floating):
logging.warning(
"Casting positions as {}, considering using a floating point "
- "dtype.".format(positions_dtype))
+ "dtype.".format(positions_dtype)
+ )
if not np.issubdtype(offsets_dtype, np.integer):
logging.warning(
- "Casting offsets as {}, considering using a integer "
- "dtype.".format(offsets_dtype))
+                "Casting offsets as {}, consider using an integer dtype.".format(
+ offsets_dtype
+ )
+ )
trx = TrxFile(
nb_vertices=len(tractogram.streamlines._data),
@@ -1625,24 +1706,26 @@ def from_tractogram(tractogram, reference,
tmp_streamlines = deepcopy(tractogram.streamlines)
# Cast the int64 of Nibabel to uint32
- tmp_streamlines._offsets = tmp_streamlines._offsets.astype(
- offsets_dtype)
+ tmp_streamlines._offsets = tmp_streamlines._offsets.astype(offsets_dtype)
tmp_streamlines._data = tmp_streamlines._data.astype(positions_dtype)
trx.streamlines = tmp_streamlines
for key in tractogram.data_per_point:
- dtype_to_use = dtype_dict['dpv'][key] if key in dtype_dict['dpv'] \
- else np.float32
- trx.data_per_vertex[key] = \
- tractogram.data_per_point[key]
- trx.data_per_vertex[key]._data = \
- tractogram.data_per_point[key]._data.astype(dtype_to_use)
+ dtype_to_use = (
+ dtype_dict["dpv"][key] if key in dtype_dict["dpv"] else np.float32
+ )
+ trx.data_per_vertex[key] = tractogram.data_per_point[key]
+ trx.data_per_vertex[key]._data = tractogram.data_per_point[
+ key
+ ]._data.astype(dtype_to_use)
for key in tractogram.data_per_streamline:
- dtype_to_use = dtype_dict['dps'][key] if key in dtype_dict['dps'] \
- else np.float32
- trx.data_per_streamline[key] = \
- tractogram.data_per_streamline[key].astype(dtype_to_use)
+ dtype_to_use = (
+ dtype_dict["dps"][key] if key in dtype_dict["dps"] else np.float32
+ )
+ trx.data_per_streamline[key] = tractogram.data_per_streamline[key].astype(
+ dtype_to_use
+ )
# For safety and for RAM, convert the whole object to memmaps
tmp_dir = get_trx_tmp_dir()
@@ -1687,8 +1770,7 @@ def to_memory(self, resize: bool = False) -> Type["TrxFile"]:
trx_obj.data_per_vertex[key] = deepcopy(self.data_per_vertex[key])
for key in self.data_per_streamline:
- trx_obj.data_per_streamline[key] = deepcopy(
- self.data_per_streamline[key])
+ trx_obj.data_per_streamline[key] = deepcopy(self.data_per_streamline[key])
for key in self.groups:
trx_obj.groups[key] = deepcopy(self.groups[key])
@@ -1701,10 +1783,11 @@ def to_memory(self, resize: bool = False) -> Type["TrxFile"]:
def to_sft(self, resize=False):
"""Convert a TrxFile to a valid StatefulTractogram (in RAM)"""
try:
- from dipy.io.stateful_tractogram import StatefulTractogram, Space
+ from dipy.io.stateful_tractogram import Space, StatefulTractogram
except ImportError:
- logging.error('Dipy library is missing, cannot convert to '
- 'StatefulTractogram.')
+ logging.error(
+ "Dipy library is missing, cannot convert to StatefulTractogram."
+ )
return None
affine = np.array(self.header["VOXEL_TO_RASMM"], dtype=np.float32)
@@ -1723,8 +1806,8 @@ def to_sft(self, resize=False):
data_per_streamline=deepcopy(self.data_per_streamline),
)
tmp_dict = self.get_dtype_dict()
- if 'dpv' in tmp_dict:
- tmp_dict['dpp'] = tmp_dict.pop('dpv')
+ if "dpv" in tmp_dict:
+ tmp_dict["dpp"] = tmp_dict.pop("dpv")
sft.dtype_dict = self.get_dtype_dict()
return sft
@@ -1751,7 +1834,9 @@ def close(self) -> None:
try:
self._uncompressed_folder_handle.cleanup()
except PermissionError:
- logging.error("Windows PermissionError, temporary directory {}"
- "was not deleted!".format(self._uncompressed_folder_handle.name))
+ logging.error(
+ "Windows PermissionError, temporary directory %s was not deleted!",
+ self._uncompressed_folder_handle.name,
+ )
self.__init__()
- logging.debug("Deleted memmaps and intialized empty TrxFile.")
+ logging.debug("Deleted memmaps and initialized empty TrxFile.")
diff --git a/trx/utils.py b/trx/utils.py
index b58a2a2..314b4c2 100644
--- a/trx/utils.py
+++ b/trx/utils.py
@@ -1,16 +1,16 @@
# -*- coding: utf-8 -*-
-from nibabel.streamlines.tractogram import TractogramItem
-from nibabel.streamlines.tractogram import Tractogram
-from nibabel.streamlines.array_sequence import ArraySequence
-import os
import logging
+import os
import nibabel as nib
+from nibabel.streamlines.array_sequence import ArraySequence
+from nibabel.streamlines.tractogram import Tractogram, TractogramItem
import numpy as np
try:
import dipy
+
dipy_available = True
except ImportError:
dipy_available = False
@@ -26,7 +26,7 @@ def close_or_delete_mmap(obj):
The object that potentially has a memory-mapped file to be closed.
"""
- if hasattr(obj, '_mmap') and obj._mmap is not None:
+ if hasattr(obj, "_mmap") and obj._mmap is not None:
obj._mmap.close()
elif isinstance(obj, ArraySequence):
close_or_delete_mmap(obj._data)
@@ -35,7 +35,7 @@ def close_or_delete_mmap(obj):
elif isinstance(obj, np.memmap):
del obj
else:
- logging.debug('Object to be close or deleted must be np.memmap')
+        logging.debug("Object to be closed or deleted must be np.memmap")
def split_name_with_gz(filename):
@@ -66,8 +66,8 @@ def split_name_with_gz(filename):
return base, ext
-def get_reference_info_wrapper(reference):
- """ Will compare the spatial attribute of 2 references.
+def get_reference_info_wrapper(reference): # noqa: C901
+    """Get the spatial attributes of a reference.
Parameters
----------
@@ -77,25 +77,26 @@ def get_reference_info_wrapper(reference):
Returns
-------
output : tuple
- - affine ndarray (4,4), np.float32, tranformation of VOX to RASMM
+ - affine ndarray (4,4), np.float32, transformation of VOX to RASMM
- dimensions ndarray (3,), int16, volume shape for each axis
- voxel_sizes ndarray (3,), float32, size of voxel for each axis
- voxel_order, string, Typically 'RAS' or 'LPS'
"""
from trx import trx_file_memmap
+
is_nifti = False
is_trk = False
is_sft = False
is_trx = False
if isinstance(reference, str):
_, ext = split_name_with_gz(reference)
- if ext in ['.nii', '.nii.gz']:
+ if ext in [".nii", ".nii.gz"]:
header = nib.load(reference).header
is_nifti = True
- elif ext == '.trk':
+ elif ext == ".trk":
header = nib.streamlines.load(reference, lazy_load=True).header
is_trk = True
- elif ext == '.trx':
+ elif ext == ".trx":
header = trx_file_memmap.load(reference).header
is_trx = True
elif isinstance(reference, trx_file_memmap.TrxFile):
@@ -110,53 +111,56 @@ def get_reference_info_wrapper(reference):
elif isinstance(reference, nib.nifti1.Nifti1Header):
header = reference
is_nifti = True
- elif isinstance(reference, dict) and 'magic_number' in reference:
+ elif isinstance(reference, dict) and "magic_number" in reference:
header = reference
is_trk = True
- elif isinstance(reference, dict) and 'NB_VERTICES' in reference:
+ elif isinstance(reference, dict) and "NB_VERTICES" in reference:
header = reference
is_trx = True
- elif dipy_available and \
- isinstance(reference, dipy.io.stateful_tractogram.StatefulTractogram):
+ elif dipy_available and isinstance(
+ reference, dipy.io.stateful_tractogram.StatefulTractogram
+ ):
is_sft = True
if is_nifti:
affine = header.get_best_affine()
- dimensions = header['dim'][1:4]
- voxel_sizes = header['pixdim'][1:4]
+ dimensions = header["dim"][1:4]
+ voxel_sizes = header["pixdim"][1:4]
if not affine[0:3, 0:3].any():
raise ValueError(
- 'Invalid affine, contains only zeros.'
- 'Cannot determine voxel order from transformation')
- voxel_order = ''.join(nib.aff2axcodes(affine))
+                "Invalid affine, contains only zeros. "
+ "Cannot determine voxel order from transformation"
+ )
+ voxel_order = "".join(nib.aff2axcodes(affine))
elif is_trk:
- affine = header['voxel_to_rasmm']
- dimensions = header['dimensions']
- voxel_sizes = header['voxel_sizes']
- voxel_order = header['voxel_order']
+ affine = header["voxel_to_rasmm"]
+ dimensions = header["dimensions"]
+ voxel_sizes = header["voxel_sizes"]
+ voxel_order = header["voxel_order"]
elif is_sft:
affine, dimensions, voxel_sizes, voxel_order = reference.space_attributes
elif is_trx:
- affine = header['VOXEL_TO_RASMM']
- dimensions = header['DIMENSIONS']
+ affine = header["VOXEL_TO_RASMM"]
+ dimensions = header["DIMENSIONS"]
voxel_sizes = nib.affines.voxel_sizes(affine)
- voxel_order = ''.join(nib.aff2axcodes(affine))
+ voxel_order = "".join(nib.aff2axcodes(affine))
else:
- raise TypeError('Input reference is not one of the supported format')
+        raise TypeError("Input reference is not one of the supported formats")
if isinstance(voxel_order, np.bytes_):
- voxel_order = voxel_order.decode('utf-8')
+ voxel_order = voxel_order.decode("utf-8")
if dipy_available:
from dipy.io.utils import is_reference_info_valid
+
is_reference_info_valid(affine, dimensions, voxel_sizes, voxel_order)
return affine, dimensions, voxel_sizes, voxel_order
def is_header_compatible(reference_1, reference_2):
- """ Will compare the spatial attribute of 2 references.
+    """Will compare the spatial attributes of 2 references.
Parameters
----------
@@ -173,25 +177,27 @@ def is_header_compatible(reference_1, reference_2):
"""
affine_1, dimensions_1, voxel_sizes_1, voxel_order_1 = get_reference_info_wrapper(
- reference_1)
+ reference_1
+ )
affine_2, dimensions_2, voxel_sizes_2, voxel_order_2 = get_reference_info_wrapper(
- reference_2)
+ reference_2
+ )
identical_header = True
if not np.allclose(affine_1, affine_2, rtol=1e-03, atol=1e-03):
- logging.error('Affine not equal')
+ logging.error("Affine not equal")
identical_header = False
if not np.array_equal(dimensions_1, dimensions_2):
- logging.error('Dimensions not equal')
+ logging.error("Dimensions not equal")
identical_header = False
if not np.allclose(voxel_sizes_1, voxel_sizes_2, rtol=1e-03, atol=1e-03):
- logging.error('Voxel_size not equal')
+ logging.error("Voxel_size not equal")
identical_header = False
if voxel_order_1 != voxel_order_2:
- logging.error('Voxel_order not equal')
+ logging.error("Voxel_order not equal")
identical_header = False
return identical_header
@@ -211,11 +217,11 @@ def get_axis_shift_vector(flip_axes):
Possible values are -1, 1
"""
shift_vector = np.zeros(3)
- if 'x' in flip_axes:
+ if "x" in flip_axes:
shift_vector[0] = -1.0
- if 'y' in flip_axes:
+ if "y" in flip_axes:
shift_vector[1] = -1.0
- if 'z' in flip_axes:
+ if "z" in flip_axes:
shift_vector[2] = -1.0
return shift_vector
@@ -235,11 +241,11 @@ def get_axis_flip_vector(flip_axes):
Possible values are -1, 1
"""
flip_vector = np.ones(3)
- if 'x' in flip_axes:
+ if "x" in flip_axes:
flip_vector[0] = -1.0
- if 'y' in flip_axes:
+ if "y" in flip_axes:
flip_vector[1] = -1.0
- if 'z' in flip_axes:
+ if "z" in flip_axes:
flip_vector[2] = -1.0
return flip_vector
@@ -266,7 +272,7 @@ def get_shift_vector(sft):
def flip_sft(sft, flip_axes):
- """ Flip the streamlines in the StatefulTractogram according to the
+ """Flip the streamlines in the StatefulTractogram according to the
flip_axes. Uses the spatial information to flip according to the center
of the grid.
@@ -283,8 +289,10 @@ def flip_sft(sft, flip_axes):
StatefulTractogram with flipped axes
"""
if not dipy_available:
- logging.error('Dipy library is missing, cannot use functions related '
- 'to the StatefulTractogram.')
+ logging.error(
+ "Dipy library is missing, cannot use functions related "
+ "to the StatefulTractogram."
+ )
return None
flip_vector = get_axis_flip_vector(flip_axes)
@@ -298,14 +306,18 @@ def flip_sft(sft, flip_axes):
flipped_streamlines.append(mod_streamline)
from dipy.io.stateful_tractogram import StatefulTractogram
- new_sft = StatefulTractogram.from_sft(flipped_streamlines, sft,
- data_per_point=sft.data_per_point,
- data_per_streamline=sft.data_per_streamline)
+
+ new_sft = StatefulTractogram.from_sft(
+ flipped_streamlines,
+ sft,
+ data_per_point=sft.data_per_point,
+ data_per_streamline=sft.data_per_streamline,
+ )
return new_sft
def load_matrix_in_any_format(filepath):
- """ Load a matrix from a txt file OR a npy file.
+ """Load a matrix from a txt file OR a npy file.
Parameters
----------
@@ -317,18 +329,18 @@ def load_matrix_in_any_format(filepath):
The matrix.
"""
_, ext = os.path.splitext(filepath)
- if ext == '.txt':
+ if ext == ".txt":
data = np.loadtxt(filepath)
- elif ext == '.npy':
+ elif ext == ".npy":
data = np.load(filepath)
else:
- raise ValueError('Extension {} is not supported'.format(ext))
+ raise ValueError("Extension {} is not supported".format(ext))
return data
def get_reverse_enum(space_str, origin_str):
- """ Convert string representation to enums for the StatefulTractogram.
+ """Convert string representation to enums for the StatefulTractogram.
Parameters
----------
@@ -342,14 +354,17 @@ def get_reverse_enum(space_str, origin_str):
Space and Origin as Enums.
"""
if not dipy_available:
- logging.error('Dipy library is missing, cannot use functions related '
- 'to the StatefulTractogram.')
+ logging.error(
+ "Dipy library is missing, cannot use functions related "
+ "to the StatefulTractogram."
+ )
return None
- from dipy.io.stateful_tractogram import Space, Origin
- origin = Origin.NIFTI if origin_str.lower() == 'nifti' else Origin.TRACKVIS
- if space_str.lower() == 'rasmm':
+ from dipy.io.stateful_tractogram import Origin, Space
+
+ origin = Origin.NIFTI if origin_str.lower() == "nifti" else Origin.TRACKVIS
+ if space_str.lower() == "rasmm":
space = Space.RASMM
- elif space_str.lower() == 'voxmm':
+ elif space_str.lower() == "voxmm":
space = Space.VOXMM
else:
space = Space.VOX
@@ -358,7 +373,7 @@ def get_reverse_enum(space_str, origin_str):
def convert_data_dict_to_tractogram(data):
- """ Convert a data from a lazy tractogram to a tractogram
+    """Convert a data dictionary from a lazy tractogram to a Tractogram
Keyword arguments:
data -- The data dictionary to convert into a nibabel tractogram
@@ -366,49 +381,50 @@ def convert_data_dict_to_tractogram(data):
Returns:
A Tractogram object
"""
- streamlines = ArraySequence(data['strs'])
+ streamlines = ArraySequence(data["strs"])
streamlines._data = streamlines._data
- for key in data['dps']:
- shape = (len(streamlines), len(data['dps'][key]) // len(streamlines))
- data['dps'][key] = np.array(data['dps'][key]).reshape(shape)
+ for key in data["dps"]:
+ shape = (len(streamlines), len(data["dps"][key]) // len(streamlines))
+ data["dps"][key] = np.array(data["dps"][key]).reshape(shape)
- for key in data['dpv']:
- shape = (len(streamlines._data), len(
- data['dpv'][key]) // len(streamlines._data))
- data['dpv'][key] = np.array(data['dpv'][key]).reshape(shape)
+ for key in data["dpv"]:
+ shape = (
+ len(streamlines._data),
+ len(data["dpv"][key]) // len(streamlines._data),
+ )
+ data["dpv"][key] = np.array(data["dpv"][key]).reshape(shape)
tmp_arr = ArraySequence()
- tmp_arr._data = data['dpv'][key]
+ tmp_arr._data = data["dpv"][key]
tmp_arr._offsets = streamlines._offsets
tmp_arr._lengths = streamlines._lengths
- data['dpv'][key] = tmp_arr
+ data["dpv"][key] = tmp_arr
- obj = Tractogram(streamlines, data_per_point=data['dpv'],
- data_per_streamline=data['dps'])
+ obj = Tractogram(
+ streamlines, data_per_point=data["dpv"], data_per_streamline=data["dps"]
+ )
return obj
def append_generator_to_dict(gen, data):
if isinstance(gen, TractogramItem):
- data['strs'].append(gen.streamline.tolist())
+ data["strs"].append(gen.streamline.tolist())
for key in gen.data_for_points:
- if key not in data['dpv']:
- data['dpv'][key] = np.array([])
- data['dpv'][key] = np.append(
- data['dpv'][key], gen.data_for_points[key])
+ if key not in data["dpv"]:
+ data["dpv"][key] = np.array([])
+ data["dpv"][key] = np.append(data["dpv"][key], gen.data_for_points[key])
for key in gen.data_for_streamline:
- if key not in data['dps']:
- data['dps'][key] = np.array([])
- data['dps'][key] = np.append(
- data['dps'][key], gen.data_for_streamline[key])
+ if key not in data["dps"]:
+ data["dps"][key] = np.array([])
+ data["dps"][key] = np.append(data["dps"][key], gen.data_for_streamline[key])
else:
- data['strs'].append(gen.tolist())
+ data["strs"].append(gen.tolist())
-def verify_trx_dtype(trx, dict_dtype):
- """ Verify if the dtype of the data in the trx is the same as the one in
+def verify_trx_dtype(trx, dict_dtype): # noqa: C901
+ """Verify if the dtype of the data in the trx is the same as the one in
the dict.
Parameters
@@ -424,38 +440,48 @@ def verify_trx_dtype(trx, dict_dtype):
"""
identical = True
for key in dict_dtype:
- if key == 'positions':
+ if key == "positions":
if trx.streamlines._data.dtype != dict_dtype[key]:
- logging.warning('Positions dtype is different')
+ logging.warning("Positions dtype is different")
identical = False
- elif key == 'offsets':
+ elif key == "offsets":
if trx.streamlines._offsets.dtype != dict_dtype[key]:
- logging.warning('Offsets dtype is different')
+ logging.warning("Offsets dtype is different")
identical = False
- elif key == 'dpv':
+ elif key == "dpv":
for key_dpv in dict_dtype[key]:
if trx.data_per_vertex[key_dpv]._data.dtype != dict_dtype[key][key_dpv]:
logging.warning(
- 'Data per vertex ({}) dtype is different'.format(key_dpv))
+ "Data per vertex ({}) dtype is different".format(key_dpv)
+ )
identical = False
- elif key == 'dps':
+ elif key == "dps":
for key_dps in dict_dtype[key]:
if trx.data_per_streamline[key_dps].dtype != dict_dtype[key][key_dps]:
logging.warning(
- 'Data per streamline ({}) dtype is different'.format(key_dps))
+ "Data per streamline ({}) dtype is different".format(key_dps)
+ )
identical = False
- elif key == 'dpg':
+ elif key == "dpg":
for key_group in dict_dtype[key]:
for key_dpg in dict_dtype[key][key_group]:
- if trx.data_per_point[key_group][key_dpg].dtype != dict_dtype[key][key_group][key_dpg]:
+ if (
+                        trx.data_per_group[key_group][key_dpg].dtype
+ != dict_dtype[key][key_group][key_dpg]
+ ):
logging.warning(
- 'Data per group ({}) dtype is different'.format(key_dpg))
+ "Data per group ({}) dtype is different".format(key_dpg)
+ )
identical = False
- elif key == 'groups':
+ elif key == "groups":
for key_group in dict_dtype[key]:
- if trx.data_per_point[key_group]._data.dtype != dict_dtype[key][key_group]:
+ if (
+                    trx.groups[key_group].dtype
+ != dict_dtype[key][key_group]
+ ):
logging.warning(
- 'Data per group ({}) dtype is different'.format(key_group))
+ "Data per group ({}) dtype is different".format(key_group)
+ )
identical = False
return identical
diff --git a/trx/viz.py b/trx/viz.py
index d5afb9b..c995a84 100644
--- a/trx/viz.py
+++ b/trx/viz.py
@@ -4,36 +4,52 @@
import logging
import numpy as np
+
try:
- from dipy.viz import window, actor, colormap
- from fury.utils import get_bounds
+ from dipy.viz import actor, colormap, window
import fury.utils as ut_vtk
+ from fury.utils import get_bounds
import vtk
+
fury_available = True
except ImportError:
fury_available = False
-def display(volume, volume_affine=None, streamlines=None, title='FURY',
- display_bounds=True):
+def display(
+ volume, volume_affine=None, streamlines=None, title="FURY", display_bounds=True
+):
if not fury_available:
- logging.error('Fury library is missing, visualization functions '
- 'are not available.')
+ logging.error(
+ "Fury library is missing, visualization functions are not available."
+ )
return None
volume = volume.astype(float)
scene = window.Scene()
- scene.background((1., 0.5, 0.))
+ scene.background((1.0, 0.5, 0.0))
# Show the X/Y/Z plane intersecting, mid-slices
- slicer_actor_1 = actor.slicer(volume, affine=volume_affine,
- value_range=(volume.min(), volume.max()),
- interpolation='nearest', opacity=0.8)
- slicer_actor_2 = actor.slicer(volume, affine=volume_affine,
- value_range=(volume.min(), volume.max()),
- interpolation='nearest', opacity=0.8)
- slicer_actor_3 = actor.slicer(volume, affine=volume_affine,
- value_range=(volume.min(), volume.max()),
- interpolation='nearest', opacity=0.8)
+ slicer_actor_1 = actor.slicer(
+ volume,
+ affine=volume_affine,
+ value_range=(volume.min(), volume.max()),
+ interpolation="nearest",
+ opacity=0.8,
+ )
+ slicer_actor_2 = actor.slicer(
+ volume,
+ affine=volume_affine,
+ value_range=(volume.min(), volume.max()),
+ interpolation="nearest",
+ opacity=0.8,
+ )
+ slicer_actor_3 = actor.slicer(
+ volume,
+ affine=volume_affine,
+ value_range=(volume.min(), volume.max()),
+ interpolation="nearest",
+ opacity=0.8,
+ )
slicer_actor_1.display(y=volume.shape[1] // 2)
slicer_actor_2.display(x=volume.shape[0] // 2)
slicer_actor_3.display(z=volume.shape[2] // 2)
@@ -55,33 +71,40 @@ def display(volume, volume_affine=None, streamlines=None, title='FURY',
# Show each corner's coordinates
corners = itertools.product(bounds[0:2], bounds[2:4], bounds[4:6])
for corner in corners:
- text_actor = actor.text_3d('{}, {}, {}'.format(
- *corner), corner, font_size=6, justification='center')
+ text_actor = actor.text_3d(
+ "{}, {}, {}".format(*corner),
+ corner,
+ font_size=6,
+ justification="center",
+ )
scene.add(text_actor)
# Show the X/Y/Z dimensions
- text_actor_x = actor.text_3d('{}'.format(np.abs(bounds[0]-bounds[1])),
- ((bounds[0]+bounds[1])/2,
- bounds[2],
- bounds[4]),
- font_size=10, justification='center')
- text_actor_y = actor.text_3d('{}'.format(np.abs(bounds[2]-bounds[3])),
- (bounds[0],
- (bounds[2]+bounds[3])/2,
- bounds[4]),
- font_size=10, justification='center')
- text_actor_z = actor.text_3d('{}'.format(np.abs(bounds[4]-bounds[5])),
- (bounds[0],
- bounds[2],
- (bounds[4]+bounds[5])/2),
- font_size=10, justification='center')
+ text_actor_x = actor.text_3d(
+ "{}".format(np.abs(bounds[0] - bounds[1])),
+ ((bounds[0] + bounds[1]) / 2, bounds[2], bounds[4]),
+ font_size=10,
+ justification="center",
+ )
+ text_actor_y = actor.text_3d(
+ "{}".format(np.abs(bounds[2] - bounds[3])),
+ (bounds[0], (bounds[2] + bounds[3]) / 2, bounds[4]),
+ font_size=10,
+ justification="center",
+ )
+ text_actor_z = actor.text_3d(
+ "{}".format(np.abs(bounds[4] - bounds[5])),
+ (bounds[0], bounds[2], (bounds[4] + bounds[5]) / 2),
+ font_size=10,
+ justification="center",
+ )
scene.add(text_actor_x)
scene.add(text_actor_y)
scene.add(text_actor_z)
if streamlines is not None:
- streamlines_actor = actor.line(streamlines,
- colormap.line_colors(streamlines),
- opacity=0.25)
+ streamlines_actor = actor.line(
+ streamlines, colormap.line_colors(streamlines), opacity=0.25
+ )
scene.add(streamlines_actor)
window.show(scene, title=title, size=(800, 800))
diff --git a/trx/workflows.py b/trx/workflows.py
index eaa85c5..6eb7921 100644
--- a/trx/workflows.py
+++ b/trx/workflows.py
@@ -11,54 +11,66 @@
import nibabel as nib
from nibabel.streamlines.array_sequence import ArraySequence
import numpy as np
+
try:
- import dipy
+ import dipy # noqa: F401
+
dipy_available = True
except ImportError:
dipy_available = False
from trx.io import get_trx_tmp_dir, load, load_sft_with_reference, save
-from trx.streamlines_ops import perform_streamlines_operation, intersection
+from trx.streamlines_ops import intersection, perform_streamlines_operation
import trx.trx_file_memmap as tmm
+from trx.utils import (
+ flip_sft,
+ get_axis_shift_vector,
+ get_reference_info_wrapper,
+ get_reverse_enum,
+ is_header_compatible,
+ load_matrix_in_any_format,
+ split_name_with_gz,
+)
from trx.viz import display
-from trx.utils import (flip_sft, is_header_compatible,
- get_axis_shift_vector,
- get_reference_info_wrapper,
- get_reverse_enum,
- load_matrix_in_any_format,
- split_name_with_gz)
-def convert_dsi_studio(in_dsi_tractogram, in_dsi_fa, out_tractogram,
- remove_invalid=True, keep_invalid=False):
+def convert_dsi_studio(
+ in_dsi_tractogram,
+ in_dsi_fa,
+ out_tractogram,
+ remove_invalid=True,
+ keep_invalid=False,
+):
if not dipy_available:
- logging.error('Dipy library is missing, scripts are not available.')
+ logging.error("Dipy library is missing, scripts are not available.")
return None
- from dipy.io.stateful_tractogram import StatefulTractogram, Space
- from dipy.io.streamline import save_tractogram, load_tractogram
+ from dipy.io.stateful_tractogram import Space, StatefulTractogram
+ from dipy.io.streamline import load_tractogram, save_tractogram
in_ext = split_name_with_gz(in_dsi_tractogram)[1]
out_ext = split_name_with_gz(out_tractogram)[1]
- if in_ext == '.trk.gz':
- with gzip.open(in_dsi_tractogram, 'rb') as f_in:
- with open('tmp.trk', 'wb') as f_out:
+ if in_ext == ".trk.gz":
+ with gzip.open(in_dsi_tractogram, "rb") as f_in:
+ with open("tmp.trk", "wb") as f_out:
f_out.writelines(f_in)
- sft = load_tractogram('tmp.trk', 'same',
- bbox_valid_check=False)
- os.remove('tmp.trk')
- elif in_ext == '.trk':
- sft = load_tractogram(in_dsi_tractogram, 'same',
- bbox_valid_check=False)
+ sft = load_tractogram("tmp.trk", "same", bbox_valid_check=False)
+ os.remove("tmp.trk")
+ elif in_ext == ".trk":
+ sft = load_tractogram(in_dsi_tractogram, "same", bbox_valid_check=False)
else:
- raise IOError('{} is not currently supported.'.format(in_ext))
+ raise IOError("{} is not currently supported.".format(in_ext))
sft.to_vox()
- sft_fix = StatefulTractogram(sft.streamlines, in_dsi_fa, Space.VOXMM,
- data_per_point=sft.data_per_point,
- data_per_streamline=sft.data_per_streamline)
+ sft_fix = StatefulTractogram(
+ sft.streamlines,
+ in_dsi_fa,
+ Space.VOXMM,
+ data_per_point=sft.data_per_point,
+ data_per_streamline=sft.data_per_streamline,
+ )
sft_fix.to_vox()
- flip_axis = ['x', 'y']
+ flip_axis = ["x", "y"]
sft_fix.streamlines._data -= get_axis_shift_vector(flip_axis)
sft_flip = flip_sft(sft_fix, flip_axis)
@@ -68,18 +80,22 @@ def convert_dsi_studio(in_dsi_tractogram, in_dsi_fa, out_tractogram,
if remove_invalid:
sft_flip.remove_invalid_streamlines()
- if out_ext != '.trx':
- save_tractogram(sft_flip, out_tractogram,
- bbox_valid_check=not keep_invalid)
+ if out_ext != ".trx":
+ save_tractogram(sft_flip, out_tractogram, bbox_valid_check=not keep_invalid)
else:
trx = tmm.TrxFile.from_sft(sft_flip)
tmm.save(trx, out_tractogram)
-def convert_tractogram(in_tractogram, out_tractogram, reference,
- pos_dtype='float32', offsets_dtype='uint32'):
+def convert_tractogram( # noqa: C901
+ in_tractogram,
+ out_tractogram,
+ reference,
+ pos_dtype="float32",
+ offsets_dtype="uint32",
+):
if not dipy_available:
- logging.error('Dipy library is missing, scripts are not available.')
+ logging.error("Dipy library is missing, scripts are not available.")
return None
from dipy.io.streamline import save_tractogram
@@ -87,40 +103,39 @@ def convert_tractogram(in_tractogram, out_tractogram, reference,
out_ext = split_name_with_gz(out_tractogram)[1]
if in_ext == out_ext:
- raise IOError('Input and output cannot be of the same file format.')
+ raise IOError("Input and output cannot be of the same file format.")
- if in_ext != '.trx':
- sft = load_sft_with_reference(in_tractogram, reference,
- bbox_check=False)
+ if in_ext != ".trx":
+ sft = load_sft_with_reference(in_tractogram, reference, bbox_check=False)
else:
trx = tmm.load(in_tractogram)
sft = trx.to_sft()
trx.close()
- if out_ext != '.trx':
- if out_ext == '.vtk':
+ if out_ext != ".trx":
+ if out_ext == ".vtk":
if sft.streamlines._data.dtype.name != pos_dtype:
sft.streamlines._data = sft.streamlines._data.astype(pos_dtype)
- if offsets_dtype == 'uint64' or offsets_dtype == 'uint32':
+ if offsets_dtype == "uint64" or offsets_dtype == "uint32":
offsets_dtype = offsets_dtype[1:]
if sft.streamlines._offsets.dtype.name != offsets_dtype:
sft.streamlines._offsets = sft.streamlines._offsets.astype(
- offsets_dtype)
+ offsets_dtype
+ )
save_tractogram(sft, out_tractogram, bbox_valid_check=False)
else:
trx = tmm.TrxFile.from_sft(sft)
if trx.streamlines._data.dtype.name != pos_dtype:
trx.streamlines._data = trx.streamlines._data.astype(pos_dtype)
if trx.streamlines._offsets.dtype.name != offsets_dtype:
- trx.streamlines._offsets = trx.streamlines._offsets.astype(
- offsets_dtype)
+ trx.streamlines._offsets = trx.streamlines._offsets.astype(offsets_dtype)
tmm.save(trx, out_tractogram)
trx.close()
def tractogram_simple_compare(in_tractograms, reference):
if not dipy_available:
- logging.error('Dipy library is missing, scripts are not available.')
+ logging.error("Dipy library is missing, scripts are not available.")
return
from dipy.io.stateful_tractogram import StatefulTractogram
@@ -138,56 +153,62 @@ def tractogram_simple_compare(in_tractograms, reference):
else:
sft_2 = tractogram_obj
- if np.allclose(sft_1.streamlines._data, sft_2.streamlines._data,
- atol=0.001):
- print('Matching tractograms in rasmm!')
+ if np.allclose(sft_1.streamlines._data, sft_2.streamlines._data, atol=0.001):
+ print("Matching tractograms in rasmm!")
else:
- print('Average difference in rasmm of {}'.format(np.average(
- sft_1.streamlines._data - sft_2.streamlines._data, axis=0)))
+ print(
+ "Average difference in rasmm of {}".format(
+ np.average(sft_1.streamlines._data - sft_2.streamlines._data, axis=0)
+ )
+ )
sft_1.to_voxmm()
sft_2.to_voxmm()
- if np.allclose(sft_1.streamlines._data, sft_2.streamlines._data,
- atol=0.001):
- print('Matching tractograms in voxmm!')
+ if np.allclose(sft_1.streamlines._data, sft_2.streamlines._data, atol=0.001):
+ print("Matching tractograms in voxmm!")
else:
- print('Average difference in voxmm of {}'.format(np.average(
- sft_1.streamlines._data - sft_2.streamlines._data, axis=0)))
+ print(
+ "Average difference in voxmm of {}".format(
+ np.average(sft_1.streamlines._data - sft_2.streamlines._data, axis=0)
+ )
+ )
sft_1.to_vox()
sft_2.to_vox()
- if np.allclose(sft_1.streamlines._data, sft_2.streamlines._data,
- atol=0.001):
- print('Matching tractograms in vox!')
+ if np.allclose(sft_1.streamlines._data, sft_2.streamlines._data, atol=0.001):
+ print("Matching tractograms in vox!")
else:
- print('Average difference in vox of {}'.format(np.average(
- sft_1.streamlines._data - sft_2.streamlines._data, axis=0)))
+ print(
+ "Average difference in vox of {}".format(
+ np.average(sft_1.streamlines._data - sft_2.streamlines._data, axis=0)
+ )
+ )
def verify_header_compatibility(in_files):
if not dipy_available:
- logging.error('Dipy library is missing, scripts are not available.')
+ logging.error("Dipy library is missing, scripts are not available.")
return
all_valid = True
for filepath in in_files:
if not os.path.isfile(filepath):
- print('{} does not exist'.format(filepath))
+ print("{} does not exist".format(filepath))
_, in_extension = split_name_with_gz(filepath)
- if in_extension not in ['.trk', '.nii', '.nii.gz', '.trx']:
- raise IOError('{} does not have a supported extension'.format(
- filepath))
+ if in_extension not in [".trk", ".nii", ".nii.gz", ".trx"]:
+ raise IOError("{} does not have a supported extension".format(filepath))
if not is_header_compatible(in_files[0], filepath):
- print('{} and {} do not have compatible header.'.format(
- in_files[0], filepath))
+ print(
+ "{} and {} do not have compatible header.".format(in_files[0], filepath)
+ )
all_valid = False
if all_valid:
- print('All input files have compatible headers.')
+ print("All input files have compatible headers.")
def tractogram_visualize_overlap(in_tractogram, reference, remove_invalid=True):
if not dipy_available:
- logging.error('Dipy library is missing, scripts are not available.')
+ logging.error("Dipy library is missing, scripts are not available.")
return None
from dipy.io.stateful_tractogram import StatefulTractogram
from dipy.tracking.streamline import set_number_of_points
@@ -210,8 +231,12 @@ def tractogram_visualize_overlap(in_tractogram, reference, remove_invalid=True):
# Approach (1)
density_1 = density_map(sft.streamlines, sft.affine, sft.dimensions)
img = nib.load(reference)
- display(img.get_fdata(), volume_affine=img.affine,
- streamlines=sft.streamlines, title='RASMM')
+ display(
+ img.get_fdata(),
+ volume_affine=img.affine,
+ streamlines=sft.streamlines,
+ title="RASMM",
+ )
# Approach (2)
sft.to_vox()
@@ -219,25 +244,36 @@ def tractogram_visualize_overlap(in_tractogram, reference, remove_invalid=True):
# Small difference due to casting of the affine as float32 or float64
diff = density_1 - density_2
- print('Total difference of {} voxels with total value of {}'.format(
- np.count_nonzero(diff), np.sum(np.abs(diff))))
+ print(
+ "Total difference of {} voxels with total value of {}".format(
+ np.count_nonzero(diff), np.sum(np.abs(diff))
+ )
+ )
- display(img.get_fdata(), streamlines=sft.streamlines, title='VOX')
+ display(img.get_fdata(), streamlines=sft.streamlines, title="VOX")
# Try VOXMM
sft.to_voxmm()
affine = np.eye(4)
affine[0:3, 0:3] *= sft.voxel_sizes
- display(img.get_fdata(), volume_affine=affine,
- streamlines=sft.streamlines, title='VOXMM')
-
-
-def validate_tractogram(in_tractogram, reference, out_tractogram,
- remove_identical_streamlines=True, precision=1):
-
+ display(
+ img.get_fdata(),
+ volume_affine=affine,
+ streamlines=sft.streamlines,
+ title="VOXMM",
+ )
+
+
+def validate_tractogram(
+ in_tractogram,
+ reference,
+ out_tractogram,
+ remove_identical_streamlines=True,
+ precision=1,
+):
if not dipy_available:
- logging.error('Dipy library is missing, scripts are not available.')
+ logging.error("Dipy library is missing, scripts are not available.")
return None
from dipy.io.stateful_tractogram import StatefulTractogram
@@ -255,34 +291,42 @@ def validate_tractogram(in_tractogram, reference, out_tractogram,
invalid_coord_ind, _ = sft.remove_invalid_streamlines()
tot_remove += len(invalid_coord_ind)
- logging.warning('Removed {} streamlines with invalid coordinates.'.format(
- len(invalid_coord_ind)))
+ logging.warning(
+ "Removed {} streamlines with invalid coordinates.".format(
+ len(invalid_coord_ind)
+ )
+ )
indices = [i for i in range(len(sft)) if len(sft.streamlines[i]) <= 1]
- tot_remove = + len(indices)
- logging.warning('Removed {} invalid streamlines (1 or 0 points).'.format(
- len(indices)))
+    tot_remove += len(indices)
+ logging.warning(
+ "Removed {} invalid streamlines (1 or 0 points).".format(len(indices))
+ )
for i in np.setdiff1d(range(len(sft)), indices):
- norm = np.linalg.norm(np.diff(sft.streamlines[i],
- axis=0), axis=1)
+ norm = np.linalg.norm(np.diff(sft.streamlines[i], axis=0), axis=1)
if (norm < 0.001).any():
indices.append(i)
indices_val = np.setdiff1d(range(len(sft)), indices).astype(np.uint32)
- logging.warning('Removed {} invalid streamlines (overlapping points).'.format(
- ori_len - len(indices_val)))
+ logging.warning(
+ "Removed {} invalid streamlines (overlapping points).".format(
+ ori_len - len(indices_val)
+ )
+ )
tot_remove += ori_len - len(indices_val)
if remove_identical_streamlines:
- _, indices_uniq = perform_streamlines_operation(intersection,
- [sft.streamlines],
- precision=precision)
- indices_final = np.intersect1d(
- indices_val, indices_uniq).astype(np.uint32)
- logging.warning('Removed {} overlapping streamlines.'.format(
- ori_len - len(indices_final) - tot_remove))
+ _, indices_uniq = perform_streamlines_operation(
+ intersection, [sft.streamlines], precision=precision
+ )
+ indices_final = np.intersect1d(indices_val, indices_uniq).astype(np.uint32)
+ logging.warning(
+ "Removed {} overlapping streamlines.".format(
+ ori_len - len(indices_final) - tot_remove
+ )
+ )
indices_final = np.intersect1d(indices_val, indices_uniq)
else:
@@ -297,197 +341,261 @@ def validate_tractogram(in_tractogram, reference, out_tractogram,
dps = {}
for key in sft.data_per_streamline.keys():
dps[key] = sft.data_per_streamline[key][indices_final]
- new_sft = StatefulTractogram.from_sft(streamlines, sft,
- data_per_point=dpp,
- data_per_streamline=dps)
+ new_sft = StatefulTractogram.from_sft(
+ streamlines, sft, data_per_point=dpp, data_per_streamline=dps
+ )
new_sft.dtype_dict = ori_dtype
save(new_sft, out_tractogram)
-def generate_trx_from_scratch(reference, out_tractogram, positions_csv=False,
- positions=False, offsets=False,
- positions_dtype='float32', offsets_dtype='uint64',
- space_str='rasmm', origin_str='nifti',
- verify_invalid=True, dpv=[], dps=[],
- groups=[], dpg=[]):
+def _load_streamlines_from_csv(positions_csv):
+ """Load streamlines from CSV file."""
+ with open(positions_csv, newline="") as f:
+ reader = csv.reader(f)
+ data = list(reader)
+ data = [np.reshape(i, (len(i) // 3, 3)).astype(float) for i in data]
+ return ArraySequence(data)
+
+
+def _load_streamlines_from_arrays(positions, offsets):
+ """Load streamlines from position and offset arrays."""
+ positions = load_matrix_in_any_format(positions)
+ offsets = load_matrix_in_any_format(offsets)
+ lengths = tmm._compute_lengths(offsets)
+ streamlines = ArraySequence()
+ streamlines._data = positions
+ streamlines._offsets = deepcopy(offsets)
+ streamlines._lengths = lengths
+ return streamlines, offsets
+
+
+def _apply_spatial_transforms(
+ streamlines, reference, space_str, origin_str, verify_invalid, offsets
+):
+ """Apply spatial transforms and verify streamlines."""
+ if not dipy_available:
+ logging.error(
+ "Dipy library is missing, advanced options "
+ "related to spatial transforms and invalid "
+ "streamlines are not available."
+ )
+ return None
+
+ from dipy.io.stateful_tractogram import StatefulTractogram
+
+ space, origin = get_reverse_enum(space_str, origin_str)
+ sft = StatefulTractogram(streamlines, reference, space, origin)
+ if verify_invalid:
+ rem, _ = sft.remove_invalid_streamlines()
+ print(
+ "{} streamlines were removed becaused they were invalid.".format(len(rem))
+ )
+ sft.to_rasmm()
+ sft.to_center()
+ streamlines = sft.streamlines
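+    # Re-attach the offsets that were loaded alongside the positions, if any.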
+    if offsets is not None:
+        streamlines._offsets = offsets
+ return streamlines
+
+
+def _write_header(tmp_dir_name, reference, streamlines):
+ """Write header file."""
+ affine, dimensions, _, _ = get_reference_info_wrapper(reference)
+ header = {
+ "DIMENSIONS": dimensions.tolist(),
+ "VOXEL_TO_RASMM": affine.tolist(),
+ "NB_VERTICES": len(streamlines._data),
+ "NB_STREAMLINES": len(streamlines) - 1,
+ }
+
+ if header["NB_STREAMLINES"] <= 1:
+        raise IOError("To use this script, you need at least 2 streamlines.")
+
+ with open(os.path.join(tmp_dir_name, "header.json"), "w") as out_json:
+ json.dump(header, out_json)
+
+
+def _write_streamline_data(tmp_dir_name, streamlines, positions_dtype, offsets_dtype):
+ """Write streamline position and offset data."""
+ curr_filename = os.path.join(tmp_dir_name, "positions.3.{}".format(positions_dtype))
+ streamlines._data.astype(positions_dtype).tofile(curr_filename)
+
+ curr_filename = os.path.join(tmp_dir_name, "offsets.{}".format(offsets_dtype))
+ streamlines._offsets.astype(offsets_dtype).tofile(curr_filename)
+
+
+def _normalize_dtype(dtype_str):
+ """Normalize dtype string format."""
+ return "bit" if dtype_str == "bool" else dtype_str
+
+
+def _write_data_array(tmp_dir_name, subdir_name, args, is_dpg=False):
+ """Write data array to file."""
+ if is_dpg:
+ os.makedirs(os.path.join(tmp_dir_name, "dpg", args[0]), exist_ok=True)
+ curr_arr = load_matrix_in_any_format(args[1]).astype(args[2])
+ basename = os.path.basename(os.path.splitext(args[1])[0])
+        dtype = _normalize_dtype(args[2])
+ else:
+ os.makedirs(os.path.join(tmp_dir_name, subdir_name), exist_ok=True)
+ curr_arr = np.squeeze(load_matrix_in_any_format(args[0]).astype(args[1]))
+ basename = os.path.basename(os.path.splitext(args[0])[0])
+        dtype = _normalize_dtype(args[1])
+
+ if curr_arr.ndim > 2:
+ raise IOError("Maximum of 2 dimensions for dpv/dps/dpg.")
+
+ if curr_arr.shape == (1, 1):
+ curr_arr = curr_arr.reshape((1,))
+
+ dim = "" if curr_arr.ndim == 1 else "{}.".format(curr_arr.shape[-1])
+
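+    # Output file names follow the TRX on-disk convention "<name>.<columns>.<dtype>";
+    # the "<columns>." part is dropped for 1-D arrays.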
+ if is_dpg:
+ curr_filename = os.path.join(
+ tmp_dir_name, "dpg", args[0], "{}.{}{}".format(basename, dim, dtype)
+ )
+ else:
+ curr_filename = os.path.join(
+ tmp_dir_name, subdir_name, "{}.{}{}".format(basename, dim, dtype)
+ )
+
+ curr_arr.tofile(curr_filename)
+
+
+def generate_trx_from_scratch( # noqa: C901
+ reference,
+ out_tractogram,
+ positions_csv=False,
+ positions=False,
+ offsets=False,
+ positions_dtype="float32",
+ offsets_dtype="uint64",
+ space_str="rasmm",
+ origin_str="nifti",
+ verify_invalid=True,
+ dpv=None,
+ dps=None,
+ groups=None,
+ dpg=None,
+):
+ """Generate TRX file from scratch using various input formats."""
+ if dpv is None:
+ dpv = []
+ if dps is None:
+ dps = []
+ if groups is None:
+ groups = []
+ if dpg is None:
+ dpg = []
+
with get_trx_tmp_dir() as tmp_dir_name:
if positions_csv:
- with open(positions_csv, newline='') as f:
- reader = csv.reader(f)
- data = list(reader)
- data = [np.reshape(i, (len(i) // 3, 3)).astype(float)
- for i in data]
- streamlines = ArraySequence(data)
+ streamlines = _load_streamlines_from_csv(positions_csv)
+ offsets = None
else:
- positions = load_matrix_in_any_format(positions)
- offsets = load_matrix_in_any_format(offsets)
- lengths = tmm._compute_lengths(offsets)
- streamlines = ArraySequence()
- streamlines._data = positions
- streamlines._offsets = deepcopy(offsets)
- streamlines._lengths = lengths
-
- if space_str.lower() != 'rasmm' or origin_str.lower() != 'nifti' or \
- verify_invalid:
- if not dipy_available:
- logging.error('Dipy library is missing, advanced options '
- 'related to spatial transforms and invalid '
- 'streamlines are not available.')
+ streamlines, offsets = _load_streamlines_from_arrays(positions, offsets)
+
+ if (
+ space_str.lower() != "rasmm"
+ or origin_str.lower() != "nifti"
+ or verify_invalid
+ ):
+ streamlines = _apply_spatial_transforms(
+ streamlines, reference, space_str, origin_str, verify_invalid, offsets
+ )
+ if streamlines is None:
return
- from dipy.io.stateful_tractogram import StatefulTractogram
-
- space, origin = get_reverse_enum(space_str, origin_str)
- sft = StatefulTractogram(streamlines, reference, space, origin)
- if verify_invalid:
- rem, _ = sft.remove_invalid_streamlines()
- print('{} streamlines were removed becaused they were '
- 'invalid.'.format(len(rem)))
- sft.to_rasmm()
- sft.to_center()
- streamlines = sft.streamlines
- streamlines._offsets = offsets
-
- affine, dimensions, _, _ = get_reference_info_wrapper(reference)
- header = {
- "DIMENSIONS": dimensions.tolist(),
- "VOXEL_TO_RASMM": affine.tolist(),
- "NB_VERTICES": len(streamlines._data),
- "NB_STREAMLINES": len(streamlines)-1,
- }
-
- if header['NB_STREAMLINES'] <= 1:
- raise IOError('To use this script, you need at least 2'
- 'streamlines.')
-
- with open(os.path.join(tmp_dir_name, "header.json"), "w") as out_json:
- json.dump(header, out_json)
-
- curr_filename = os.path.join(tmp_dir_name, 'positions.3.{}'.format(
- positions_dtype))
- streamlines._data.astype(positions_dtype).tofile(
- curr_filename)
- curr_filename = os.path.join(tmp_dir_name, 'offsets.{}'.format(
- offsets_dtype))
- streamlines._offsets.astype(offsets_dtype).tofile(
- curr_filename)
+
+ _write_header(tmp_dir_name, reference, streamlines)
+ _write_streamline_data(
+ tmp_dir_name, streamlines, positions_dtype, offsets_dtype
+ )
if dpv:
- os.mkdir(os.path.join(tmp_dir_name, 'dpv'))
for arg in dpv:
- curr_arr = np.squeeze(load_matrix_in_any_format(arg[0]).astype(
- arg[1]))
- if arg[1] == 'bool':
- arg[1] = 'bit'
- if curr_arr.ndim > 2:
- raise IOError('Maximum of 2 dimensions for dpv/dps/dpg.')
- dim = '' if curr_arr.ndim == 1 else '{}.'.format(
- curr_arr.shape[-1])
- curr_filename = os.path.join(tmp_dir_name, 'dpv', '{}.{}{}'.format(
- os.path.basename(os.path.splitext(arg[0])[0]), dim, arg[1]))
- curr_arr.tofile(curr_filename)
+ _write_data_array(tmp_dir_name, "dpv", arg)
if dps:
- os.mkdir(os.path.join(tmp_dir_name, 'dps'))
for arg in dps:
- curr_arr = np.squeeze(load_matrix_in_any_format(arg[0]).astype(
- arg[1]))
- if arg[1] == 'bool':
- arg[1] = 'bit'
- if curr_arr.ndim > 2:
- raise IOError('Maximum of 2 dimensions for dpv/dps/dpg.')
- dim = '' if curr_arr.ndim == 1 else '{}.'.format(
- curr_arr.shape[-1])
- curr_filename = os.path.join(tmp_dir_name, 'dps', '{}.{}{}'.format(
- os.path.basename(os.path.splitext(arg[0])[0]), dim, arg[1]))
- curr_arr.tofile(curr_filename)
+ _write_data_array(tmp_dir_name, "dps", arg)
+
if groups:
- os.mkdir(os.path.join(tmp_dir_name, 'groups'))
for arg in groups:
- curr_arr = load_matrix_in_any_format(arg[0]).astype(arg[1])
- if arg[1] == 'bool':
- arg[1] = 'bit'
- if curr_arr.ndim > 2:
- raise IOError('Maximum of 2 dimensions for dpv/dps/dpg.')
- dim = '' if curr_arr.ndim == 1 else '{}.'.format(
- curr_arr.shape[-1])
- curr_filename = os.path.join(tmp_dir_name, 'groups', '{}.{}{}'.format(
- os.path.basename(os.path.splitext(arg[0])[0]), dim, arg[1]))
- curr_arr.tofile(curr_filename)
+ _write_data_array(tmp_dir_name, "groups", arg)
if dpg:
- os.mkdir(os.path.join(tmp_dir_name, 'dpg'))
for arg in dpg:
- if not os.path.isdir(os.path.join(tmp_dir_name, 'dpg', arg[0])):
- os.mkdir(os.path.join(tmp_dir_name, 'dpg', arg[0]))
- curr_arr = load_matrix_in_any_format(arg[1]).astype(arg[2])
- if arg[1] == 'bool':
- arg[1] = 'bit'
- if curr_arr.ndim > 2:
- raise IOError('Maximum of 2 dimensions for dpv/dps/dpg.')
- if curr_arr.shape == (1, 1):
- curr_arr = curr_arr.reshape((1,))
- dim = '' if curr_arr.ndim == 1 else '{}.'.format(
- curr_arr.shape[-1])
- curr_filename = os.path.join(tmp_dir_name, 'dpg', arg[0], '{}.{}{}'.format(
- os.path.basename(os.path.splitext(arg[1])[0]), dim, arg[2]))
- curr_arr.tofile(curr_filename)
+ _write_data_array(tmp_dir_name, "dpg", arg, is_dpg=True)
trx = tmm.load(tmp_dir_name)
tmm.save(trx, out_tractogram)
trx.close()
-def manipulate_trx_datatype(in_filename, out_filename, dict_dtype):
+def manipulate_trx_datatype(in_filename, out_filename, dict_dtype): # noqa: C901
trx = tmm.load(in_filename)
# For each key in dict_dtype, we create a new memmap with the new dtype
# and we copy the data from the old memmap to the new one.
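+    # A hypothetical example of the expected structure:
+    # dict_dtype = {"positions": "float16", "dpv": {"metric": "float32"}}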
for key in dict_dtype:
- if key == 'positions':
- tmp_mm = np.memmap(tempfile.NamedTemporaryFile(),
- dtype=dict_dtype[key],
- mode='w+',
- shape=trx.streamlines._data.shape)
+ if key == "positions":
+ tmp_mm = np.memmap(
+ tempfile.NamedTemporaryFile(),
+ dtype=dict_dtype[key],
+ mode="w+",
+ shape=trx.streamlines._data.shape,
+ )
tmp_mm[:] = trx.streamlines._data[:]
trx.streamlines._data = tmp_mm
- elif key == 'offsets':
- tmp_mm = np.memmap(tempfile.NamedTemporaryFile(),
- dtype=dict_dtype[key],
- mode='w+',
- shape=trx.streamlines._offsets.shape)
+ elif key == "offsets":
+ tmp_mm = np.memmap(
+ tempfile.NamedTemporaryFile(),
+ dtype=dict_dtype[key],
+ mode="w+",
+ shape=trx.streamlines._offsets.shape,
+ )
tmp_mm[:] = trx.streamlines._offsets[:]
trx.streamlines._offsets = tmp_mm
- elif key == 'dpv':
+ elif key == "dpv":
for key_dpv in dict_dtype[key]:
- tmp_mm = np.memmap(tempfile.NamedTemporaryFile(),
- dtype=dict_dtype[key][key_dpv],
- mode='w+',
- shape=trx.data_per_vertex[key_dpv]._data.shape)
+ tmp_mm = np.memmap(
+ tempfile.NamedTemporaryFile(),
+ dtype=dict_dtype[key][key_dpv],
+ mode="w+",
+ shape=trx.data_per_vertex[key_dpv]._data.shape,
+ )
tmp_mm[:] = trx.data_per_vertex[key_dpv]._data[:]
trx.data_per_vertex[key_dpv]._data = tmp_mm
- elif key == 'dps':
+ elif key == "dps":
for key_dps in dict_dtype[key]:
- tmp_mm = np.memmap(tempfile.NamedTemporaryFile(),
- dtype=dict_dtype[key][key_dps],
- mode='w+',
- shape=trx.data_per_streamline[key_dps].shape)
+ tmp_mm = np.memmap(
+ tempfile.NamedTemporaryFile(),
+ dtype=dict_dtype[key][key_dps],
+ mode="w+",
+ shape=trx.data_per_streamline[key_dps].shape,
+ )
tmp_mm[:] = trx.data_per_streamline[key_dps][:]
trx.data_per_streamline[key_dps] = tmp_mm
- elif key == 'dpg':
+ elif key == "dpg":
for key_group in dict_dtype[key]:
for key_dpg in dict_dtype[key][key_group]:
- tmp_mm = np.memmap(tempfile.NamedTemporaryFile(),
- dtype=dict_dtype[key][key_group][key_dpg],
- mode='w+',
- shape=trx.data_per_group[key_group][key_dpg].shape)
+ tmp_mm = np.memmap(
+ tempfile.NamedTemporaryFile(),
+ dtype=dict_dtype[key][key_group][key_dpg],
+ mode="w+",
+ shape=trx.data_per_group[key_group][key_dpg].shape,
+ )
tmp_mm[:] = trx.data_per_group[key_group][key_dpg][:]
trx.data_per_group[key_group][key_dpg] = tmp_mm
- elif key == 'groups':
+ elif key == "groups":
for key_group in dict_dtype[key]:
- tmp_mm = np.memmap(tempfile.NamedTemporaryFile(),
- dtype=dict_dtype[key][key_group],
- mode='w+',
- shape=trx.groups[key_group].shape)
+ tmp_mm = np.memmap(
+ tempfile.NamedTemporaryFile(),
+ dtype=dict_dtype[key][key_group],
+ mode="w+",
+ shape=trx.groups[key_group].shape,
+ )
tmp_mm[:] = trx.groups[key_group][:]
trx.groups[key_group] = tmp_mm