From 0408812988a2766c0a36a41e3172970ac52b8461 Mon Sep 17 00:00:00 2001 From: Serge Koudoro Date: Thu, 15 Jan 2026 10:48:04 -0500 Subject: [PATCH 1/7] ci: remove old python version(3.8-3.10) and add new ones (3.11-3.14) --- .github/workflows/python-package.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 822b770..7fdbcce 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -12,7 +12,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.11", "3.12", "3.13", "3.14"] os: [ubuntu-latest, windows-latest, macos-latest] steps: - uses: actions/checkout@v2 From 0c1386f8bb318628387447eb5f604b239844e82a Mon Sep 17 00:00:00 2001 From: Serge Koudoro Date: Thu, 15 Jan 2026 10:48:19 -0500 Subject: [PATCH 2/7] RF: update gitignore --- .gitignore | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.gitignore b/.gitignore index b5a6b76..aac37e3 100644 --- a/.gitignore +++ b/.gitignore @@ -132,3 +132,7 @@ dmypy.json # Pyre type checker .pyre/ .vscode/ + +tmp/ +CLAUDE.md +claude.md From 6339adff3eeabab3f88c3e83c9c10335210677c8 Mon Sep 17 00:00:00 2001 From: Serge Koudoro Date: Thu, 15 Jan 2026 10:49:05 -0500 Subject: [PATCH 3/7] ci: refresh publish to pypi --- .github/workflows/publish-to-test-pypi.yml | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/.github/workflows/publish-to-test-pypi.yml b/.github/workflows/publish-to-test-pypi.yml index 572d0d1..0cfa0bb 100644 --- a/.github/workflows/publish-to-test-pypi.yml +++ b/.github/workflows/publish-to-test-pypi.yml @@ -1,14 +1,21 @@ name: Publish Python 🐍 distributions 📦 to PyPI and TestPyPI -on: push +on: + push: + branches: + - master + tags: + - '*' + jobs: build-n-publish: name: Build and publish Python 🐍 distributions 📦 to PyPI and TestPyPI runs-on: ubuntu-latest + if: github.repository == 'tee-ar-ex/trx-python' steps: - - uses: actions/checkout@master + - uses: actions/checkout@v4 - name: Set up Python 3.9 - uses: actions/setup-python@v1 + uses: actions/setup-python@v5 with: python-version: 3.9 - name: Install pypa/build @@ -30,7 +37,7 @@ jobs: with: user: __token__ password: ${{ secrets.test_pypi_password }} - repository_url: https://test.pypi.org/legacy/ + repository-url: https://test.pypi.org/legacy/ - name: Publish distribution 📦 to PyPI if: startsWith(github.event.ref, 'refs/tags') uses: pypa/gh-action-pypi-publish@release/v1 From 091bc12db1a0c4cdfa50fd6149f7cbfeec693880 Mon Sep 17 00:00:00 2001 From: Serge Koudoro Date: Thu, 15 Jan 2026 10:49:35 -0500 Subject: [PATCH 4/7] doc: large improvement of doc generation. 
manage version --- .github/workflows/docbuild.yml | 85 ++++++++++++++++++-- docs/_static/switcher.json | 13 +++ docs/source/conf.py | 32 +++++++- setup.cfg | 2 +- tools/update_switcher.py | 143 +++++++++++++++++++++++++++++++++ 5 files changed, 267 insertions(+), 8 deletions(-) create mode 100644 docs/_static/switcher.json create mode 100644 tools/update_switcher.py diff --git a/.github/workflows/docbuild.yml b/.github/workflows/docbuild.yml index 757b4c9..59cbf2d 100644 --- a/.github/workflows/docbuild.yml +++ b/.github/workflows/docbuild.yml @@ -17,12 +17,12 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.9"] + python-version: ["3.13"] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install @@ -35,13 +35,86 @@ jobs: cd docs make html - name: Upload docs - uses: actions/upload-artifact@v1 + uses: actions/upload-artifact@v4 with: name: docs path: docs/_build/html - - name: Publish docs to Github Pages - if: startsWith(github.event.ref, 'refs/tags') + + deploy-dev: + needs: build + runs-on: ubuntu-latest + if: github.event_name == 'push' && github.ref == 'refs/heads/master' && github.repository == 'tee-ar-ex/trx-python' + steps: + - uses: actions/checkout@v4 + - uses: actions/download-artifact@v4 + with: + name: docs + path: docs/_build/html + - name: Publish dev docs to Github Pages + uses: JamesIves/github-pages-deploy-action@v4 + with: + branch: gh-pages + folder: docs/_build/html + target-folder: dev + + deploy-release: + needs: build + runs-on: ubuntu-latest + if: startsWith(github.ref, 'refs/tags/') && github.repository == 'tee-ar-ex/trx-python' + steps: + - uses: actions/checkout@v4 + - uses: actions/download-artifact@v4 + with: + name: docs + path: docs/_build/html + - name: Get version from tag + id: get_version + run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT + - name: Fetch existing switcher.json from gh-pages + run: | + curl -sSL https://tee-ar-ex.github.io/trx-python/switcher.json -o switcher.json || echo '[]' > switcher.json + - name: Update switcher.json with new version + run: python tools/update_switcher.py switcher.json --version ${{ steps.get_version.outputs.VERSION }} + - name: Create root files (redirect + switcher.json) + run: | + mkdir -p root_files + cp switcher.json root_files/ + cat > root_files/index.html << 'EOF' + + + + + + trx-python - TRX File Format for Tractography + + + + + + + +

+        <h1>trx-python Documentation</h1>
+        <p>Python implementation of the TRX file format for tractography data.</p>
+        <p>If you are not redirected automatically, visit the <a href="https://tee-ar-ex.github.io/trx-python/stable/">stable documentation</a>.</p>
+ + + EOF + - name: Publish root files (redirect + switcher.json) + uses: JamesIves/github-pages-deploy-action@v4 + with: + branch: gh-pages + folder: root_files + target-folder: . + clean: false + - name: Publish release docs to Github Pages + uses: JamesIves/github-pages-deploy-action@v4 + with: + branch: gh-pages + folder: docs/_build/html + target-folder: ${{ steps.get_version.outputs.VERSION }} + - name: Publish stable docs to Github Pages uses: JamesIves/github-pages-deploy-action@v4 with: branch: gh-pages folder: docs/_build/html + target-folder: stable diff --git a/docs/_static/switcher.json b/docs/_static/switcher.json new file mode 100644 index 0000000..ad36d9e --- /dev/null +++ b/docs/_static/switcher.json @@ -0,0 +1,13 @@ +[ + { + "name": "dev", + "version": "dev", + "url": "https://tee-ar-ex.github.io/trx-python/dev/" + }, + { + "name": "stable", + "version": "stable", + "url": "https://tee-ar-ex.github.io/trx-python/stable/", + "preferred": true + } +] diff --git a/docs/source/conf.py b/docs/source/conf.py index 5970660..eb70011 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -13,13 +13,35 @@ # import os # import sys # sys.path.insert(0, os.path.abspath('.')) +import os from datetime import datetime as dt +# -- Version information ----------------------------------------------------- +# Get version from environment variable (set by CI) or package +version = os.environ.get('TRX_VERSION', None) +if version is None: + try: + from trx import __version__ + version = __version__ + except ImportError: + version = "dev" + +# Normalize version for switcher matching +# Remove .devX suffix for matching against switcher.json +version_match = version.split('.dev')[0] if '.dev' in version else version +if version_match == version and 'dev' not in version: + # This is a release version + pass +else: + # Development version - match against "dev" + version_match = "dev" + # -- Project information ----------------------------------------------------- project = 'trx-python' copyright = copyright = f'2021-{dt.now().year}, The TRX developers' author = 'The TRX developers' +release = version # -- General configuration --------------------------------------------------- @@ -59,6 +81,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['../_static'] html_logo = "../_static/trx_logo.png" @@ -74,7 +97,14 @@ # The type of image to be used (see below for details) "type": "fontawesome", } - ] + ], + # Version switcher configuration + "switcher": { + "json_url": "https://tee-ar-ex.github.io/trx-python/switcher.json", + "version_match": version_match, + }, + "navbar_start": ["navbar-logo", "version-switcher"], + "show_version_warning_banner": True, } diff --git a/setup.cfg b/setup.cfg index 13177b3..b9e7279 100644 --- a/setup.cfg +++ b/setup.cfg @@ -34,7 +34,7 @@ install_requires = [options.extras_require] doc = astroid==2.15.8 sphinx - pydata-sphinx-theme + pydata-sphinx-theme >= 0.16.1 sphinx-autoapi numpydoc diff --git a/tools/update_switcher.py b/tools/update_switcher.py new file mode 100644 index 0000000..6bc5ad6 --- /dev/null +++ b/tools/update_switcher.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +"""Update switcher.json for documentation version switching. 
+ +This script maintains the version switcher JSON file used by pydata-sphinx-theme +to enable users to switch between different documentation versions. +""" +import argparse +import json +import sys +from pathlib import Path + +BASE_URL = "https://tee-ar-ex.github.io/trx-python" + + +def load_switcher(path): + """Load existing switcher.json or return empty list.""" + try: + with open(path, 'r') as f: + return json.load(f) + except (FileNotFoundError, json.JSONDecodeError): + return [] + + +def save_switcher(path, versions): + """Save switcher.json with proper formatting.""" + with open(path, 'w') as f: + json.dump(versions, f, indent=4) + f.write('\n') + + +def ensure_dev_entry(versions): + """Ensure dev entry exists in versions list.""" + dev_exists = any(v.get('version') == 'dev' for v in versions) + if not dev_exists: + versions.insert(0, { + "name": "dev", + "version": "dev", + "url": f"{BASE_URL}/dev/" + }) + return versions + + +def ensure_stable_entry(versions): + """Ensure stable entry exists with preferred flag.""" + stable_idx = next( + (i for i, v in enumerate(versions) if v.get('version') == 'stable'), + None + ) + if stable_idx is not None: + versions[stable_idx]['preferred'] = True + else: + versions.append({ + "name": "stable", + "version": "stable", + "url": f"{BASE_URL}/stable/", + "preferred": True + }) + return versions + + +def add_version(versions, version): + """Add a new version entry to the versions list. + + Parameters + ---------- + versions : list + List of version entries. + version : str + Version string to add (e.g., "0.5.0"). + + Returns + ------- + list + Updated versions list. + """ + # Remove 'preferred' from all existing entries + for v in versions: + v.pop('preferred', None) + + # Check if this version already exists + version_exists = any(v.get('version') == version for v in versions) + + if not version_exists: + new_entry = { + "name": version, + "version": version, + "url": f"{BASE_URL}/{version}/" + } + # Find dev entry index to insert after it + dev_idx = next( + (i for i, v in enumerate(versions) if v.get('version') == 'dev'), + -1 + ) + if dev_idx >= 0: + versions.insert(dev_idx + 1, new_entry) + else: + versions.insert(0, new_entry) + + return versions + + +def main(): + """Main entry point.""" + parser = argparse.ArgumentParser( + description='Update switcher.json for documentation version switching' + ) + parser.add_argument( + 'switcher_path', + type=Path, + help='Path to switcher.json file' + ) + parser.add_argument( + '--version', + type=str, + help='New version to add (e.g., 0.5.0)' + ) + + args = parser.parse_args() + + # Load existing versions + versions = load_switcher(args.switcher_path) + + # Add new version if specified + if args.version: + versions = add_version(versions, args.version) + + # Ensure required entries exist + versions = ensure_dev_entry(versions) + versions = ensure_stable_entry(versions) + + # Save updated switcher.json + save_switcher(args.switcher_path, versions) + + # Print result for CI logs + print(f"Updated {args.switcher_path}:") + print(json.dumps(versions, indent=4)) + + return 0 + + +if __name__ == '__main__': + sys.exit(main()) From a7c07a20c3832e2937196e6c9362653b458f1b0d Mon Sep 17 00:00:00 2001 From: Serge Koudoro Date: Thu, 15 Jan 2026 11:17:43 -0500 Subject: [PATCH 5/7] RF: spin integration, removing setup.cfg and setup.py --- .flake8 | 11 + .github/workflows/docbuild.yml | 2 + .github/workflows/publish-to-test-pypi.yml | 87 ++++-- .github/workflows/python-package.yml | 39 --- 
.github/workflows/test.yml | 42 +++ .gitignore | 2 + .spin/cmds.py | 210 ++++++++++++++ README.md | 24 ++ docs/source/conf.py | 3 +- pyproject.toml | 91 +++++- scripts/tff_manipulate_datatype.py | 2 +- setup.cfg | 49 ---- setup.py | 33 --- trx/io.py | 2 +- trx/tests/test_io.py | 4 +- trx/tests/test_memmap.py | 2 +- trx/trx_file_memmap.py | 310 ++++++++++++--------- trx/utils.py | 14 +- trx/workflows.py | 275 ++++++++++-------- 19 files changed, 799 insertions(+), 403 deletions(-) create mode 100644 .flake8 delete mode 100644 .github/workflows/python-package.yml create mode 100644 .github/workflows/test.yml create mode 100644 .spin/cmds.py delete mode 100644 setup.cfg delete mode 100644 setup.py diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..f6b4ec2 --- /dev/null +++ b/.flake8 @@ -0,0 +1,11 @@ +[flake8] +max-line-length = 88 +max-complexity = 10 +exclude = + .git, + __pycache__, + .tox, + .eggs, + *.egg, + build, + dist \ No newline at end of file diff --git a/.github/workflows/docbuild.yml b/.github/workflows/docbuild.yml index 59cbf2d..5109045 100644 --- a/.github/workflows/docbuild.yml +++ b/.github/workflows/docbuild.yml @@ -21,6 +21,8 @@ jobs: steps: - uses: actions/checkout@v4 + with: + fetch-depth: 0 # Fetch all history and tags for setuptools_scm - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: diff --git a/.github/workflows/publish-to-test-pypi.yml b/.github/workflows/publish-to-test-pypi.yml index 0cfa0bb..3b3c454 100644 --- a/.github/workflows/publish-to-test-pypi.yml +++ b/.github/workflows/publish-to-test-pypi.yml @@ -8,39 +8,88 @@ on: - '*' jobs: - build-n-publish: - name: Build and publish Python 🐍 distributions 📦 to PyPI and TestPyPI + build: + name: Build distribution 📦 runs-on: ubuntu-latest if: github.repository == 'tee-ar-ex/trx-python' steps: - uses: actions/checkout@v4 + with: + fetch-depth: 0 # Fetch all history and tags for setuptools_scm - name: Set up Python 3.9 uses: actions/setup-python@v5 with: python-version: 3.9 - name: Install pypa/build - run: >- - python -m - pip install - build - --user + run: python -m pip install build --user - name: Build a binary wheel and a source tarball - run: >- - python -m - build - --sdist - --wheel - --outdir dist/ - . + run: python -m build --sdist --wheel --outdir dist/ . 
+ - name: Upload distribution artifacts + uses: actions/upload-artifact@v4 + with: + name: python-package-distributions + path: dist/ + + publish-to-testpypi: + name: Publish to TestPyPI + needs: build + runs-on: ubuntu-latest + environment: + name: testpypi + url: https://test.pypi.org/p/trx-python + permissions: + id-token: write + steps: + - name: Download distribution artifacts + uses: actions/download-artifact@v4 + with: + name: python-package-distributions + path: dist/ - name: Publish distribution 📦 to Test PyPI uses: pypa/gh-action-pypi-publish@release/v1 with: - user: __token__ - password: ${{ secrets.test_pypi_password }} repository-url: https://test.pypi.org/legacy/ + + publish-to-pypi: + name: Publish to PyPI + needs: build + runs-on: ubuntu-latest + if: startsWith(github.ref, 'refs/tags/') + environment: + name: pypi + url: https://pypi.org/p/trx-python + permissions: + id-token: write + steps: + - name: Download distribution artifacts + uses: actions/download-artifact@v4 + with: + name: python-package-distributions + path: dist/ - name: Publish distribution 📦 to PyPI - if: startsWith(github.event.ref, 'refs/tags') uses: pypa/gh-action-pypi-publish@release/v1 + + github-release: + name: Create GitHub Release + needs: publish-to-pypi + runs-on: ubuntu-latest + permissions: + contents: write + steps: + - uses: actions/checkout@v4 + - name: Download distribution artifacts + uses: actions/download-artifact@v4 with: - user: __token__ - password: ${{ secrets.pypi_password }} + name: python-package-distributions + path: dist/ + - name: Get version from tag + id: get_version + run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT + - name: Create GitHub Release + env: + GH_TOKEN: ${{ github.token }} + run: | + gh release create "${{ steps.get_version.outputs.VERSION }}" \ + --title "Release ${{ steps.get_version.outputs.VERSION }}" \ + --generate-notes \ + dist/* diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml deleted file mode 100644 index 7fdbcce..0000000 --- a/.github/workflows/python-package.yml +++ /dev/null @@ -1,39 +0,0 @@ -name: Python package - -on: - push: - branches: [ master ] - pull_request: - branches: [ master ] - -jobs: - build: - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - python-version: ["3.11", "3.12", "3.13", "3.14"] - os: [ubuntu-latest, windows-latest, macos-latest] - steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install the package - run: | - python -m pip install --upgrade pip - python -m pip install -e .[test] - - - name: Report version - run: | - python setup.py --version - - name: Lint with flake8 - run: | - # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - - name: Test with pytest - run: | - pytest trx/tests; pytest scripts/tests diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..8f8df59 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,42 @@ +name: Tests + +on: + push: + branches: [master] + pull_request: + branches: [master] + +permissions: + contents: read + +jobs: + test: + name: Python ${{ matrix.python-version }} • ${{ matrix.os }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + os: [ubuntu-latest, windows-latest, macos-latest] + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 # needed for setuptools_scm version detection + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: pip + cache-dependency-path: pyproject.toml + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install -e .[dev,test] + + - name: Lint + run: spin lint + + - name: Test + run: spin test diff --git a/.gitignore b/.gitignore index aac37e3..3987e02 100644 --- a/.gitignore +++ b/.gitignore @@ -136,3 +136,5 @@ dmypy.json tmp/ CLAUDE.md claude.md +agents.md +AGENTS.md diff --git a/.spin/cmds.py b/.spin/cmds.py new file mode 100644 index 0000000..aab54a3 --- /dev/null +++ b/.spin/cmds.py @@ -0,0 +1,210 @@ +"""Custom spin commands for trx-python development.""" +import subprocess +import sys + +import click + + +UPSTREAM_URL = "https://github.com/tee-ar-ex/trx-python.git" +UPSTREAM_NAME = "upstream" + + +def run(cmd, check=True, capture=True): + """Run a shell command.""" + result = subprocess.run( + cmd, + capture_output=capture, + text=True, + check=False + ) + if check and result.returncode != 0: + if capture: + click.echo(f"Error: {result.stderr}", err=True) + return None + return result.stdout.strip() if capture else result.returncode + + +def get_remotes(): + """Get dict of remote names to URLs.""" + output = run(["git", "remote", "-v"]) + if not output: + return {} + remotes = {} + for line in output.split("\n"): + if "(fetch)" in line: + parts = line.split() + remotes[parts[0]] = parts[1] + return remotes + + +@click.command() +def setup(): + """Set up development environment (fetch tags from upstream). + + This command configures your fork for development by: + 1. Adding the upstream remote if not present + 2. Fetching tags from upstream (required for correct version detection) + + Run this once after cloning your fork. 
+ """ + click.echo("Setting up trx-python development environment...\n") + + # Check if in git repo + if run(["git", "rev-parse", "--git-dir"], check=False) is None: + click.echo("Error: Not in a git repository", err=True) + sys.exit(1) + + # Check/add upstream remote + remotes = get_remotes() + upstream_remote = None + + for name, url in remotes.items(): + if UPSTREAM_URL.rstrip(".git") in url.rstrip(".git"): + upstream_remote = name + click.echo(f"Found upstream remote: {name}") + break + + if upstream_remote is None: + click.echo(f"Adding upstream remote: {UPSTREAM_URL}") + run(["git", "remote", "add", UPSTREAM_NAME, UPSTREAM_URL]) + upstream_remote = UPSTREAM_NAME + + # Fetch tags + click.echo(f"\nFetching tags from {upstream_remote}...") + run(["git", "fetch", upstream_remote, "--tags"], capture=False) + + # Verify version + click.echo("\nVerifying version detection...") + try: + from setuptools_scm import get_version + version = get_version() + click.echo(f"Detected version: {version}") + + # Check for suspicious version patterns + if version.startswith("0.0"): + click.echo( + "\nWarning: Version starts with 0.0 - tags may not be fetched.", + err=True + ) + sys.exit(1) + except ImportError: + click.echo("Note: Install setuptools_scm to verify version detection") + + click.echo("\nSetup complete! You can now run:") + click.echo(" spin install # Install in development mode") + click.echo(" spin test # Run tests") + + +@click.command() +@click.option( + "-m", "--match", "pattern", default=None, + help="Only run tests matching this pattern (passed to pytest -k)" +) +@click.option( + "-v", "--verbose", is_flag=True, default=False, + help="Verbose output" +) +@click.argument("pytest_args", nargs=-1) +def test(pattern, verbose, pytest_args): + """Run tests using pytest. + + Additional arguments are passed directly to pytest. + + Examples: + spin test # Run all tests + spin test -m memmap # Run tests matching 'memmap' + spin test -v # Verbose output + spin test -- -x --tb=short # Pass args to pytest + """ + cmd = ["pytest", "trx/tests", "scripts/tests"] + + if pattern: + cmd.extend(["-k", pattern]) + + if verbose: + cmd.append("-v") + + if pytest_args: + cmd.extend(pytest_args) + + click.echo(f"Running: {' '.join(cmd)}\n") + sys.exit(run(cmd, capture=False, check=False)) + + +@click.command() +@click.option( + "--fix", is_flag=True, default=False, + help="Currently unused (for future auto-fix support)" +) +def lint(fix): + """Run linting checks using flake8. + + Examples: + spin lint # Run flake8 checks + """ + # Strict check for syntax errors + click.echo("Checking for syntax errors...") + cmd_strict = [ + "flake8", ".", "--count", + "--select=E9,F63,F7,F82", + "--show-source", "--statistics" + ] + result = run(cmd_strict, capture=False, check=False) + if result != 0: + click.echo("Syntax errors found!", err=True) + sys.exit(1) + + # Full lint check + click.echo("\nRunning full lint check...") + cmd_full = [ + "flake8", ".", "--count", + "--max-line-length=88", + "--max-complexity=10", + "--statistics" + ] + sys.exit(run(cmd_full, capture=False, check=False)) + + +@click.command() +@click.option( + "--clean", is_flag=True, default=False, + help="Clean build directory before building" +) +@click.option( + "--open", "open_browser", is_flag=True, default=False, + help="Open documentation in browser after building" +) +def docs(clean, open_browser): + """Build documentation using Sphinx. 
+ + Examples: + spin docs # Build docs + spin docs --clean # Clean and rebuild + spin docs --open # Build and open in browser + """ + import os + docs_dir = "docs" + + if clean: + click.echo("Cleaning build directory...") + build_dir = os.path.join(docs_dir, "_build") + if os.path.exists(build_dir): + import shutil + shutil.rmtree(build_dir) + + click.echo("Building documentation...") + cmd = ["make", "-C", docs_dir, "html"] + result = run(cmd, capture=False, check=False) + + if result == 0: + index_path = os.path.abspath( + os.path.join(docs_dir, "_build", "html", "index.html") + ) + click.echo("\nDocs built successfully!") + click.echo(f"Open: {index_path}") + + if open_browser: + import webbrowser + webbrowser.open(f"file://{index_path}") + + sys.exit(result) diff --git a/README.md b/README.md index deca3d3..1f1c46e 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,30 @@ Or, to install from source: cd trx-python pip install . +### Development + +For contributors, we use [spin](https://github.com/scientific-python/spin) to manage the development workflow. This ensures proper version detection when working with forks. + +**First-time setup (required for forks):** + + git clone https://github.com/YOUR_USERNAME/trx-python.git + cd trx-python + pip install -e ".[dev]" + spin setup + +The `spin setup` command configures your fork by fetching version tags from the upstream repository. This is required for correct version detection with `setuptools_scm`. + +**Common development commands:** + + spin setup # Set up development environment (fetch upstream tags) + spin install # Install package in development/editable mode + spin test # Run all tests + spin test -m memmap # Run tests matching 'memmap' + spin lint # Run linting checks + spin docs # Build documentation + +Run `spin` without arguments to see all available commands. + ### Temporary Directory The TRX file format uses memmaps to limit RAM usage. When dealing with large files this means several gigabytes could be required on disk (instead of RAM). 
diff --git a/docs/source/conf.py b/docs/source/conf.py index eb70011..8b02e6e 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -92,7 +92,8 @@ "name": "GitHub", # URL where the link will redirect "url": "https://github.com/tee-ar-ex", # required - # Icon class (if "type": "fontawesome"), or path to local image (if "type": "local") + # Icon class (if "type": "fontawesome"), or path to local image + # (if "type": "local") "icon": "fab fa-github-square", # The type of image to be used (see below for details) "type": "fontawesome", diff --git a/pyproject.toml b/pyproject.toml index 6d97948..bf24e02 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,84 @@ [build-system] -requires = ["setuptools >= 42.0", "wheel", "setuptools_scm[toml] >= 7"] -build-backend = "setuptools.build_meta:__legacy__" +requires = ["setuptools >= 64", "wheel", "setuptools_scm[toml] >= 7"] +build-backend = "setuptools.build_meta" + +[project] +name = "trx-python" +dynamic = ["version"] +description = "Experiments with new file format for tractography" +readme = "README.md" +license = {text = "BSD License"} +requires-python = ">=3.9" +authors = [ + {name = "The TRX developers"} +] +classifiers = [ + "Development Status :: 3 - Alpha", + "Environment :: Console", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: BSD License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Topic :: Scientific/Engineering", +] +dependencies = [ + "deepdiff", + "nibabel >= 5", + "numpy >= 1.22", +] + +[project.optional-dependencies] +dev = [ + "spin >= 0.13", + "setuptools_scm", +] +doc = [ + "astroid >= 4.0.0", + "sphinx >= 9.0.0", + "pydata-sphinx-theme >= 0.16.1", + "sphinx-autoapi >= 3.0.0", + "numpydoc", +] +test = [ + "flake8", + "psutil", + "pytest >= 7", + "pytest-console-scripts >= 0", +] +all = [ + "trx-python[dev]", + "trx-python[doc]", + "trx-python[test]", +] + +[project.urls] +Homepage = "https://github.com/tee-ar-ex/trx-python" +Documentation = "https://tee-ar-ex.github.io/trx-python/" +Repository = "https://github.com/tee-ar-ex/trx-python" + +[tool.setuptools] +packages = ["trx"] +include-package-data = true +script-files = [ + "scripts/tff_concatenate_tractograms.py", + "scripts/tff_convert_dsi_studio.py", + "scripts/tff_convert_tractogram.py", + "scripts/tff_generate_trx_from_scratch.py", + "scripts/tff_manipulate_datatype.py", + "scripts/tff_simple_compare.py", + "scripts/tff_validate_trx.py", + "scripts/tff_verify_header_compatibility.py", + "scripts/tff_visualize_overlap.py", +] + +[tool.setuptools.dynamic] +version = {attr = "trx._version.__version__"} [tool.setuptools_scm] write_to = "trx/_version.py" @@ -10,3 +88,12 @@ __version__ = "{version}" """ fallback_version = "0.0" local_scheme = "no-local-version" + +[tool.spin] +package = "trx" + +[tool.spin.commands] +"Setup" = [".spin/cmds.py:setup"] +"Build" = ["spin.cmds.pip.install"] +"Test" = [".spin/cmds.py:test", ".spin/cmds.py:lint"] +"Docs" = [".spin/cmds.py:docs"] diff --git a/scripts/tff_manipulate_datatype.py b/scripts/tff_manipulate_datatype.py index 4c9c074..2c0afe5 100644 --- a/scripts/tff_manipulate_datatype.py +++ b/scripts/tff_manipulate_datatype.py @@ -65,7 +65,7 @@ def _build_arg_parser(): return p -def main(): +def main(): # 
noqa: C901 parser = _build_arg_parser() args = parser.parse_args() diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index b9e7279..0000000 --- a/setup.cfg +++ /dev/null @@ -1,49 +0,0 @@ -[metadata] -name = trx-python -url = https://github.com/tee-ar-ex/trx-python -classifiers = - Development Status :: 3 - Alpha - Environment :: Console - Intended Audience :: Science/Research - License :: OSI Approved :: BSD License - Operating System :: OS Independent - Programming Language :: Python - Topic :: Scientific/Engineering - -license = BSD License -description = Experiments with new file format for tractography -long_description = file: README.md -long_description_content_type = text/markdown -platforms = OS Independent - -packages = find: -include_package_data = True - - -[options] -python_requires = >=3.8 -setup_requires = - packaging >= 19.0 - cython >= 0.29 -install_requires = - setuptools_scm - deepdiff - nibabel >= 5 - numpy >= 1.22 - -[options.extras_require] -doc = astroid==2.15.8 - sphinx - pydata-sphinx-theme >= 0.16.1 - sphinx-autoapi - numpydoc - -test = - flake8 - psutil - pytest >= 7 - pytest-console-scripts >= 0 - -all = - %(doc)s - %(test)s diff --git a/setup.py b/setup.py deleted file mode 100644 index 72e2115..0000000 --- a/setup.py +++ /dev/null @@ -1,33 +0,0 @@ -import glob -import os.path as op -import string -from setuptools_scm import get_version -from setuptools import setup - - -def local_version(version): - """ - Patch in a version that can be uploaded to test PyPI - """ - scm_version = get_version() - if "dev" in scm_version: - gh_in_int = [] - for char in version.node: - if char.isdigit(): - gh_in_int.append(str(char)) - else: - gh_in_int.append(str(string.ascii_letters.find(char))) - return "".join(gh_in_int) - else: - return "" - - -opts = dict(use_scm_version={ - "root": ".", "relative_to": __file__, - "write_to": op.join("trx", "version.py"), - "local_scheme": local_version}, - scripts=glob.glob("scripts/*.py")) - - -if __name__ == '__main__': - setup(**opts) diff --git a/trx/io.py b/trx/io.py index e292778..26b758d 100644 --- a/trx/io.py +++ b/trx/io.py @@ -6,7 +6,7 @@ import sys try: - import dipy + import dipy # noqa: F401 dipy_available = True except ImportError: dipy_available = False diff --git a/trx/tests/test_io.py b/trx/tests/test_io.py index 33d02bf..3ddb6ad 100644 --- a/trx/tests/test_io.py +++ b/trx/tests/test_io.py @@ -18,7 +18,7 @@ import trx.trx_file_memmap as tmm from trx.trx_file_memmap import TrxFile -from trx.io import load, save, get_trx_tmp_dir +from trx.io import load, save from trx.fetcher import (get_testing_files_dict, fetch_data, get_home) @@ -41,7 +41,7 @@ def test_seq_ops_sft(path): obj.close() save_tractogram(sft_1, os.path.join(tmp_dir, 'tmp.trk')) - sft_2 = load_tractogram(os.path.join(tmp_dir, 'tmp.trk'), 'same') + _ = load_tractogram(os.path.join(tmp_dir, 'tmp.trk'), 'same') def test_seq_ops_trx(): diff --git a/trx/tests/test_memmap.py b/trx/tests/test_memmap.py index a671c08..2ccedbd 100644 --- a/trx/tests/test_memmap.py +++ b/trx/tests/test_memmap.py @@ -8,7 +8,7 @@ import pytest try: - import dipy + import dipy # noqa: F401 dipy_available = True except ImportError: dipy_available = False diff --git a/trx/trx_file_memmap.py b/trx/trx_file_memmap.py index c617ebb..471cdc4 100644 --- a/trx/trx_file_memmap.py +++ b/trx/trx_file_memmap.py @@ -24,7 +24,7 @@ get_reference_info_wrapper) try: - import dipy + import dipy # noqa: F401 dipy_available = True except ImportError: dipy_available = False @@ -350,6 +350,155 @@ def 
load_from_directory(directory: str) -> Type["TrxFile"]: root=directory) +def _filter_empty_trx_files(trx_list: List["TrxFile"]) -> List["TrxFile"]: + """Remove empty TrxFiles from the list.""" + return [curr_trx for curr_trx in trx_list if curr_trx.header["NB_STREAMLINES"] > 0] + + +def _get_all_data_keys(trx_list: List["TrxFile"]) -> Tuple[set, set]: + """Get all dps and dpv keys from the TrxFile list.""" + all_dps = [] + all_dpv = [] + for curr_trx in trx_list: + all_dps.extend(list(curr_trx.data_per_streamline.keys())) + all_dpv.extend(list(curr_trx.data_per_vertex.keys())) + return set(all_dps), set(all_dpv) + + +def _check_space_attributes(trx_list: List["TrxFile"]) -> None: + """Verify that space attributes are consistent across TrxFiles.""" + ref_trx = trx_list[0] + for curr_trx in trx_list[1:]: + if not np.allclose( + ref_trx.header["VOXEL_TO_RASMM"], curr_trx.header["VOXEL_TO_RASMM"] + ) or not np.array_equal( + ref_trx.header["DIMENSIONS"], curr_trx.header["DIMENSIONS"] + ): + raise ValueError("Wrong space attributes.") + + +def _verify_dpv_coherence(trx_list: List["TrxFile"], all_dpv: set, + ref_trx: "TrxFile", delete_dpv: bool) -> None: + """Verify dpv coherence across TrxFiles.""" + for curr_trx in trx_list: + for key in all_dpv: + if (key not in ref_trx.data_per_vertex.keys() or + key not in curr_trx.data_per_vertex.keys()): + if not delete_dpv: + logging.debug( + "{} dpv key does not exist in all TrxFile.".format(key) + ) + raise ValueError( + "TrxFile must be sharing identical dpv keys.") + elif (ref_trx.data_per_vertex[key]._data.dtype != + curr_trx.data_per_vertex[key]._data.dtype): + logging.debug( + "{} dpv key is not declared with the same dtype " + "in all TrxFile.".format(key) + ) + raise ValueError("Shared dpv key, has different dtype.") + + +def _verify_dps_coherence(trx_list: List["TrxFile"], all_dps: set, + ref_trx: "TrxFile", delete_dps: bool) -> None: + """Verify dps coherence across TrxFiles.""" + for curr_trx in trx_list: + for key in all_dps: + if (key not in ref_trx.data_per_streamline.keys() or + key not in curr_trx.data_per_streamline.keys()): + if not delete_dps: + logging.debug( + "{} dps key does not exist in all TrxFile.".format(key) + ) + raise ValueError( + "TrxFile must be sharing identical dps keys.") + elif (ref_trx.data_per_streamline[key].dtype != + curr_trx.data_per_streamline[key].dtype): + logging.debug( + "{} dps key is not declared with the same dtype " + "in all TrxFile.".format(key) + ) + raise ValueError("Shared dps key, has different dtype.") + + +def _compute_groups_info(trx_list: List["TrxFile"]) -> Tuple[dict, dict]: + """Compute group length and dtype information.""" + all_groups_len = {} + all_groups_dtype = {} + + for trx_1 in trx_list: + for group_key in trx_1.groups.keys(): + if group_key in all_groups_len: + all_groups_len[group_key] += len(trx_1.groups[group_key]) + else: + all_groups_len[group_key] = len(trx_1.groups[group_key]) + + if (group_key in all_groups_dtype and + trx_1.groups[group_key].dtype != all_groups_dtype[group_key]): + raise ValueError("Shared group key, has different dtype.") + else: + all_groups_dtype[group_key] = trx_1.groups[group_key].dtype + + return all_groups_len, all_groups_dtype + + +def _create_new_trx_for_concatenation(trx_list: List["TrxFile"], ref_trx: "TrxFile", + delete_dps: bool, delete_dpv: bool, + delete_groups: bool) -> "TrxFile": + """Create a new TrxFile for concatenation.""" + nb_vertices = 0 + nb_streamlines = 0 + for curr_trx in trx_list: + curr_strs_len, curr_pts_len = 
curr_trx._get_real_len() + nb_streamlines += curr_strs_len + nb_vertices += curr_pts_len + + new_trx = TrxFile( + nb_vertices=nb_vertices, nb_streamlines=nb_streamlines, + init_as=ref_trx + ) + if delete_dps: + new_trx.data_per_streamline = {} + if delete_dpv: + new_trx.data_per_vertex = {} + if delete_groups: + new_trx.groups = {} + + return new_trx + + +def _setup_groups_for_concatenation(new_trx: "TrxFile", trx_list: List["TrxFile"], + all_groups_len: dict, all_groups_dtype: dict, + delete_groups: bool) -> None: + """Setup groups in the new TrxFile for concatenation.""" + if delete_groups: + return + + tmp_dir = new_trx._uncompressed_folder_handle.name + + for group_key in all_groups_len.keys(): + if not os.path.isdir(os.path.join(tmp_dir, "groups/")): + os.mkdir(os.path.join(tmp_dir, "groups/")) + + dtype = all_groups_dtype[group_key] + group_filename = os.path.join( + tmp_dir, "groups/" "{}.{}".format(group_key, dtype.name) + ) + group_len = all_groups_len[group_key] + new_trx.groups[group_key] = _create_memmap( + group_filename, mode="w+", shape=(group_len,), dtype=dtype + ) + + pos = 0 + count = 0 + for curr_trx in trx_list: + curr_len = len(curr_trx.groups[group_key]) + new_trx.groups[group_key][pos: pos + curr_len] = \ + curr_trx.groups[group_key] + count + pos += curr_len + count += curr_trx.header["NB_STREAMLINES"] + + def concatenate( trx_list: List["TrxFile"], delete_dpv: bool = False, @@ -378,150 +527,41 @@ def concatenate( TrxFile representing the concatenated data """ - trx_list = [ - curr_trx for curr_trx in trx_list if curr_trx.header["NB_STREAMLINES"] > 0 - ] + trx_list = _filter_empty_trx_files(trx_list) if len(trx_list) == 0: logging.warning("Inputs of concatenation were empty.") return TrxFile() ref_trx = trx_list[0] - all_dps = [] - all_dpv = [] - for curr_trx in trx_list: - all_dps.extend(list(curr_trx.data_per_streamline.keys())) - all_dpv.extend(list(curr_trx.data_per_vertex.keys())) - all_dps, all_dpv = set(all_dps), set(all_dpv) + all_dps, all_dpv = _get_all_data_keys(trx_list) if check_space_attributes: - for curr_trx in trx_list[1:]: - if not np.allclose( - ref_trx.header["VOXEL_TO_RASMM"], curr_trx.header["VOXEL_TO_RASMM"] - ) or not np.array_equal( - ref_trx.header["DIMENSIONS"], curr_trx.header["DIMENSIONS"] - ): - raise ValueError("Wrong space attributes.") + _check_space_attributes(trx_list) if preallocation and not delete_groups: raise ValueError( - "Groups are variables, cannot be handled with " "preallocation" + "Groups are variables, cannot be handled with preallocation" ) - # Verifying the validity of fixed-size arrays, coherence between inputs - for curr_trx in trx_list: - for key in all_dpv: - if key not in ref_trx.data_per_vertex.keys() \ - or key not in curr_trx.data_per_vertex.keys(): - if not delete_dpv: - logging.debug( - "{} dpv key does not exist in all TrxFile.".format(key) - ) - raise ValueError( - "TrxFile must be sharing identical dpv " "keys.") - elif ( - ref_trx.data_per_vertex[key]._data.dtype - != curr_trx.data_per_vertex[key]._data.dtype - ): - logging.debug( - "{} dpv key is not declared with the same dtype " - "in all TrxFile.".format(key) - ) - raise ValueError("Shared dpv key, has different dtype.") + _verify_dpv_coherence(trx_list, all_dpv, ref_trx, delete_dpv) + _verify_dps_coherence(trx_list, all_dps, ref_trx, delete_dps) - for curr_trx in trx_list: - for key in all_dps: - if key not in ref_trx.data_per_streamline.keys() \ - or key not in curr_trx.data_per_streamline.keys(): - if not delete_dps: - logging.debug( - "{} dps key 
does not exist in all " "TrxFile.".format( - key) - ) - raise ValueError( - "TrxFile must be sharing identical dps " "keys.") - elif ( - ref_trx.data_per_streamline[key].dtype - != curr_trx.data_per_streamline[key].dtype - ): - logging.debug( - "{} dps key is not declared with the same dtype " - "in all TrxFile.".format(key) - ) - raise ValueError("Shared dps key, has different dtype.") + all_groups_len, all_groups_dtype = _compute_groups_info(trx_list) - all_groups_len = {} - all_groups_dtype = {} - # Variable-size arrays do not have to exist in all TrxFile - if not delete_groups: - for trx_1 in trx_list: - for group_key in trx_1.groups.keys(): - # Concatenating groups together - if group_key in all_groups_len: - all_groups_len[group_key] += len(trx_1.groups[group_key]) - else: - all_groups_len[group_key] = len(trx_1.groups[group_key]) - if ( - group_key in all_groups_dtype - and trx_1.groups[group_key].dtype != all_groups_dtype[group_key] - ): - raise ValueError("Shared group key, has different dtype.") - else: - all_groups_dtype[group_key] = trx_1.groups[group_key].dtype - - # Once the checks are done, actually concatenate to_concat_list = trx_list[1:] if preallocation else trx_list if not preallocation: - nb_vertices = 0 - nb_streamlines = 0 - for curr_trx in to_concat_list: - curr_strs_len, curr_pts_len = curr_trx._get_real_len() - nb_streamlines += curr_strs_len - nb_vertices += curr_pts_len - - new_trx = TrxFile( - nb_vertices=nb_vertices, nb_streamlines=nb_streamlines, - init_as=ref_trx + new_trx = _create_new_trx_for_concatenation( + to_concat_list, ref_trx, delete_dps, delete_dpv, delete_groups + ) + _setup_groups_for_concatenation( + new_trx, trx_list, all_groups_len, all_groups_dtype, delete_groups ) - if delete_dps: - new_trx.data_per_streamline = {} - if delete_dpv: - new_trx.data_per_vertex = {} - if delete_groups: - new_trx.groups = {} - - tmp_dir = new_trx._uncompressed_folder_handle.name - - # When memory is allocated on the spot, groups and data_per_group can - # be concatenated together - for group_key in all_groups_len.keys(): - if not os.path.isdir(os.path.join(tmp_dir, "groups/")): - os.mkdir(os.path.join(tmp_dir, "groups/")) - dtype = all_groups_dtype[group_key] - group_filename = os.path.join( - tmp_dir, "groups/" "{}.{}".format(group_key, dtype.name) - ) - group_len = all_groups_len[group_key] - new_trx.groups[group_key] = _create_memmap( - group_filename, mode="w+", shape=(group_len,), dtype=dtype - ) - if delete_groups: - continue - pos = 0 - count = 0 - for curr_trx in trx_list: - curr_len = len(curr_trx.groups[group_key]) - new_trx.groups[group_key][pos: pos + curr_len] = \ - curr_trx.groups[group_key] + count - pos += curr_len - count += curr_trx.header["NB_STREAMLINES"] - strs_end, pts_end = 0, 0 else: new_trx = ref_trx strs_end, pts_end = new_trx._get_real_len() for curr_trx in to_concat_list: - # Copy the TrxFile fixed-size info (the right chunk) strs_end, pts_end = new_trx._copy_fixed_arrays_from( curr_trx, strs_start=strs_end, pts_start=pts_end ) @@ -725,7 +765,7 @@ def __getitem__(self, key) -> Any: def __deepcopy__(self) -> Type["TrxFile"]: return self.deepcopy() - def deepcopy(self) -> Type["TrxFile"]: + def deepcopy(self) -> Type["TrxFile"]: # noqa: C901 """Create a deepcopy of the TrxFile Returns @@ -892,9 +932,11 @@ def _copy_fixed_arrays_from( return strs_end, pts_end @staticmethod - def _initialize_empty_trx( - nb_streamlines: int, nb_vertices: int, - init_as: Optional[Type["TrxFile"]] = None) -> Type["TrxFile"]: + def _initialize_empty_trx( # 
noqa: C901 + nb_streamlines: int, + nb_vertices: int, + init_as: Optional[Type["TrxFile"]] = None, + ) -> Type["TrxFile"]: """Create on-disk memmaps of a certain size (preallocation) Keyword arguments: @@ -1022,7 +1064,7 @@ def _initialize_empty_trx( return trx - def _create_trx_from_pointer( + def _create_trx_from_pointer( # noqa: C901 header: dict, dict_pointer_size: dict, root_zip: Optional[str] = None, @@ -1062,7 +1104,10 @@ def _create_trx_from_pointer( if os.name != 'nt' and folder.startswith(root.rstrip("/")): folder = folder.replace(root, "").lstrip("/") # These three are for Windows - elif os.path.isdir(folder) and os.path.basename(folder) in ['dpv', 'dps', 'groups']: + elif ( + os.path.isdir(folder) + and os.path.basename(folder) in ['dpv', 'dps', 'groups'] + ): folder = os.path.basename(folder) elif os.path.basename(os.path.dirname(folder)) == 'dpg': folder = os.path.join('dpg', os.path.basename(folder)) @@ -1164,7 +1209,7 @@ def _create_trx_from_pointer( trx.data_per_vertex[dpv_key]._lengths = lengths return trx - def resize( + def resize( # noqa: C901 self, nb_streamlines: Optional[int] = None, nb_vertices: Optional[int] = None, @@ -1751,7 +1796,10 @@ def close(self) -> None: try: self._uncompressed_folder_handle.cleanup() except PermissionError: - logging.error("Windows PermissionError, temporary directory {}" - "was not deleted!".format(self._uncompressed_folder_handle.name)) + logging.error( + "Windows PermissionError, temporary directory %s was not " + "deleted!", + self._uncompressed_folder_handle.name, + ) self.__init__() logging.debug("Deleted memmaps and intialized empty TrxFile.") diff --git a/trx/utils.py b/trx/utils.py index b58a2a2..d009f63 100644 --- a/trx/utils.py +++ b/trx/utils.py @@ -66,7 +66,7 @@ def split_name_with_gz(filename): return base, ext -def get_reference_info_wrapper(reference): +def get_reference_info_wrapper(reference): # noqa: C901 """ Will compare the spatial attribute of 2 references. Parameters @@ -407,7 +407,7 @@ def append_generator_to_dict(gen, data): data['strs'].append(gen.tolist()) -def verify_trx_dtype(trx, dict_dtype): +def verify_trx_dtype(trx, dict_dtype): # noqa: C901 """ Verify if the dtype of the data in the trx is the same as the one in the dict. 
@@ -447,13 +447,19 @@ def verify_trx_dtype(trx, dict_dtype): elif key == 'dpg': for key_group in dict_dtype[key]: for key_dpg in dict_dtype[key][key_group]: - if trx.data_per_point[key_group][key_dpg].dtype != dict_dtype[key][key_group][key_dpg]: + if ( + trx.data_per_point[key_group][key_dpg].dtype + != dict_dtype[key][key_group][key_dpg] + ): logging.warning( 'Data per group ({}) dtype is different'.format(key_dpg)) identical = False elif key == 'groups': for key_group in dict_dtype[key]: - if trx.data_per_point[key_group]._data.dtype != dict_dtype[key][key_group]: + if ( + trx.data_per_point[key_group]._data.dtype + != dict_dtype[key][key_group] + ): logging.warning( 'Data per group ({}) dtype is different'.format(key_group)) identical = False diff --git a/trx/workflows.py b/trx/workflows.py index eaa85c5..a1782a2 100644 --- a/trx/workflows.py +++ b/trx/workflows.py @@ -12,7 +12,7 @@ from nibabel.streamlines.array_sequence import ArraySequence import numpy as np try: - import dipy + import dipy # noqa: F401 dipy_available = True except ImportError: dipy_available = False @@ -76,8 +76,10 @@ def convert_dsi_studio(in_dsi_tractogram, in_dsi_fa, out_tractogram, tmm.save(trx, out_tractogram) -def convert_tractogram(in_tractogram, out_tractogram, reference, - pos_dtype='float32', offsets_dtype='uint32'): +def convert_tractogram( # noqa: C901 + in_tractogram, out_tractogram, reference, + pos_dtype='float32', offsets_dtype='uint32', +): if not dipy_available: logging.error('Dipy library is missing, scripts are not available.') return None @@ -304,140 +306,171 @@ def validate_tractogram(in_tractogram, reference, out_tractogram, save(new_sft, out_tractogram) -def generate_trx_from_scratch(reference, out_tractogram, positions_csv=False, - positions=False, offsets=False, - positions_dtype='float32', offsets_dtype='uint64', - space_str='rasmm', origin_str='nifti', - verify_invalid=True, dpv=[], dps=[], - groups=[], dpg=[]): +def _load_streamlines_from_csv(positions_csv): + """Load streamlines from CSV file.""" + with open(positions_csv, newline='') as f: + reader = csv.reader(f) + data = list(reader) + data = [np.reshape(i, (len(i) // 3, 3)).astype(float) + for i in data] + return ArraySequence(data) + + +def _load_streamlines_from_arrays(positions, offsets): + """Load streamlines from position and offset arrays.""" + positions = load_matrix_in_any_format(positions) + offsets = load_matrix_in_any_format(offsets) + lengths = tmm._compute_lengths(offsets) + streamlines = ArraySequence() + streamlines._data = positions + streamlines._offsets = deepcopy(offsets) + streamlines._lengths = lengths + return streamlines, offsets + + +def _apply_spatial_transforms(streamlines, reference, space_str, origin_str, + verify_invalid, offsets): + """Apply spatial transforms and verify streamlines.""" + if not dipy_available: + logging.error('Dipy library is missing, advanced options ' + 'related to spatial transforms and invalid ' + 'streamlines are not available.') + return None + + from dipy.io.stateful_tractogram import StatefulTractogram + + space, origin = get_reverse_enum(space_str, origin_str) + sft = StatefulTractogram(streamlines, reference, space, origin) + if verify_invalid: + rem, _ = sft.remove_invalid_streamlines() + print('{} streamlines were removed becaused they were ' + 'invalid.'.format(len(rem))) + sft.to_rasmm() + sft.to_center() + streamlines = sft.streamlines + streamlines._offsets = offsets + return streamlines + + +def _write_header(tmp_dir_name, reference, streamlines): + """Write header 
file.""" + affine, dimensions, _, _ = get_reference_info_wrapper(reference) + header = { + "DIMENSIONS": dimensions.tolist(), + "VOXEL_TO_RASMM": affine.tolist(), + "NB_VERTICES": len(streamlines._data), + "NB_STREAMLINES": len(streamlines)-1, + } + + if header['NB_STREAMLINES'] <= 1: + raise IOError('To use this script, you need at least 2' + 'streamlines.') + + with open(os.path.join(tmp_dir_name, "header.json"), "w") as out_json: + json.dump(header, out_json) + + +def _write_streamline_data(tmp_dir_name, streamlines, positions_dtype, + offsets_dtype): + """Write streamline position and offset data.""" + curr_filename = os.path.join(tmp_dir_name, 'positions.3.{}'.format( + positions_dtype)) + streamlines._data.astype(positions_dtype).tofile(curr_filename) + + curr_filename = os.path.join(tmp_dir_name, 'offsets.{}'.format( + offsets_dtype)) + streamlines._offsets.astype(offsets_dtype).tofile(curr_filename) + + +def _normalize_dtype(dtype_str): + """Normalize dtype string format.""" + return 'bit' if dtype_str == 'bool' else dtype_str + + +def _write_data_array(tmp_dir_name, subdir_name, args, is_dpg=False): + """Write data array to file.""" + if is_dpg: + os.makedirs(os.path.join(tmp_dir_name, 'dpg', args[0]), exist_ok=True) + curr_arr = load_matrix_in_any_format(args[1]).astype(args[2]) + basename = os.path.basename(os.path.splitext(args[1])[0]) + dtype_str = _normalize_dtype(args[1]) if args[1] != 'bool' else 'bit' + dtype = args[2] + else: + os.makedirs(os.path.join(tmp_dir_name, subdir_name), exist_ok=True) + curr_arr = np.squeeze(load_matrix_in_any_format(args[0]).astype(args[1])) + basename = os.path.basename(os.path.splitext(args[0])[0]) + dtype_str = _normalize_dtype(args[1]) + dtype = dtype_str + + if curr_arr.ndim > 2: + raise IOError('Maximum of 2 dimensions for dpv/dps/dpg.') + + if curr_arr.shape == (1, 1): + curr_arr = curr_arr.reshape((1,)) + + dim = '' if curr_arr.ndim == 1 else '{}.'.format(curr_arr.shape[-1]) + + if is_dpg: + curr_filename = os.path.join(tmp_dir_name, 'dpg', args[0], + '{}.{}{}'.format(basename, dim, dtype)) + else: + curr_filename = os.path.join(tmp_dir_name, subdir_name, + '{}.{}{}'.format(basename, dim, dtype)) + + curr_arr.tofile(curr_filename) + + +def generate_trx_from_scratch( # noqa: C901 + reference, out_tractogram, positions_csv=False, + positions=False, offsets=False, + positions_dtype='float32', offsets_dtype='uint64', + space_str='rasmm', origin_str='nifti', + verify_invalid=True, dpv=[], dps=[], + groups=[], dpg=[], +): + """Generate TRX file from scratch using various input formats.""" with get_trx_tmp_dir() as tmp_dir_name: if positions_csv: - with open(positions_csv, newline='') as f: - reader = csv.reader(f) - data = list(reader) - data = [np.reshape(i, (len(i) // 3, 3)).astype(float) - for i in data] - streamlines = ArraySequence(data) + streamlines = _load_streamlines_from_csv(positions_csv) + offsets = None else: - positions = load_matrix_in_any_format(positions) - offsets = load_matrix_in_any_format(offsets) - lengths = tmm._compute_lengths(offsets) - streamlines = ArraySequence() - streamlines._data = positions - streamlines._offsets = deepcopy(offsets) - streamlines._lengths = lengths - - if space_str.lower() != 'rasmm' or origin_str.lower() != 'nifti' or \ - verify_invalid: - if not dipy_available: - logging.error('Dipy library is missing, advanced options ' - 'related to spatial transforms and invalid ' - 'streamlines are not available.') + streamlines, offsets = _load_streamlines_from_arrays(positions, offsets) + + if 
(space_str.lower() != 'rasmm' or origin_str.lower() != 'nifti' or + verify_invalid): + streamlines = _apply_spatial_transforms( + streamlines, reference, space_str, origin_str, + verify_invalid, offsets + ) + if streamlines is None: return - from dipy.io.stateful_tractogram import StatefulTractogram - - space, origin = get_reverse_enum(space_str, origin_str) - sft = StatefulTractogram(streamlines, reference, space, origin) - if verify_invalid: - rem, _ = sft.remove_invalid_streamlines() - print('{} streamlines were removed becaused they were ' - 'invalid.'.format(len(rem))) - sft.to_rasmm() - sft.to_center() - streamlines = sft.streamlines - streamlines._offsets = offsets - - affine, dimensions, _, _ = get_reference_info_wrapper(reference) - header = { - "DIMENSIONS": dimensions.tolist(), - "VOXEL_TO_RASMM": affine.tolist(), - "NB_VERTICES": len(streamlines._data), - "NB_STREAMLINES": len(streamlines)-1, - } - - if header['NB_STREAMLINES'] <= 1: - raise IOError('To use this script, you need at least 2' - 'streamlines.') - - with open(os.path.join(tmp_dir_name, "header.json"), "w") as out_json: - json.dump(header, out_json) - - curr_filename = os.path.join(tmp_dir_name, 'positions.3.{}'.format( - positions_dtype)) - streamlines._data.astype(positions_dtype).tofile( - curr_filename) - curr_filename = os.path.join(tmp_dir_name, 'offsets.{}'.format( - offsets_dtype)) - streamlines._offsets.astype(offsets_dtype).tofile( - curr_filename) + + _write_header(tmp_dir_name, reference, streamlines) + _write_streamline_data(tmp_dir_name, streamlines, positions_dtype, + offsets_dtype) if dpv: - os.mkdir(os.path.join(tmp_dir_name, 'dpv')) for arg in dpv: - curr_arr = np.squeeze(load_matrix_in_any_format(arg[0]).astype( - arg[1])) - if arg[1] == 'bool': - arg[1] = 'bit' - if curr_arr.ndim > 2: - raise IOError('Maximum of 2 dimensions for dpv/dps/dpg.') - dim = '' if curr_arr.ndim == 1 else '{}.'.format( - curr_arr.shape[-1]) - curr_filename = os.path.join(tmp_dir_name, 'dpv', '{}.{}{}'.format( - os.path.basename(os.path.splitext(arg[0])[0]), dim, arg[1])) - curr_arr.tofile(curr_filename) + _write_data_array(tmp_dir_name, 'dpv', arg) if dps: - os.mkdir(os.path.join(tmp_dir_name, 'dps')) for arg in dps: - curr_arr = np.squeeze(load_matrix_in_any_format(arg[0]).astype( - arg[1])) - if arg[1] == 'bool': - arg[1] = 'bit' - if curr_arr.ndim > 2: - raise IOError('Maximum of 2 dimensions for dpv/dps/dpg.') - dim = '' if curr_arr.ndim == 1 else '{}.'.format( - curr_arr.shape[-1]) - curr_filename = os.path.join(tmp_dir_name, 'dps', '{}.{}{}'.format( - os.path.basename(os.path.splitext(arg[0])[0]), dim, arg[1])) - curr_arr.tofile(curr_filename) + _write_data_array(tmp_dir_name, 'dps', arg) + if groups: - os.mkdir(os.path.join(tmp_dir_name, 'groups')) for arg in groups: - curr_arr = load_matrix_in_any_format(arg[0]).astype(arg[1]) - if arg[1] == 'bool': - arg[1] = 'bit' - if curr_arr.ndim > 2: - raise IOError('Maximum of 2 dimensions for dpv/dps/dpg.') - dim = '' if curr_arr.ndim == 1 else '{}.'.format( - curr_arr.shape[-1]) - curr_filename = os.path.join(tmp_dir_name, 'groups', '{}.{}{}'.format( - os.path.basename(os.path.splitext(arg[0])[0]), dim, arg[1])) - curr_arr.tofile(curr_filename) + _write_data_array(tmp_dir_name, 'groups', arg) if dpg: - os.mkdir(os.path.join(tmp_dir_name, 'dpg')) for arg in dpg: - if not os.path.isdir(os.path.join(tmp_dir_name, 'dpg', arg[0])): - os.mkdir(os.path.join(tmp_dir_name, 'dpg', arg[0])) - curr_arr = load_matrix_in_any_format(arg[1]).astype(arg[2]) - if arg[1] == 'bool': - 
arg[1] = 'bit' - if curr_arr.ndim > 2: - raise IOError('Maximum of 2 dimensions for dpv/dps/dpg.') - if curr_arr.shape == (1, 1): - curr_arr = curr_arr.reshape((1,)) - dim = '' if curr_arr.ndim == 1 else '{}.'.format( - curr_arr.shape[-1]) - curr_filename = os.path.join(tmp_dir_name, 'dpg', arg[0], '{}.{}{}'.format( - os.path.basename(os.path.splitext(arg[1])[0]), dim, arg[2])) - curr_arr.tofile(curr_filename) + _write_data_array(tmp_dir_name, 'dpg', arg, is_dpg=True) trx = tmm.load(tmp_dir_name) tmm.save(trx, out_tractogram) trx.close() -def manipulate_trx_datatype(in_filename, out_filename, dict_dtype): +def manipulate_trx_datatype(in_filename, out_filename, dict_dtype): # noqa: C901 trx = tmm.load(in_filename) # For each key in dict_dtype, we create a new memmap with the new dtype @@ -476,10 +509,12 @@ def manipulate_trx_datatype(in_filename, out_filename, dict_dtype): elif key == 'dpg': for key_group in dict_dtype[key]: for key_dpg in dict_dtype[key][key_group]: - tmp_mm = np.memmap(tempfile.NamedTemporaryFile(), - dtype=dict_dtype[key][key_group][key_dpg], - mode='w+', - shape=trx.data_per_group[key_group][key_dpg].shape) + tmp_mm = np.memmap( + tempfile.NamedTemporaryFile(), + dtype=dict_dtype[key][key_group][key_dpg], + mode='w+', + shape=trx.data_per_group[key_group][key_dpg].shape, + ) tmp_mm[:] = trx.data_per_group[key_group][key_dpg][:] trx.data_per_group[key_group][key_dpg] = tmp_mm elif key == 'groups': From 78265773ba0895fdb9d48ec5ef60ce07b956cca6 Mon Sep 17 00:00:00 2001 From: Serge Koudoro Date: Thu, 15 Jan 2026 14:15:20 -0500 Subject: [PATCH 6/7] NF: add spin clean to help clean up --- .spin/__init__.py | 1 + .spin/cmds.py | 38 +++++++++++++++++++++++ pyproject.toml | 1 + trx/fetcher.py | 78 ++++++++++++++++++++++++++++++++++++++--------- 4 files changed, 103 insertions(+), 15 deletions(-) create mode 100644 .spin/__init__.py diff --git a/.spin/__init__.py b/.spin/__init__.py new file mode 100644 index 0000000..85155f8 --- /dev/null +++ b/.spin/__init__.py @@ -0,0 +1 @@ +# Spin commands package diff --git a/.spin/cmds.py b/.spin/cmds.py index aab54a3..9892753 100644 --- a/.spin/cmds.py +++ b/.spin/cmds.py @@ -1,6 +1,10 @@ """Custom spin commands for trx-python development.""" +import os import subprocess import sys +import tempfile +import glob +import shutil import click @@ -208,3 +212,37 @@ def docs(clean, open_browser): webbrowser.open(f"file://{index_path}") sys.exit(result) + + +@click.command() +def clean(): # noqa: C901 + """Clean up temporary files and build artifacts.""" + click.echo("Cleaning up temporary files...") + + # Clean TRX temp directory + trx_tmp_dir = os.getenv('TRX_TMPDIR', tempfile.gettempdir()) + if os.path.exists(trx_tmp_dir): + temp_files = glob.glob(os.path.join(trx_tmp_dir, 'trx_*')) + for temp_dir in temp_files: + if os.path.isdir(temp_dir): + click.echo(f"Removing temporary directory: {temp_dir}") + shutil.rmtree(temp_dir) + + # Clean build artifacts + for build_pattern in ['build', 'dist', '*.egg-info']: + for path in glob.glob(build_pattern): + if os.path.isdir(path): + click.echo(f"Removing build directory: {path}") + shutil.rmtree(path) + elif os.path.isfile(path): + click.echo(f"Removing build file: {path}") + os.remove(path) + + # Clean Python cache + for cache_dir in ['**/__pycache__', '**/.pytest_cache']: + for path in glob.glob(cache_dir, recursive=True): + if os.path.isdir(path): + click.echo(f"Removing cache directory: {path}") + shutil.rmtree(path) + + click.echo("Cleanup complete!") diff --git a/pyproject.toml b/pyproject.toml 
index bf24e02..4bf13ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,3 +97,4 @@ package = "trx" "Build" = ["spin.cmds.pip.install"] "Test" = [".spin/cmds.py:test", ".spin/cmds.py:lint"] "Docs" = [".spin/cmds.py:docs"] +"Clean" = [".spin/cmds.py:clean"] diff --git a/trx/fetcher.py b/trx/fetcher.py index 469fb8a..7a033a8 100644 --- a/trx/fetcher.py +++ b/trx/fetcher.py @@ -7,6 +7,18 @@ import urllib.request +TEST_DATA_REPO = "tee-ar-ex/trx-test-data" +TEST_DATA_TAG = "v0.1.0" +# GitHub release API entrypoint for metadata (asset list, sizes, etc.). +TEST_DATA_API_URL = ( + f"https://api.github.com/repos/{TEST_DATA_REPO}/releases/tags/{TEST_DATA_TAG}" +) +# Direct download base for release assets. +TEST_DATA_BASE_URL = ( + f"https://github.com/{TEST_DATA_REPO}/releases/download/{TEST_DATA_TAG}" +) + + def get_home(): """ Set a user-writeable file-system location to put files """ if 'TRX_HOME' in os.environ: @@ -17,20 +29,33 @@ def get_home(): def get_testing_files_dict(): - """ Get dictionary linking zip file to their Figshare URL & MD5SUM """ + """ Get dictionary linking zip file to their GitHub release URL & checksums. + + Assets are hosted under the v0.1.0 release of tee-ar-ex/trx-test-data. + If URLs change, check TEST_DATA_API_URL to discover the latest asset + locations. + """ return { - 'DSI.zip': - ('https://figshare.com/ndownloader/files/37624154', - 'b847f053fc694d55d935c0be0e5268f7'), # V1 (27.09.2022) - 'memmap_test_data.zip': - ('https://figshare.com/ndownloader/files/37624148', - '03f7651a0f9e3eeabee9aed0ad5f69e1'), # V2 (27.09.2022) - 'trx_from_scratch.zip': - ('https://figshare.com/ndownloader/files/37624151', - 'd9f220a095ce7f027772fcd9451a2ee5'), # V2 (27.09.2022) - 'gold_standard.zip': - ('https://figshare.com/ndownloader/files/38146098', - '57e3f9951fe77245684ede8688af3ae8') # V1 (8.11.2022) + 'DSI.zip': ( + f'{TEST_DATA_BASE_URL}/DSI.zip', + 'b847f053fc694d55d935c0be0e5268f7', # md5 + '1b09ce8b4b47b2600336c558fdba7051218296e8440e737364f2c4b8ebae666c', + ), + 'memmap_test_data.zip': ( + f'{TEST_DATA_BASE_URL}/memmap_test_data.zip', + '03f7651a0f9e3eeabee9aed0ad5f69e1', # md5 + '98ba89d7a9a7baa2d37956a0a591dce9bb4581bd01296ad5a596706ee90a52ef', + ), + 'trx_from_scratch.zip': ( + f'{TEST_DATA_BASE_URL}/trx_from_scratch.zip', + 'd9f220a095ce7f027772fcd9451a2ee5', # md5 + 'f98ab6da6a6065527fde4b0b6aa40f07583e925d952182e9bbd0febd55c0f6b2', + ), + 'gold_standard.zip': ( + f'{TEST_DATA_BASE_URL}/gold_standard.zip', + '57e3f9951fe77245684ede8688af3ae8', # md5 + '35a0b633560cc2b0d8ecda885aa72d06385499e0cd1ca11a956b0904c3358f01', + ), } @@ -43,7 +68,16 @@ def md5sum(filename): return h.hexdigest() -def fetch_data(files_dict, keys=None): +def sha256sum(filename): + """ Compute one sha256 checksum for a file """ + h = hashlib.sha256() + with open(filename, 'rb') as f: + for chunk in iter(lambda: f.read(128 * h.block_size), b''): + h.update(chunk) + return h.hexdigest() + + +def fetch_data(files_dict, keys=None): # noqa: C901 """ Downloads files to folder and checks their md5 checksums Parameters @@ -71,7 +105,12 @@ def fetch_data(files_dict, keys=None): keys = [keys] for f in keys: - url, expected_md5 = files_dict[f] + file_entry = files_dict[f] + if len(file_entry) == 2: + url, expected_md5 = file_entry + expected_sha = None + else: + url, expected_md5, expected_sha = file_entry full_path = os.path.join(trx_home, f) logging.info('Downloading {} to {}'.format(f, trx_home)) @@ -86,6 +125,15 @@ def fetch_data(files_dict, keys=None): full_path ) + if expected_sha is not None: + 
actual_sha = sha256sum(full_path) + if expected_sha != actual_sha: + raise ValueError( + f'SHA256 for {f} does not match. ' + 'Please remove the file to download it again: ' + + full_path + ) + if f.endswith('.zip'): dst_dir = os.path.join(trx_home, f[:-4]) shutil.unpack_archive(full_path, From 79966be5ffc450d9c4c5a2484e92ce2c8cd6c52f Mon Sep 17 00:00:00 2001 From: Serge Koudoro Date: Thu, 15 Jan 2026 15:18:13 -0500 Subject: [PATCH 7/7] RF: improve cli management and introduce typer --- .github/workflows/codeformat.yml | 27 + .github/workflows/coverage.yml | 48 ++ .github/workflows/publish-to-test-pypi.yml | 4 +- .github/workflows/test.yml | 5 +- .pre-commit-config.yaml | 19 + .spin/cmds.py | 105 +-- README.md | 173 ++++- docs/source/contributing.rst | 209 ++++++ docs/source/dev.rst | 342 +++++++++ docs/source/index.rst | 19 +- docs/source/scripts.rst | 156 ++++ pyproject.toml | 36 +- ruff.toml | 41 ++ scripts/tests/test_workflows.py | 322 -------- scripts/tff_concatenate_tractograms.py | 74 -- scripts/tff_convert_dsi_studio.py | 63 -- scripts/tff_convert_tractogram.py | 59 -- scripts/tff_generate_trx_from_scratch.py | 140 ---- scripts/tff_manipulate_datatype.py | 103 --- scripts/tff_simple_compare.py | 36 - scripts/tff_validate_trx.py | 64 -- scripts/tff_verify_header_compatibility.py | 33 - scripts/tff_visualize_overlap.py | 39 - tools/update_switcher.py | 61 +- trx/cli.py | 815 +++++++++++++++++++++ trx/fetcher.py | 75 +- trx/io.py | 76 +- trx/streamlines_ops.py | 7 +- trx/tests/test_cli.py | 420 +++++++++++ trx/tests/test_io.py | 162 ++-- trx/tests/test_memmap.py | 146 ++-- trx/tests/test_streamlines_ops.py | 59 +- trx/trx_file_memmap.py | 529 ++++++------- trx/utils.py | 210 +++--- trx/viz.py | 95 ++- trx/workflows.py | 411 ++++++----- 36 files changed, 3307 insertions(+), 1876 deletions(-) create mode 100644 .github/workflows/codeformat.yml create mode 100644 .github/workflows/coverage.yml create mode 100644 .pre-commit-config.yaml create mode 100644 docs/source/contributing.rst create mode 100644 docs/source/dev.rst create mode 100644 docs/source/scripts.rst create mode 100644 ruff.toml delete mode 100644 scripts/tests/test_workflows.py delete mode 100755 scripts/tff_concatenate_tractograms.py delete mode 100644 scripts/tff_convert_dsi_studio.py delete mode 100644 scripts/tff_convert_tractogram.py delete mode 100755 scripts/tff_generate_trx_from_scratch.py delete mode 100644 scripts/tff_manipulate_datatype.py delete mode 100755 scripts/tff_simple_compare.py delete mode 100755 scripts/tff_validate_trx.py delete mode 100644 scripts/tff_verify_header_compatibility.py delete mode 100644 scripts/tff_visualize_overlap.py create mode 100644 trx/cli.py create mode 100644 trx/tests/test_cli.py diff --git a/.github/workflows/codeformat.yml b/.github/workflows/codeformat.yml new file mode 100644 index 0000000..03e3a21 --- /dev/null +++ b/.github/workflows/codeformat.yml @@ -0,0 +1,27 @@ +name: Code Format + +on: + push: + branches: [master] + pull_request: + branches: [master] + +permissions: + contents: read + +jobs: + pre-commit: + name: Pre-commit checks + runs-on: ubuntu-latest + + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install and run pre-commit hooks + uses: pre-commit/action@v3.0.1 diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml new file mode 100644 index 0000000..5738183 --- /dev/null +++ b/.github/workflows/coverage.yml @@ -0,0 
+1,48 @@ +name: Coverage + +on: + push: + branches: [master] + pull_request: + branches: [master] + +permissions: + contents: read + +jobs: + coverage: + name: Code Coverage + runs-on: ubuntu-latest + + steps: + - name: Check out repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + cache: pip + cache-dependency-path: pyproject.toml + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install -e .[test] + + - name: Run tests with coverage + run: | + pytest trx/tests --cov=trx --cov-report=xml --cov-report=term-missing + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v4 + with: + files: ./coverage.xml + flags: unittests + name: codecov-trx + fail_ci_if_error: false + verbose: true + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} diff --git a/.github/workflows/publish-to-test-pypi.yml b/.github/workflows/publish-to-test-pypi.yml index 3b3c454..67918b9 100644 --- a/.github/workflows/publish-to-test-pypi.yml +++ b/.github/workflows/publish-to-test-pypi.yml @@ -16,10 +16,10 @@ jobs: - uses: actions/checkout@v4 with: fetch-depth: 0 # Fetch all history and tags for setuptools_scm - - name: Set up Python 3.9 + - name: Set up Python 3.11 uses: actions/setup-python@v5 with: - python-version: 3.9 + python-version: "3.11" - name: Install pypa/build run: python -m pip install build --user - name: Build a binary wheel and a source tarball diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8f8df59..e9f231e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -16,7 +16,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + python-version: ["3.11", "3.12", "3.13"] os: [ubuntu-latest, windows-latest, macos-latest] steps: - uses: actions/checkout@v4 @@ -35,8 +35,5 @@ jobs: python -m pip install --upgrade pip python -m pip install -e .[dev,test] - - name: Lint - run: spin lint - - name: Test run: spin test diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..898b98a --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,19 @@ +default_language_version: + python: python3 + +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.8.6 + hooks: + # Run the linter + - id: ruff + args: [ --fix ] + # Run the formatter + - id: ruff-format + - repo: https://github.com/codespell-project/codespell + rev: v2.3.0 + hooks: + - id: codespell + args: [--skip, "pyproject.toml,docs/_build/*,*.egg-info"] + additional_dependencies: + - tomli diff --git a/.spin/cmds.py b/.spin/cmds.py index 9892753..ea7b2c3 100644 --- a/.spin/cmds.py +++ b/.spin/cmds.py @@ -1,26 +1,21 @@ """Custom spin commands for trx-python development.""" + +import glob import os +import shutil import subprocess import sys import tempfile -import glob -import shutil import click - UPSTREAM_URL = "https://github.com/tee-ar-ex/trx-python.git" UPSTREAM_NAME = "upstream" def run(cmd, check=True, capture=True): """Run a shell command.""" - result = subprocess.run( - cmd, - capture_output=capture, - text=True, - check=False - ) + result = subprocess.run(cmd, capture_output=capture, text=True, check=False) if check and result.returncode != 0: if capture: click.echo(f"Error: {result.stderr}", err=True) @@ -81,6 +76,7 @@ def setup(): click.echo("\nVerifying version detection...") try: from setuptools_scm import get_version + version = 
get_version() click.echo(f"Detected version: {version}") @@ -88,7 +84,7 @@ def setup(): if version.startswith("0.0"): click.echo( "\nWarning: Version starts with 0.0 - tags may not be fetched.", - err=True + err=True, ) sys.exit(1) except ImportError: @@ -101,13 +97,13 @@ def setup(): @click.command() @click.option( - "-m", "--match", "pattern", default=None, - help="Only run tests matching this pattern (passed to pytest -k)" -) -@click.option( - "-v", "--verbose", is_flag=True, default=False, - help="Verbose output" + "-m", + "--match", + "pattern", + default=None, + help="Only run tests matching this pattern (passed to pytest -k)", ) +@click.option("-v", "--verbose", is_flag=True, default=False, help="Verbose output") @click.argument("pytest_args", nargs=-1) def test(pattern, verbose, pytest_args): """Run tests using pytest. @@ -120,7 +116,7 @@ def test(pattern, verbose, pytest_args): spin test -v # Verbose output spin test -- -x --tb=short # Pass args to pytest """ - cmd = ["pytest", "trx/tests", "scripts/tests"] + cmd = ["pytest", "trx/tests"] if pattern: cmd.extend(["-k", pattern]) @@ -137,46 +133,60 @@ def test(pattern, verbose, pytest_args): @click.command() @click.option( - "--fix", is_flag=True, default=False, - help="Currently unused (for future auto-fix support)" + "--fix", is_flag=True, default=False, help="Automatically fix issues where possible" ) def lint(fix): - """Run linting checks using flake8. + """Run linting checks using ruff and codespell. Examples: - spin lint # Run flake8 checks + spin lint # Run ruff and codespell checks + spin lint --fix # Run ruff and auto-fix issues """ - # Strict check for syntax errors - click.echo("Checking for syntax errors...") - cmd_strict = [ - "flake8", ".", "--count", - "--select=E9,F63,F7,F82", - "--show-source", "--statistics" - ] - result = run(cmd_strict, capture=False, check=False) + click.echo("Running ruff linter...") + cmd = ["ruff", "check", "."] + + if fix: + cmd.append("--fix") + + result = run(cmd, capture=False, check=False) + if result != 0: + click.echo("\nLinting issues found!", err=True) + sys.exit(1) + + click.echo("\nRunning ruff formatter check...") + cmd_format = ["ruff", "format", "--check", "."] + result = run(cmd_format, capture=False, check=False) if result != 0: - click.echo("Syntax errors found!", err=True) + click.echo("\nFormatting issues found!", err=True) sys.exit(1) - # Full lint check - click.echo("\nRunning full lint check...") - cmd_full = [ - "flake8", ".", "--count", - "--max-line-length=88", - "--max-complexity=10", - "--statistics" + click.echo("\nRunning codespell...") + cmd_spell = [ + "codespell", + "--skip", + "*.pyc,.git,pyproject.toml,./docs/_build/*,*.egg-info,./build/*,./dist/*,./tmp/*", + "trx", + "docs/source", + ".spin", ] - sys.exit(run(cmd_full, capture=False, check=False)) + result = run(cmd_spell, capture=False, check=False) + if result != 0: + click.echo("\nSpelling issues found!", err=True) + sys.exit(1) + + click.echo("\nAll checks passed!") @click.command() @click.option( - "--clean", is_flag=True, default=False, - help="Clean build directory before building" + "--clean", is_flag=True, default=False, help="Clean build directory before building" ) @click.option( - "--open", "open_browser", is_flag=True, default=False, - help="Open documentation in browser after building" + "--open", + "open_browser", + is_flag=True, + default=False, + help="Open documentation in browser after building", ) def docs(clean, open_browser): """Build documentation using Sphinx. 
@@ -187,6 +197,7 @@ def docs(clean, open_browser): spin docs --open # Build and open in browser """ import os + docs_dir = "docs" if clean: @@ -194,6 +205,7 @@ def docs(clean, open_browser): build_dir = os.path.join(docs_dir, "_build") if os.path.exists(build_dir): import shutil + shutil.rmtree(build_dir) click.echo("Building documentation...") @@ -209,6 +221,7 @@ def docs(clean, open_browser): if open_browser: import webbrowser + webbrowser.open(f"file://{index_path}") sys.exit(result) @@ -220,16 +233,16 @@ def clean(): # noqa: C901 click.echo("Cleaning up temporary files...") # Clean TRX temp directory - trx_tmp_dir = os.getenv('TRX_TMPDIR', tempfile.gettempdir()) + trx_tmp_dir = os.getenv("TRX_TMPDIR", tempfile.gettempdir()) if os.path.exists(trx_tmp_dir): - temp_files = glob.glob(os.path.join(trx_tmp_dir, 'trx_*')) + temp_files = glob.glob(os.path.join(trx_tmp_dir, "trx_*")) for temp_dir in temp_files: if os.path.isdir(temp_dir): click.echo(f"Removing temporary directory: {temp_dir}") shutil.rmtree(temp_dir) # Clean build artifacts - for build_pattern in ['build', 'dist', '*.egg-info']: + for build_pattern in ["build", "dist", "*.egg-info"]: for path in glob.glob(build_pattern): if os.path.isdir(path): click.echo(f"Removing build directory: {path}") @@ -239,7 +252,7 @@ def clean(): # noqa: C901 os.remove(path) # Clean Python cache - for cache_dir in ['**/__pycache__', '**/.pytest_cache']: + for cache_dir in ["**/__pycache__", "**/.pytest_cache"]: for path in glob.glob(cache_dir, recursive=True): if os.path.isdir(path): click.echo(f"Removing cache directory: {path}") diff --git a/README.md b/README.md index 1f1c46e..f86676a 100644 --- a/README.md +++ b/README.md @@ -1,53 +1,162 @@ # trx-python -This is a Python implementation of the trx file-format for tractography data. +[![Tests](https://github.com/tee-ar-ex/trx-python/actions/workflows/test.yml/badge.svg)](https://github.com/tee-ar-ex/trx-python/actions/workflows/test.yml) +[![Code Format](https://github.com/tee-ar-ex/trx-python/actions/workflows/codeformat.yml/badge.svg)](https://github.com/tee-ar-ex/trx-python/actions/workflows/codeformat.yml) +[![codecov](https://codecov.io/gh/tee-ar-ex/trx-python/branch/master/graph/badge.svg)](https://codecov.io/gh/tee-ar-ex/trx-python) +[![PyPI version](https://badge.fury.io/py/trx-python.svg)](https://badge.fury.io/py/trx-python) -For details, please visit the documentation web-page at https://tee-ar-ex.github.io/trx-python/. +A Python implementation of the TRX file format for tractography data. -To install this, you can run: +For details, please visit the [documentation](https://tee-ar-ex.github.io/trx-python/). - pip install trx-python +## Installation -Or, to install from source: +### From PyPI - git clone https://github.com/tee-ar-ex/trx-python.git - cd trx-python - pip install . +```bash +pip install trx-python +``` -### Development +### From Source -For contributors, we use [spin](https://github.com/scientific-python/spin) to manage the development workflow. This ensures proper version detection when working with forks. +```bash +git clone https://github.com/tee-ar-ex/trx-python.git +cd trx-python +pip install . +``` -**First-time setup (required for forks):** +## Quick Start - git clone https://github.com/YOUR_USERNAME/trx-python.git - cd trx-python - pip install -e ".[dev]" - spin setup +### Loading and Saving Tractograms -The `spin setup` command configures your fork by fetching version tags from the upstream repository. 
This is required for correct version detection with `setuptools_scm`. +```python +from trx.io import load, save -**Common development commands:** +# Load a tractogram (supports .trx, .trk, .tck, .vtk, .fib, .dpy) +trx = load("tractogram.trx") - spin setup # Set up development environment (fetch upstream tags) - spin install # Install package in development/editable mode - spin test # Run all tests - spin test -m memmap # Run tests matching 'memmap' - spin lint # Run linting checks - spin docs # Build documentation +# Save to a different format +save(trx, "output.trk") +``` + +### Command-Line Interface + +TRX-Python provides a unified CLI (`tff`) for common operations: + +```bash +# Show all available commands +tff --help + +# Convert between formats +tff convert input.trk output.trx + +# Concatenate tractograms +tff concatenate tract1.trx tract2.trx merged.trx + +# Validate a TRX file +tff validate data.trx +``` + +Individual commands are also available for backward compatibility: + +```bash +tff_convert_tractogram input.trk output.trx +tff_concatenate_tractograms tract1.trx tract2.trx merged.trx +tff_validate_trx data.trx +``` + +## Development + +We use [spin](https://github.com/scientific-python/spin) for development workflow. + +### First-Time Setup + +```bash +# Clone the repository (or your fork) +git clone https://github.com/tee-ar-ex/trx-python.git +cd trx-python + +# Install with all dependencies +pip install -e ".[all]" + +# Set up development environment (fetches upstream tags) +spin setup +``` + +### Common Commands + +```bash +spin setup # Set up development environment +spin install # Install in editable mode +spin test # Run all tests +spin test -m memmap # Run tests matching pattern +spin lint # Run linting (ruff) +spin lint --fix # Auto-fix linting issues +spin docs # Build documentation +spin clean # Clean temporary files +``` Run `spin` without arguments to see all available commands. -### Temporary Directory -The TRX file format uses memmaps to limit RAM usage. When dealing with large files this means several gigabytes could be required on disk (instead of RAM). +### Code Quality + +We use [ruff](https://docs.astral.sh/ruff/) for linting and formatting: + +```bash +# Check for issues +spin lint + +# Auto-fix issues +spin lint --fix + +# Format code +ruff format . +``` + +### Pre-commit Hooks + +```bash +# Install hooks +pre-commit install + +# Run on all files +pre-commit run --all-files +``` + +## Temporary Directory + +The TRX file format uses memory-mapped files to limit RAM usage. When dealing with large files, several gigabytes may be required on disk. + +By default, temporary files are stored in: +- Linux/macOS: `/tmp` +- Windows: `C:\WINDOWS\Temp` + +To change the directory: + +```bash +# Use a specific directory (must exist) +export TRX_TMPDIR=/path/to/tmp + +# Use current working directory +export TRX_TMPDIR=use_working_dir +``` + +Temporary folders are automatically cleaned, but if the code crashes unexpectedly, ensure folders are deleted manually. + +## Documentation + +Full documentation is available at https://tee-ar-ex.github.io/trx-python/ + +To build locally: + +```bash +spin docs --open +``` -By default, the temporary directory on Linux and MacOS is `/tmp` and on Windows it should be `C:\WINDOWS\Temp`. 
+## Contributing -If you wish to change the directory add the following variable to your script or to your .bashrc or .bash_profile: -`export TRX_TMPDIR=/WHERE/I/WANT/MY/TMP/DATA` (a) -OR -`export TRX_TMPDIR=use_working_dir` (b) +We welcome contributions! Please see our [Contributing Guide](https://tee-ar-ex.github.io/trx-python/contributing.html) for details. -The provided folder must already exists (a). `use_working_dir` will be the directory where the code is being executed from (b). +## License -The temporary folders should be automatically cleaned. But, if the code crash unexpectedly, make sure the folders are deleted. +BSD License - see [LICENSE](LICENSE) for details. diff --git a/docs/source/contributing.rst b/docs/source/contributing.rst new file mode 100644 index 0000000..0ab192a --- /dev/null +++ b/docs/source/contributing.rst @@ -0,0 +1,209 @@ +Contributing to TRX-Python +========================== + +We welcome contributions from the community! This guide will help you get started +with contributing to the TRX-Python project. + +Ways to Contribute +------------------ + +There are many ways to contribute to TRX-Python: + +- **Report bugs**: If you find a bug, please open an issue on GitHub +- **Suggest features**: Have an idea? Open an issue to discuss it +- **Fix bugs**: Look for issues labeled "good first issue" or "help wanted" +- **Write documentation**: Help improve our docs or add examples +- **Write tests**: Increase test coverage +- **Code review**: Review pull requests from other contributors + +Getting Started +--------------- + +1. **Fork the repository** on GitHub +2. **Clone your fork**: + + .. code-block:: bash + + git clone https://github.com/YOUR_USERNAME/trx-python.git + cd trx-python + +3. **Set up development environment**: + + .. code-block:: bash + + pip install -e ".[all]" + spin setup + + The ``spin setup`` command fetches version tags from upstream, which is + required for correct version detection. + +4. **Create a branch** for your changes: + + .. code-block:: bash + + git checkout -b my-feature-branch + +Making Changes +-------------- + +Development Workflow +~~~~~~~~~~~~~~~~~~~~ + +We use `spin `_ for development workflow: + +.. code-block:: bash + + spin install # Install in editable mode + spin test # Run all tests + spin lint # Run linting (ruff) + spin docs # Build documentation + +Before Submitting +~~~~~~~~~~~~~~~~~ + +1. **Run tests** to ensure your changes don't break existing functionality: + + .. code-block:: bash + + spin test + +2. **Run linting** to ensure code style compliance: + + .. code-block:: bash + + spin lint + + You can auto-fix many issues with: + + .. code-block:: bash + + spin lint --fix + +3. **Format your code** using ruff: + + .. code-block:: bash + + ruff format . + +4. **Write tests** for any new functionality + +5. **Update documentation** if needed + +Submitting a Pull Request +------------------------- + +1. **Push your changes** to your fork: + + .. code-block:: bash + + git push origin my-feature-branch + +2. **Open a Pull Request** on GitHub against the ``master`` branch + +3. **Describe your changes** in the PR description: + + - What does this PR do? + - Why is this change needed? + - How was it tested? + +4. **Wait for CI checks** to pass + +5. 
**Address review feedback** if requested + +Code Style +---------- + +We follow these conventions: + +- **PEP 8** style guide +- **Line length**: 88 characters maximum +- **Docstrings**: NumPy style format +- **Type hints**: Encouraged but not required + +Example docstring: + +.. code-block:: python + + def my_function(param1, param2): + """Short description of the function. + + Parameters + ---------- + param1 : int + Description of param1. + param2 : str + Description of param2. + + Returns + ------- + result : bool + Description of return value. + + Examples + -------- + >>> my_function(1, "test") + True + """ + pass + +We use `ruff `_ for linting and formatting. +Configuration is in ``ruff.toml``. + +Testing +------- + +Tests are located in ``trx/tests/``. We use pytest for testing. + +Running Tests +~~~~~~~~~~~~~ + +.. code-block:: bash + + # Run all tests + spin test + + # Run tests matching a pattern + spin test -m memmap + + # Run with verbose output + spin test -v + + # Run a specific test file + pytest trx/tests/test_memmap.py + +Writing Tests +~~~~~~~~~~~~~ + +- Place tests in ``trx/tests/`` +- Name test files ``test_*.py`` +- Name test functions ``test_*`` +- Use pytest fixtures for common setup + +Documentation +------------- + +Documentation is built with Sphinx and hosted on GitHub Pages. + +Building Docs +~~~~~~~~~~~~~ + +.. code-block:: bash + + spin docs # Build documentation + spin docs --clean # Clean build + spin docs --open # Build and open in browser + +Writing Documentation +~~~~~~~~~~~~~~~~~~~~~ + +- Documentation source is in ``docs/source/`` +- Use reStructuredText format +- API documentation is auto-generated from docstrings + +Getting Help +------------ + +- **GitHub Issues**: For bugs and feature requests +- **GitHub Discussions**: For questions and discussions + +Thank you for contributing to TRX-Python! diff --git a/docs/source/dev.rst b/docs/source/dev.rst new file mode 100644 index 0000000..5aacdfd --- /dev/null +++ b/docs/source/dev.rst @@ -0,0 +1,342 @@ +Developer Guide +=============== + +This guide provides detailed information for developers working on TRX-Python. + +Installation for Development +---------------------------- + +Prerequisites +~~~~~~~~~~~~~ + +- Python 3.11 or later (Python 3.12+ recommended) +- Git +- pip + +Setting Up Your Environment +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +1. **Clone the repository**: + + .. code-block:: bash + + # If you're a contributor, fork first then clone your fork + git clone https://github.com/YOUR_USERNAME/trx-python.git + cd trx-python + +2. **Install with all development dependencies**: + + .. code-block:: bash + + pip install -e ".[all]" + + This installs: + + - Core dependencies (numpy, nibabel, deepdiff, typer) + - Development tools (spin, setuptools_scm) + - Documentation tools (sphinx, numpydoc) + - Style tools (ruff, pre-commit) + - Testing tools (pytest, pytest-cov) + +3. **Set up the development environment**: + + .. code-block:: bash + + spin setup + + This command: + + - Adds upstream remote if missing + - Fetches version tags for correct ``setuptools_scm`` version detection + +Using Spin +---------- + +We use `spin `_ for development workflow. +Spin provides a consistent interface for common development tasks. + +Available Commands +~~~~~~~~~~~~~~~~~~ + +Run ``spin`` without arguments to see all available commands: + +.. code-block:: bash + + spin + +**Setup Commands:** + +.. code-block:: bash + + spin setup # Configure development environment + +**Build Commands:** + +.. 
code-block:: bash + + spin install # Install package in editable mode + +**Test Commands:** + +.. code-block:: bash + + spin test # Run all tests + spin test -m NAME # Run tests matching pattern + spin test -v # Verbose output + spin lint # Run ruff linting + spin lint --fix # Auto-fix linting issues + +**Documentation Commands:** + +.. code-block:: bash + + spin docs # Build documentation + spin docs --clean # Clean and rebuild + spin docs --open # Build and open in browser + +**Cleanup Commands:** + +.. code-block:: bash + + spin clean # Remove temporary files and build artifacts + +Code Quality +------------ + +Linting with Ruff +~~~~~~~~~~~~~~~~~ + +We use `ruff <https://docs.astral.sh/ruff/>`_ for linting and formatting. +Configuration is in ``ruff.toml``. + +.. code-block:: bash + + # Check for issues + spin lint + + # Auto-fix issues + spin lint --fix + + # Format code + ruff format . + + # Check formatting without changes + ruff format --check . + +Pre-commit Hooks +~~~~~~~~~~~~~~~~ + +We recommend using pre-commit hooks to catch issues before committing: + +.. code-block:: bash + + # Install pre-commit hooks + pre-commit install + + # Run hooks manually on all files + pre-commit run --all-files + +The hooks run: + +- ``ruff`` - Linting with auto-fix +- ``ruff-format`` - Code formatting +- ``codespell`` - Spell checking + +Testing +------- + +Running Tests +~~~~~~~~~~~~~ + +.. code-block:: bash + + # Run all tests + spin test + + # Run tests matching a pattern + spin test -m memmap + + # Run with pytest directly + pytest trx/tests + + # Run with coverage + pytest trx/tests --cov=trx --cov-report=term-missing + +Test Data +~~~~~~~~~ + +Test data is automatically downloaded from the ``tee-ar-ex/trx-test-data`` GitHub release on first run. +Data is cached in ``~/.tee_ar_ex/``. + +You can manually fetch test data: + +.. code-block:: python + + from trx.fetcher import fetch_data, get_testing_files_dict + fetch_data(get_testing_files_dict()) + +Writing Tests +~~~~~~~~~~~~~ + +- Tests go in ``trx/tests/`` +- Use pytest fixtures for setup/teardown +- Use ``pytest.mark.skipif`` for conditional tests + +Example: + +.. code-block:: python + + import pytest + import numpy as np + from numpy.testing import assert_array_equal + + def test_my_function(): + result = my_function(input_data) + expected = np.array([1, 2, 3]) + assert_array_equal(result, expected) + + @pytest.mark.skipif(not dipy_available, reason="Dipy required") + def test_with_dipy(): + # Test that requires dipy + pass + +Documentation +------------- + +Building Documentation +~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + # Build docs + spin docs + + # Clean build + spin docs --clean + + # Build and open in browser + spin docs --open + +Documentation is built with Sphinx and uses: + +- ``pydata-sphinx-theme`` for styling +- ``sphinx-autoapi`` for API documentation +- ``numpydoc`` for NumPy-style docstrings + +Writing Documentation +~~~~~~~~~~~~~~~~~~~~~ + +- Source files are in ``docs/source/`` +- Use reStructuredText format +- API docs are auto-generated from docstrings + +NumPy Docstring Format +~~~~~~~~~~~~~~~~~~~~~~ + +All functions and classes should be documented using NumPy-style docstrings: + +.. code-block:: python + + def load(filename, reference=None): + """Load a tractogram file. + + Parameters + ---------- + filename : str + Path to the tractogram file. + reference : str, optional + Path to reference anatomy for formats that require it. + + Returns + ------- + tractogram : TrxFile or StatefulTractogram + The loaded tractogram. 
+ + Raises + ------ + ValueError + If the file format is not supported. + + See Also + -------- + save : Save a tractogram to file. + + Examples + -------- + >>> from trx.io import load + >>> trx = load("tractogram.trx") + """ + pass + +Project Structure +----------------- + +.. code-block:: text + + trx-python/ + ├── trx/ # Main package + │ ├── __init__.py + │ ├── cli.py # Command-line interface (Typer) + │ ├── fetcher.py # Test data fetching + │ ├── io.py # Unified I/O interface + │ ├── streamlines_ops.py # Streamline operations + │ ├── trx_file_memmap.py # Core TrxFile class + │ ├── utils.py # Utility functions + │ ├── viz.py # Visualization (optional) + │ ├── workflows.py # High-level workflows + │ └── tests/ # Test suite + ├── docs/ # Documentation + │ └── source/ + ├── .github/ # GitHub Actions workflows + │ └── workflows/ + ├── .spin/ # Spin configuration + │ └── cmds.py + ├── pyproject.toml # Project configuration + ├── ruff.toml # Ruff configuration + └── .pre-commit-config.yaml # Pre-commit hooks + +Continuous Integration +---------------------- + +GitHub Actions runs on every push and pull request: + +- **test.yml**: Runs tests on Python 3.11-3.13 across Linux, macOS, Windows +- **codeformat.yml**: Checks code formatting with pre-commit/ruff +- **coverage.yml**: Generates code coverage reports +- **docbuild.yml**: Builds and deploys documentation + +Environment Variables +--------------------- + +TRX_TMPDIR +~~~~~~~~~~ + +Controls where temporary files are stored during memory-mapped operations. + +.. code-block:: bash + + # Use a specific directory + export TRX_TMPDIR=/path/to/tmp + + # Use current working directory + export TRX_TMPDIR=use_working_dir + +Default: System temp directory (``/tmp`` on Linux/macOS, ``C:\WINDOWS\Temp`` on Windows) + +Release Process +--------------- + +Releases are managed via GitHub: + +1. Update version in ``pyproject.toml`` if needed +2. Create a GitHub release with appropriate tag +3. CI automatically publishes to PyPI + +Version Detection +~~~~~~~~~~~~~~~~~ + +We use ``setuptools_scm`` for automatic version detection from git tags. +This requires: + +- Proper git tags from upstream +- Running ``spin setup`` after cloning a fork diff --git a/docs/source/index.rst b/docs/source/index.rst index 84be5d0..55ecf8e 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -42,8 +42,19 @@ Development of TRX is supported by `NIMH grant 1R01MH126699 --help # Show help for a specific command + +Available subcommands: + +- ``tff concatenate`` - Concatenate multiple tractograms +- ``tff convert`` - Convert between tractography formats +- ``tff convert-dsi`` - Fix DSI-Studio TRK files +- ``tff generate`` - Generate TRX from raw data files +- ``tff manipulate-dtype`` - Change array data types +- ``tff compare`` - Simple tractogram comparison +- ``tff validate`` - Validate and clean TRX files +- ``tff verify-header`` - Check header compatibility +- ``tff visualize`` - Visualize tractogram overlap + +Standalone Commands +------------------- + +For backward compatibility, standalone commands are also available: + +tff_concatenate_tractograms +~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Concatenate multiple tractograms into a single output. + +- Supports ``trk``, ``tck``, ``vtk``, ``fib``, ``dpy``, and ``trx`` inputs. +- Flags: ``--delete-dpv``, ``--delete-dps``, ``--delete-groups`` to drop mismatched metadata; ``--reference`` for formats needing an anatomy reference; ``-f`` to overwrite. + +.. 
code-block:: bash + + # Using unified CLI + tff concatenate in1.trk in2.trk merged.trx + + # Using standalone command + tff_concatenate_tractograms in1.trk in2.trk merged.trx + +tff_convert_dsi_studio +~~~~~~~~~~~~~~~~~~~~~~ +Convert a DSI Studio ``.trk`` with accompanying ``.nii.gz`` reference into a cleaned ``.trk`` or TRX. + +.. code-block:: bash + + # Using unified CLI + tff convert-dsi input.trk reference.nii.gz cleaned.trk + + # Using standalone command + tff_convert_dsi_studio input.trk reference.nii.gz cleaned.trk + +tff_convert_tractogram +~~~~~~~~~~~~~~~~~~~~~~ +General-purpose converter between ``trk``, ``tck``, ``vtk``, ``fib``, ``dpy``, and ``trx``. + +- Flags: ``--reference`` for formats needing a NIfTI, ``--positions-dtype``, ``--offsets-dtype``, ``-f`` to overwrite. + +.. code-block:: bash + + # Using unified CLI + tff convert input.trk output.trx --positions-dtype float32 --offsets-dtype uint64 + + # Using standalone command + tff_convert_tractogram input.trk output.trx --positions-dtype float32 --offsets-dtype uint64 + +tff_generate_trx_from_scratch +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Build a TRX file from raw NumPy arrays or CSV streamline coordinates. + +- Flags: ``--positions``, ``--offsets``, ``--positions-dtype``, ``--offsets-dtype``, spatial options (``--space``, ``--origin``), and metadata loaders for dpv/dps/groups/dpg. + +.. code-block:: bash + + # Using unified CLI + tff generate fa.nii.gz output.trx --positions positions.npy --offsets offsets.npy + + # Using standalone command + tff_generate_trx_from_scratch fa.nii.gz output.trx --positions positions.npy --offsets offsets.npy + +tff_manipulate_datatype +~~~~~~~~~~~~~~~~~~~~~~~ +Rewrite TRX datasets with new dtypes for positions/offsets/dpv/dps/dpg/groups. + +- Accepts per-field dtype arguments and overwrites with ``-f``. + +.. code-block:: bash + + # Using unified CLI + tff manipulate-dtype input.trx output.trx --positions-dtype float16 --dpv color,uint8 + + # Using standalone command + tff_manipulate_datatype input.trx output.trx --positions-dtype float16 --dpv color,uint8 + +tff_simple_compare +~~~~~~~~~~~~~~~~~~ +Compare two tractograms for quick difference checks. + +.. code-block:: bash + + # Using unified CLI + tff compare first.trk second.trk + + # Using standalone command + tff_simple_compare first.trk second.trk + +tff_validate_trx +~~~~~~~~~~~~~~~~ +Validate a TRX file for consistency and remove invalid streamlines. + +.. code-block:: bash + + # Using unified CLI + tff validate data.trx --out cleaned.trx + + # Using standalone command + tff_validate_trx data.trx --out cleaned.trx + +tff_verify_header_compatibility +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Check whether tractogram headers are compatible for operations such as concatenation. + +.. code-block:: bash + + # Using unified CLI + tff verify-header file1.trk file2.trk + + # Using standalone command + tff_verify_header_compatibility file1.trk file2.trk + +tff_visualize_overlap +~~~~~~~~~~~~~~~~~~~~~ +Visualize streamline overlap between tractograms (requires visualization dependencies). + +.. code-block:: bash + + # Using unified CLI + tff visualize tractogram.trk reference.nii.gz + + # Using standalone command + tff_visualize_overlap tractogram.trk reference.nii.gz + +Notes +----- +- Test datasets for examples can be fetched with ``python -m trx.fetcher`` helpers: ``fetch_data(get_testing_files_dict())`` downloads to ``$TRX_HOME`` (default ``~/.tee_ar_ex``). +- All commands print detailed usage with ``--help``. 
+- The unified ``tff`` CLI uses `Typer `_ for beautiful terminal output with colors and rich formatting. diff --git a/pyproject.toml b/pyproject.toml index 4bf13ac..46c5a1c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ dynamic = ["version"] description = "Experiments with new file format for tractography" readme = "README.md" license = {text = "BSD License"} -requires-python = ">=3.9" +requires-python = ">=3.11" authors = [ {name = "The TRX developers"} ] @@ -20,8 +20,6 @@ classifiers = [ "Operating System :: OS Independent", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", @@ -31,6 +29,7 @@ dependencies = [ "deepdiff", "nibabel >= 5", "numpy >= 1.22", + "typer >= 0.9", ] [project.optional-dependencies] @@ -45,15 +44,21 @@ doc = [ "sphinx-autoapi >= 3.0.0", "numpydoc", ] +style = [ + "codespell", + "pre-commit", + "ruff", +] test = [ - "flake8", "psutil", "pytest >= 7", "pytest-console-scripts >= 0", + "pytest-cov", ] all = [ "trx-python[dev]", "trx-python[doc]", + "trx-python[style]", "trx-python[test]", ] @@ -62,20 +67,21 @@ Homepage = "https://github.com/tee-ar-ex/trx-python" Documentation = "https://tee-ar-ex.github.io/trx-python/" Repository = "https://github.com/tee-ar-ex/trx-python" +[project.scripts] +tff = "trx.cli:main" +tff_concatenate_tractograms = "trx.cli:concatenate_tractograms_cmd" +tff_convert_dsi_studio = "trx.cli:convert_dsi_cmd" +tff_convert_tractogram = "trx.cli:convert_cmd" +tff_generate_trx_from_scratch = "trx.cli:generate_cmd" +tff_manipulate_datatype = "trx.cli:manipulate_dtype_cmd" +tff_simple_compare = "trx.cli:compare_cmd" +tff_validate_trx = "trx.cli:validate_cmd" +tff_verify_header_compatibility = "trx.cli:verify_header_cmd" +tff_visualize_overlap = "trx.cli:visualize_cmd" + [tool.setuptools] packages = ["trx"] include-package-data = true -script-files = [ - "scripts/tff_concatenate_tractograms.py", - "scripts/tff_convert_dsi_studio.py", - "scripts/tff_convert_tractogram.py", - "scripts/tff_generate_trx_from_scratch.py", - "scripts/tff_manipulate_datatype.py", - "scripts/tff_simple_compare.py", - "scripts/tff_validate_trx.py", - "scripts/tff_verify_header_compatibility.py", - "scripts/tff_visualize_overlap.py", -] [tool.setuptools.dynamic] version = {attr = "trx._version.__version__"} diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 0000000..0cfb2a9 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,41 @@ +target-version = "py312" + +line-length = 88 +force-exclude = true +extend-exclude = [ + "__pycache__", + "build", + "_version.py", + "docs/**", +] + +[lint] +select = [ + "F", # Pyflakes + "E", # pycodestyle errors + "W", # pycodestyle warnings + "C", # mccabe complexity + "B", # flake8-bugbear + "I", # isort +] +ignore = [ + "B905", # zip without explicit strict parameter + "C901", # too complex + "E203", # whitespace before ':' +] + +[lint.extend-per-file-ignores] +"trx/tests/**" = ["B011"] + +[lint.isort] +case-sensitive = true +combine-as-imports = true +force-sort-within-sections = true +known-first-party = ["trx"] +no-sections = false +order-by-type = true +relative-imports-order = "closest-to-furthest" +section-order = ["future", "standard-library", "third-party", "first-party", "local-folder"] + +[format] +quote-style = "double" diff --git a/scripts/tests/test_workflows.py 
b/scripts/tests/test_workflows.py deleted file mode 100644 index 3ae5e95..0000000 --- a/scripts/tests/test_workflows.py +++ /dev/null @@ -1,322 +0,0 @@ -#! /usr/bin/env python3 -# -*- coding: utf-8 -*- - -from deepdiff import DeepDiff -import os -import tempfile - -import pytest -import numpy as np -from numpy.testing import assert_equal, assert_array_equal, assert_allclose -try: - from dipy.io.streamline import load_tractogram - dipy_available = True -except ImportError: - dipy_available = False - -from trx.fetcher import (get_testing_files_dict, - fetch_data, get_home) -import trx.trx_file_memmap as tmm -from trx.workflows import (convert_dsi_studio, - convert_tractogram, - manipulate_trx_datatype, - generate_trx_from_scratch, - validate_tractogram,) - - -# If they already exist, this only takes 5 seconds (check md5sum) -fetch_data(get_testing_files_dict(), keys=['DSI.zip', 'trx_from_scratch.zip']) - - -def test_help_option_convert_dsi(script_runner): - ret = script_runner.run('tff_convert_dsi_studio.py', '--help') - assert ret.success - - -def test_help_option_convert(script_runner): - ret = script_runner.run('tff_convert_tractogram.py', '--help') - assert ret.success - - -def test_help_option_generate_trx_from_scratch(script_runner): - ret = script_runner.run('tff_generate_trx_from_scratch.py', '--help') - assert ret.success - - -@pytest.mark.skipif(not dipy_available, - reason='Dipy is not installed.') -def test_execution_convert_dsi(): - with tempfile.TemporaryDirectory() as tmp_dir: - in_trk = os.path.join(get_home(), 'DSI', - 'CC.trk.gz') - in_nii = os.path.join(get_home(), 'DSI', - 'CC.nii.gz') - exp_data = os.path.join(get_home(), 'DSI', - 'CC_fix_data.npy') - exp_offsets = os.path.join(get_home(), 'DSI', - 'CC_fix_offsets.npy') - out_fix_path = os.path.join(tmp_dir, 'fixed.trk') - convert_dsi_studio(in_trk, in_nii, out_fix_path, - remove_invalid=False, - keep_invalid=True) - - data_fix = np.load(exp_data) - offsets_fix = np.load(exp_offsets) - - sft = load_tractogram(out_fix_path, 'same') - assert_equal(sft.streamlines._data, data_fix) - assert_equal(sft.streamlines._offsets, offsets_fix) - - -@pytest.mark.skipif(not dipy_available, - reason='Dipy is not installed.') -def test_execution_convert_to_trx(): - with tempfile.TemporaryDirectory() as tmp_dir: - in_trk = os.path.join(get_home(), 'DSI', - 'CC_fix.trk') - exp_data = os.path.join(get_home(), 'DSI', - 'CC_fix_data.npy') - exp_offsets = os.path.join(get_home(), 'DSI', - 'CC_fix_offsets.npy') - out_trx_path = os.path.join(tmp_dir, 'CC_fix.trx') - convert_tractogram(in_trk, out_trx_path, None) - - data_fix = np.load(exp_data) - offsets_fix = np.load(exp_offsets) - - trx = tmm.load(out_trx_path) - assert_equal(trx.streamlines._data.dtype, np.float32) - assert_equal(trx.streamlines._offsets.dtype, np.uint32) - assert_array_equal(trx.streamlines._data, data_fix) - assert_array_equal(trx.streamlines._offsets, offsets_fix) - trx.close() - - -@pytest.mark.skipif(not dipy_available, - reason='Dipy is not installed.') -def test_execution_convert_from_trx(): - with tempfile.TemporaryDirectory() as tmp_dir: - in_trk = os.path.join(get_home(), 'DSI', - 'CC_fix.trk') - in_nii = os.path.join(get_home(), 'DSI', - 'CC.nii.gz') - exp_data = os.path.join(get_home(), 'DSI', - 'CC_fix_data.npy') - exp_offsets = os.path.join(get_home(), 'DSI', - 'CC_fix_offsets.npy') - - # Sequential conversions - out_trx_path = os.path.join(tmp_dir, 'CC_fix.trx') - out_trk_path = os.path.join(tmp_dir, 'CC_fix.trk') - out_tck_path = os.path.join(tmp_dir, 
'CC_fix.tck') - convert_tractogram(in_trk, out_trx_path, None) - convert_tractogram(out_trx_path, out_tck_path, None) - convert_tractogram(out_trx_path, out_trk_path, None) - - data_fix = np.load(exp_data) - offsets_fix = np.load(exp_offsets) - - sft = load_tractogram(out_trk_path, 'same') - assert_equal(sft.streamlines._data, data_fix) - assert_equal(sft.streamlines._offsets, offsets_fix) - - sft = load_tractogram(out_tck_path, in_nii) - assert_equal(sft.streamlines._data, data_fix) - assert_equal(sft.streamlines._offsets, offsets_fix) - - -@pytest.mark.skipif(not dipy_available, - reason='Dipy is not installed.') -def test_execution_convert_dtype_p16_o64(): - with tempfile.TemporaryDirectory() as tmp_dir: - in_trk = os.path.join(get_home(), 'DSI', - 'CC_fix.trk') - out_convert_path = os.path.join(tmp_dir, 'CC_fix_p16_o64.trx') - convert_tractogram(in_trk, out_convert_path, None, - pos_dtype='float16', offsets_dtype='uint64') - - trx = tmm.load(out_convert_path) - assert_equal(trx.streamlines._data.dtype, np.float16) - assert_equal(trx.streamlines._offsets.dtype, np.uint64) - trx.close() - - -@pytest.mark.skipif(not dipy_available, - reason='Dipy is not installed.') -def test_execution_convert_dtype_p64_o32(): - with tempfile.TemporaryDirectory() as tmp_dir: - in_trk = os.path.join(get_home(), 'DSI', - 'CC_fix.trk') - out_convert_path = os.path.join(tmp_dir, 'CC_fix_p16_o64.trx') - convert_tractogram(in_trk, out_convert_path, None, - pos_dtype='float64', offsets_dtype='uint32') - - trx = tmm.load(out_convert_path) - assert_equal(trx.streamlines._data.dtype, np.float64) - assert_equal(trx.streamlines._offsets.dtype, np.uint32) - trx.close() - - -def test_execution_generate_trx_from_scratch(): - with tempfile.TemporaryDirectory() as tmp_dir: - reference_fa = os.path.join(get_home(), 'trx_from_scratch', - 'fa.nii.gz') - raw_arr_dir = os.path.join(get_home(), 'trx_from_scratch', - 'test_npy') - expected_trx = os.path.join(get_home(), 'trx_from_scratch', - 'expected.trx') - - dpv = [(os.path.join(raw_arr_dir, 'dpv_cx.npy'), 'uint8'), - (os.path.join(raw_arr_dir, 'dpv_cy.npy'), 'uint8'), - (os.path.join(raw_arr_dir, 'dpv_cz.npy'), 'uint8')] - dps = [(os.path.join(raw_arr_dir, 'dps_algo.npy'), 'uint8'), - (os.path.join(raw_arr_dir, 'dps_cw.npy'), 'float64')] - dpg = [('g_AF_L', os.path.join(raw_arr_dir, 'dpg_AF_L_mean_fa.npy'), 'float32'), - ('g_AF_R', os.path.join(raw_arr_dir, 'dpg_AF_R_mean_fa.npy'), 'float32'), - ('g_AF_L', os.path.join(raw_arr_dir, 'dpg_AF_L_volume.npy'), 'float32')] - groups = [(os.path.join(raw_arr_dir, 'g_AF_L.npy'), 'int32'), - (os.path.join(raw_arr_dir, 'g_AF_R.npy'), 'int32'), - (os.path.join(raw_arr_dir, 'g_CST_L.npy'), 'int32')] - - out_gen_path = os.path.join(tmp_dir, 'generated.trx') - generate_trx_from_scratch(reference_fa, out_gen_path, - positions=os.path.join(raw_arr_dir, - 'positions.npy'), - offsets=os.path.join(raw_arr_dir, - 'offsets.npy'), - positions_dtype='float16', - offsets_dtype='uint64', - space_str='rasmm', origin_str='nifti', - verify_invalid=False, dpv=dpv, dps=dps, - groups=groups, dpg=dpg) - exp_trx = tmm.load(expected_trx) - gen_trx = tmm.load(out_gen_path) - - assert DeepDiff(exp_trx.get_dtype_dict(), - gen_trx.get_dtype_dict()) == {} - - assert_allclose(exp_trx.streamlines._data, gen_trx.streamlines._data, - atol=0.1, rtol=0.1) - assert_equal(exp_trx.streamlines._offsets, - gen_trx.streamlines._offsets) - - for key in exp_trx.data_per_vertex.keys(): - assert_equal(exp_trx.data_per_vertex[key]._data, - gen_trx.data_per_vertex[key]._data) - 
assert_equal(exp_trx.data_per_vertex[key]._offsets, - gen_trx.data_per_vertex[key]._offsets) - for key in exp_trx.data_per_streamline.keys(): - assert_equal(exp_trx.data_per_streamline[key], - gen_trx.data_per_streamline[key]) - for key in exp_trx.groups.keys(): - assert_equal(exp_trx.groups[key], gen_trx.groups[key]) - - for group in exp_trx.groups.keys(): - if group in exp_trx.data_per_group: - for key in exp_trx.data_per_group[group].keys(): - assert_equal(exp_trx.data_per_group[group][key], - gen_trx.data_per_group[group][key]) - exp_trx.close() - gen_trx.close() - - -@pytest.mark.skipif(not dipy_available, - reason='Dipy is not installed.') -def test_execution_concatenate_validate_trx(): - with tempfile.TemporaryDirectory() as tmp_dir: - trx1 = tmm.load(os.path.join(get_home(), 'gold_standard', - 'gs.trx')) - trx2 = tmm.load(os.path.join(get_home(), 'gold_standard', - 'gs.trx')) - # trx2.streamlines._data += 0.001 - trx = tmm.concatenate([trx1, trx2], preallocation=False) - - # Right size - assert_equal(len(trx.streamlines), 2*len(trx1.streamlines)) - - # Right data - end_idx = trx1.header['NB_VERTICES'] - assert_allclose( - trx.streamlines._data[:end_idx], trx1.streamlines._data) - assert_allclose( - trx.streamlines._data[end_idx:], trx2.streamlines._data) - - # Right data_per_* - for key in trx.data_per_vertex.keys(): - assert_equal(trx.data_per_vertex[key]._data[:end_idx], - trx1.data_per_vertex[key]._data) - assert_equal(trx.data_per_vertex[key]._data[end_idx:], - trx2.data_per_vertex[key]._data) - - end_idx = trx1.header['NB_STREAMLINES'] - for key in trx.data_per_streamline.keys(): - assert_equal(trx.data_per_streamline[key][:end_idx], - trx1.data_per_streamline[key]) - assert_equal(trx.data_per_streamline[key][end_idx:], - trx2.data_per_streamline[key]) - - # Validate - out_concat_path = os.path.join(tmp_dir, 'concat.trx') - out_valid_path = os.path.join(tmp_dir, 'valid.trx') - tmm.save(trx, out_concat_path) - validate_tractogram(out_concat_path, None, out_valid_path, - remove_identical_streamlines=True, - precision=0) - trx_val = tmm.load(out_valid_path) - - # # Right dtype and size - assert DeepDiff(trx.get_dtype_dict(), trx_val.get_dtype_dict()) == {} - assert_equal(len(trx1.streamlines), len(trx_val.streamlines)) - - trx.close() - trx1.close() - trx2.close() - trx_val.close() - - -@pytest.mark.skipif(not dipy_available, - reason='Dipy is not installed.') -def test_execution_manipulate_trx_datatype(): - with tempfile.TemporaryDirectory() as tmp_dir: - expected_trx = os.path.join(get_home(), 'trx_from_scratch', - 'expected.trx') - trx = tmm.load(expected_trx) - - expected_dtype = {'positions': np.dtype('float16'), - 'offsets': np.dtype('uint64'), - 'dpv': {'dpv_cx': np.dtype('uint8'), - 'dpv_cy': np.dtype('uint8'), - 'dpv_cz': np.dtype('uint8')}, - 'dps': {'dps_algo': np.dtype('uint8'), - 'dps_cw': np.dtype('float64')}, - 'dpg': {'g_AF_L': - {'dpg_AF_L_mean_fa': np.dtype('float32'), - 'dpg_AF_L_volume': np.dtype('float32')}, - 'g_AF_R': - {'dpg_AF_R_mean_fa': np.dtype('float32')}}, - 'groups': {'g_AF_L': np.dtype('int32'), - 'g_AF_R': np.dtype('int32')}} - - assert DeepDiff(trx.get_dtype_dict(), expected_dtype) == {} - trx.close() - - generated_dtype = {'positions': np.dtype('float32'), - 'offsets': np.dtype('uint32'), - 'dpv': {'dpv_cx': np.dtype('uint16'), - 'dpv_cy': np.dtype('uint16'), - 'dpv_cz': np.dtype('uint16')}, - 'dps': {'dps_algo': np.dtype('uint8'), - 'dps_cw': np.dtype('float32')}, - 'dpg': {'g_AF_L': - {'dpg_AF_L_mean_fa': np.dtype('float64'), - 
'dpg_AF_L_volume': np.dtype('float32')}, - 'g_AF_R': - {'dpg_AF_R_mean_fa': np.dtype('float64')}}, - 'groups': {'g_AF_L': np.dtype('uint16'), - 'g_AF_R': np.dtype('uint16')}} - - out_gen_path = os.path.join(tmp_dir, 'generated.trx') - manipulate_trx_datatype(expected_trx, out_gen_path, generated_dtype) - trx = tmm.load(out_gen_path) - assert DeepDiff(trx.get_dtype_dict(), generated_dtype) == {} - trx.close() diff --git a/scripts/tff_concatenate_tractograms.py b/scripts/tff_concatenate_tractograms.py deleted file mode 100755 index 30ac7a9..0000000 --- a/scripts/tff_concatenate_tractograms.py +++ /dev/null @@ -1,74 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -""" -Concatenate multiple tractograms into one. - -If the data_per_point or data_per_streamline is not the same for all -tractograms, the data must be deleted first. -""" - -import argparse -import os - -from trx.io import load, save -from trx.trx_file_memmap import TrxFile, concatenate - - -def _build_arg_parser(): - p = argparse.ArgumentParser(description=__doc__, - formatter_class=argparse.RawTextHelpFormatter) - - p.add_argument('in_tractograms', nargs='+', - help='Tractogram filename. Format must be one of \n' - 'trk, tck, vtk, fib, dpy, trx.') - p.add_argument('out_tractogram', - help='Filename of the concatenated tractogram.') - - p.add_argument('--delete_dpv', action='store_true', - help='Delete the dpv if it exists. ' - 'Required if not all input has the same metadata.') - p.add_argument('--delete_dps', action='store_true', - help='Delete the dps if it exists. ' - 'Required if not all input has the same metadata.') - p.add_argument('--delete_groups', action='store_true', - help='Delete the groups if it exists. ' - 'Required if not all input has the same metadata.') - p.add_argument('--reference', - help='Reference anatomy for tck/vtk/fib/dpy file\n' - 'support (.nii or .nii.gz).') - p.add_argument('-f', dest='overwrite', action='store_true', - help='Force overwriting of the output files.') - - return p - - -def main(): - parser = _build_arg_parser() - args = parser.parse_args() - - if os.path.isfile(args.out_tractogram) and not args.overwrite: - raise IOError('{} already exists, use -f to overwrite.'.format( - args.out_tractogram)) - - trx_list = [] - has_group = False - for filename in args.in_tractograms: - tractogram_obj = load(filename, args.reference) - - if not isinstance(tractogram_obj, TrxFile): - tractogram_obj = TrxFile.from_sft(tractogram_obj) - elif len(tractogram_obj.groups): - has_group = True - trx_list.append(tractogram_obj) - - trx = concatenate(trx_list, delete_dpv=args.delete_dpv, - delete_dps=args.delete_dps, - delete_groups=args.delete_groups or not has_group, - check_space_attributes=True, - preallocation=False) - save(trx, args.out_tractogram) - - -if __name__ == "__main__": - main() diff --git a/scripts/tff_convert_dsi_studio.py b/scripts/tff_convert_dsi_studio.py deleted file mode 100644 index b99111f..0000000 --- a/scripts/tff_convert_dsi_studio.py +++ /dev/null @@ -1,63 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -""" -This script is made to fix DSI-Studio TRK file (unknown space/convention) to -make it compatible with TrackVis, MI-Brain, Dipy Horizon (Stateful Tractogram). - -The script either make it match with an anatomy from DSI-Studio. - -This script was tested on various datasets and worked on all of them. However, -always verify the results and if a specific case does not work. Open an issue -on the Scilpy GitHub repository. 
- -WARNING: This script is still experimental, DSI-Studio evolves quickly and -results may vary depending on the data itself as well as DSI-studio version. -""" - -import argparse -import os - -from trx.workflows import convert_dsi_studio - - -def _build_arg_parser(): - p = argparse.ArgumentParser( - description=__doc__, formatter_class=argparse.RawTextHelpFormatter) - - p.add_argument('in_dsi_tractogram', metavar='IN_DSI_TRACTOGRAM', - help='Path of the input tractogram file from DSI studio ' - '(.trk).') - p.add_argument('in_dsi_fa', metavar='IN_DSI_FA', - help='Path of the input FA from DSI Studio (.nii.gz).') - p.add_argument('out_tractogram', metavar='OUT_TRACTOGRAM', - help='Path of the output tractogram file.') - - invalid = p.add_mutually_exclusive_group() - invalid.add_argument('--remove_invalid', action='store_true', - help='Remove the streamlines landing out of the ' - 'bounding box.') - invalid.add_argument('--keep_invalid', action='store_true', - help='Keep the streamlines landing out of the ' - 'bounding box.') - p.add_argument('-f', dest='overwrite', action='store_true', - help='Force overwriting of the output files.') - - return p - - -def main(): - parser = _build_arg_parser() - args = parser.parse_args() - - if os.path.isfile(args.out_tractogram) and not args.overwrite: - raise IOError('{} already exists, use -f to overwrite.'.format( - args.out_tractogram)) - - convert_dsi_studio(args.in_dsi_tractogram, args.in_dsi_fa, - args.out_tractogram, remove_invalid=args.remove_invalid, - keep_invalid=args.keep_invalid) - - -if __name__ == "__main__": - main() diff --git a/scripts/tff_convert_tractogram.py b/scripts/tff_convert_tractogram.py deleted file mode 100644 index cffee68..0000000 --- a/scripts/tff_convert_tractogram.py +++ /dev/null @@ -1,59 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -""" -Conversion of '.tck', '.trk', '.fib', '.vtk', '.trx' and 'dpy' files using -updated file format standard. TCK file always needs a reference file, a NIFTI, -for conversion. The FIB file format is in fact a VTK, MITK Diffusion supports -it. -""" - -import argparse -import os - -from trx.workflows import convert_tractogram - - -def _build_arg_parser(): - p = argparse.ArgumentParser(description=__doc__, - formatter_class=argparse.RawTextHelpFormatter) - - p.add_argument('in_tractogram', metavar='IN_TRACTOGRAM', - help='Tractogram filename. Format must be one of \n' - 'trk, tck, vtk, fib, dpy, trx.') - p.add_argument('out_tractogram', metavar='OUT_TRACTOGRAM', - help='Output filename. Format must be one of \n' - 'trk, tck, vtk, fib, dpy, trx.') - - p.add_argument('--reference', - help='Reference anatomy for tck/vtk/fib/dpy file\n' - 'support (.nii or .nii.gz).') - - p2 = p.add_argument_group(title='Data type options') - p2.add_argument('--positions_dtype', default='float32', - choices=['float16', 'float32', 'float64'], - help='Specify the datatype for positions for trx. [%(default)s]') - p2.add_argument('--offsets_dtype', default='uint64', - choices=['uint32', 'uint64'], - help='Specify the datatype for offsets for trx. 
[%(default)s]') - p.add_argument('-f', dest='overwrite', action='store_true', - help='Force overwriting of the output files.') - - return p - - -def main(): - parser = _build_arg_parser() - args = parser.parse_args() - - if os.path.isfile(args.out_tractogram) and not args.overwrite: - raise IOError('{} already exists, use -f to overwrite.'.format( - args.out_tractogram)) - - convert_tractogram(args.in_tractogram, args.out_tractogram, args.reference, - pos_dtype=args.positions_dtype, - offsets_dtype=args.offsets_dtype) - - -if __name__ == "__main__": - main() diff --git a/scripts/tff_generate_trx_from_scratch.py b/scripts/tff_generate_trx_from_scratch.py deleted file mode 100755 index 38fd1f9..0000000 --- a/scripts/tff_generate_trx_from_scratch.py +++ /dev/null @@ -1,140 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -""" -Generate TRX file from a collection of CSV, TXT or NPY files by individually -specifying positions, offsets, data_per_vertex, data_per_streamlines, -groups and data_per_group. Each file must have its data type specified by the -users. - -A reference file must be provided (NIFTI) and the option --verify_invalid will -remove invalid streamlines (outside of the bounding box in VOX space). - -All dimensions (nbr_vertices and nbr_streamlines) and groups/dpg must match -otherwise the script will (likely) crash. - -Each instance of --dps, --dpv, --groups require 2 arguments (FILE, DTYPE). ---dpg requires 3 arguments (GROUP, FILE, DTYPE). -The choice of DTYPE are: - - (u)int8, (u)int16, (u)int32, (u)int64 - - float16, float32, float64 - - bool - -Example command: -tff_generate_trx_from_scratch.py fa.nii.gz generated.trx -f \ - --positions test_npy/positions.npy --positions_dtype float16 \ - --offsets test_npy/offsets.npy --offsets_dtype uint32 \ - --dpv test_npy/dpv_cx.npy uint8 \ - --dpv test_npy/dpv_cy.npy uint8 \ - --dpv test_npy/dpv_cz.npy uint8 \ - --dps test_npy/dps_algo.npy uint8 \ - --dps test_npy/dps_cw.npy float64 \ - --groups test_npy/g_AF_L.npy int32 \ - --groups test_npy/g_AF_R.npy int32 \ - --dpg g_AF_L test_npy/dpg_AF_L_mean_fa.npy float32 \ - --dpg g_AF_R test_npy/dpg_AF_R_mean_fa.npy float32 \ - --dpg g_AF_L test_npy/dpg_AF_L_volume.npy float32 -""" - -import argparse -import os - -from trx.workflows import generate_trx_from_scratch - - -def _build_arg_parser(): - p = argparse.ArgumentParser(description=__doc__, - formatter_class=argparse.RawTextHelpFormatter) - p.add_argument('reference', - help='Reference anatomy for tck/vtk/fib/dpy file\n' - 'support (.nii or .nii.gz).') - p.add_argument('out_tractogram', metavar='OUT_TRACTOGRAM', - help='Output filename. Format must be one of\n' - 'trk, tck, vtk, fib, dpy, trx.') - - p1 = p.add_argument_group(title='Positions options') - p1.add_argument('--positions', metavar='POSITIONS', - help='Binary file containing the streamlines coordinates.' - '\nMust be Nx3 (.npy)') - p1.add_argument('--offsets', metavar='OFFSETS', - help='Binary file containing the streamlines offsets (.npy)') - p1.add_argument('--positions_csv', metavar='POSITIONS', - help='CSV file containing the streamlines coordinates.' - '\nRows for each streamlines organized as x1,y1,z1,\n' - 'x2,y2,z2,...,xN,yN,zN') - p1.add_argument('--space', choices=['RASMM', 'VOXMM', 'VOX'], - default='RASMM', - help='Space in which the coordinates are declared.' - '[%(default)s]\nNon-default option requires Dipy.') - p1.add_argument('--origin', choices=['NIFTI', 'TRACKVIS'], - default='NIFTI', - help='Origin in which the coordinates are declared. 
' - '[%(default)s]\nNon-default option requires Dipy.') - p2 = p.add_argument_group(title='Data type options') - p2.add_argument('--positions_dtype', default='float32', - choices=['float16', 'float32', 'float64'], - help='Specify the datatype for positions for trx. ' - '[%(default)s]') - p2.add_argument('--offsets_dtype', default='uint64', - choices=['uint32', 'uint64'], - help='Specify the datatype for offsets for trx. ' - '[%(default)s]') - - p3 = p.add_argument_group(title='Streamlines metadata options') - p3.add_argument('--dpv', metavar=('FILE', 'DTYPE'), nargs=2, - action='append', - help='Binary file containing data_per_vertex.\n Must have' - 'NB_VERTICES as first dimension (.npy)') - p3.add_argument('--dps', metavar=('FILE', 'DTYPE'), nargs=2, - action='append', - help='Binary file containing data_per_vertex.\n Must have' - 'NB_STREAMLINES as first dimension (.npy)') - p3.add_argument('--groups', metavar=('FILE', 'DTYPE'), nargs=2, - action='append', - help='Binary file containing a sparse group (indices).\n ' - 'Indices should be lower than NB_STREAMLINES (.npy)') - p3.add_argument('--dpg', metavar=('GROUP', 'FILE', 'DTYPE'), nargs=3, - action='append', - help='Binary file containing data_per_group.\n Must have' - '(1,) as first dimension (.npy)') - - p.add_argument('--verify_invalid', action='store_true', - help='Verify that the positions are all valid.\n' - 'None outside of the bounding box in VOX space.\n' - 'Requires Dipy (due to use of SFT).') - p.add_argument('-f', dest='overwrite', action='store_true', - help='Force overwriting of the output files.') - - return p - - -def main(): - parser = _build_arg_parser() - args = parser.parse_args() - - if os.path.isfile(args.out_tractogram) and not args.overwrite: - raise IOError('{} already exists, use -f to overwrite.'.format( - args.out_tractogram)) - - if not args.positions and not args.positions_csv: - parser.error('At least one positions options must be used.') - if args.positions_csv and args.positions: - parser.error('Cannot use both positions options.') - if args.positions and args.offsets is None: - parser.error('--offsets must be provided if --positions is used.') - if args.offsets and args.positions is None: - parser.error('--positions must be provided if --offsets is used.') - - generate_trx_from_scratch(args.reference, args.out_tractogram, - positions_csv=args.positions_csv, - positions=args.positions, offsets=args.offsets, - positions_dtype=args.positions_dtype, - offsets_dtype=args.offsets_dtype, - space_str=args.space, origin_str=args.origin, - verify_invalid=args.verify_invalid, - dpv=args.dpv, dps=args.dps, - groups=args.groups, dpg=args.dpg) - - -if __name__ == "__main__": - main() diff --git a/scripts/tff_manipulate_datatype.py b/scripts/tff_manipulate_datatype.py deleted file mode 100644 index 2c0afe5..0000000 --- a/scripts/tff_manipulate_datatype.py +++ /dev/null @@ -1,103 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -""" -Manipulate a TRX file internal array to change their data type. - -Each instance of --dps, --dpv, --groups require 2 arguments (FILE, DTYPE). ---dpg requires 3 arguments (GROUP, FILE, DTYPE). 
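The options are ultimately collected into a nested dictionary of numpy dtypes and handed to trx.workflows.manipulate_trx_datatype; a small sketch, with illustrative array names and only the entries that should change:

import numpy as np
from trx.workflows import manipulate_trx_datatype

dtype_dict = {
    "positions": np.dtype("float16"),
    "offsets": np.dtype("uint64"),
    "dps": {"algo": np.dtype("uint8")},
    "dpg": {"CC": {"mean_fa": np.dtype("float64")}},
}
manipulate_trx_datatype("input.trx", "output.trx", dtype_dict)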
-The choice of DTYPE are: - - (u)int8, (u)int16, (u)int32, (u)int64 - - float16, float32, float64 - - bool - -Example command: -tff_manipulate_datatype.py input.trx output.trx \ - --position float16 --offsets uint64 \ - --dpv color_x uint8 --dpv color_y uint8 --dpv color_z uint8 \ - --dpv fa float16 --dps algo uint8 --dps clusters_QB uint16 \ - --dps commit_colors uint8 --dps commit_weights float16 \ - --group CC uint64 --dpg CC mean_fa float64 -""" - -import argparse -import os - -import numpy as np -from trx.workflows import manipulate_trx_datatype - - -def _build_arg_parser(): - p = argparse.ArgumentParser(description=__doc__, - formatter_class=argparse.RawTextHelpFormatter) - p.add_argument('in_tractogram', - help='Input TRX file.') - p.add_argument('out_tractogram', - help='Output filename. Format must be one of\n' - 'trk, tck, vtk, fib, dpy, trx.') - - p2 = p.add_argument_group(title='Data type options') - p2.add_argument('--positions_dtype', - choices=['float16', 'float32', 'float64'], - help='Specify the datatype for positions for trx. ' - '[%(choices)s]') - p2.add_argument('--offsets_dtype', - choices=['uint32', 'uint64'], - help='Specify the datatype for offsets for trx. ' - '[%(choices)s]') - - p3 = p.add_argument_group(title='Streamlines metadata options') - p3.add_argument('--dpv', metavar=('NAME', 'DTYPE'), nargs=2, - action='append', - help='Specify the datatype for a specific data_per_vertex.') - p3.add_argument('--dps', metavar=('NAME', 'DTYPE'), nargs=2, - action='append', - help='Specify the datatype for a specific data_per_streamline.') - p3.add_argument('--groups', metavar=('NAME', 'DTYPE'), nargs=2, - action='append', - help='Specify the datatype for a specific group.') - p3.add_argument('--dpg', metavar=('GROUP', 'NAME', 'DTYPE'), nargs=3, - action='append', - help='Specify the datatype for a specific data_per_group.') - p.add_argument('-f', dest='overwrite', action='store_true', - help='Force overwriting of the output files.') - - return p - - -def main(): # noqa: C901 - parser = _build_arg_parser() - args = parser.parse_args() - - if os.path.isfile(args.out_tractogram) and not args.overwrite: - raise IOError('{} already exists, use -f to overwrite.'.format( - args.out_tractogram)) - - dtype_dict = {} - if args.positions_dtype: - dtype_dict['positions'] = np.dtype(args.positions_dtype) - if args.offsets_dtype: - dtype_dict['offsets'] = np.dtype(args.offsets_dtype) - if args.dpv: - dtype_dict['dpv'] = {} - for name, dtype in args.dpv: - dtype_dict['dpv'][name] = np.dtype(dtype) - if args.dps: - dtype_dict['dps'] = {} - for name, dtype in args.dps: - dtype_dict['dps'][name] = np.dtype(dtype) - if args.groups: - dtype_dict['groups'] = {} - for name, dtype in args.groups: - dtype_dict['groups'][name] = np.dtype(dtype) - if args.dpg: - dtype_dict['dpg'] = {} - for group, name, dtype in args.dpg: - dtype_dict['dpg'][group] = {name: np.dtype(dtype)} - - manipulate_trx_datatype( - args.in_tractogram, args.out_tractogram, dtype_dict) - - -if __name__ == "__main__": - main() diff --git a/scripts/tff_simple_compare.py b/scripts/tff_simple_compare.py deleted file mode 100755 index 4e3e146..0000000 --- a/scripts/tff_simple_compare.py +++ /dev/null @@ -1,36 +0,0 @@ -#!/usr/bin/env python - -""" Simple comparison of tractogram by subtracting the coordinates' data. -Does not account for shuffling of streamlines. Simple A-B operations. - -Differences below 1e^3 are expected for affine with large rotation/scaling. 
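In practice the check boils down to subtracting the two coordinate arrays and inspecting the largest difference; a rough illustration with hypothetical arrays:

import numpy as np

coords_a = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
coords_b = coords_a + 1e-4   # e.g. after a round-trip through another format
print(np.abs(coords_a - coords_b).max())  # ~1e-4, a negligible A-B difference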
-Difference below 1e^6 are expected for isotropic data with small rotation. -""" - -import argparse - -from trx.workflows import tractogram_simple_compare - - -def _build_arg_parser(): - p = argparse.ArgumentParser(description=__doc__, - formatter_class=argparse.RawTextHelpFormatter) - - p.add_argument('in_tractograms', nargs=2, metavar='IN_TRACTOGRAM', - help='Tractogram filename. Format must be one of \n' - 'trk, tck, vtk, fib, dpy, trx.') - p.add_argument('--reference', metavar='REFERENCE', - help='Reference anatomy for tck/vtk/fib/dpy file\n' - 'support (.nii or .nii.gz).') - return p - - -def main(): - parser = _build_arg_parser() - args = parser.parse_args() - - tractogram_simple_compare(args.in_tractograms, args.reference) - - -if __name__ == "__main__": - main() diff --git a/scripts/tff_validate_trx.py b/scripts/tff_validate_trx.py deleted file mode 100755 index 03c808b..0000000 --- a/scripts/tff_validate_trx.py +++ /dev/null @@ -1,64 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -""" -Validate TRX file. - -Removes streamlines that are out of the volume bounding box. In voxel space, -no negative coordinate and no above volume dimension coordinate are possible. -Any streamline that do not respect these two conditions are removed. - -Also removes streamlines with single or no point. -The --remove_identical_streamlines option will remove identical streamlines. -'identical' is defined as having the same number of points and the same -points coordinates (to a specified precision, using a hash table). -""" - -import argparse -import os - -from trx.workflows import validate_tractogram - - -def _build_arg_parser(): - p = argparse.ArgumentParser(description=__doc__, - formatter_class=argparse.RawTextHelpFormatter) - - p.add_argument('in_tractogram', - help='Tractogram filename. Format must be one of \n' - 'trk, tck, vtk, fib, dpy, trx.') - p.add_argument('--out_tractogram', - help='Filename of the tractogram after removing invalid ' - 'streamlines.') - p.add_argument('--remove_identical_streamlines', action='store_true', - help='Remove identical streamlines from the set.') - p.add_argument('--precision', type=int, default=1, - help='Number of decimals to keep when hashing the points ' - 'of streamlines [%(default)s].') - - p.add_argument('--reference', - help='Reference anatomy for tck/vtk/fib/dpy file\n' - 'support (.nii or .nii.gz).') - p.add_argument('-f', dest='overwrite', action='store_true', - help='Force overwriting of the output files.') - - return p - - -def main(): - parser = _build_arg_parser() - args = parser.parse_args() - - if args.out_tractogram and os.path.isfile(args.out_tractogram) \ - and not args.overwrite: - raise IOError('{} already exists, use -f to overwrite.'.format( - args.out_tractogram)) - - validate_tractogram(args.in_tractogram, reference=args.reference, - out_tractogram=args.out_tractogram, - remove_identical_streamlines=args.remove_identical_streamlines, - precision=args.precision) - - -if __name__ == "__main__": - main() diff --git a/scripts/tff_verify_header_compatibility.py b/scripts/tff_verify_header_compatibility.py deleted file mode 100644 index 629edda..0000000 --- a/scripts/tff_verify_header_compatibility.py +++ /dev/null @@ -1,33 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -""" -Will compare all input files against the first one for the compatibility -of their spatial attributes. - -Spatial attributes are: affine, dimensions, voxel sizes and voxel order. 
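The same check is available programmatically; a minimal sketch with placeholder file names, comparing every file against the first entry:

from trx.workflows import verify_header_compatibility

verify_header_compatibility(["bundle_a.trx", "bundle_b.trk", "t1.nii.gz"])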
-""" - -import argparse -from trx.workflows import verify_header_compatibility - - -def _build_arg_parser(): - p = argparse.ArgumentParser(description=__doc__, - formatter_class=argparse.RawTextHelpFormatter) - - p.add_argument('in_files', nargs='+', - help='List of file to compare (trk, trx and nii).') - - return p - - -def main(): - parser = _build_arg_parser() - args = parser.parse_args() - - verify_header_compatibility(args.in_files) - - -if __name__ == "__main__": - main() diff --git a/scripts/tff_visualize_overlap.py b/scripts/tff_visualize_overlap.py deleted file mode 100644 index 5e530f8..0000000 --- a/scripts/tff_visualize_overlap.py +++ /dev/null @@ -1,39 +0,0 @@ -#!/usr/bin/env python - -""" -Display a tractogram and its density map (computed from Dipy) in rasmm, -voxmm and vox space with its bounding box. -""" - -import argparse - -from trx.workflows import tractogram_visualize_overlap - - -def _build_arg_parser(): - p = argparse.ArgumentParser(description=__doc__, - formatter_class=argparse.RawTextHelpFormatter) - - p.add_argument('in_tractogram', metavar='IN_TRACTOGRAM', - help='Tractogram filename. Format must be one of \n' - 'trk, tck, vtk, fib, dpy, trx.') - p.add_argument('reference', - help='Reference anatomy for tck/vtk/fib/dpy file\n' - 'support (nii or nii.gz).') - p.add_argument('--remove_invalid', action='store_true', - help='Removes invalid streamlines to avoid the density_map' - 'function to crash.') - - return p - - -def main(): - parser = _build_arg_parser() - args = parser.parse_args() - - tractogram_visualize_overlap(args.in_tractogram, args.reference, - args.remove_invalid) - - -if __name__ == "__main__": - main() diff --git a/tools/update_switcher.py b/tools/update_switcher.py index 6bc5ad6..4a66988 100644 --- a/tools/update_switcher.py +++ b/tools/update_switcher.py @@ -5,10 +5,11 @@ This script maintains the version switcher JSON file used by pydata-sphinx-theme to enable users to switch between different documentation versions. 
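Expressed in Python, the helpers below keep the JSON list shaped roughly like this (the 0.5.0 entry is only an illustration):

switcher = [
    {"name": "dev", "version": "dev",
     "url": "https://tee-ar-ex.github.io/trx-python/dev/"},
    {"name": "0.5.0", "version": "0.5.0",
     "url": "https://tee-ar-ex.github.io/trx-python/0.5.0/"},
    {"name": "stable", "version": "stable",
     "url": "https://tee-ar-ex.github.io/trx-python/stable/", "preferred": True},
]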
""" + import argparse import json -import sys from pathlib import Path +import sys BASE_URL = "https://tee-ar-ex.github.io/trx-python" @@ -16,7 +17,7 @@ def load_switcher(path): """Load existing switcher.json or return empty list.""" try: - with open(path, 'r') as f: + with open(path, "r") as f: return json.load(f) except (FileNotFoundError, json.JSONDecodeError): return [] @@ -24,38 +25,35 @@ def load_switcher(path): def save_switcher(path, versions): """Save switcher.json with proper formatting.""" - with open(path, 'w') as f: + with open(path, "w") as f: json.dump(versions, f, indent=4) - f.write('\n') + f.write("\n") def ensure_dev_entry(versions): """Ensure dev entry exists in versions list.""" - dev_exists = any(v.get('version') == 'dev' for v in versions) + dev_exists = any(v.get("version") == "dev" for v in versions) if not dev_exists: - versions.insert(0, { - "name": "dev", - "version": "dev", - "url": f"{BASE_URL}/dev/" - }) + versions.insert(0, {"name": "dev", "version": "dev", "url": f"{BASE_URL}/dev/"}) return versions def ensure_stable_entry(versions): """Ensure stable entry exists with preferred flag.""" stable_idx = next( - (i for i, v in enumerate(versions) if v.get('version') == 'stable'), - None + (i for i, v in enumerate(versions) if v.get("version") == "stable"), None ) if stable_idx is not None: - versions[stable_idx]['preferred'] = True + versions[stable_idx]["preferred"] = True else: - versions.append({ - "name": "stable", - "version": "stable", - "url": f"{BASE_URL}/stable/", - "preferred": True - }) + versions.append( + { + "name": "stable", + "version": "stable", + "url": f"{BASE_URL}/stable/", + "preferred": True, + } + ) return versions @@ -76,21 +74,20 @@ def add_version(versions, version): """ # Remove 'preferred' from all existing entries for v in versions: - v.pop('preferred', None) + v.pop("preferred", None) # Check if this version already exists - version_exists = any(v.get('version') == version for v in versions) + version_exists = any(v.get("version") == version for v in versions) if not version_exists: new_entry = { "name": version, "version": version, - "url": f"{BASE_URL}/{version}/" + "url": f"{BASE_URL}/{version}/", } # Find dev entry index to insert after it dev_idx = next( - (i for i, v in enumerate(versions) if v.get('version') == 'dev'), - -1 + (i for i, v in enumerate(versions) if v.get("version") == "dev"), -1 ) if dev_idx >= 0: versions.insert(dev_idx + 1, new_entry) @@ -103,18 +100,10 @@ def add_version(versions, version): def main(): """Main entry point.""" parser = argparse.ArgumentParser( - description='Update switcher.json for documentation version switching' - ) - parser.add_argument( - 'switcher_path', - type=Path, - help='Path to switcher.json file' - ) - parser.add_argument( - '--version', - type=str, - help='New version to add (e.g., 0.5.0)' + description="Update switcher.json for documentation version switching" ) + parser.add_argument("switcher_path", type=Path, help="Path to switcher.json file") + parser.add_argument("--version", type=str, help="New version to add (e.g., 0.5.0)") args = parser.parse_args() @@ -139,5 +128,5 @@ def main(): return 0 -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main()) diff --git a/trx/cli.py b/trx/cli.py new file mode 100644 index 0000000..cae5ed0 --- /dev/null +++ b/trx/cli.py @@ -0,0 +1,815 @@ +# -*- coding: utf-8 -*- +""" +TRX Command Line Interface. + +This module provides a unified CLI for all TRX file format operations using Typer. 
+""" + +from pathlib import Path +from typing import List, Optional + +import numpy as np +import typer +from typing_extensions import Annotated + +from trx.io import load, save +from trx.trx_file_memmap import TrxFile, concatenate +from trx.workflows import ( + convert_dsi_studio, + convert_tractogram, + generate_trx_from_scratch, + manipulate_trx_datatype, + tractogram_simple_compare, + tractogram_visualize_overlap, + validate_tractogram, + verify_header_compatibility, +) + +app = typer.Typer( + name="tff", + help="TRX File Format Tools - CLI for brain tractography data manipulation.", + add_completion=False, + rich_markup_mode="rich", +) + + +def _check_overwrite(filepath: Path, overwrite: bool) -> None: + """Check if file exists and raise error if overwrite is not enabled. + + Parameters + ---------- + filepath : Path + Path to the output file. + overwrite : bool + If True, allow overwriting existing files. + + Raises + ------ + typer.Exit + If file exists and overwrite is False. + """ + if filepath.is_file() and not overwrite: + typer.echo( + typer.style( + f"Error: {filepath} already exists. Use --force to overwrite.", + fg=typer.colors.RED, + ), + err=True, + ) + raise typer.Exit(code=1) + + +@app.command("concatenate") +def concatenate_tractograms( + in_tractograms: Annotated[ + List[Path], + typer.Argument( + help="Input tractogram files. Format: trk, tck, vtk, fib, dpy, trx.", + ), + ], + out_tractogram: Annotated[ + Path, + typer.Argument(help="Output filename for the concatenated tractogram."), + ], + delete_dpv: Annotated[ + bool, + typer.Option( + "--delete-dpv", + help="Delete data_per_vertex if not all inputs have the same metadata.", + ), + ] = False, + delete_dps: Annotated[ + bool, + typer.Option( + "--delete-dps", + help="Delete data_per_streamline if not all inputs have the same metadata.", + ), + ] = False, + delete_groups: Annotated[ + bool, + typer.Option( + "--delete-groups", + help="Delete groups if not all inputs have the same metadata.", + ), + ] = False, + reference: Annotated[ + Optional[Path], + typer.Option( + "--reference", + "-r", + help="Reference anatomy for tck/vtk/fib/dpy files (.nii or .nii.gz).", + ), + ] = None, + force: Annotated[ + bool, + typer.Option("--force", "-f", help="Force overwriting of output files."), + ] = False, +) -> None: + """Concatenate multiple tractograms into one. + + If the data_per_point or data_per_streamline is not the same for all + tractograms, the data must be deleted first using the appropriate flags. + """ + _check_overwrite(out_tractogram, force) + + ref = str(reference) if reference else None + + trx_list = [] + has_group = False + for filename in in_tractograms: + tractogram_obj = load(str(filename), ref) + + if not isinstance(tractogram_obj, TrxFile): + tractogram_obj = TrxFile.from_sft(tractogram_obj) + elif len(tractogram_obj.groups): + has_group = True + trx_list.append(tractogram_obj) + + trx = concatenate( + trx_list, + delete_dpv=delete_dpv, + delete_dps=delete_dps, + delete_groups=delete_groups or not has_group, + check_space_attributes=True, + preallocation=False, + ) + save(trx, str(out_tractogram)) + + typer.echo( + typer.style( + f"Successfully concatenated {len(in_tractograms)} tractograms " + f"to {out_tractogram}", + fg=typer.colors.GREEN, + ) + ) + + +@app.command("convert") +def convert( + in_tractogram: Annotated[ + Path, + typer.Argument(help="Input tractogram. Format: trk, tck, vtk, fib, dpy, trx."), + ], + out_tractogram: Annotated[ + Path, + typer.Argument(help="Output tractogram. 
Format: trk, tck, vtk, fib, dpy, trx."), + ], + reference: Annotated[ + Optional[Path], + typer.Option( + "--reference", + "-r", + help="Reference anatomy for tck/vtk/fib/dpy files (.nii or .nii.gz).", + ), + ] = None, + positions_dtype: Annotated[ + str, + typer.Option( + "--positions-dtype", + help="Datatype for positions in TRX output.", + ), + ] = "float32", + offsets_dtype: Annotated[ + str, + typer.Option( + "--offsets-dtype", + help="Datatype for offsets in TRX output.", + ), + ] = "uint64", + force: Annotated[ + bool, + typer.Option("--force", "-f", help="Force overwriting of output files."), + ] = False, +) -> None: + """Convert tractograms between formats. + + Supports conversion of .tck, .trk, .fib, .vtk, .trx and .dpy files. + TCK files always need a reference NIFTI file for conversion. + """ + _check_overwrite(out_tractogram, force) + + ref = str(reference) if reference else None + convert_tractogram( + str(in_tractogram), + str(out_tractogram), + ref, + pos_dtype=positions_dtype, + offsets_dtype=offsets_dtype, + ) + + typer.echo( + typer.style( + f"Successfully converted {in_tractogram} to {out_tractogram}", + fg=typer.colors.GREEN, + ) + ) + + +@app.command("convert-dsi") +def convert_dsi( + in_dsi_tractogram: Annotated[ + Path, + typer.Argument(help="Input tractogram from DSI Studio (.trk)."), + ], + in_dsi_fa: Annotated[ + Path, + typer.Argument(help="Input FA from DSI Studio (.nii.gz)."), + ], + out_tractogram: Annotated[ + Path, + typer.Argument(help="Output tractogram file."), + ], + remove_invalid: Annotated[ + bool, + typer.Option( + "--remove-invalid", + help="Remove streamlines landing out of the bounding box.", + ), + ] = False, + keep_invalid: Annotated[ + bool, + typer.Option( + "--keep-invalid", + help="Keep streamlines landing out of the bounding box.", + ), + ] = False, + force: Annotated[ + bool, + typer.Option("--force", "-f", help="Force overwriting of output files."), + ] = False, +) -> None: + """Fix DSI-Studio TRK files for compatibility. + + This script fixes DSI-Studio TRK files (unknown space/convention) to make + them compatible with TrackVis, MI-Brain, and Dipy Horizon. + + [bold yellow]WARNING:[/bold yellow] This script is experimental. DSI-Studio evolves + quickly and results may vary depending on the data and DSI-Studio version. + """ + _check_overwrite(out_tractogram, force) + + if remove_invalid and keep_invalid: + typer.echo( + typer.style( + "Error: Cannot use both --remove-invalid and --keep-invalid.", + fg=typer.colors.RED, + ), + err=True, + ) + raise typer.Exit(code=1) + + convert_dsi_studio( + str(in_dsi_tractogram), + str(in_dsi_fa), + str(out_tractogram), + remove_invalid=remove_invalid, + keep_invalid=keep_invalid, + ) + + typer.echo( + typer.style( + f"Successfully converted DSI-Studio tractogram to {out_tractogram}", + fg=typer.colors.GREEN, + ) + ) + + +@app.command("generate") +def generate( + reference: Annotated[ + Path, + typer.Argument(help="Reference anatomy (.nii or .nii.gz)."), + ], + out_tractogram: Annotated[ + Path, + typer.Argument(help="Output tractogram. 
Format: trk, tck, vtk, fib, dpy, trx."), + ], + positions: Annotated[ + Optional[Path], + typer.Option( + "--positions", + help="Binary file with streamline coordinates (Nx3 .npy).", + ), + ] = None, + offsets: Annotated[ + Optional[Path], + typer.Option( + "--offsets", + help="Binary file with streamline offsets (.npy).", + ), + ] = None, + positions_csv: Annotated[ + Optional[Path], + typer.Option( + "--positions-csv", + help="CSV file with streamline coordinates (x1,y1,z1,x2,y2,z2,...).", + ), + ] = None, + space: Annotated[ + str, + typer.Option( + "--space", + help="Coordinate space. Non-default requires Dipy.", + ), + ] = "RASMM", + origin: Annotated[ + str, + typer.Option( + "--origin", + help="Coordinate origin. Non-default requires Dipy.", + ), + ] = "NIFTI", + positions_dtype: Annotated[ + str, + typer.Option("--positions-dtype", help="Datatype for positions."), + ] = "float32", + offsets_dtype: Annotated[ + str, + typer.Option("--offsets-dtype", help="Datatype for offsets."), + ] = "uint64", + dpv: Annotated[ + Optional[List[str]], + typer.Option( + "--dpv", + help="Data per vertex: FILE,DTYPE (e.g., color.npy,uint8).", + ), + ] = None, + dps: Annotated[ + Optional[List[str]], + typer.Option( + "--dps", + help="Data per streamline: FILE,DTYPE (e.g., algo.npy,uint8).", + ), + ] = None, + groups: Annotated[ + Optional[List[str]], + typer.Option( + "--groups", + help="Groups: FILE,DTYPE (e.g., AF_L.npy,int32).", + ), + ] = None, + dpg: Annotated[ + Optional[List[str]], + typer.Option( + "--dpg", + help="Data per group: GROUP,FILE,DTYPE (e.g., AF_L,mean_fa.npy,float32).", + ), + ] = None, + verify_invalid: Annotated[ + bool, + typer.Option( + "--verify-invalid", + help="Verify positions are valid (within bounding box). Requires Dipy.", + ), + ] = False, + force: Annotated[ + bool, + typer.Option("--force", "-f", help="Force overwriting of output files."), + ] = False, +) -> None: + """Generate TRX file from raw data files. + + Create a TRX file from CSV, TXT, or NPY files by specifying positions, + offsets, data_per_vertex, data_per_streamlines, groups, and data_per_group. + + Each --dpv, --dps, --groups option requires FILE,DTYPE format. + Each --dpg option requires GROUP,FILE,DTYPE format. 
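Each repeated option arrives as one comma-separated string and is split back into a tuple before being passed to the workflow, roughly:

# e.g. --dpv dpv_cx.npy,uint8 --dpv dpv_cy.npy,uint8
dpv = ["dpv_cx.npy,uint8", "dpv_cy.npy,uint8"]
dpv_list = [tuple(item.split(",")) for item in dpv]
# -> [("dpv_cx.npy", "uint8"), ("dpv_cy.npy", "uint8")]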
+ + Valid DTYPEs: (u)int8, (u)int16, (u)int32, (u)int64, float16, float32, float64, bool + """ + _check_overwrite(out_tractogram, force) + + # Validate input combinations + if not positions and not positions_csv: + typer.echo( + typer.style( + "Error: At least one positions option must be provided " + "(--positions or --positions-csv).", + fg=typer.colors.RED, + ), + err=True, + ) + raise typer.Exit(code=1) + if positions_csv and positions: + typer.echo( + typer.style( + "Error: Cannot use both --positions and --positions-csv.", + fg=typer.colors.RED, + ), + err=True, + ) + raise typer.Exit(code=1) + if positions and offsets is None: + typer.echo( + typer.style( + "Error: --offsets must be provided if --positions is used.", + fg=typer.colors.RED, + ), + err=True, + ) + raise typer.Exit(code=1) + if offsets and positions is None: + typer.echo( + typer.style( + "Error: --positions must be provided if --offsets is used.", + fg=typer.colors.RED, + ), + err=True, + ) + raise typer.Exit(code=1) + + # Parse comma-separated arguments to tuples + dpv_list = None + if dpv: + dpv_list = [tuple(item.split(",")) for item in dpv] + + dps_list = None + if dps: + dps_list = [tuple(item.split(",")) for item in dps] + + groups_list = None + if groups: + groups_list = [tuple(item.split(",")) for item in groups] + + dpg_list = None + if dpg: + dpg_list = [tuple(item.split(",")) for item in dpg] + + generate_trx_from_scratch( + str(reference), + str(out_tractogram), + positions_csv=str(positions_csv) if positions_csv else None, + positions=str(positions) if positions else None, + offsets=str(offsets) if offsets else None, + positions_dtype=positions_dtype, + offsets_dtype=offsets_dtype, + space_str=space, + origin_str=origin, + verify_invalid=verify_invalid, + dpv=dpv_list, + dps=dps_list, + groups=groups_list, + dpg=dpg_list, + ) + + typer.echo( + typer.style( + f"Successfully generated {out_tractogram}", + fg=typer.colors.GREEN, + ) + ) + + +@app.command("manipulate-dtype") +def manipulate_dtype( + in_tractogram: Annotated[ + Path, + typer.Argument(help="Input TRX file."), + ], + out_tractogram: Annotated[ + Path, + typer.Argument(help="Output tractogram file."), + ], + positions_dtype: Annotated[ + Optional[str], + typer.Option( + "--positions-dtype", + help="Datatype for positions (float16, float32, float64).", + ), + ] = None, + offsets_dtype: Annotated[ + Optional[str], + typer.Option( + "--offsets-dtype", + help="Datatype for offsets (uint32, uint64).", + ), + ] = None, + dpv: Annotated[ + Optional[List[str]], + typer.Option( + "--dpv", + help="Data per vertex dtype: NAME,DTYPE (e.g., color_x,uint8).", + ), + ] = None, + dps: Annotated[ + Optional[List[str]], + typer.Option( + "--dps", + help="Data per streamline dtype: NAME,DTYPE (e.g., algo,uint8).", + ), + ] = None, + groups: Annotated[ + Optional[List[str]], + typer.Option( + "--groups", + help="Groups dtype: NAME,DTYPE (e.g., CC,uint64).", + ), + ] = None, + dpg: Annotated[ + Optional[List[str]], + typer.Option( + "--dpg", + help="Data per group dtype: GROUP,NAME,DTYPE (e.g., CC,mean_fa,float64).", + ), + ] = None, + force: Annotated[ + bool, + typer.Option("--force", "-f", help="Force overwriting of output files."), + ] = False, +) -> None: # noqa: C901 + """Manipulate TRX file internal array data types. + + Change the data types of positions, offsets, data_per_vertex, + data_per_streamline, groups, and data_per_group arrays. 
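The command can also be driven programmatically through the unified Typer app; a hedged sketch using Typer's test runner, with placeholder file and array names:

from typer.testing import CliRunner

from trx.cli import app

runner = CliRunner()
# Same as: tff manipulate-dtype input.trx output.trx --dps algo,uint8 --dpg CC,mean_fa,float64
result = runner.invoke(
    app,
    ["manipulate-dtype", "input.trx", "output.trx",
     "--dps", "algo,uint8", "--dpg", "CC,mean_fa,float64"],
)
print(result.exit_code)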
+ + Valid DTYPEs: (u)int8, (u)int16, (u)int32, (u)int64, float16, float32, float64, bool + """ + _check_overwrite(out_tractogram, force) + + dtype_dict = {} + if positions_dtype: + dtype_dict["positions"] = np.dtype(positions_dtype) + if offsets_dtype: + dtype_dict["offsets"] = np.dtype(offsets_dtype) + if dpv: + dtype_dict["dpv"] = {} + for item in dpv: + name, dtype = item.split(",") + dtype_dict["dpv"][name] = np.dtype(dtype) + if dps: + dtype_dict["dps"] = {} + for item in dps: + name, dtype = item.split(",") + dtype_dict["dps"][name] = np.dtype(dtype) + if groups: + dtype_dict["groups"] = {} + for item in groups: + name, dtype = item.split(",") + dtype_dict["groups"][name] = np.dtype(dtype) + if dpg: + dtype_dict["dpg"] = {} + for item in dpg: + parts = item.split(",") + group, name, dtype = parts[0], parts[1], parts[2] + if group not in dtype_dict["dpg"]: + dtype_dict["dpg"][group] = {} + dtype_dict["dpg"][group][name] = np.dtype(dtype) + + manipulate_trx_datatype(str(in_tractogram), str(out_tractogram), dtype_dict) + + typer.echo( + typer.style( + f"Successfully manipulated datatypes and saved to {out_tractogram}", + fg=typer.colors.GREEN, + ) + ) + + +@app.command("compare") +def compare( + in_tractogram1: Annotated[ + Path, + typer.Argument(help="First tractogram file."), + ], + in_tractogram2: Annotated[ + Path, + typer.Argument(help="Second tractogram file."), + ], + reference: Annotated[ + Optional[Path], + typer.Option( + "--reference", + "-r", + help="Reference anatomy for tck/vtk/fib/dpy files (.nii or .nii.gz).", + ), + ] = None, +) -> None: + """Simple comparison of tractograms by subtracting coordinates. + + Does not account for shuffling of streamlines. Simple A-B operations. + + Differences below 1e-3 are expected for affines with large rotation/scaling. + Differences below 1e-6 are expected for isotropic data with small rotation. + """ + ref = str(reference) if reference else None + tractogram_simple_compare([str(in_tractogram1), str(in_tractogram2)], ref) + + +@app.command("validate") +def validate( + in_tractogram: Annotated[ + Path, + typer.Argument(help="Input tractogram. Format: trk, tck, vtk, fib, dpy, trx."), + ], + out_tractogram: Annotated[ + Optional[Path], + typer.Option( + "--out", + "-o", + help="Output tractogram after removing invalid streamlines.", + ), + ] = None, + remove_identical: Annotated[ + bool, + typer.Option( + "--remove-identical", + help="Remove identical streamlines from the set.", + ), + ] = False, + precision: Annotated[ + int, + typer.Option( + "--precision", + "-p", + help="Number of decimals when hashing streamline points.", + ), + ] = 1, + reference: Annotated[ + Optional[Path], + typer.Option( + "--reference", + "-r", + help="Reference anatomy for tck/vtk/fib/dpy files (.nii or .nii.gz).", + ), + ] = None, + force: Annotated[ + bool, + typer.Option("--force", "-f", help="Force overwriting of output files."), + ] = False, +) -> None: + """Validate TRX file and remove invalid streamlines. + + Removes streamlines that are out of the volume bounding box (in voxel space, + no negative coordinates or coordinates above volume dimensions). + + Also removes streamlines with single or no points. + Use --remove-identical to remove duplicate streamlines based on precision. 
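The duplicate test works on coordinates rounded to the requested precision; a rough sketch of the idea (not the actual hashing used in trx.streamlines_ops):

import numpy as np

def streamline_key(points, precision=1):
    # Streamlines whose points land on the same rounded grid get the same key.
    return hash(np.round(points, precision).tobytes())

a = np.array([[0.0, 0.0, 0.0], [1.04, 2.0, 3.0]])
b = np.array([[0.0, 0.0, 0.0], [1.02, 2.0, 3.0]])
print(streamline_key(a) == streamline_key(b))  # True: identical at precision=1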
+ """ + if out_tractogram: + _check_overwrite(out_tractogram, force) + + ref = str(reference) if reference else None + out = str(out_tractogram) if out_tractogram else None + + validate_tractogram( + str(in_tractogram), + reference=ref, + out_tractogram=out, + remove_identical_streamlines=remove_identical, + precision=precision, + ) + + if out_tractogram: + typer.echo( + typer.style( + f"Validation complete. Output saved to {out_tractogram}", + fg=typer.colors.GREEN, + ) + ) + else: + typer.echo( + typer.style( + "Validation complete.", + fg=typer.colors.GREEN, + ) + ) + + +@app.command("verify-header") +def verify_header( + in_files: Annotated[ + List[Path], + typer.Argument(help="Files to compare (trk, trx, and nii)."), + ], +) -> None: + """Compare spatial attributes of input files. + + Compares all input files against the first one for compatibility of + spatial attributes: affine, dimensions, voxel sizes, and voxel order. + """ + verify_header_compatibility([str(f) for f in in_files]) + + +@app.command("visualize") +def visualize( + in_tractogram: Annotated[ + Path, + typer.Argument(help="Input tractogram. Format: trk, tck, vtk, fib, dpy, trx."), + ], + reference: Annotated[ + Path, + typer.Argument(help="Reference anatomy (.nii or .nii.gz)."), + ], + remove_invalid: Annotated[ + bool, + typer.Option( + "--remove-invalid", + help="Remove invalid streamlines to avoid density_map crash.", + ), + ] = False, +) -> None: + """Display tractogram and density map with bounding box. + + Shows the tractogram and its density map (computed from Dipy) in + rasmm, voxmm, and vox space with its bounding box. + """ + tractogram_visualize_overlap( + str(in_tractogram), + str(reference), + remove_invalid, + ) + + +def main(): + """Entry point for the TRX CLI.""" + app() + + +# Standalone entry points for backward compatibility +# These create individual Typer apps for each command + + +def _create_standalone_app(command_func, name: str, help_text: str): + """Create a standalone Typer app for a single command. + + Parameters + ---------- + command_func : callable + The command function to wrap. + name : str + Name of the command. + help_text : str + Help text for the command. + + Returns + ------- + callable + Entry point function. 
+ """ + standalone = typer.Typer( + name=name, + help=help_text, + add_completion=False, + rich_markup_mode="rich", + ) + standalone.command()(command_func) + return lambda: standalone() + + +concatenate_tractograms_cmd = _create_standalone_app( + concatenate_tractograms, + "tff_concatenate_tractograms", + "Concatenate multiple tractograms into one.", +) + +convert_dsi_cmd = _create_standalone_app( + convert_dsi, + "tff_convert_dsi_studio", + "Fix DSI-Studio TRK files for compatibility.", +) + +convert_cmd = _create_standalone_app( + convert, + "tff_convert_tractogram", + "Convert tractograms between formats.", +) + +generate_cmd = _create_standalone_app( + generate, + "tff_generate_trx_from_scratch", + "Generate TRX file from raw data files.", +) + +manipulate_dtype_cmd = _create_standalone_app( + manipulate_dtype, + "tff_manipulate_datatype", + "Manipulate TRX file internal array data types.", +) + +compare_cmd = _create_standalone_app( + compare, + "tff_simple_compare", + "Simple comparison of tractograms by subtracting coordinates.", +) + +validate_cmd = _create_standalone_app( + validate, + "tff_validate_trx", + "Validate TRX file and remove invalid streamlines.", +) + +verify_header_cmd = _create_standalone_app( + verify_header, + "tff_verify_header_compatibility", + "Compare spatial attributes of input files.", +) + +visualize_cmd = _create_standalone_app( + visualize, + "tff_visualize_overlap", + "Display tractogram and density map with bounding box.", +) + + +if __name__ == "__main__": + main() diff --git a/trx/fetcher.py b/trx/fetcher.py index 7a033a8..4aa3d1b 100644 --- a/trx/fetcher.py +++ b/trx/fetcher.py @@ -6,7 +6,6 @@ import shutil import urllib.request - TEST_DATA_REPO = "tee-ar-ex/trx-test-data" TEST_DATA_TAG = "v0.1.0" # GitHub release API entrypoint for metadata (asset list, sizes, etc.). @@ -20,65 +19,65 @@ def get_home(): - """ Set a user-writeable file-system location to put files """ - if 'TRX_HOME' in os.environ: - trx_home = os.environ['TRX_HOME'] + """Set a user-writeable file-system location to put files""" + if "TRX_HOME" in os.environ: + trx_home = os.environ["TRX_HOME"] else: - trx_home = os.path.join(os.path.expanduser('~'), '.tee_ar_ex') + trx_home = os.path.join(os.path.expanduser("~"), ".tee_ar_ex") return trx_home def get_testing_files_dict(): - """ Get dictionary linking zip file to their GitHub release URL & checksums. + """Get dictionary linking zip file to their GitHub release URL & checksums. Assets are hosted under the v0.1.0 release of tee-ar-ex/trx-test-data. If URLs change, check TEST_DATA_API_URL to discover the latest asset locations. 
""" return { - 'DSI.zip': ( - f'{TEST_DATA_BASE_URL}/DSI.zip', - 'b847f053fc694d55d935c0be0e5268f7', # md5 - '1b09ce8b4b47b2600336c558fdba7051218296e8440e737364f2c4b8ebae666c', + "DSI.zip": ( + f"{TEST_DATA_BASE_URL}/DSI.zip", + "b847f053fc694d55d935c0be0e5268f7", # md5 + "1b09ce8b4b47b2600336c558fdba7051218296e8440e737364f2c4b8ebae666c", ), - 'memmap_test_data.zip': ( - f'{TEST_DATA_BASE_URL}/memmap_test_data.zip', - '03f7651a0f9e3eeabee9aed0ad5f69e1', # md5 - '98ba89d7a9a7baa2d37956a0a591dce9bb4581bd01296ad5a596706ee90a52ef', + "memmap_test_data.zip": ( + f"{TEST_DATA_BASE_URL}/memmap_test_data.zip", + "03f7651a0f9e3eeabee9aed0ad5f69e1", # md5 + "98ba89d7a9a7baa2d37956a0a591dce9bb4581bd01296ad5a596706ee90a52ef", ), - 'trx_from_scratch.zip': ( - f'{TEST_DATA_BASE_URL}/trx_from_scratch.zip', - 'd9f220a095ce7f027772fcd9451a2ee5', # md5 - 'f98ab6da6a6065527fde4b0b6aa40f07583e925d952182e9bbd0febd55c0f6b2', + "trx_from_scratch.zip": ( + f"{TEST_DATA_BASE_URL}/trx_from_scratch.zip", + "d9f220a095ce7f027772fcd9451a2ee5", # md5 + "f98ab6da6a6065527fde4b0b6aa40f07583e925d952182e9bbd0febd55c0f6b2", ), - 'gold_standard.zip': ( - f'{TEST_DATA_BASE_URL}/gold_standard.zip', - '57e3f9951fe77245684ede8688af3ae8', # md5 - '35a0b633560cc2b0d8ecda885aa72d06385499e0cd1ca11a956b0904c3358f01', + "gold_standard.zip": ( + f"{TEST_DATA_BASE_URL}/gold_standard.zip", + "57e3f9951fe77245684ede8688af3ae8", # md5 + "35a0b633560cc2b0d8ecda885aa72d06385499e0cd1ca11a956b0904c3358f01", ), } def md5sum(filename): - """ Compute one md5 checksum for a file """ + """Compute one md5 checksum for a file""" h = hashlib.md5() - with open(filename, 'rb') as f: - for chunk in iter(lambda: f.read(128 * h.block_size), b''): + with open(filename, "rb") as f: + for chunk in iter(lambda: f.read(128 * h.block_size), b""): h.update(chunk) return h.hexdigest() def sha256sum(filename): - """ Compute one sha256 checksum for a file """ + """Compute one sha256 checksum for a file""" h = hashlib.sha256() - with open(filename, 'rb') as f: - for chunk in iter(lambda: f.read(128 * h.block_size), b''): + with open(filename, "rb") as f: + for chunk in iter(lambda: f.read(128 * h.block_size), b""): h.update(chunk) return h.hexdigest() def fetch_data(files_dict, keys=None): # noqa: C901 - """ Downloads files to folder and checks their md5 checksums + """Downloads files to folder and checks their md5 checksums Parameters ---------- @@ -113,29 +112,25 @@ def fetch_data(files_dict, keys=None): # noqa: C901 url, expected_md5, expected_sha = file_entry full_path = os.path.join(trx_home, f) - logging.info('Downloading {} to {}'.format(f, trx_home)) + logging.info("Downloading {} to {}".format(f, trx_home)) if not os.path.exists(full_path): urllib.request.urlretrieve(url, full_path) actual_md5 = md5sum(full_path) if expected_md5 != actual_md5: raise ValueError( - f'Md5sum for {f} does not match. ' - 'Please remove the file to download it again: ' + - full_path + f"Md5sum for {f} does not match. " + "Please remove the file to download it again: " + full_path ) if expected_sha is not None: actual_sha = sha256sum(full_path) if expected_sha != actual_sha: raise ValueError( - f'SHA256 for {f} does not match. ' - 'Please remove the file to download it again: ' + - full_path + f"SHA256 for {f} does not match. 
" + "Please remove the file to download it again: " + full_path ) - if f.endswith('.zip'): + if f.endswith(".zip"): dst_dir = os.path.join(trx_home, f[:-4]) - shutil.unpack_archive(full_path, - extract_dir=dst_dir, - format='zip') + shutil.unpack_archive(full_path, extract_dir=dst_dir, format="zip") diff --git a/trx/io.py b/trx/io.py index 26b758d..86c8156 100644 --- a/trx/io.py +++ b/trx/io.py @@ -1,12 +1,13 @@ # -*- coding: utf-8 -*- -import os import logging -import tempfile +import os import sys +import tempfile try: import dipy # noqa: F401 + dipy_available = True except ImportError: dipy_available = False @@ -15,58 +16,61 @@ def get_trx_tmp_dir(): - if os.getenv('TRX_TMPDIR') is not None: - if os.getenv('TRX_TMPDIR') == 'use_working_dir': + if os.getenv("TRX_TMPDIR") is not None: + if os.getenv("TRX_TMPDIR") == "use_working_dir": trx_tmp_dir = os.getcwd() else: - trx_tmp_dir = os.getenv('TRX_TMPDIR') + trx_tmp_dir = os.getenv("TRX_TMPDIR") else: trx_tmp_dir = tempfile.gettempdir() if sys.version_info[1] >= 10: - return tempfile.TemporaryDirectory(dir=trx_tmp_dir, prefix='trx_', - ignore_cleanup_errors=True) + return tempfile.TemporaryDirectory( + dir=trx_tmp_dir, prefix="trx_", ignore_cleanup_errors=True + ) else: - return tempfile.TemporaryDirectory(dir=trx_tmp_dir, prefix='trx_') + return tempfile.TemporaryDirectory(dir=trx_tmp_dir, prefix="trx_") -def load_sft_with_reference(filepath, reference=None, - bbox_check=True): +def load_sft_with_reference(filepath, reference=None, bbox_check=True): if not dipy_available: - logging.error('Dipy library is missing, cannot use functions related ' - 'to the StatefulTractogram.') + logging.error( + "Dipy library is missing, cannot use functions related " + "to the StatefulTractogram." + ) return None from dipy.io.streamline import load_tractogram # Force the usage of --reference for all file formats without an header _, ext = os.path.splitext(filepath) - if ext == '.trk': - if reference is not None and reference != 'same': - logging.warning('Reference is discarded for this file format ' - '{}.'.format(filepath)) - sft = load_tractogram(filepath, 'same', - bbox_valid_check=bbox_check) - elif ext in ['.tck', '.fib', '.vtk', '.dpy']: - if reference is None or reference == 'same': - raise IOError('--reference is required for this file format ' - '{}.'.format(filepath)) + if ext == ".trk": + if reference is not None and reference != "same": + logging.warning( + "Reference is discarded for this file format {}.".format(filepath) + ) + sft = load_tractogram(filepath, "same", bbox_valid_check=bbox_check) + elif ext in [".tck", ".fib", ".vtk", ".dpy"]: + if reference is None or reference == "same": + raise IOError( + "--reference is required for this file format {}.".format(filepath) + ) else: - sft = load_tractogram(filepath, reference, - bbox_valid_check=bbox_check) + sft = load_tractogram(filepath, reference, bbox_valid_check=bbox_check) else: - raise IOError('{} is an unsupported file format'.format(filepath)) + raise IOError("{} is an unsupported file format".format(filepath)) return sft def load(tractogram_filename, reference): import trx.trx_file_memmap as tmm + in_ext = split_name_with_gz(tractogram_filename)[1] - if in_ext != '.trx' and not os.path.isdir(tractogram_filename): - tractogram_obj = load_sft_with_reference(tractogram_filename, - reference, - bbox_check=False) + if in_ext != ".trx" and not os.path.isdir(tractogram_filename): + tractogram_obj = load_sft_with_reference( + tractogram_filename, reference, bbox_check=False + ) else: 
tractogram_obj = tmm.load(tractogram_filename) @@ -75,20 +79,24 @@ def load(tractogram_filename, reference): def save(tractogram_obj, tractogram_filename, bbox_valid_check=False): if not dipy_available: - logging.error('Dipy library is missing, cannot use functions related ' - 'to the StatefulTractogram.') + logging.error( + "Dipy library is missing, cannot use functions related " + "to the StatefulTractogram." + ) return None from dipy.io.stateful_tractogram import StatefulTractogram from dipy.io.streamline import save_tractogram + import trx.trx_file_memmap as tmm out_ext = split_name_with_gz(tractogram_filename)[1] - if out_ext != '.trx': + if out_ext != ".trx": if not isinstance(tractogram_obj, StatefulTractogram): tractogram_obj = tractogram_obj.to_sft() - save_tractogram(tractogram_obj, tractogram_filename, - bbox_valid_check=bbox_valid_check) + save_tractogram( + tractogram_obj, tractogram_filename, bbox_valid_check=bbox_valid_check + ) else: if not isinstance(tractogram_obj, tmm.TrxFile): tractogram_obj = tmm.TrxFile.from_sft(tractogram_obj) diff --git a/trx/streamlines_ops.py b/trx/streamlines_ops.py index d8aa12a..007aaba 100644 --- a/trx/streamlines_ops.py +++ b/trx/streamlines_ops.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- -import itertools from functools import reduce +import itertools import numpy as np @@ -89,7 +89,7 @@ def hash_streamlines(streamlines, start_index=0, precision=None): def perform_streamlines_operation(operation, streamlines, precision=0): - """Peforms an operation on a list of list of streamlines + """Performs an operation on a list of list of streamlines Given a list of list of streamlines, this function applies the operation to the first two lists of streamlines. The result in then used recursively @@ -123,8 +123,7 @@ def perform_streamlines_operation(operation, streamlines, precision=0): # Hash the streamlines using the desired precision. indices = np.cumsum([0] + [len(s) for s in streamlines[:-1]]) - hashes = [hash_streamlines(s, i, precision) for - s, i in zip(streamlines, indices)] + hashes = [hash_streamlines(s, i, precision) for s, i in zip(streamlines, indices)] # Perform the operation on the hashes and get the output streamlines. 
to_keep = reduce(operation, hashes) diff --git a/trx/tests/test_cli.py b/trx/tests/test_cli.py new file mode 100644 index 0000000..5860a26 --- /dev/null +++ b/trx/tests/test_cli.py @@ -0,0 +1,420 @@ +# -*- coding: utf-8 -*- +"""Tests for CLI commands and workflow functions.""" + +import os +import tempfile + +from deepdiff import DeepDiff +import numpy as np +from numpy.testing import assert_allclose, assert_array_equal, assert_equal +import pytest + +try: + from dipy.io.streamline import load_tractogram + + dipy_available = True +except ImportError: + dipy_available = False + +from trx.fetcher import fetch_data, get_home, get_testing_files_dict +import trx.trx_file_memmap as tmm +from trx.workflows import ( + convert_dsi_studio, + convert_tractogram, + generate_trx_from_scratch, + manipulate_trx_datatype, + validate_tractogram, +) + +# If they already exist, this only takes 5 seconds (check md5sum) +fetch_data(get_testing_files_dict(), keys=["DSI.zip", "trx_from_scratch.zip"]) + + +# Tests for standalone CLI commands (tff_* commands) +class TestStandaloneCommands: + """Tests for standalone CLI commands.""" + + def test_help_option_convert_dsi(self, script_runner): + ret = script_runner.run(["tff_convert_dsi_studio", "--help"]) + assert ret.success + + def test_help_option_convert(self, script_runner): + ret = script_runner.run(["tff_convert_tractogram", "--help"]) + assert ret.success + + def test_help_option_generate_trx_from_scratch(self, script_runner): + ret = script_runner.run(["tff_generate_trx_from_scratch", "--help"]) + assert ret.success + + def test_help_option_concatenate(self, script_runner): + ret = script_runner.run(["tff_concatenate_tractograms", "--help"]) + assert ret.success + + def test_help_option_manipulate(self, script_runner): + ret = script_runner.run(["tff_manipulate_datatype", "--help"]) + assert ret.success + + def test_help_option_compare(self, script_runner): + ret = script_runner.run(["tff_simple_compare", "--help"]) + assert ret.success + + def test_help_option_validate(self, script_runner): + ret = script_runner.run(["tff_validate_trx", "--help"]) + assert ret.success + + def test_help_option_verify_header(self, script_runner): + ret = script_runner.run(["tff_verify_header_compatibility", "--help"]) + assert ret.success + + def test_help_option_visualize(self, script_runner): + ret = script_runner.run(["tff_visualize_overlap", "--help"]) + assert ret.success + + +# Tests for unified tff CLI +class TestUnifiedCLI: + """Tests for the unified tff CLI.""" + + def test_tff_help(self, script_runner): + ret = script_runner.run(["tff", "--help"]) + assert ret.success + + def test_tff_concatenate_help(self, script_runner): + ret = script_runner.run(["tff", "concatenate", "--help"]) + assert ret.success + + def test_tff_convert_help(self, script_runner): + ret = script_runner.run(["tff", "convert", "--help"]) + assert ret.success + + def test_tff_convert_dsi_help(self, script_runner): + ret = script_runner.run(["tff", "convert-dsi", "--help"]) + assert ret.success + + def test_tff_generate_help(self, script_runner): + ret = script_runner.run(["tff", "generate", "--help"]) + assert ret.success + + def test_tff_manipulate_dtype_help(self, script_runner): + ret = script_runner.run(["tff", "manipulate-dtype", "--help"]) + assert ret.success + + def test_tff_compare_help(self, script_runner): + ret = script_runner.run(["tff", "compare", "--help"]) + assert ret.success + + def test_tff_validate_help(self, script_runner): + ret = script_runner.run(["tff", "validate", 
"--help"]) + assert ret.success + + def test_tff_verify_header_help(self, script_runner): + ret = script_runner.run(["tff", "verify-header", "--help"]) + assert ret.success + + def test_tff_visualize_help(self, script_runner): + ret = script_runner.run(["tff", "visualize", "--help"]) + assert ret.success + + +# Tests for workflow functions +class TestWorkflowFunctions: + """Tests for workflow functions.""" + + @pytest.mark.skipif(not dipy_available, reason="Dipy is not installed.") + def test_execution_convert_dsi(self): + with tempfile.TemporaryDirectory() as tmp_dir: + in_trk = os.path.join(get_home(), "DSI", "CC.trk.gz") + in_nii = os.path.join(get_home(), "DSI", "CC.nii.gz") + exp_data = os.path.join(get_home(), "DSI", "CC_fix_data.npy") + exp_offsets = os.path.join(get_home(), "DSI", "CC_fix_offsets.npy") + out_fix_path = os.path.join(tmp_dir, "fixed.trk") + convert_dsi_studio( + in_trk, in_nii, out_fix_path, remove_invalid=False, keep_invalid=True + ) + + data_fix = np.load(exp_data) + offsets_fix = np.load(exp_offsets) + + sft = load_tractogram(out_fix_path, "same") + assert_equal(sft.streamlines._data, data_fix) + assert_equal(sft.streamlines._offsets, offsets_fix) + + @pytest.mark.skipif(not dipy_available, reason="Dipy is not installed.") + def test_execution_convert_to_trx(self): + with tempfile.TemporaryDirectory() as tmp_dir: + in_trk = os.path.join(get_home(), "DSI", "CC_fix.trk") + exp_data = os.path.join(get_home(), "DSI", "CC_fix_data.npy") + exp_offsets = os.path.join(get_home(), "DSI", "CC_fix_offsets.npy") + out_trx_path = os.path.join(tmp_dir, "CC_fix.trx") + convert_tractogram(in_trk, out_trx_path, None) + + data_fix = np.load(exp_data) + offsets_fix = np.load(exp_offsets) + + trx = tmm.load(out_trx_path) + assert_equal(trx.streamlines._data.dtype, np.float32) + assert_equal(trx.streamlines._offsets.dtype, np.uint32) + assert_array_equal(trx.streamlines._data, data_fix) + assert_array_equal(trx.streamlines._offsets, offsets_fix) + trx.close() + + @pytest.mark.skipif(not dipy_available, reason="Dipy is not installed.") + def test_execution_convert_from_trx(self): + with tempfile.TemporaryDirectory() as tmp_dir: + in_trk = os.path.join(get_home(), "DSI", "CC_fix.trk") + in_nii = os.path.join(get_home(), "DSI", "CC.nii.gz") + exp_data = os.path.join(get_home(), "DSI", "CC_fix_data.npy") + exp_offsets = os.path.join(get_home(), "DSI", "CC_fix_offsets.npy") + + # Sequential conversions + out_trx_path = os.path.join(tmp_dir, "CC_fix.trx") + out_trk_path = os.path.join(tmp_dir, "CC_fix.trk") + out_tck_path = os.path.join(tmp_dir, "CC_fix.tck") + convert_tractogram(in_trk, out_trx_path, None) + convert_tractogram(out_trx_path, out_tck_path, None) + convert_tractogram(out_trx_path, out_trk_path, None) + + data_fix = np.load(exp_data) + offsets_fix = np.load(exp_offsets) + + sft = load_tractogram(out_trk_path, "same") + assert_equal(sft.streamlines._data, data_fix) + assert_equal(sft.streamlines._offsets, offsets_fix) + + sft = load_tractogram(out_tck_path, in_nii) + assert_equal(sft.streamlines._data, data_fix) + assert_equal(sft.streamlines._offsets, offsets_fix) + + @pytest.mark.skipif(not dipy_available, reason="Dipy is not installed.") + def test_execution_convert_dtype_p16_o64(self): + with tempfile.TemporaryDirectory() as tmp_dir: + in_trk = os.path.join(get_home(), "DSI", "CC_fix.trk") + out_convert_path = os.path.join(tmp_dir, "CC_fix_p16_o64.trx") + convert_tractogram( + in_trk, + out_convert_path, + None, + pos_dtype="float16", + offsets_dtype="uint64", + ) + + trx = 
tmm.load(out_convert_path) + assert_equal(trx.streamlines._data.dtype, np.float16) + assert_equal(trx.streamlines._offsets.dtype, np.uint64) + trx.close() + + @pytest.mark.skipif(not dipy_available, reason="Dipy is not installed.") + def test_execution_convert_dtype_p64_o32(self): + with tempfile.TemporaryDirectory() as tmp_dir: + in_trk = os.path.join(get_home(), "DSI", "CC_fix.trk") + out_convert_path = os.path.join(tmp_dir, "CC_fix_p16_o64.trx") + convert_tractogram( + in_trk, + out_convert_path, + None, + pos_dtype="float64", + offsets_dtype="uint32", + ) + + trx = tmm.load(out_convert_path) + assert_equal(trx.streamlines._data.dtype, np.float64) + assert_equal(trx.streamlines._offsets.dtype, np.uint32) + trx.close() + + def test_execution_generate_trx_from_scratch(self): + with tempfile.TemporaryDirectory() as tmp_dir: + reference_fa = os.path.join(get_home(), "trx_from_scratch", "fa.nii.gz") + raw_arr_dir = os.path.join(get_home(), "trx_from_scratch", "test_npy") + expected_trx = os.path.join(get_home(), "trx_from_scratch", "expected.trx") + + dpv = [ + (os.path.join(raw_arr_dir, "dpv_cx.npy"), "uint8"), + (os.path.join(raw_arr_dir, "dpv_cy.npy"), "uint8"), + (os.path.join(raw_arr_dir, "dpv_cz.npy"), "uint8"), + ] + dps = [ + (os.path.join(raw_arr_dir, "dps_algo.npy"), "uint8"), + (os.path.join(raw_arr_dir, "dps_cw.npy"), "float64"), + ] + dpg = [ + ( + "g_AF_L", + os.path.join(raw_arr_dir, "dpg_AF_L_mean_fa.npy"), + "float32", + ), + ( + "g_AF_R", + os.path.join(raw_arr_dir, "dpg_AF_R_mean_fa.npy"), + "float32", + ), + ("g_AF_L", os.path.join(raw_arr_dir, "dpg_AF_L_volume.npy"), "float32"), + ] + groups = [ + (os.path.join(raw_arr_dir, "g_AF_L.npy"), "int32"), + (os.path.join(raw_arr_dir, "g_AF_R.npy"), "int32"), + (os.path.join(raw_arr_dir, "g_CST_L.npy"), "int32"), + ] + + out_gen_path = os.path.join(tmp_dir, "generated.trx") + generate_trx_from_scratch( + reference_fa, + out_gen_path, + positions=os.path.join(raw_arr_dir, "positions.npy"), + offsets=os.path.join(raw_arr_dir, "offsets.npy"), + positions_dtype="float16", + offsets_dtype="uint64", + space_str="rasmm", + origin_str="nifti", + verify_invalid=False, + dpv=dpv, + dps=dps, + groups=groups, + dpg=dpg, + ) + exp_trx = tmm.load(expected_trx) + gen_trx = tmm.load(out_gen_path) + + assert DeepDiff(exp_trx.get_dtype_dict(), gen_trx.get_dtype_dict()) == {} + + assert_allclose( + exp_trx.streamlines._data, gen_trx.streamlines._data, atol=0.1, rtol=0.1 + ) + assert_equal(exp_trx.streamlines._offsets, gen_trx.streamlines._offsets) + + for key in exp_trx.data_per_vertex.keys(): + assert_equal( + exp_trx.data_per_vertex[key]._data, + gen_trx.data_per_vertex[key]._data, + ) + assert_equal( + exp_trx.data_per_vertex[key]._offsets, + gen_trx.data_per_vertex[key]._offsets, + ) + for key in exp_trx.data_per_streamline.keys(): + assert_equal( + exp_trx.data_per_streamline[key], gen_trx.data_per_streamline[key] + ) + for key in exp_trx.groups.keys(): + assert_equal(exp_trx.groups[key], gen_trx.groups[key]) + + for group in exp_trx.groups.keys(): + if group in exp_trx.data_per_group: + for key in exp_trx.data_per_group[group].keys(): + assert_equal( + exp_trx.data_per_group[group][key], + gen_trx.data_per_group[group][key], + ) + exp_trx.close() + gen_trx.close() + + @pytest.mark.skipif(not dipy_available, reason="Dipy is not installed.") + def test_execution_concatenate_validate_trx(self): + with tempfile.TemporaryDirectory() as tmp_dir: + trx1 = tmm.load(os.path.join(get_home(), "gold_standard", "gs.trx")) + trx2 = 
tmm.load(os.path.join(get_home(), "gold_standard", "gs.trx")) + trx = tmm.concatenate([trx1, trx2], preallocation=False) + + # Right size + assert_equal(len(trx.streamlines), 2 * len(trx1.streamlines)) + + # Right data + end_idx = trx1.header["NB_VERTICES"] + assert_allclose(trx.streamlines._data[:end_idx], trx1.streamlines._data) + assert_allclose(trx.streamlines._data[end_idx:], trx2.streamlines._data) + + # Right data_per_* + for key in trx.data_per_vertex.keys(): + assert_equal( + trx.data_per_vertex[key]._data[:end_idx], + trx1.data_per_vertex[key]._data, + ) + assert_equal( + trx.data_per_vertex[key]._data[end_idx:], + trx2.data_per_vertex[key]._data, + ) + + end_idx = trx1.header["NB_STREAMLINES"] + for key in trx.data_per_streamline.keys(): + assert_equal( + trx.data_per_streamline[key][:end_idx], + trx1.data_per_streamline[key], + ) + assert_equal( + trx.data_per_streamline[key][end_idx:], + trx2.data_per_streamline[key], + ) + + # Validate + out_concat_path = os.path.join(tmp_dir, "concat.trx") + out_valid_path = os.path.join(tmp_dir, "valid.trx") + tmm.save(trx, out_concat_path) + validate_tractogram( + out_concat_path, + None, + out_valid_path, + remove_identical_streamlines=True, + precision=0, + ) + trx_val = tmm.load(out_valid_path) + + # Right dtype and size + assert DeepDiff(trx.get_dtype_dict(), trx_val.get_dtype_dict()) == {} + assert_equal(len(trx1.streamlines), len(trx_val.streamlines)) + + trx.close() + trx1.close() + trx2.close() + trx_val.close() + + @pytest.mark.skipif(not dipy_available, reason="Dipy is not installed.") + def test_execution_manipulate_trx_datatype(self): + with tempfile.TemporaryDirectory() as tmp_dir: + expected_trx = os.path.join(get_home(), "trx_from_scratch", "expected.trx") + trx = tmm.load(expected_trx) + + expected_dtype = { + "positions": np.dtype("float16"), + "offsets": np.dtype("uint64"), + "dpv": { + "dpv_cx": np.dtype("uint8"), + "dpv_cy": np.dtype("uint8"), + "dpv_cz": np.dtype("uint8"), + }, + "dps": {"dps_algo": np.dtype("uint8"), "dps_cw": np.dtype("float64")}, + "dpg": { + "g_AF_L": { + "dpg_AF_L_mean_fa": np.dtype("float32"), + "dpg_AF_L_volume": np.dtype("float32"), + }, + "g_AF_R": {"dpg_AF_R_mean_fa": np.dtype("float32")}, + }, + "groups": {"g_AF_L": np.dtype("int32"), "g_AF_R": np.dtype("int32")}, + } + + assert DeepDiff(trx.get_dtype_dict(), expected_dtype) == {} + trx.close() + + generated_dtype = { + "positions": np.dtype("float32"), + "offsets": np.dtype("uint32"), + "dpv": { + "dpv_cx": np.dtype("uint16"), + "dpv_cy": np.dtype("uint16"), + "dpv_cz": np.dtype("uint16"), + }, + "dps": {"dps_algo": np.dtype("uint8"), "dps_cw": np.dtype("float32")}, + "dpg": { + "g_AF_L": { + "dpg_AF_L_mean_fa": np.dtype("float64"), + "dpg_AF_L_volume": np.dtype("float32"), + }, + "g_AF_R": {"dpg_AF_R_mean_fa": np.dtype("float64")}, + }, + "groups": {"g_AF_L": np.dtype("uint16"), "g_AF_R": np.dtype("uint16")}, + } + + out_gen_path = os.path.join(tmp_dir, "generated.trx") + manipulate_trx_datatype(expected_trx, out_gen_path, generated_dtype) + trx = tmm.load(out_gen_path) + assert DeepDiff(trx.get_dtype_dict(), generated_dtype) == {} + trx.close() diff --git a/trx/tests/test_io.py b/trx/tests/test_io.py index 3ddb6ad..c5e823d 100644 --- a/trx/tests/test_io.py +++ b/trx/tests/test_io.py @@ -2,69 +2,64 @@ from copy import deepcopy import os -import psutil from tempfile import TemporaryDirectory import zipfile -import pytest import numpy as np from numpy.testing import assert_allclose +import psutil +import pytest try: - from dipy.io.streamline 
import save_tractogram, load_tractogram + from dipy.io.streamline import load_tractogram, save_tractogram + dipy_available = True except ImportError: dipy_available = False +from trx.fetcher import fetch_data, get_home, get_testing_files_dict +from trx.io import load, save import trx.trx_file_memmap as tmm from trx.trx_file_memmap import TrxFile -from trx.io import load, save -from trx.fetcher import (get_testing_files_dict, - fetch_data, get_home) - -fetch_data(get_testing_files_dict(), keys=['gold_standard.zip']) +fetch_data(get_testing_files_dict(), keys=["gold_standard.zip"]) -@pytest.mark.parametrize("path", [("gs.trk"), ("gs.tck"), - ("gs.vtk")]) -@pytest.mark.skipif(not dipy_available, reason='Dipy is not installed.') +@pytest.mark.parametrize("path", [("gs.trk"), ("gs.tck"), ("gs.vtk")]) +@pytest.mark.skipif(not dipy_available, reason="Dipy is not installed.") def test_seq_ops_sft(path): with TemporaryDirectory() as tmp_dir: - gs_dir = os.path.join(get_home(), 'gold_standard') + gs_dir = os.path.join(get_home(), "gold_standard") path = os.path.join(tmp_dir, path) - obj = load(os.path.join(gs_dir, 'gs.trx'), - os.path.join(gs_dir, 'gs.nii')) + obj = load(os.path.join(gs_dir, "gs.trx"), os.path.join(gs_dir, "gs.nii")) sft_1 = obj.to_sft() save_tractogram(sft_1, path) obj.close() - save_tractogram(sft_1, os.path.join(tmp_dir, 'tmp.trk')) + save_tractogram(sft_1, os.path.join(tmp_dir, "tmp.trk")) - _ = load_tractogram(os.path.join(tmp_dir, 'tmp.trk'), 'same') + _ = load_tractogram(os.path.join(tmp_dir, "tmp.trk"), "same") def test_seq_ops_trx(): with TemporaryDirectory() as tmp_dir: - gs_dir = os.path.join(get_home(), 'gold_standard') - path = os.path.join(gs_dir, 'gs.trx') + gs_dir = os.path.join(get_home(), "gold_standard") + path = os.path.join(gs_dir, "gs.trx") trx_1 = tmm.load(path) - tmm.save(trx_1, os.path.join(tmp_dir, 'tmp.trx')) + tmm.save(trx_1, os.path.join(tmp_dir, "tmp.trx")) trx_1.close() - trx_2 = tmm.load(os.path.join(tmp_dir, 'tmp.trx')) + trx_2 = tmm.load(os.path.join(tmp_dir, "tmp.trx")) trx_2.close() -@pytest.mark.parametrize("path", [("gs.trx"), ("gs.trk"), ("gs.tck"), - ("gs.vtk")]) -@pytest.mark.skipif(not dipy_available, reason='Dipy is not installed.') +@pytest.mark.parametrize("path", [("gs.trx"), ("gs.trk"), ("gs.tck"), ("gs.vtk")]) +@pytest.mark.skipif(not dipy_available, reason="Dipy is not installed.") def test_load_vox(path): - gs_dir = os.path.join(get_home(), 'gold_standard') + gs_dir = os.path.join(get_home(), "gold_standard") path = os.path.join(gs_dir, path) - coord = np.loadtxt(os.path.join(get_home(), 'gold_standard', - 'gs_vox_space.txt')) - obj = load(path, os.path.join(gs_dir, 'gs.nii')) + coord = np.loadtxt(os.path.join(get_home(), "gold_standard", "gs_vox_space.txt")) + obj = load(path, os.path.join(gs_dir, "gs.nii")) sft = obj.to_sft() if isinstance(obj, TrxFile) else obj sft.to_vox() @@ -74,15 +69,13 @@ def test_load_vox(path): obj.close() -@pytest.mark.parametrize("path", [("gs.trx"), ("gs.trk"), ("gs.tck"), - ("gs.vtk")]) -@pytest.mark.skipif(not dipy_available, reason='Dipy is not installed.') +@pytest.mark.parametrize("path", [("gs.trx"), ("gs.trk"), ("gs.tck"), ("gs.vtk")]) +@pytest.mark.skipif(not dipy_available, reason="Dipy is not installed.") def test_load_voxmm(path): - gs_dir = os.path.join(get_home(), 'gold_standard') + gs_dir = os.path.join(get_home(), "gold_standard") path = os.path.join(gs_dir, path) - coord = np.loadtxt(os.path.join(get_home(), 'gold_standard', - 'gs_voxmm_space.txt')) - obj = load(path, os.path.join(gs_dir, 
'gs.nii')) + coord = np.loadtxt(os.path.join(get_home(), "gold_standard", "gs_voxmm_space.txt")) + obj = load(path, os.path.join(gs_dir, "gs.nii")) sft = obj.to_sft() if isinstance(obj, TrxFile) else obj sft.to_voxmm() @@ -93,25 +86,25 @@ def test_load_voxmm(path): @pytest.mark.parametrize("path", [("gs.trk"), ("gs.trx"), ("gs_fldr.trx")]) -@pytest.mark.skipif(not dipy_available, reason='Dipy is not installed.') +@pytest.mark.skipif(not dipy_available, reason="Dipy is not installed.") def test_multi_load_save_rasmm(path): with TemporaryDirectory() as tmp_gs_dir: - gs_dir = os.path.join(get_home(), 'gold_standard') + gs_dir = os.path.join(get_home(), "gold_standard") basename, ext = os.path.splitext(path) path = os.path.join(gs_dir, path) - coord = np.loadtxt(os.path.join(get_home(), 'gold_standard', - 'gs_rasmm_space.txt')) + coord = np.loadtxt( + os.path.join(get_home(), "gold_standard", "gs_rasmm_space.txt") + ) - obj = load(path, os.path.join(gs_dir, 'gs.nii')) + obj = load(path, os.path.join(gs_dir, "gs.nii")) for i in range(3): - out_path = os.path.join( - tmp_gs_dir, '{}_tmp{}_{}'.format(basename, i, ext)) + out_path = os.path.join(tmp_gs_dir, "{}_tmp{}_{}".format(basename, i, ext)) save(obj, out_path) if isinstance(obj, TrxFile): obj.close() - obj = load(out_path, os.path.join(gs_dir, 'gs.nii')) + obj = load(out_path, os.path.join(gs_dir, "gs.nii")) assert_allclose(obj.streamlines._data, coord, rtol=1e-04, atol=1e-06) if isinstance(obj, TrxFile): @@ -119,9 +112,9 @@ def test_multi_load_save_rasmm(path): @pytest.mark.parametrize("path", [("gs.trx"), ("gs_fldr.trx")]) -@pytest.mark.skipif(not dipy_available, reason='Dipy is not installed.') +@pytest.mark.skipif(not dipy_available, reason="Dipy is not installed.") def test_delete_tmp_gs_dir(path): - gs_dir = os.path.join(get_home(), 'gold_standard') + gs_dir = os.path.join(get_home(), "gold_standard") path = os.path.join(gs_dir, path) trx1 = tmm.load(path) @@ -131,10 +124,12 @@ def test_delete_tmp_gs_dir(path): sft = trx1.to_sft() trx1.close() - coord_rasmm = np.loadtxt(os.path.join(get_home(), 'gold_standard', - 'gs_rasmm_space.txt')) - coord_vox = np.loadtxt(os.path.join(get_home(), 'gold_standard', - 'gs_vox_space.txt')) + coord_rasmm = np.loadtxt( + os.path.join(get_home(), "gold_standard", "gs_rasmm_space.txt") + ) + coord_vox = np.loadtxt( + os.path.join(get_home(), "gold_standard", "gs_vox_space.txt") + ) # The folder trx representation does not need tmp files if os.path.isfile(path): @@ -144,33 +139,38 @@ def test_delete_tmp_gs_dir(path): # Reloading the TRX and checking its data, then closing trx2 = tmm.load(path) - assert_allclose(trx2.streamlines._data, - sft.streamlines._data, rtol=1e-04, atol=1e-06) + assert_allclose( + trx2.streamlines._data, sft.streamlines._data, rtol=1e-04, atol=1e-06 + ) trx2.close() sft.to_vox() assert_allclose(sft.streamlines._data, coord_vox, rtol=1e-04, atol=1e-06) trx3 = tmm.load(path) - assert_allclose(trx3.streamlines._data, - coord_rasmm, rtol=1e-04, atol=1e-06) + assert_allclose(trx3.streamlines._data, coord_rasmm, rtol=1e-04, atol=1e-06) trx3.close() @pytest.mark.parametrize("path", [("gs.trx")]) -@pytest.mark.skipif(not dipy_available, reason='Dipy is not installed.') +@pytest.mark.skipif(not dipy_available, reason="Dipy is not installed.") def test_close_tmp_files(path): - gs_dir = os.path.join(get_home(), 'gold_standard') + gs_dir = os.path.join(get_home(), "gold_standard") path = os.path.join(gs_dir, path) trx = tmm.load(path) process = psutil.Process(os.getpid()) open_files = 
process.open_files() - expected_content = ['offsets.uint32', 'positions.3.float32', - 'header.json', 'random_coord.3.float32', - 'color_y.float32', 'color_x.float32', - 'color_z.float32'] + expected_content = [ + "offsets.uint32", + "positions.3.float32", + "header.json", + "random_coord.3.float32", + "color_y.float32", + "color_x.float32", + "color_z.float32", + ] count = 0 for open_file in open_files: @@ -192,18 +192,18 @@ def test_close_tmp_files(path): @pytest.mark.parametrize("tmp_path", [("~"), ("use_working_dir")]) def test_change_tmp_dir(tmp_path): - gs_dir = os.path.join(get_home(), 'gold_standard') - path = os.path.join(gs_dir, 'gs.trx') + gs_dir = os.path.join(get_home(), "gold_standard") + path = os.path.join(gs_dir, "gs.trx") - if tmp_path == 'use_working_dir': - os.environ['TRX_TMPDIR'] = 'use_working_dir' + if tmp_path == "use_working_dir": + os.environ["TRX_TMPDIR"] = "use_working_dir" else: - os.environ['TRX_TMPDIR'] = os.path.expanduser(tmp_path) + os.environ["TRX_TMPDIR"] = os.path.expanduser(tmp_path) trx = tmm.load(path) tmp_gs_dir = deepcopy(trx._uncompressed_folder_handle.name) - if tmp_path == 'use_working_dir': + if tmp_path == "use_working_dir": assert os.path.dirname(tmp_gs_dir) == os.getcwd() else: assert os.path.dirname(tmp_gs_dir) == os.path.expanduser(tmp_path) @@ -214,7 +214,7 @@ def test_change_tmp_dir(tmp_path): @pytest.mark.parametrize("path", [("gs.trx"), ("gs_fldr.trx")]) def test_complete_dir_from_trx(path): - gs_dir = os.path.join(get_home(), 'gold_standard') + gs_dir = os.path.join(get_home(), "gold_standard") path = os.path.join(gs_dir, path) trx = tmm.load(path) @@ -227,25 +227,35 @@ def test_complete_dir_from_trx(path): for dirpath, _, filenames in os.walk(dir_to_check): for filename in filenames: full_path = os.path.join(dirpath, filename) - cut_path = full_path.split(dir_to_check)[1][1:].replace('\\', '/') + cut_path = full_path.split(dir_to_check)[1][1:].replace("\\", "/") file_paths.append(cut_path) - expected_content = ['offsets.uint32', 'positions.3.float32', - 'header.json', 'dps/random_coord.3.float32', - 'dpv/color_y.float32', 'dpv/color_x.float32', - 'dpv/color_z.float32'] + expected_content = [ + "offsets.uint32", + "positions.3.float32", + "header.json", + "dps/random_coord.3.float32", + "dpv/color_y.float32", + "dpv/color_x.float32", + "dpv/color_z.float32", + ] assert set(file_paths) == set(expected_content) def test_complete_zip_from_trx(): - gs_dir = os.path.join(get_home(), 'gold_standard') - path = os.path.join(gs_dir, 'gs.trx') + gs_dir = os.path.join(get_home(), "gold_standard") + path = os.path.join(gs_dir, "gs.trx") with zipfile.ZipFile(path, mode="r") as zf: zip_file_list = zf.namelist() - expected_content = ['offsets.uint32', 'positions.3.float32', - 'header.json', 'dps/random_coord.3.float32', - 'dpv/color_y.float32', 'dpv/color_x.float32', - 'dpv/color_z.float32'] + expected_content = [ + "offsets.uint32", + "positions.3.float32", + "header.json", + "dps/random_coord.3.float32", + "dpv/color_y.float32", + "dpv/color_x.float32", + "dpv/color_z.float32", + ] assert set(zip_file_list) == set(expected_content) diff --git a/trx/tests/test_memmap.py b/trx/tests/test_memmap.py index 2ccedbd..529600c 100644 --- a/trx/tests/test_memmap.py +++ b/trx/tests/test_memmap.py @@ -2,24 +2,23 @@ import os -from nibabel.streamlines.tests.test_tractogram import make_dummy_streamline from nibabel.streamlines import LazyTractogram +from nibabel.streamlines.tests.test_tractogram import make_dummy_streamline import numpy as np import pytest try: 
import dipy # noqa: F401 + dipy_available = True except ImportError: dipy_available = False +from trx.fetcher import fetch_data, get_home, get_testing_files_dict from trx.io import get_trx_tmp_dir import trx.trx_file_memmap as tmm -from trx.fetcher import (get_testing_files_dict, - fetch_data, get_home) - -fetch_data(get_testing_files_dict(), keys=['memmap_test_data.zip']) +fetch_data(get_testing_files_dict(), keys=["memmap_test_data.zip"]) tmp_dir = get_trx_tmp_dir() @@ -36,11 +35,9 @@ def test__generate_filename_from_data( arr, expected, value_error, filename="mean_fa.bit" ): - if value_error: with pytest.raises(ValueError): - new_fn = tmm._generate_filename_from_data(arr=arr, - filename=filename) + new_fn = tmm._generate_filename_from_data(arr=arr, filename=filename) assert new_fn is None else: new_fn = tmm._generate_filename_from_data(arr=arr, filename=filename) @@ -55,8 +52,7 @@ def test__generate_filename_from_data( ("mean_fa", None, True), ("mean_fa.5.4.int32", None, True), pytest.param( - "mean_fa.fa", None, True, marks=pytest.mark.xfail, - id="invalid extension" + "mean_fa.fa", None, True, marks=pytest.mark.xfail, id="invalid extension" ), ], ) @@ -72,16 +68,18 @@ def test__split_ext_with_dimensionality(filename, expected, value_error): "offsets,nb_vertices,expected", [ (np.array(range(5), dtype=np.int16), 4, np.array([1, 1, 1, 1, 0])), - (np.array([0, 1, 1, 3, 4], dtype=np.int32), - 4, np.array([1, 0, 2, 1, 0])), + (np.array([0, 1, 1, 3, 4], dtype=np.int32), 4, np.array([1, 0, 2, 1, 0])), (np.array(range(4), dtype=np.uint64), 4, np.array([1, 1, 1, 1])), - pytest.param(np.array([0, 1, 0, 3, 4], dtype=np.int16), 4, - np.array([1, 3, 0, 1, 0]), marks=pytest.mark.xfail, - id="offsets not sorted"), + pytest.param( + np.array([0, 1, 0, 3, 4], dtype=np.int16), + 4, + np.array([1, 3, 0, 1, 0]), + marks=pytest.mark.xfail, + id="offsets not sorted", + ), ], ) def test__compute_lengths(offsets, nb_vertices, expected): - offsets = tmm._append_last_offsets(offsets, nb_vertices) lengths = tmm._compute_lengths(offsets=offsets) assert np.array_equal(lengths, expected) @@ -120,8 +118,7 @@ def test__dichotomic_search(arr, l_bound, r_bound, expected): @pytest.mark.parametrize( "basename, create, expected", [ - ("offsets.int16", True, np.array(range(12), dtype=np.int16).reshape(( - 3, 4))), + ("offsets.int16", True, np.array(range(12), dtype=np.int16).reshape((3, 4))), ("offsets.float32", False, None), ], ) @@ -132,18 +129,15 @@ def test__create_memmap(basename, create, expected): filename = os.path.join(dirname, basename) fp = np.memmap(filename, dtype=np.int16, mode="w+", shape=(3, 4)) fp[:] = expected[:] - mmarr = tmm._create_memmap(filename=filename, shape=(3, 4), - dtype=np.int16) + mmarr = tmm._create_memmap(filename=filename, shape=(3, 4), dtype=np.int16) assert np.array_equal(mmarr, expected) else: with get_trx_tmp_dir() as dirname: filename = os.path.join(dirname, basename) - mmarr = tmm._create_memmap(filename=filename, shape=(0,), - dtype=np.int16) + mmarr = tmm._create_memmap(filename=filename, shape=(0,), dtype=np.int16) assert os.path.isfile(filename) - assert np.array_equal(mmarr, np.zeros( - shape=(0,), dtype=np.float32)) + assert np.array_equal(mmarr, np.zeros(shape=(0,), dtype=np.float32)) # need dpg test with missing keys @@ -157,7 +151,7 @@ def test__create_memmap(basename, create, expected): ], ) def test_load(path, check_dpg, value_error): - path = os.path.join(get_home(), 'memmap_test_data', path) + path = os.path.join(get_home(), "memmap_test_data", path) # Need to perhaps 
improve test if value_error: with pytest.raises(ValueError): @@ -165,25 +159,24 @@ def test_load(path, check_dpg, value_error): tmm.load(input_obj=path, check_dpg=check_dpg), tmm.TrxFile ) else: - assert isinstance(tmm.load(input_obj=path, check_dpg=check_dpg), - tmm.TrxFile) + assert isinstance(tmm.load(input_obj=path, check_dpg=check_dpg), tmm.TrxFile) @pytest.mark.parametrize("path", [("small.trx")]) def test_load_zip(path): - path = os.path.join(get_home(), 'memmap_test_data', path) + path = os.path.join(get_home(), "memmap_test_data", path) assert isinstance(tmm.load_from_zip(path), tmm.TrxFile) @pytest.mark.parametrize("path", [("small_fldr.trx")]) def test_load_directory(path): - path = os.path.join(get_home(), 'memmap_test_data', path) + path = os.path.join(get_home(), "memmap_test_data", path) assert isinstance(tmm.load_from_directory(path), tmm.TrxFile) @pytest.mark.parametrize("path", [("small.trx")]) def test_concatenate(path): - path = os.path.join(get_home(), 'memmap_test_data', path) + path = os.path.join(get_home(), "memmap_test_data", path) trx1 = tmm.load(path) trx2 = tmm.load(path) concat = tmm.concatenate([trx1, trx2]) @@ -196,10 +189,9 @@ def test_concatenate(path): @pytest.mark.parametrize("path", [("small.trx")]) def test_resize(path): - path = os.path.join(get_home(), 'memmap_test_data', path) + path = os.path.join(get_home(), "memmap_test_data", path) trx1 = tmm.load(path) - concat = tmm.TrxFile(nb_vertices=1000000, nb_streamlines=10000, - init_as=trx1) + concat = tmm.TrxFile(nb_vertices=1000000, nb_streamlines=10000, init_as=trx1) tmm.concatenate([concat, trx1], preallocation=True, delete_groups=True) concat.resize() @@ -209,18 +201,11 @@ def test_resize(path): concat.close() -@pytest.mark.parametrize( - "path, buffer", - [ - ("small.trx", 10000), - ("small.trx", 0) - ] -) +@pytest.mark.parametrize("path, buffer", [("small.trx", 10000), ("small.trx", 0)]) def test_append(path, buffer): - path = os.path.join(get_home(), 'memmap_test_data', path) + path = os.path.join(get_home(), "memmap_test_data", path) trx1 = tmm.load(path) - concat = tmm.TrxFile(nb_vertices=1, nb_streamlines=1, - init_as=trx1) + concat = tmm.TrxFile(nb_vertices=1, nb_streamlines=1, init_as=trx1) concat.append(trx1, extra_buffer=buffer) if buffer > 0: @@ -234,7 +219,7 @@ def test_append(path, buffer): @pytest.mark.parametrize("path, buffer", [("small.trx", 10000)]) @pytest.mark.skipif(not dipy_available, reason="Dipy is not installed") def test_append_StatefulTractogram(path, buffer): - path = os.path.join(get_home(), 'memmap_test_data', path) + path = os.path.join(get_home(), "memmap_test_data", path) trx = tmm.load(path) obj = trx.to_sft() concat = tmm.TrxFile(nb_vertices=1, nb_streamlines=1, init_as=trx) @@ -250,7 +235,7 @@ def test_append_StatefulTractogram(path, buffer): @pytest.mark.parametrize("path, buffer", [("small.trx", 10000)]) def test_append_Tractogram(path, buffer): - path = os.path.join(get_home(), 'memmap_test_data', path) + path = os.path.join(get_home(), "memmap_test_data", path) trx = tmm.load(path) obj = trx.to_tractogram() concat = tmm.TrxFile(nb_vertices=1, nb_streamlines=1, init_as=trx) @@ -264,12 +249,17 @@ def test_append_Tractogram(path, buffer): concat.close() -@pytest.mark.parametrize("path, size, buffer", [("small.trx", 50, 10000), - ("small.trx", 0, 10000), - ("small.trx", 25000, 10000), - ("small.trx", 50, 0), - ("small.trx", 0, 0), - ("small.trx", 25000, 10000)]) +@pytest.mark.parametrize( + "path, size, buffer", + [ + ("small.trx", 50, 10000), + ("small.trx", 0, 
10000), + ("small.trx", 25000, 10000), + ("small.trx", 50, 0), + ("small.trx", 0, 0), + ("small.trx", 25000, 10000), + ], +) def test_from_lazy_tractogram(path, size, buffer): _ = np.random.RandomState(1776) streamlines = [] @@ -281,32 +271,36 @@ def test_from_lazy_tractogram(path, size, buffer): data = make_dummy_streamline(i) streamline, data_per_point, data_for_streamline = data streamlines.append(streamline) - fa.append(data_per_point['fa'].astype(np.float16)) - commit_weights.append( - data_for_streamline['mean_curvature'].astype(np.float32)) - clusters_QB.append( - data_for_streamline['mean_torsion'].astype(np.uint16)) - - def streamlines_func(): return (e for e in streamlines) - data_per_point_func = {'fa': lambda: (e for e in fa)} + fa.append(data_per_point["fa"].astype(np.float16)) + commit_weights.append(data_for_streamline["mean_curvature"].astype(np.float32)) + clusters_QB.append(data_for_streamline["mean_torsion"].astype(np.uint16)) + + def streamlines_func(): + return (e for e in streamlines) + + data_per_point_func = {"fa": lambda: (e for e in fa)} data_per_streamline_func = { - 'commit_weights': lambda: (e for e in commit_weights), - 'clusters_QB': lambda: (e for e in clusters_QB)} - - obj = LazyTractogram(streamlines_func, - data_per_streamline_func, - data_per_point_func, - affine_to_rasmm=np.eye(4)) - - dtype_dict = {'positions': np.float32, 'offsets': np.uint32, - 'dpv': {'fa': np.float16}, - 'dps': {'commit_weights': np.float32, - 'clusters_QB': np.uint16}} - path = os.path.join(get_home(), 'memmap_test_data', path) - trx = tmm.TrxFile.from_lazy_tractogram(obj, reference=path, - extra_buffer=buffer, - chunk_size=1000, - dtype_dict=dtype_dict) + "commit_weights": lambda: (e for e in commit_weights), + "clusters_QB": lambda: (e for e in clusters_QB), + } + + obj = LazyTractogram( + streamlines_func, + data_per_streamline_func, + data_per_point_func, + affine_to_rasmm=np.eye(4), + ) + + dtype_dict = { + "positions": np.float32, + "offsets": np.uint32, + "dpv": {"fa": np.float16}, + "dps": {"commit_weights": np.float32, "clusters_QB": np.uint16}, + } + path = os.path.join(get_home(), "memmap_test_data", path) + trx = tmm.TrxFile.from_lazy_tractogram( + obj, reference=path, extra_buffer=buffer, chunk_size=1000, dtype_dict=dtype_dict + ) assert len(trx) == len(gen_range) diff --git a/trx/tests/test_streamlines_ops.py b/trx/tests/test_streamlines_ops.py index 3cd2678..268606b 100644 --- a/trx/tests/test_streamlines_ops.py +++ b/trx/tests/test_streamlines_ops.py @@ -3,11 +3,17 @@ import numpy as np import pytest -from trx.streamlines_ops import (perform_streamlines_operation, - intersection, union, difference) +from trx.streamlines_ops import ( + difference, + intersection, + perform_streamlines_operation, + union, +) -streamlines_ori = [np.ones(90).reshape((30, 3)), - np.arange(90).reshape((30, 3)) + 0.3333] +streamlines_ori = [ + np.ones(90).reshape((30, 3)), + np.arange(90).reshape((30, 3)) + 0.3333, +] @pytest.mark.parametrize( @@ -27,16 +33,15 @@ def test_intersection(precision, noise, expected): streamlines_new = [] for i in range(5): if i < 4: - streamlines_new.append(streamlines_ori[1] + - np.random.random((30, 3))) + streamlines_new.append(streamlines_ori[1] + np.random.random((30, 3))) else: - streamlines_new.append(streamlines_ori[1] + - noise * np.random.random((30, 3))) + streamlines_new.append( + streamlines_ori[1] + noise * np.random.random((30, 3)) + ) # print(streamlines_new) - _, indices_uniq = perform_streamlines_operation(intersection, - [streamlines_new, - 
streamlines_ori], - precision=precision) + _, indices_uniq = perform_streamlines_operation( + intersection, [streamlines_new, streamlines_ori], precision=precision + ) indices_uniq = indices_uniq.tolist() assert indices_uniq == expected @@ -58,16 +63,15 @@ def test_union(precision, noise, expected): streamlines_new = [] for i in range(5): if i < 4: - streamlines_new.append(streamlines_ori[1] + - np.random.random((30, 3))) + streamlines_new.append(streamlines_ori[1] + np.random.random((30, 3))) else: - streamlines_new.append(streamlines_ori[1] + - noise * np.random.random((30, 3))) + streamlines_new.append( + streamlines_ori[1] + noise * np.random.random((30, 3)) + ) - unique_streamlines, _ = perform_streamlines_operation(union, - [streamlines_new, - streamlines_ori], - precision=precision) + unique_streamlines, _ = perform_streamlines_operation( + union, [streamlines_new, streamlines_ori], precision=precision + ) assert len(unique_streamlines) == expected @@ -88,14 +92,13 @@ def test_difference(precision, noise, expected): streamlines_new = [] for i in range(5): if i < 4: - streamlines_new.append(streamlines_ori[1] + - np.random.random((30, 3))) + streamlines_new.append(streamlines_ori[1] + np.random.random((30, 3))) else: - streamlines_new.append(streamlines_ori[1] + - noise * np.random.random((30, 3))) + streamlines_new.append( + streamlines_ori[1] + noise * np.random.random((30, 3)) + ) - unique_streamlines, _ = perform_streamlines_operation(difference, - [streamlines_new, - streamlines_ori], - precision=precision) + unique_streamlines, _ = perform_streamlines_operation( + difference, [streamlines_new, streamlines_ori], precision=precision + ) assert len(unique_streamlines) == expected diff --git a/trx/trx_file_memmap.py b/trx/trx_file_memmap.py index 471cdc4..5a70e82 100644 --- a/trx/trx_file_memmap.py +++ b/trx/trx_file_memmap.py @@ -5,7 +5,7 @@ import logging import os import shutil -from typing import Any, List, Tuple, Type, Union, Optional +from typing import Any, List, Optional, Tuple, Type, Union import zipfile import nibabel as nib @@ -13,18 +13,21 @@ from nibabel.nifti1 import Nifti1Header, Nifti1Image from nibabel.orientations import aff2axcodes from nibabel.streamlines.array_sequence import ArraySequence +from nibabel.streamlines.tractogram import LazyTractogram, Tractogram from nibabel.streamlines.trk import TrkFile -from nibabel.streamlines.tractogram import Tractogram, LazyTractogram import numpy as np from trx.io import get_trx_tmp_dir -from trx.utils import (append_generator_to_dict, - close_or_delete_mmap, - convert_data_dict_to_tractogram, - get_reference_info_wrapper) +from trx.utils import ( + append_generator_to_dict, + close_or_delete_mmap, + convert_data_dict_to_tractogram, + get_reference_info_wrapper, +) try: import dipy # noqa: F401 + dipy_available = True except ImportError: dipy_available = False @@ -42,9 +45,12 @@ def _append_last_offsets(nib_offsets: np.ndarray, nb_vertices: int) -> np.ndarra Returns: Offsets -- np.ndarray (VTK convention) """ - def is_sorted(a): return np.all(a[:-1] <= a[1:]) + + def is_sorted(a): + return np.all(a[:-1] <= a[1:]) + if not is_sorted(nib_offsets): - raise ValueError('Offsets must be sorted values.') + raise ValueError("Offsets must be sorted values.") return np.append(nib_offsets, nb_vertices).astype(nib_offsets.dtype) @@ -64,7 +70,7 @@ def _generate_filename_from_data(arr: np.ndarray, filename: str) -> str: logging.warning("Will overwrite provided extension if needed.") dtype = arr.dtype - dtype = "bit" if dtype == bool 
else dtype.name + dtype = "bit" if dtype is np.dtype(bool) else dtype.name if arr.ndim == 1: new_filename = "{}.{}".format(base, dtype) @@ -196,8 +202,7 @@ def _create_memmap( if shape[0]: return np.memmap( - filename, mode=mode, offset=offset, shape=shape, dtype=dtype, - order=order + filename, mode=mode, offset=offset, shape=shape, dtype=dtype, order=order ) else: if not os.path.isfile(filename): @@ -233,8 +238,7 @@ def load(input_obj: str, check_dpg: bool = True) -> Type["TrxFile"]: trx = load_from_directory(tmp_dir.name) trx._uncompressed_folder_handle = tmp_dir logging.info( - "File was compressed, call the close() function before " - "exiting." + "File was compressed, call the close() function before exiting." ) else: trx = load_from_zip(input_obj) @@ -248,8 +252,7 @@ def load(input_obj: str, check_dpg: bool = True) -> Type["TrxFile"]: for dpg in trx.data_per_group.keys(): if dpg not in trx.groups.keys(): raise ValueError( - "An undeclared group ({}) has " "data_per_group.".format( - dpg) + "An undeclared group ({}) has data_per_group.".format(dpg) ) return trx @@ -270,8 +273,7 @@ def load_from_zip(filename: str) -> Type["TrxFile"]: header["VOXEL_TO_RASMM"] = np.reshape( header["VOXEL_TO_RASMM"], (4, 4) ).astype(np.float32) - header["DIMENSIONS"] = np.array( - header["DIMENSIONS"], dtype=np.uint16) + header["DIMENSIONS"] = np.array(header["DIMENSIONS"], dtype=np.uint16) files_pointer_size = {} for zip_info in zf.filelist: @@ -282,8 +284,7 @@ def load_from_zip(filename: str) -> Type["TrxFile"]: if not _is_dtype_valid(ext): continue - raise ValueError( - "The dtype {} is not supported".format(elem_filename)) + raise ValueError("The dtype {} is not supported".format(elem_filename)) if ext == ".bit": ext = ".bool" @@ -318,11 +319,12 @@ def load_from_directory(directory: str) -> Type["TrxFile"]: directory = os.path.abspath(directory) with open(os.path.join(directory, "header.json")) as header: header = json.load(header) - header["VOXEL_TO_RASMM"] = np.reshape(header["VOXEL_TO_RASMM"], - (4, 4)).astype(np.float32) + header["VOXEL_TO_RASMM"] = np.reshape(header["VOXEL_TO_RASMM"], (4, 4)).astype( + np.float32 + ) header["DIMENSIONS"] = np.array(header["DIMENSIONS"], dtype=np.uint16) files_pointer_size = {} - for root, dirs, files in os.walk(directory): + for root, _dirs, files in os.walk(directory): for name in files: elem_filename = os.path.join(root, name) _, ext = os.path.splitext(elem_filename) @@ -346,8 +348,7 @@ def load_from_directory(directory: str) -> Type["TrxFile"]: else: raise ValueError("Wrong size or datatype") - return TrxFile._create_trx_from_pointer(header, files_pointer_size, - root=directory) + return TrxFile._create_trx_from_pointer(header, files_pointer_size, root=directory) def _filter_empty_trx_files(trx_list: List["TrxFile"]) -> List["TrxFile"]: @@ -377,21 +378,25 @@ def _check_space_attributes(trx_list: List["TrxFile"]) -> None: raise ValueError("Wrong space attributes.") -def _verify_dpv_coherence(trx_list: List["TrxFile"], all_dpv: set, - ref_trx: "TrxFile", delete_dpv: bool) -> None: +def _verify_dpv_coherence( + trx_list: List["TrxFile"], all_dpv: set, ref_trx: "TrxFile", delete_dpv: bool +) -> None: """Verify dpv coherence across TrxFiles.""" for curr_trx in trx_list: for key in all_dpv: - if (key not in ref_trx.data_per_vertex.keys() or - key not in curr_trx.data_per_vertex.keys()): + if ( + key not in ref_trx.data_per_vertex.keys() + or key not in curr_trx.data_per_vertex.keys() + ): if not delete_dpv: logging.debug( "{} dpv key does not exist in all 
TrxFile.".format(key) ) - raise ValueError( - "TrxFile must be sharing identical dpv keys.") - elif (ref_trx.data_per_vertex[key]._data.dtype != - curr_trx.data_per_vertex[key]._data.dtype): + raise ValueError("TrxFile must be sharing identical dpv keys.") + elif ( + ref_trx.data_per_vertex[key]._data.dtype + != curr_trx.data_per_vertex[key]._data.dtype + ): logging.debug( "{} dpv key is not declared with the same dtype " "in all TrxFile.".format(key) @@ -399,21 +404,25 @@ def _verify_dpv_coherence(trx_list: List["TrxFile"], all_dpv: set, raise ValueError("Shared dpv key, has different dtype.") -def _verify_dps_coherence(trx_list: List["TrxFile"], all_dps: set, - ref_trx: "TrxFile", delete_dps: bool) -> None: +def _verify_dps_coherence( + trx_list: List["TrxFile"], all_dps: set, ref_trx: "TrxFile", delete_dps: bool +) -> None: """Verify dps coherence across TrxFiles.""" for curr_trx in trx_list: for key in all_dps: - if (key not in ref_trx.data_per_streamline.keys() or - key not in curr_trx.data_per_streamline.keys()): + if ( + key not in ref_trx.data_per_streamline.keys() + or key not in curr_trx.data_per_streamline.keys() + ): if not delete_dps: logging.debug( "{} dps key does not exist in all TrxFile.".format(key) ) - raise ValueError( - "TrxFile must be sharing identical dps keys.") - elif (ref_trx.data_per_streamline[key].dtype != - curr_trx.data_per_streamline[key].dtype): + raise ValueError("TrxFile must be sharing identical dps keys.") + elif ( + ref_trx.data_per_streamline[key].dtype + != curr_trx.data_per_streamline[key].dtype + ): logging.debug( "{} dps key is not declared with the same dtype " "in all TrxFile.".format(key) @@ -433,8 +442,10 @@ def _compute_groups_info(trx_list: List["TrxFile"]) -> Tuple[dict, dict]: else: all_groups_len[group_key] = len(trx_1.groups[group_key]) - if (group_key in all_groups_dtype and - trx_1.groups[group_key].dtype != all_groups_dtype[group_key]): + if ( + group_key in all_groups_dtype + and trx_1.groups[group_key].dtype != all_groups_dtype[group_key] + ): raise ValueError("Shared group key, has different dtype.") else: all_groups_dtype[group_key] = trx_1.groups[group_key].dtype @@ -442,9 +453,13 @@ def _compute_groups_info(trx_list: List["TrxFile"]) -> Tuple[dict, dict]: return all_groups_len, all_groups_dtype -def _create_new_trx_for_concatenation(trx_list: List["TrxFile"], ref_trx: "TrxFile", - delete_dps: bool, delete_dpv: bool, - delete_groups: bool) -> "TrxFile": +def _create_new_trx_for_concatenation( + trx_list: List["TrxFile"], + ref_trx: "TrxFile", + delete_dps: bool, + delete_dpv: bool, + delete_groups: bool, +) -> "TrxFile": """Create a new TrxFile for concatenation.""" nb_vertices = 0 nb_streamlines = 0 @@ -454,8 +469,7 @@ def _create_new_trx_for_concatenation(trx_list: List["TrxFile"], ref_trx: "TrxFi nb_vertices += curr_pts_len new_trx = TrxFile( - nb_vertices=nb_vertices, nb_streamlines=nb_streamlines, - init_as=ref_trx + nb_vertices=nb_vertices, nb_streamlines=nb_streamlines, init_as=ref_trx ) if delete_dps: new_trx.data_per_streamline = {} @@ -467,9 +481,13 @@ def _create_new_trx_for_concatenation(trx_list: List["TrxFile"], ref_trx: "TrxFi return new_trx -def _setup_groups_for_concatenation(new_trx: "TrxFile", trx_list: List["TrxFile"], - all_groups_len: dict, all_groups_dtype: dict, - delete_groups: bool) -> None: +def _setup_groups_for_concatenation( + new_trx: "TrxFile", + trx_list: List["TrxFile"], + all_groups_len: dict, + all_groups_dtype: dict, + delete_groups: bool, +) -> None: """Setup groups in the new TrxFile for 
concatenation.""" if delete_groups: return @@ -482,7 +500,7 @@ def _setup_groups_for_concatenation(new_trx: "TrxFile", trx_list: List["TrxFile" dtype = all_groups_dtype[group_key] group_filename = os.path.join( - tmp_dir, "groups/" "{}.{}".format(group_key, dtype.name) + tmp_dir, "groups/{}.{}".format(group_key, dtype.name) ) group_len = all_groups_len[group_key] new_trx.groups[group_key] = _create_memmap( @@ -493,8 +511,9 @@ def _setup_groups_for_concatenation(new_trx: "TrxFile", trx_list: List["TrxFile" count = 0 for curr_trx in trx_list: curr_len = len(curr_trx.groups[group_key]) - new_trx.groups[group_key][pos: pos + curr_len] = \ + new_trx.groups[group_key][pos : pos + curr_len] = ( curr_trx.groups[group_key] + count + ) pos += curr_len count += curr_trx.header["NB_STREAMLINES"] @@ -539,9 +558,7 @@ def concatenate( _check_space_attributes(trx_list) if preallocation and not delete_groups: - raise ValueError( - "Groups are variables, cannot be handled with preallocation" - ) + raise ValueError("Groups are variables, cannot be handled with preallocation") _verify_dpv_coherence(trx_list, all_dpv, ref_trx, delete_dpv) _verify_dps_coherence(trx_list, all_dps, ref_trx, delete_dps) @@ -663,10 +680,9 @@ def __init__( if nb_vertices is None and nb_streamlines is None: if init_as is not None: raise ValueError( - "Cant use init_as without declaring " - "nb_vertices AND nb_streamlines" + "Can't use init_as without declaring nb_vertices AND nb_streamlines" ) - logging.debug("Intializing empty TrxFile.") + logging.debug("Initializing empty TrxFile.") self.header = {} # Using the new format default type tmp_strs = ArraySequence() @@ -685,16 +701,16 @@ def __init__( elif nb_vertices is not None and nb_streamlines is not None: logging.debug( - "Preallocating TrxFile with size {} streamlines" - "and {} vertices.".format(nb_streamlines, nb_vertices) + "Preallocating TrxFile with size {} streamlinesand {} vertices.".format( + nb_streamlines, nb_vertices + ) ) trx = self._initialize_empty_trx( nb_streamlines, nb_vertices, init_as=init_as ) self.__dict__ = trx.__dict__ else: - raise ValueError( - "You must declare both nb_vertices AND " "NB_STREAMLINES") + raise ValueError("You must declare both nb_vertices AND NB_STREAMLINES") self.header["VOXEL_TO_RASMM"] = affine self.header["DIMENSIONS"] = dimensions @@ -710,13 +726,11 @@ def __str__(self) -> str: vox_order = "".join(aff2axcodes(affine)) text = "VOXEL_TO_RASMM: \n{}".format( - np.array2string(affine, formatter={ - "float_kind": lambda x: "%.6f" % x}) + np.array2string(affine, formatter={"float_kind": lambda x: "%.6f" % x}) ) text += "\nDIMENSIONS: {}".format(np.array2string(dimensions)) text += "\nVOX_SIZES: {}".format( - np.array2string(vox_sizes, formatter={ - "float_kind": lambda x: "%.2f" % x}) + np.array2string(vox_sizes, formatter={"float_kind": lambda x: "%.2f" % x}) ) text += "\nVOX_ORDER: {}".format(vox_order) @@ -730,8 +744,7 @@ def __str__(self) -> str: text += "\nstreamline_count: {}".format(strs_len) text += "\nvertex_count: {}".format(pts_len) - text += "\ndata_per_vertex keys: {}".format( - list(self.data_per_vertex.keys())) + text += "\ndata_per_vertex keys: {}".format(list(self.data_per_vertex.keys())) text += "\ndata_per_streamline keys: {}".format( list(self.data_per_streamline.keys()) ) @@ -758,7 +771,7 @@ def __getitem__(self, key) -> Any: key += len(self) key = [key] elif isinstance(key, slice): - key = [ii for ii in range(*key.indices(len(self)))] + key = list(range(*key.indices(len(self)))) return self.select(key, 
keep_group=False) @@ -780,7 +793,7 @@ def deepcopy(self) -> Type["TrxFile"]: # noqa: C901 if not isinstance(tmp_header["DIMENSIONS"], list): tmp_header["DIMENSIONS"] = tmp_header["DIMENSIONS"].tolist() - # tofile() alway write in C-order + # tofile() always write in C-order if not self._copy_safe: to_dump = self.streamlines.copy()._data tmp_header["NB_STREAMLINES"] = len(self.streamlines) @@ -796,11 +809,13 @@ def deepcopy(self) -> Type["TrxFile"]: # noqa: C901 to_dump.tofile(positions_filename) if not self._copy_safe: - to_dump = _append_last_offsets(self.streamlines.copy()._offsets, - self.header["NB_VERTICES"]) + to_dump = _append_last_offsets( + self.streamlines.copy()._offsets, self.header["NB_VERTICES"] + ) else: - to_dump = _append_last_offsets(self.streamlines._offsets, - self.header["NB_VERTICES"]) + to_dump = _append_last_offsets( + self.streamlines._offsets, self.header["NB_VERTICES"] + ) offsets_filename = _generate_filename_from_data( self.streamlines._offsets, os.path.join(tmp_dir.name, "offsets") ) @@ -847,8 +862,7 @@ def deepcopy(self) -> Type["TrxFile"]: # noqa: C901 os.mkdir(os.path.join(tmp_dir.name, "dpg/", group_key)) to_dump = self.data_per_group[group_key][dpg_key] dpg_filename = _generate_filename_from_data( - to_dump, os.path.join( - tmp_dir.name, "dpg/", group_key, dpg_key) + to_dump, os.path.join(tmp_dir.name, "dpg/", group_key, dpg_key) ) to_dump.tofile(dpg_filename) @@ -909,25 +923,28 @@ def _copy_fixed_arrays_from( return strs_start, pts_start # Mandatory arrays - self.streamlines._data[pts_start:pts_end] = \ - trx.streamlines._data[0:curr_pts_len] - self.streamlines._offsets[strs_start:strs_end] = \ - (trx.streamlines._offsets[0:curr_strs_len] + pts_start) - self.streamlines._lengths[strs_start:strs_end] = \ - trx.streamlines._lengths[0:curr_strs_len] + self.streamlines._data[pts_start:pts_end] = trx.streamlines._data[ + 0:curr_pts_len + ] + self.streamlines._offsets[strs_start:strs_end] = ( + trx.streamlines._offsets[0:curr_strs_len] + pts_start + ) + self.streamlines._lengths[strs_start:strs_end] = trx.streamlines._lengths[ + 0:curr_strs_len + ] # Optional fixed-sized arrays for dpv_key in self.data_per_vertex.keys(): - self.data_per_vertex[dpv_key]._data[ - pts_start:pts_end - ] = trx.data_per_vertex[dpv_key]._data[0:curr_pts_len] + self.data_per_vertex[dpv_key]._data[pts_start:pts_end] = ( + trx.data_per_vertex[dpv_key]._data[0:curr_pts_len] + ) self.data_per_vertex[dpv_key]._offsets = self.streamlines._offsets self.data_per_vertex[dpv_key]._lengths = self.streamlines._lengths for dps_key in self.data_per_streamline.keys(): - self.data_per_streamline[dps_key][ - strs_start:strs_end - ] = trx.data_per_streamline[dps_key][0:curr_strs_len] + self.data_per_streamline[dps_key][strs_start:strs_end] = ( + trx.data_per_streamline[dps_key][0:curr_strs_len] + ) return strs_end, pts_end @@ -968,29 +985,24 @@ def _initialize_empty_trx( # noqa: C901 lengths_dtype = np.dtype(np.uint32) logging.debug( - "Initializing positions with dtype: {}".format( - positions_dtype.name) + "Initializing positions with dtype: {}".format(positions_dtype.name) ) - logging.debug( - "Initializing offsets with dtype: {}".format(offsets_dtype.name)) - logging.debug( - "Initializing lengths with dtype: {}".format(lengths_dtype.name)) + logging.debug("Initializing offsets with dtype: {}".format(offsets_dtype.name)) + logging.debug("Initializing lengths with dtype: {}".format(lengths_dtype.name)) # A TrxFile without init_as only contain the essential arrays positions_filename = os.path.join( 
tmp_dir.name, "positions.3.{}".format(positions_dtype.name) ) trx.streamlines._data = _create_memmap( - positions_filename, mode="w+", shape=(nb_vertices, 3), - dtype=positions_dtype + positions_filename, mode="w+", shape=(nb_vertices, 3), dtype=positions_dtype ) offsets_filename = os.path.join( tmp_dir.name, "offsets.{}".format(offsets_dtype.name) ) trx.streamlines._offsets = _create_memmap( - offsets_filename, mode="w+", shape=(nb_streamlines,), - dtype=offsets_dtype + offsets_filename, mode="w+", shape=(nb_streamlines,), dtype=offsets_dtype ) trx.streamlines._lengths = np.zeros( shape=(nb_streamlines,), dtype=lengths_dtype @@ -1008,23 +1020,20 @@ def _initialize_empty_trx( # noqa: C901 tmp_as = init_as.data_per_vertex[dpv_key]._data if tmp_as.ndim == 1: dpv_filename = os.path.join( - tmp_dir.name, "dpv/" "{}.{}".format( - dpv_key, dtype.name) + tmp_dir.name, "dpv/{}.{}".format(dpv_key, dtype.name) ) shape = (nb_vertices, 1) elif tmp_as.ndim == 2: dim = tmp_as.shape[-1] shape = (nb_vertices, dim) dpv_filename = os.path.join( - tmp_dir.name, "dpv/" "{}.{}.{}".format( - dpv_key, dim, dtype.name) + tmp_dir.name, "dpv/{}.{}.{}".format(dpv_key, dim, dtype.name) ) else: raise ValueError("Invalid dimensionality.") logging.debug( - "Initializing {} (dpv) with dtype: " - "{}".format(dpv_key, dtype.name) + "Initializing {} (dpv) with dtype: {}".format(dpv_key, dtype.name) ) trx.data_per_vertex[dpv_key] = ArraySequence() trx.data_per_vertex[dpv_key]._data = _create_memmap( @@ -1038,23 +1047,22 @@ def _initialize_empty_trx( # noqa: C901 tmp_as = init_as.data_per_streamline[dps_key] if tmp_as.ndim == 1: dps_filename = os.path.join( - tmp_dir.name, "dps/" "{}.{}".format( - dps_key, dtype.name) + tmp_dir.name, "dps/{}.{}".format(dps_key, dtype.name) ) shape = (nb_streamlines,) elif tmp_as.ndim == 2: dim = tmp_as.shape[-1] shape = (nb_streamlines, dim) dps_filename = os.path.join( - tmp_dir.name, "dps/" "{}.{}.{}".format( - dps_key, dim, dtype.name) + tmp_dir.name, "dps/{}.{}.{}".format(dps_key, dim, dtype.name) ) else: raise ValueError("Invalid dimensionality.") logging.debug( - "Initializing {} (dps) with and dtype: " - "{}".format(dps_key, dtype.name) + "Initializing {} (dps) with and dtype: {}".format( + dps_key, dtype.name + ) ) trx.data_per_streamline[dps_key] = _create_memmap( dps_filename, mode="w+", shape=shape, dtype=dtype @@ -1081,7 +1089,7 @@ def _create_trx_from_pointer( # noqa: C901 root -- The dirname of the ZipFile pointer Returns: - A TrxFile constructer from the pointer provided + A TrxFile constructor from the pointer provided """ # TODO support empty positions, using optional tag? 
trx = TrxFile() @@ -1101,18 +1109,19 @@ def _create_trx_from_pointer( # noqa: C901 if root is not None: # This is for Unix - if os.name != 'nt' and folder.startswith(root.rstrip("/")): + if os.name != "nt" and folder.startswith(root.rstrip("/")): folder = folder.replace(root, "").lstrip("/") # These three are for Windows - elif ( - os.path.isdir(folder) - and os.path.basename(folder) in ['dpv', 'dps', 'groups'] - ): + elif os.path.isdir(folder) and os.path.basename(folder) in [ + "dpv", + "dps", + "groups", + ]: folder = os.path.basename(folder) - elif os.path.basename(os.path.dirname(folder)) == 'dpg': - folder = os.path.join('dpg', os.path.basename(folder)) + elif os.path.basename(os.path.dirname(folder)) == "dpg": + folder = os.path.join("dpg", os.path.basename(folder)) else: - folder = '' + folder = "" # Parse/walk the directory tree if base == "positions" and folder == "": @@ -1126,13 +1135,13 @@ def _create_trx_from_pointer( # noqa: C901 dtype=ext[1:], ) elif base == "offsets" and folder == "": - if size != trx.header["NB_STREAMLINES"]+1 or dim != 1: + if size != trx.header["NB_STREAMLINES"] + 1 or dim != 1: raise ValueError("Wrong offsets size/dimensionality.") offsets = _create_memmap( filename, mode="r+", offset=mem_adress, - shape=(trx.header["NB_STREAMLINES"]+1,), + shape=(trx.header["NB_STREAMLINES"] + 1,), dtype=ext[1:], ) if offsets[-1] != 0: @@ -1147,8 +1156,7 @@ def _create_trx_from_pointer( # noqa: C901 shape = (trx.header["NB_STREAMLINES"], int(nb_scalar)) trx.data_per_streamline[base] = _create_memmap( - filename, mode="r+", offset=mem_adress, shape=shape, - dtype=ext[1:] + filename, mode="r+", offset=mem_adress, shape=shape, dtype=ext[1:] ) elif folder == "dpv": nb_scalar = size / trx.header["NB_VERTICES"] @@ -1158,8 +1166,7 @@ def _create_trx_from_pointer( # noqa: C901 shape = (trx.header["NB_VERTICES"], int(nb_scalar)) trx.data_per_vertex[base] = _create_memmap( - filename, mode="r+", offset=mem_adress, shape=shape, - dtype=ext[1:] + filename, mode="r+", offset=mem_adress, shape=shape, dtype=ext[1:] ) elif folder.startswith("dpg"): if int(size) != dim: @@ -1173,8 +1180,7 @@ def _create_trx_from_pointer( # noqa: C901 if sub_folder not in trx.data_per_group: trx.data_per_group[sub_folder] = {} trx.data_per_group[sub_folder][data_name] = _create_memmap( - filename, mode="r+", offset=mem_adress, shape=shape, - dtype=ext[1:] + filename, mode="r+", offset=mem_adress, shape=shape, dtype=ext[1:] ) elif folder == "groups": # Groups are simply indices, nothing else @@ -1184,13 +1190,11 @@ def _create_trx_from_pointer( # noqa: C901 else: shape = (int(size),) trx.groups[base] = _create_memmap( - filename, mode="r+", offset=mem_adress, shape=shape, - dtype=ext[1:] + filename, mode="r+", offset=mem_adress, shape=shape, dtype=ext[1:] ) else: logging.error( - "{} is not part of a valid structure.".format( - elem_filename) + "{} is not part of a valid structure.".format(elem_filename) ) # All essential array must be declared @@ -1215,7 +1219,7 @@ def resize( # noqa: C901 nb_vertices: Optional[int] = None, delete_dpg: bool = False, ) -> None: - """Remove the ununsed portion of preallocated memmaps + """Remove the unused portion of preallocated memmaps Keyword arguments: nb_streamlines -- The number of streamlines to keep @@ -1230,8 +1234,7 @@ def resize( # noqa: C901 if nb_streamlines is not None and nb_streamlines < strs_end: strs_end = nb_streamlines logging.info( - "Resizing (down) memmaps, less streamlines than it " - "actually contains." 
+ "Resizing (down) memmaps, less streamlines than it actually contains." ) if nb_vertices is None: @@ -1252,8 +1255,7 @@ def resize( # noqa: C901 logging.debug("TrxFile of the right size, no resizing.") return - trx = self._initialize_empty_trx( - nb_streamlines, nb_vertices, init_as=self) + trx = self._initialize_empty_trx(nb_streamlines, nb_vertices, init_as=self) logging.info( "Resizing streamlines from size {} to {}".format( @@ -1290,8 +1292,7 @@ def resize( # noqa: C901 group_name, mode="w+", shape=(len(tmp),), dtype=group_dtype ) logging.debug( - "{} group went from {} items to {}".format( - group_key, ori_len, len(tmp)) + "{} group went from {} items to {}".format(group_key, ori_len, len(tmp)) ) trx.groups[group_key][:] = tmp @@ -1322,8 +1323,9 @@ def resize( # noqa: C901 dpg_filename, mode="w+", shape=shape, dtype=dpg_dtype ) - trx.data_per_group[group_key][dpg_key][:] = \ - self.data_per_group[group_key][dpg_key] + trx.data_per_group[group_key][dpg_key][:] = self.data_per_group[ + group_key + ][dpg_key] self.close() self.__dict__ = trx.__dict__ @@ -1334,23 +1336,29 @@ def get_dtype_dict(self): Returns A dictionary containing the dtype for each data element """ - dtype_dict = {"positions": self.streamlines._data.dtype, - "offsets": self.streamlines._offsets.dtype, - "dpv": {}, "dps": {}, "dpg": {}, "groups": {}} + dtype_dict = { + "positions": self.streamlines._data.dtype, + "offsets": self.streamlines._offsets.dtype, + "dpv": {}, + "dps": {}, + "dpg": {}, + "groups": {}, + } for key in self.data_per_vertex.keys(): - dtype_dict['dpv'][key] = self.data_per_vertex[key]._data.dtype + dtype_dict["dpv"][key] = self.data_per_vertex[key]._data.dtype for key in self.data_per_streamline.keys(): - dtype_dict['dps'][key] = self.data_per_streamline[key].dtype + dtype_dict["dps"][key] = self.data_per_streamline[key].dtype for group_key in self.data_per_group.keys(): - dtype_dict['groups'][group_key] = self.groups[group_key].dtype + dtype_dict["groups"][group_key] = self.groups[group_key].dtype for group_key in self.data_per_group.keys(): - dtype_dict['dpg'][group_key] = {} + dtype_dict["dpg"][group_key] = {} for dpg_key in self.data_per_group[group_key].keys(): - dtype_dict['dpg'][group_key][dpg_key] = \ - self.data_per_group[group_key][dpg_key].dtype + dtype_dict["dpg"][group_key][dpg_key] = self.data_per_group[group_key][ + dpg_key + ].dtype return dtype_dict @@ -1359,20 +1367,22 @@ def append(self, obj, extra_buffer: int = 0) -> None: if dipy_available: from dipy.io.stateful_tractogram import StatefulTractogram - if not isinstance(obj, (TrxFile, Tractogram)) \ - and (dipy_available and not isinstance(obj, StatefulTractogram)): + if not isinstance(obj, (TrxFile, Tractogram)) and ( + dipy_available and not isinstance(obj, StatefulTractogram) + ): raise TypeError( - "{} is not a supported object type for appending.".format(type(obj))) + "{} is not a supported object type for appending.".format(type(obj)) + ) elif isinstance(obj, Tractogram): - obj = self.from_tractogram(obj, reference=self.header, - dtype_dict=curr_dtype_dict) + obj = self.from_tractogram( + obj, reference=self.header, dtype_dict=curr_dtype_dict + ) elif dipy_available and isinstance(obj, StatefulTractogram): obj = self.from_sft(obj, dtype_dict=curr_dtype_dict) self._append_trx(obj, extra_buffer=extra_buffer) - def _append_trx(self, trx: Type["TrxFile"], - extra_buffer: int = 0) -> None: + def _append_trx(self, trx: Type["TrxFile"], extra_buffer: int = 0) -> None: """Append a TrxFile to another (support buffer) Keyword 
arguments: @@ -1443,14 +1453,12 @@ def select( lengths_dtype ) new_trx.header["NB_VERTICES"] = len(new_trx.streamlines._data) - new_trx.header["NB_STREAMLINES"] = len( - new_trx.streamlines._lengths) + new_trx.header["NB_STREAMLINES"] = len(new_trx.streamlines._lengths) return new_trx.deepcopy() if copy_safe else new_trx new_trx.streamlines = ( - self.streamlines[indices].copy( - ) if copy_safe else self.streamlines[indices] + self.streamlines[indices].copy() if copy_safe else self.streamlines[indices] ) for dpv_key in self.data_per_vertex.keys(): new_trx.data_per_vertex[dpv_key] = ( @@ -1468,14 +1476,12 @@ def select( # Not keeping group is equivalent to the [] operator if keep_group: - logging.warning( - "Keeping dpg despite affecting the group " "items.") + logging.warning("Keeping dpg despite affecting the group items.") for group_key in self.groups.keys(): # Keep the group indices even when fancy slicing index = np.argsort(indices) sorted_x = indices[index] - sorted_index = np.searchsorted( - sorted_x, self.groups[group_key]) + sorted_index = np.searchsorted(sorted_x, self.groups[group_key]) yindex = np.take(index, sorted_index, mode="clip") mask = indices[yindex] != self.groups[group_key] intersect = yindex[~mask] @@ -1488,22 +1494,22 @@ def select( for dpg_key in self.data_per_group[group_key].keys(): if group_key not in new_trx.data_per_group: new_trx.data_per_group[group_key] = {} - new_trx.data_per_group[group_key][ - dpg_key - ] = self.data_per_group[group_key][dpg_key] + new_trx.data_per_group[group_key][dpg_key] = ( + self.data_per_group[group_key][dpg_key] + ) new_trx.header["NB_VERTICES"] = len(new_trx.streamlines._data) new_trx.header["NB_STREAMLINES"] = len(new_trx.streamlines._lengths) return new_trx.deepcopy() if copy_safe else new_trx @staticmethod - def from_lazy_tractogram(obj: ["LazyTractogram"], reference, - extra_buffer: int = 0, - chunk_size: int = 10000, - dtype_dict: dict = {'positions': np.float32, - 'offsets': np.uint32, - 'dpv': {}, 'dps': {}}) \ - -> Type["TrxFile"]: + def from_lazy_tractogram( + obj: ["LazyTractogram"], + reference, + extra_buffer: int = 0, + chunk_size: int = 10000, + dtype_dict: dict = None, + ) -> Type["TrxFile"]: """Append a TrxFile to another (support buffer) Keyword arguments: @@ -1513,8 +1519,15 @@ def from_lazy_tractogram(obj: ["LazyTractogram"], reference, Use 0 for no buffer. chunk_size -- The number of streamlines to save at a time. 
""" - - data = {'strs': [], 'dpv': {}, 'dps': {}} + if dtype_dict is None: + dtype_dict = { + "positions": np.float32, + "offsets": np.uint32, + "dpv": {}, + "dps": {}, + } + + data = {"strs": [], "dpv": {}, "dps": {}} concat = None count = 0 iterator = iter(obj) @@ -1529,60 +1542,70 @@ def from_lazy_tractogram(obj: ["LazyTractogram"], reference, if len(obj.streamlines) == 0: concat = TrxFile() else: - concat = TrxFile.from_tractogram(obj, - reference=reference, - dtype_dict=dtype_dict) + concat = TrxFile.from_tractogram( + obj, reference=reference, dtype_dict=dtype_dict + ) elif len(obj.streamlines) > 0: - curr_obj = TrxFile.from_tractogram(obj, - reference=reference, - dtype_dict=dtype_dict) + curr_obj = TrxFile.from_tractogram( + obj, reference=reference, dtype_dict=dtype_dict + ) concat.append(curr_obj) break append_generator_to_dict(i, data) else: obj = convert_data_dict_to_tractogram(data) if concat is None: - concat = TrxFile.from_tractogram(obj, - reference=reference, - dtype_dict=dtype_dict) + concat = TrxFile.from_tractogram( + obj, reference=reference, dtype_dict=dtype_dict + ) else: - curr_obj = TrxFile.from_tractogram(obj, - reference=reference, - dtype_dict=dtype_dict) + curr_obj = TrxFile.from_tractogram( + obj, reference=reference, dtype_dict=dtype_dict + ) concat.append(curr_obj, extra_buffer=extra_buffer) - data = {'strs': [], 'dpv': {}, 'dps': {}} + data = {"strs": [], "dpv": {}, "dps": {}} count = 0 concat.resize() return concat @staticmethod - def from_sft(sft, dtype_dict={}): + def from_sft(sft, dtype_dict=None): """Generate a valid TrxFile from a StatefulTractogram""" + if dtype_dict is None: + dtype_dict = {} if len(sft.dtype_dict) > 0: dtype_dict = sft.dtype_dict - if 'dpp' in dtype_dict: - dtype_dict['dpv'] = dtype_dict.pop('dpp') + if "dpp" in dtype_dict: + dtype_dict["dpv"] = dtype_dict.pop("dpp") elif len(dtype_dict) == 0: - dtype_dict = {'positions': np.float32, 'offsets': np.uint32, - 'dpv': {}, 'dps': {}} + dtype_dict = { + "positions": np.float32, + "offsets": np.uint32, + "dpv": {}, + "dps": {}, + } - positions_dtype = dtype_dict['positions'] - offsets_dtype = dtype_dict['offsets'] + positions_dtype = dtype_dict["positions"] + offsets_dtype = dtype_dict["offsets"] if not np.issubdtype(positions_dtype, np.floating): logging.warning( "Casting positions as {}, considering using a floating point " - "dtype.".format(positions_dtype)) + "dtype.".format(positions_dtype) + ) if not np.issubdtype(offsets_dtype, np.integer): logging.warning( - "Casting offsets as {}, considering using a integer " - "dtype.".format(offsets_dtype)) + "Casting offsets as {}, considering using a integer dtype.".format( + offsets_dtype + ) + ) - trx = TrxFile(nb_vertices=len(sft.streamlines._data), - nb_streamlines=len(sft.streamlines)) + trx = TrxFile( + nb_vertices=len(sft.streamlines._data), nb_streamlines=len(sft.streamlines) + ) trx.header = { "DIMENSIONS": sft.dimensions.tolist(), "VOXEL_TO_RASMM": sft.affine.tolist(), @@ -1600,24 +1623,26 @@ def from_sft(sft, dtype_dict={}): tmp_streamlines = deepcopy(sft.streamlines) # Cast the int64 of Nibabel to uint32 - tmp_streamlines._offsets = tmp_streamlines._offsets.astype( - offsets_dtype) + tmp_streamlines._offsets = tmp_streamlines._offsets.astype(offsets_dtype) tmp_streamlines._data = tmp_streamlines._data.astype(positions_dtype) trx.streamlines = tmp_streamlines for key in sft.data_per_point: - dtype_to_use = dtype_dict['dpv'][key] if key in dtype_dict['dpv'] \ - else np.float32 - trx.data_per_vertex[key] = \ - 
sft.data_per_point[key] - trx.data_per_vertex[key]._data = \ - sft.data_per_point[key]._data.astype(dtype_to_use) + dtype_to_use = ( + dtype_dict["dpv"][key] if key in dtype_dict["dpv"] else np.float32 + ) + trx.data_per_vertex[key] = sft.data_per_point[key] + trx.data_per_vertex[key]._data = sft.data_per_point[key]._data.astype( + dtype_to_use + ) for key in sft.data_per_streamline: - dtype_to_use = dtype_dict['dps'][key] if key in dtype_dict['dps'] \ - else np.float32 + dtype_to_use = ( + dtype_dict["dps"][key] if key in dtype_dict["dps"] else np.float32 + ) trx.data_per_streamline[key] = sft.data_per_streamline[key].astype( - dtype_to_use) + dtype_to_use + ) # For safety and for RAM, convert the whole object to memmaps tmp_dir = get_trx_tmp_dir() @@ -1633,26 +1658,37 @@ def from_sft(sft, dtype_dict={}): return trx @staticmethod - def from_tractogram(tractogram, reference, - dtype_dict={'positions': np.float32, - 'offsets': np.uint32, - 'dpv': {}, 'dps': {}}): + def from_tractogram( + tractogram, + reference, + dtype_dict=None, + ): """Generate a valid TrxFile from a Nibabel Tractogram""" - - positions_dtype = dtype_dict['positions'] if 'positions' in dtype_dict \ - else np.float32 - offsets_dtype = dtype_dict['offsets'] if 'offsets' in dtype_dict \ - else np.uint32 + if dtype_dict is None: + dtype_dict = { + "positions": np.float32, + "offsets": np.uint32, + "dpv": {}, + "dps": {}, + } + + positions_dtype = ( + dtype_dict["positions"] if "positions" in dtype_dict else np.float32 + ) + offsets_dtype = dtype_dict["offsets"] if "offsets" in dtype_dict else np.uint32 if not np.issubdtype(positions_dtype, np.floating): logging.warning( "Casting positions as {}, considering using a floating point " - "dtype.".format(positions_dtype)) + "dtype.".format(positions_dtype) + ) if not np.issubdtype(offsets_dtype, np.integer): logging.warning( - "Casting offsets as {}, considering using a integer " - "dtype.".format(offsets_dtype)) + "Casting offsets as {}, considering using a integer dtype.".format( + offsets_dtype + ) + ) trx = TrxFile( nb_vertices=len(tractogram.streamlines._data), @@ -1670,24 +1706,26 @@ def from_tractogram(tractogram, reference, tmp_streamlines = deepcopy(tractogram.streamlines) # Cast the int64 of Nibabel to uint32 - tmp_streamlines._offsets = tmp_streamlines._offsets.astype( - offsets_dtype) + tmp_streamlines._offsets = tmp_streamlines._offsets.astype(offsets_dtype) tmp_streamlines._data = tmp_streamlines._data.astype(positions_dtype) trx.streamlines = tmp_streamlines for key in tractogram.data_per_point: - dtype_to_use = dtype_dict['dpv'][key] if key in dtype_dict['dpv'] \ - else np.float32 - trx.data_per_vertex[key] = \ - tractogram.data_per_point[key] - trx.data_per_vertex[key]._data = \ - tractogram.data_per_point[key]._data.astype(dtype_to_use) + dtype_to_use = ( + dtype_dict["dpv"][key] if key in dtype_dict["dpv"] else np.float32 + ) + trx.data_per_vertex[key] = tractogram.data_per_point[key] + trx.data_per_vertex[key]._data = tractogram.data_per_point[ + key + ]._data.astype(dtype_to_use) for key in tractogram.data_per_streamline: - dtype_to_use = dtype_dict['dps'][key] if key in dtype_dict['dps'] \ - else np.float32 - trx.data_per_streamline[key] = \ - tractogram.data_per_streamline[key].astype(dtype_to_use) + dtype_to_use = ( + dtype_dict["dps"][key] if key in dtype_dict["dps"] else np.float32 + ) + trx.data_per_streamline[key] = tractogram.data_per_streamline[key].astype( + dtype_to_use + ) # For safety and for RAM, convert the whole object to memmaps tmp_dir = 
get_trx_tmp_dir() @@ -1732,8 +1770,7 @@ def to_memory(self, resize: bool = False) -> Type["TrxFile"]: trx_obj.data_per_vertex[key] = deepcopy(self.data_per_vertex[key]) for key in self.data_per_streamline: - trx_obj.data_per_streamline[key] = deepcopy( - self.data_per_streamline[key]) + trx_obj.data_per_streamline[key] = deepcopy(self.data_per_streamline[key]) for key in self.groups: trx_obj.groups[key] = deepcopy(self.groups[key]) @@ -1746,10 +1783,11 @@ def to_memory(self, resize: bool = False) -> Type["TrxFile"]: def to_sft(self, resize=False): """Convert a TrxFile to a valid StatefulTractogram (in RAM)""" try: - from dipy.io.stateful_tractogram import StatefulTractogram, Space + from dipy.io.stateful_tractogram import Space, StatefulTractogram except ImportError: - logging.error('Dipy library is missing, cannot convert to ' - 'StatefulTractogram.') + logging.error( + "Dipy library is missing, cannot convert to StatefulTractogram." + ) return None affine = np.array(self.header["VOXEL_TO_RASMM"], dtype=np.float32) @@ -1768,8 +1806,8 @@ def to_sft(self, resize=False): data_per_streamline=deepcopy(self.data_per_streamline), ) tmp_dict = self.get_dtype_dict() - if 'dpv' in tmp_dict: - tmp_dict['dpp'] = tmp_dict.pop('dpv') + if "dpv" in tmp_dict: + tmp_dict["dpp"] = tmp_dict.pop("dpv") sft.dtype_dict = self.get_dtype_dict() return sft @@ -1797,9 +1835,8 @@ def close(self) -> None: self._uncompressed_folder_handle.cleanup() except PermissionError: logging.error( - "Windows PermissionError, temporary directory %s was not " - "deleted!", + "Windows PermissionError, temporary directory %s was not deleted!", self._uncompressed_folder_handle.name, ) self.__init__() - logging.debug("Deleted memmaps and intialized empty TrxFile.") + logging.debug("Deleted memmaps and initialized empty TrxFile.") diff --git a/trx/utils.py b/trx/utils.py index d009f63..314b4c2 100644 --- a/trx/utils.py +++ b/trx/utils.py @@ -1,16 +1,16 @@ # -*- coding: utf-8 -*- -from nibabel.streamlines.tractogram import TractogramItem -from nibabel.streamlines.tractogram import Tractogram -from nibabel.streamlines.array_sequence import ArraySequence -import os import logging +import os import nibabel as nib +from nibabel.streamlines.array_sequence import ArraySequence +from nibabel.streamlines.tractogram import Tractogram, TractogramItem import numpy as np try: import dipy + dipy_available = True except ImportError: dipy_available = False @@ -26,7 +26,7 @@ def close_or_delete_mmap(obj): The object that potentially has a memory-mapped file to be closed. """ - if hasattr(obj, '_mmap') and obj._mmap is not None: + if hasattr(obj, "_mmap") and obj._mmap is not None: obj._mmap.close() elif isinstance(obj, ArraySequence): close_or_delete_mmap(obj._data) @@ -35,7 +35,7 @@ def close_or_delete_mmap(obj): elif isinstance(obj, np.memmap): del obj else: - logging.debug('Object to be close or deleted must be np.memmap') + logging.debug("Object to be close or deleted must be np.memmap") def split_name_with_gz(filename): @@ -67,7 +67,7 @@ def split_name_with_gz(filename): def get_reference_info_wrapper(reference): # noqa: C901 - """ Will compare the spatial attribute of 2 references. + """Will compare the spatial attribute of 2 references. 
Parameters ---------- @@ -77,25 +77,26 @@ def get_reference_info_wrapper(reference): # noqa: C901 Returns ------- output : tuple - - affine ndarray (4,4), np.float32, tranformation of VOX to RASMM + - affine ndarray (4,4), np.float32, transformation of VOX to RASMM - dimensions ndarray (3,), int16, volume shape for each axis - voxel_sizes ndarray (3,), float32, size of voxel for each axis - voxel_order, string, Typically 'RAS' or 'LPS' """ from trx import trx_file_memmap + is_nifti = False is_trk = False is_sft = False is_trx = False if isinstance(reference, str): _, ext = split_name_with_gz(reference) - if ext in ['.nii', '.nii.gz']: + if ext in [".nii", ".nii.gz"]: header = nib.load(reference).header is_nifti = True - elif ext == '.trk': + elif ext == ".trk": header = nib.streamlines.load(reference, lazy_load=True).header is_trk = True - elif ext == '.trx': + elif ext == ".trx": header = trx_file_memmap.load(reference).header is_trx = True elif isinstance(reference, trx_file_memmap.TrxFile): @@ -110,53 +111,56 @@ def get_reference_info_wrapper(reference): # noqa: C901 elif isinstance(reference, nib.nifti1.Nifti1Header): header = reference is_nifti = True - elif isinstance(reference, dict) and 'magic_number' in reference: + elif isinstance(reference, dict) and "magic_number" in reference: header = reference is_trk = True - elif isinstance(reference, dict) and 'NB_VERTICES' in reference: + elif isinstance(reference, dict) and "NB_VERTICES" in reference: header = reference is_trx = True - elif dipy_available and \ - isinstance(reference, dipy.io.stateful_tractogram.StatefulTractogram): + elif dipy_available and isinstance( + reference, dipy.io.stateful_tractogram.StatefulTractogram + ): is_sft = True if is_nifti: affine = header.get_best_affine() - dimensions = header['dim'][1:4] - voxel_sizes = header['pixdim'][1:4] + dimensions = header["dim"][1:4] + voxel_sizes = header["pixdim"][1:4] if not affine[0:3, 0:3].any(): raise ValueError( - 'Invalid affine, contains only zeros.' - 'Cannot determine voxel order from transformation') - voxel_order = ''.join(nib.aff2axcodes(affine)) + "Invalid affine, contains only zeros." + "Cannot determine voxel order from transformation" + ) + voxel_order = "".join(nib.aff2axcodes(affine)) elif is_trk: - affine = header['voxel_to_rasmm'] - dimensions = header['dimensions'] - voxel_sizes = header['voxel_sizes'] - voxel_order = header['voxel_order'] + affine = header["voxel_to_rasmm"] + dimensions = header["dimensions"] + voxel_sizes = header["voxel_sizes"] + voxel_order = header["voxel_order"] elif is_sft: affine, dimensions, voxel_sizes, voxel_order = reference.space_attributes elif is_trx: - affine = header['VOXEL_TO_RASMM'] - dimensions = header['DIMENSIONS'] + affine = header["VOXEL_TO_RASMM"] + dimensions = header["DIMENSIONS"] voxel_sizes = nib.affines.voxel_sizes(affine) - voxel_order = ''.join(nib.aff2axcodes(affine)) + voxel_order = "".join(nib.aff2axcodes(affine)) else: - raise TypeError('Input reference is not one of the supported format') + raise TypeError("Input reference is not one of the supported format") if isinstance(voxel_order, np.bytes_): - voxel_order = voxel_order.decode('utf-8') + voxel_order = voxel_order.decode("utf-8") if dipy_available: from dipy.io.utils import is_reference_info_valid + is_reference_info_valid(affine, dimensions, voxel_sizes, voxel_order) return affine, dimensions, voxel_sizes, voxel_order def is_header_compatible(reference_1, reference_2): - """ Will compare the spatial attribute of 2 references. 
+ """Will compare the spatial attribute of 2 references. Parameters ---------- @@ -173,25 +177,27 @@ def is_header_compatible(reference_1, reference_2): """ affine_1, dimensions_1, voxel_sizes_1, voxel_order_1 = get_reference_info_wrapper( - reference_1) + reference_1 + ) affine_2, dimensions_2, voxel_sizes_2, voxel_order_2 = get_reference_info_wrapper( - reference_2) + reference_2 + ) identical_header = True if not np.allclose(affine_1, affine_2, rtol=1e-03, atol=1e-03): - logging.error('Affine not equal') + logging.error("Affine not equal") identical_header = False if not np.array_equal(dimensions_1, dimensions_2): - logging.error('Dimensions not equal') + logging.error("Dimensions not equal") identical_header = False if not np.allclose(voxel_sizes_1, voxel_sizes_2, rtol=1e-03, atol=1e-03): - logging.error('Voxel_size not equal') + logging.error("Voxel_size not equal") identical_header = False if voxel_order_1 != voxel_order_2: - logging.error('Voxel_order not equal') + logging.error("Voxel_order not equal") identical_header = False return identical_header @@ -211,11 +217,11 @@ def get_axis_shift_vector(flip_axes): Possible values are -1, 1 """ shift_vector = np.zeros(3) - if 'x' in flip_axes: + if "x" in flip_axes: shift_vector[0] = -1.0 - if 'y' in flip_axes: + if "y" in flip_axes: shift_vector[1] = -1.0 - if 'z' in flip_axes: + if "z" in flip_axes: shift_vector[2] = -1.0 return shift_vector @@ -235,11 +241,11 @@ def get_axis_flip_vector(flip_axes): Possible values are -1, 1 """ flip_vector = np.ones(3) - if 'x' in flip_axes: + if "x" in flip_axes: flip_vector[0] = -1.0 - if 'y' in flip_axes: + if "y" in flip_axes: flip_vector[1] = -1.0 - if 'z' in flip_axes: + if "z" in flip_axes: flip_vector[2] = -1.0 return flip_vector @@ -266,7 +272,7 @@ def get_shift_vector(sft): def flip_sft(sft, flip_axes): - """ Flip the streamlines in the StatefulTractogram according to the + """Flip the streamlines in the StatefulTractogram according to the flip_axes. Uses the spatial information to flip according to the center of the grid. @@ -283,8 +289,10 @@ def flip_sft(sft, flip_axes): StatefulTractogram with flipped axes """ if not dipy_available: - logging.error('Dipy library is missing, cannot use functions related ' - 'to the StatefulTractogram.') + logging.error( + "Dipy library is missing, cannot use functions related " + "to the StatefulTractogram." + ) return None flip_vector = get_axis_flip_vector(flip_axes) @@ -298,14 +306,18 @@ def flip_sft(sft, flip_axes): flipped_streamlines.append(mod_streamline) from dipy.io.stateful_tractogram import StatefulTractogram - new_sft = StatefulTractogram.from_sft(flipped_streamlines, sft, - data_per_point=sft.data_per_point, - data_per_streamline=sft.data_per_streamline) + + new_sft = StatefulTractogram.from_sft( + flipped_streamlines, + sft, + data_per_point=sft.data_per_point, + data_per_streamline=sft.data_per_streamline, + ) return new_sft def load_matrix_in_any_format(filepath): - """ Load a matrix from a txt file OR a npy file. + """Load a matrix from a txt file OR a npy file. Parameters ---------- @@ -317,18 +329,18 @@ def load_matrix_in_any_format(filepath): The matrix. 
""" _, ext = os.path.splitext(filepath) - if ext == '.txt': + if ext == ".txt": data = np.loadtxt(filepath) - elif ext == '.npy': + elif ext == ".npy": data = np.load(filepath) else: - raise ValueError('Extension {} is not supported'.format(ext)) + raise ValueError("Extension {} is not supported".format(ext)) return data def get_reverse_enum(space_str, origin_str): - """ Convert string representation to enums for the StatefulTractogram. + """Convert string representation to enums for the StatefulTractogram. Parameters ---------- @@ -342,14 +354,17 @@ def get_reverse_enum(space_str, origin_str): Space and Origin as Enums. """ if not dipy_available: - logging.error('Dipy library is missing, cannot use functions related ' - 'to the StatefulTractogram.') + logging.error( + "Dipy library is missing, cannot use functions related " + "to the StatefulTractogram." + ) return None - from dipy.io.stateful_tractogram import Space, Origin - origin = Origin.NIFTI if origin_str.lower() == 'nifti' else Origin.TRACKVIS - if space_str.lower() == 'rasmm': + from dipy.io.stateful_tractogram import Origin, Space + + origin = Origin.NIFTI if origin_str.lower() == "nifti" else Origin.TRACKVIS + if space_str.lower() == "rasmm": space = Space.RASMM - elif space_str.lower() == 'voxmm': + elif space_str.lower() == "voxmm": space = Space.VOXMM else: space = Space.VOX @@ -358,7 +373,7 @@ def get_reverse_enum(space_str, origin_str): def convert_data_dict_to_tractogram(data): - """ Convert a data from a lazy tractogram to a tractogram + """Convert a data from a lazy tractogram to a tractogram Keyword arguments: data -- The data dictionary to convert into a nibabel tractogram @@ -366,49 +381,50 @@ def convert_data_dict_to_tractogram(data): Returns: A Tractogram object """ - streamlines = ArraySequence(data['strs']) + streamlines = ArraySequence(data["strs"]) streamlines._data = streamlines._data - for key in data['dps']: - shape = (len(streamlines), len(data['dps'][key]) // len(streamlines)) - data['dps'][key] = np.array(data['dps'][key]).reshape(shape) + for key in data["dps"]: + shape = (len(streamlines), len(data["dps"][key]) // len(streamlines)) + data["dps"][key] = np.array(data["dps"][key]).reshape(shape) - for key in data['dpv']: - shape = (len(streamlines._data), len( - data['dpv'][key]) // len(streamlines._data)) - data['dpv'][key] = np.array(data['dpv'][key]).reshape(shape) + for key in data["dpv"]: + shape = ( + len(streamlines._data), + len(data["dpv"][key]) // len(streamlines._data), + ) + data["dpv"][key] = np.array(data["dpv"][key]).reshape(shape) tmp_arr = ArraySequence() - tmp_arr._data = data['dpv'][key] + tmp_arr._data = data["dpv"][key] tmp_arr._offsets = streamlines._offsets tmp_arr._lengths = streamlines._lengths - data['dpv'][key] = tmp_arr + data["dpv"][key] = tmp_arr - obj = Tractogram(streamlines, data_per_point=data['dpv'], - data_per_streamline=data['dps']) + obj = Tractogram( + streamlines, data_per_point=data["dpv"], data_per_streamline=data["dps"] + ) return obj def append_generator_to_dict(gen, data): if isinstance(gen, TractogramItem): - data['strs'].append(gen.streamline.tolist()) + data["strs"].append(gen.streamline.tolist()) for key in gen.data_for_points: - if key not in data['dpv']: - data['dpv'][key] = np.array([]) - data['dpv'][key] = np.append( - data['dpv'][key], gen.data_for_points[key]) + if key not in data["dpv"]: + data["dpv"][key] = np.array([]) + data["dpv"][key] = np.append(data["dpv"][key], gen.data_for_points[key]) for key in gen.data_for_streamline: - if key not in 
data['dps']: - data['dps'][key] = np.array([]) - data['dps'][key] = np.append( - data['dps'][key], gen.data_for_streamline[key]) + if key not in data["dps"]: + data["dps"][key] = np.array([]) + data["dps"][key] = np.append(data["dps"][key], gen.data_for_streamline[key]) else: - data['strs'].append(gen.tolist()) + data["strs"].append(gen.tolist()) def verify_trx_dtype(trx, dict_dtype): # noqa: C901 - """ Verify if the dtype of the data in the trx is the same as the one in + """Verify if the dtype of the data in the trx is the same as the one in the dict. Parameters @@ -424,27 +440,29 @@ def verify_trx_dtype(trx, dict_dtype): # noqa: C901 """ identical = True for key in dict_dtype: - if key == 'positions': + if key == "positions": if trx.streamlines._data.dtype != dict_dtype[key]: - logging.warning('Positions dtype is different') + logging.warning("Positions dtype is different") identical = False - elif key == 'offsets': + elif key == "offsets": if trx.streamlines._offsets.dtype != dict_dtype[key]: - logging.warning('Offsets dtype is different') + logging.warning("Offsets dtype is different") identical = False - elif key == 'dpv': + elif key == "dpv": for key_dpv in dict_dtype[key]: if trx.data_per_vertex[key_dpv]._data.dtype != dict_dtype[key][key_dpv]: logging.warning( - 'Data per vertex ({}) dtype is different'.format(key_dpv)) + "Data per vertex ({}) dtype is different".format(key_dpv) + ) identical = False - elif key == 'dps': + elif key == "dps": for key_dps in dict_dtype[key]: if trx.data_per_streamline[key_dps].dtype != dict_dtype[key][key_dps]: logging.warning( - 'Data per streamline ({}) dtype is different'.format(key_dps)) + "Data per streamline ({}) dtype is different".format(key_dps) + ) identical = False - elif key == 'dpg': + elif key == "dpg": for key_group in dict_dtype[key]: for key_dpg in dict_dtype[key][key_group]: if ( @@ -452,16 +470,18 @@ def verify_trx_dtype(trx, dict_dtype): # noqa: C901 != dict_dtype[key][key_group][key_dpg] ): logging.warning( - 'Data per group ({}) dtype is different'.format(key_dpg)) + "Data per group ({}) dtype is different".format(key_dpg) + ) identical = False - elif key == 'groups': + elif key == "groups": for key_group in dict_dtype[key]: if ( trx.data_per_point[key_group]._data.dtype != dict_dtype[key][key_group] ): logging.warning( - 'Data per group ({}) dtype is different'.format(key_group)) + "Data per group ({}) dtype is different".format(key_group) + ) identical = False return identical diff --git a/trx/viz.py b/trx/viz.py index d5afb9b..c995a84 100644 --- a/trx/viz.py +++ b/trx/viz.py @@ -4,36 +4,52 @@ import logging import numpy as np + try: - from dipy.viz import window, actor, colormap - from fury.utils import get_bounds + from dipy.viz import actor, colormap, window import fury.utils as ut_vtk + from fury.utils import get_bounds import vtk + fury_available = True except ImportError: fury_available = False -def display(volume, volume_affine=None, streamlines=None, title='FURY', - display_bounds=True): +def display( + volume, volume_affine=None, streamlines=None, title="FURY", display_bounds=True +): if not fury_available: - logging.error('Fury library is missing, visualization functions ' - 'are not available.') + logging.error( + "Fury library is missing, visualization functions are not available." 
+ ) return None volume = volume.astype(float) scene = window.Scene() - scene.background((1., 0.5, 0.)) + scene.background((1.0, 0.5, 0.0)) # Show the X/Y/Z plane intersecting, mid-slices - slicer_actor_1 = actor.slicer(volume, affine=volume_affine, - value_range=(volume.min(), volume.max()), - interpolation='nearest', opacity=0.8) - slicer_actor_2 = actor.slicer(volume, affine=volume_affine, - value_range=(volume.min(), volume.max()), - interpolation='nearest', opacity=0.8) - slicer_actor_3 = actor.slicer(volume, affine=volume_affine, - value_range=(volume.min(), volume.max()), - interpolation='nearest', opacity=0.8) + slicer_actor_1 = actor.slicer( + volume, + affine=volume_affine, + value_range=(volume.min(), volume.max()), + interpolation="nearest", + opacity=0.8, + ) + slicer_actor_2 = actor.slicer( + volume, + affine=volume_affine, + value_range=(volume.min(), volume.max()), + interpolation="nearest", + opacity=0.8, + ) + slicer_actor_3 = actor.slicer( + volume, + affine=volume_affine, + value_range=(volume.min(), volume.max()), + interpolation="nearest", + opacity=0.8, + ) slicer_actor_1.display(y=volume.shape[1] // 2) slicer_actor_2.display(x=volume.shape[0] // 2) slicer_actor_3.display(z=volume.shape[2] // 2) @@ -55,33 +71,40 @@ def display(volume, volume_affine=None, streamlines=None, title='FURY', # Show each corner's coordinates corners = itertools.product(bounds[0:2], bounds[2:4], bounds[4:6]) for corner in corners: - text_actor = actor.text_3d('{}, {}, {}'.format( - *corner), corner, font_size=6, justification='center') + text_actor = actor.text_3d( + "{}, {}, {}".format(*corner), + corner, + font_size=6, + justification="center", + ) scene.add(text_actor) # Show the X/Y/Z dimensions - text_actor_x = actor.text_3d('{}'.format(np.abs(bounds[0]-bounds[1])), - ((bounds[0]+bounds[1])/2, - bounds[2], - bounds[4]), - font_size=10, justification='center') - text_actor_y = actor.text_3d('{}'.format(np.abs(bounds[2]-bounds[3])), - (bounds[0], - (bounds[2]+bounds[3])/2, - bounds[4]), - font_size=10, justification='center') - text_actor_z = actor.text_3d('{}'.format(np.abs(bounds[4]-bounds[5])), - (bounds[0], - bounds[2], - (bounds[4]+bounds[5])/2), - font_size=10, justification='center') + text_actor_x = actor.text_3d( + "{}".format(np.abs(bounds[0] - bounds[1])), + ((bounds[0] + bounds[1]) / 2, bounds[2], bounds[4]), + font_size=10, + justification="center", + ) + text_actor_y = actor.text_3d( + "{}".format(np.abs(bounds[2] - bounds[3])), + (bounds[0], (bounds[2] + bounds[3]) / 2, bounds[4]), + font_size=10, + justification="center", + ) + text_actor_z = actor.text_3d( + "{}".format(np.abs(bounds[4] - bounds[5])), + (bounds[0], bounds[2], (bounds[4] + bounds[5]) / 2), + font_size=10, + justification="center", + ) scene.add(text_actor_x) scene.add(text_actor_y) scene.add(text_actor_z) if streamlines is not None: - streamlines_actor = actor.line(streamlines, - colormap.line_colors(streamlines), - opacity=0.25) + streamlines_actor = actor.line( + streamlines, colormap.line_colors(streamlines), opacity=0.25 + ) scene.add(streamlines_actor) window.show(scene, title=title, size=(800, 800)) diff --git a/trx/workflows.py b/trx/workflows.py index a1782a2..6eb7921 100644 --- a/trx/workflows.py +++ b/trx/workflows.py @@ -11,54 +11,66 @@ import nibabel as nib from nibabel.streamlines.array_sequence import ArraySequence import numpy as np + try: import dipy # noqa: F401 + dipy_available = True except ImportError: dipy_available = False from trx.io import get_trx_tmp_dir, load, 
load_sft_with_reference, save -from trx.streamlines_ops import perform_streamlines_operation, intersection +from trx.streamlines_ops import intersection, perform_streamlines_operation import trx.trx_file_memmap as tmm +from trx.utils import ( + flip_sft, + get_axis_shift_vector, + get_reference_info_wrapper, + get_reverse_enum, + is_header_compatible, + load_matrix_in_any_format, + split_name_with_gz, +) from trx.viz import display -from trx.utils import (flip_sft, is_header_compatible, - get_axis_shift_vector, - get_reference_info_wrapper, - get_reverse_enum, - load_matrix_in_any_format, - split_name_with_gz) -def convert_dsi_studio(in_dsi_tractogram, in_dsi_fa, out_tractogram, - remove_invalid=True, keep_invalid=False): +def convert_dsi_studio( + in_dsi_tractogram, + in_dsi_fa, + out_tractogram, + remove_invalid=True, + keep_invalid=False, +): if not dipy_available: - logging.error('Dipy library is missing, scripts are not available.') + logging.error("Dipy library is missing, scripts are not available.") return None - from dipy.io.stateful_tractogram import StatefulTractogram, Space - from dipy.io.streamline import save_tractogram, load_tractogram + from dipy.io.stateful_tractogram import Space, StatefulTractogram + from dipy.io.streamline import load_tractogram, save_tractogram in_ext = split_name_with_gz(in_dsi_tractogram)[1] out_ext = split_name_with_gz(out_tractogram)[1] - if in_ext == '.trk.gz': - with gzip.open(in_dsi_tractogram, 'rb') as f_in: - with open('tmp.trk', 'wb') as f_out: + if in_ext == ".trk.gz": + with gzip.open(in_dsi_tractogram, "rb") as f_in: + with open("tmp.trk", "wb") as f_out: f_out.writelines(f_in) - sft = load_tractogram('tmp.trk', 'same', - bbox_valid_check=False) - os.remove('tmp.trk') - elif in_ext == '.trk': - sft = load_tractogram(in_dsi_tractogram, 'same', - bbox_valid_check=False) + sft = load_tractogram("tmp.trk", "same", bbox_valid_check=False) + os.remove("tmp.trk") + elif in_ext == ".trk": + sft = load_tractogram(in_dsi_tractogram, "same", bbox_valid_check=False) else: - raise IOError('{} is not currently supported.'.format(in_ext)) + raise IOError("{} is not currently supported.".format(in_ext)) sft.to_vox() - sft_fix = StatefulTractogram(sft.streamlines, in_dsi_fa, Space.VOXMM, - data_per_point=sft.data_per_point, - data_per_streamline=sft.data_per_streamline) + sft_fix = StatefulTractogram( + sft.streamlines, + in_dsi_fa, + Space.VOXMM, + data_per_point=sft.data_per_point, + data_per_streamline=sft.data_per_streamline, + ) sft_fix.to_vox() - flip_axis = ['x', 'y'] + flip_axis = ["x", "y"] sft_fix.streamlines._data -= get_axis_shift_vector(flip_axis) sft_flip = flip_sft(sft_fix, flip_axis) @@ -68,20 +80,22 @@ def convert_dsi_studio(in_dsi_tractogram, in_dsi_fa, out_tractogram, if remove_invalid: sft_flip.remove_invalid_streamlines() - if out_ext != '.trx': - save_tractogram(sft_flip, out_tractogram, - bbox_valid_check=not keep_invalid) + if out_ext != ".trx": + save_tractogram(sft_flip, out_tractogram, bbox_valid_check=not keep_invalid) else: trx = tmm.TrxFile.from_sft(sft_flip) tmm.save(trx, out_tractogram) def convert_tractogram( # noqa: C901 - in_tractogram, out_tractogram, reference, - pos_dtype='float32', offsets_dtype='uint32', + in_tractogram, + out_tractogram, + reference, + pos_dtype="float32", + offsets_dtype="uint32", ): if not dipy_available: - logging.error('Dipy library is missing, scripts are not available.') + logging.error("Dipy library is missing, scripts are not available.") return None from dipy.io.streamline import 
save_tractogram @@ -89,40 +103,39 @@ def convert_tractogram( # noqa: C901 out_ext = split_name_with_gz(out_tractogram)[1] if in_ext == out_ext: - raise IOError('Input and output cannot be of the same file format.') + raise IOError("Input and output cannot be of the same file format.") - if in_ext != '.trx': - sft = load_sft_with_reference(in_tractogram, reference, - bbox_check=False) + if in_ext != ".trx": + sft = load_sft_with_reference(in_tractogram, reference, bbox_check=False) else: trx = tmm.load(in_tractogram) sft = trx.to_sft() trx.close() - if out_ext != '.trx': - if out_ext == '.vtk': + if out_ext != ".trx": + if out_ext == ".vtk": if sft.streamlines._data.dtype.name != pos_dtype: sft.streamlines._data = sft.streamlines._data.astype(pos_dtype) - if offsets_dtype == 'uint64' or offsets_dtype == 'uint32': + if offsets_dtype == "uint64" or offsets_dtype == "uint32": offsets_dtype = offsets_dtype[1:] if sft.streamlines._offsets.dtype.name != offsets_dtype: sft.streamlines._offsets = sft.streamlines._offsets.astype( - offsets_dtype) + offsets_dtype + ) save_tractogram(sft, out_tractogram, bbox_valid_check=False) else: trx = tmm.TrxFile.from_sft(sft) if trx.streamlines._data.dtype.name != pos_dtype: trx.streamlines._data = trx.streamlines._data.astype(pos_dtype) if trx.streamlines._offsets.dtype.name != offsets_dtype: - trx.streamlines._offsets = trx.streamlines._offsets.astype( - offsets_dtype) + trx.streamlines._offsets = trx.streamlines._offsets.astype(offsets_dtype) tmm.save(trx, out_tractogram) trx.close() def tractogram_simple_compare(in_tractograms, reference): if not dipy_available: - logging.error('Dipy library is missing, scripts are not available.') + logging.error("Dipy library is missing, scripts are not available.") return from dipy.io.stateful_tractogram import StatefulTractogram @@ -140,56 +153,62 @@ def tractogram_simple_compare(in_tractograms, reference): else: sft_2 = tractogram_obj - if np.allclose(sft_1.streamlines._data, sft_2.streamlines._data, - atol=0.001): - print('Matching tractograms in rasmm!') + if np.allclose(sft_1.streamlines._data, sft_2.streamlines._data, atol=0.001): + print("Matching tractograms in rasmm!") else: - print('Average difference in rasmm of {}'.format(np.average( - sft_1.streamlines._data - sft_2.streamlines._data, axis=0))) + print( + "Average difference in rasmm of {}".format( + np.average(sft_1.streamlines._data - sft_2.streamlines._data, axis=0) + ) + ) sft_1.to_voxmm() sft_2.to_voxmm() - if np.allclose(sft_1.streamlines._data, sft_2.streamlines._data, - atol=0.001): - print('Matching tractograms in voxmm!') + if np.allclose(sft_1.streamlines._data, sft_2.streamlines._data, atol=0.001): + print("Matching tractograms in voxmm!") else: - print('Average difference in voxmm of {}'.format(np.average( - sft_1.streamlines._data - sft_2.streamlines._data, axis=0))) + print( + "Average difference in voxmm of {}".format( + np.average(sft_1.streamlines._data - sft_2.streamlines._data, axis=0) + ) + ) sft_1.to_vox() sft_2.to_vox() - if np.allclose(sft_1.streamlines._data, sft_2.streamlines._data, - atol=0.001): - print('Matching tractograms in vox!') + if np.allclose(sft_1.streamlines._data, sft_2.streamlines._data, atol=0.001): + print("Matching tractograms in vox!") else: - print('Average difference in vox of {}'.format(np.average( - sft_1.streamlines._data - sft_2.streamlines._data, axis=0))) + print( + "Average difference in vox of {}".format( + np.average(sft_1.streamlines._data - sft_2.streamlines._data, axis=0) + ) + ) def 
verify_header_compatibility(in_files): if not dipy_available: - logging.error('Dipy library is missing, scripts are not available.') + logging.error("Dipy library is missing, scripts are not available.") return all_valid = True for filepath in in_files: if not os.path.isfile(filepath): - print('{} does not exist'.format(filepath)) + print("{} does not exist".format(filepath)) _, in_extension = split_name_with_gz(filepath) - if in_extension not in ['.trk', '.nii', '.nii.gz', '.trx']: - raise IOError('{} does not have a supported extension'.format( - filepath)) + if in_extension not in [".trk", ".nii", ".nii.gz", ".trx"]: + raise IOError("{} does not have a supported extension".format(filepath)) if not is_header_compatible(in_files[0], filepath): - print('{} and {} do not have compatible header.'.format( - in_files[0], filepath)) + print( + "{} and {} do not have compatible header.".format(in_files[0], filepath) + ) all_valid = False if all_valid: - print('All input files have compatible headers.') + print("All input files have compatible headers.") def tractogram_visualize_overlap(in_tractogram, reference, remove_invalid=True): if not dipy_available: - logging.error('Dipy library is missing, scripts are not available.') + logging.error("Dipy library is missing, scripts are not available.") return None from dipy.io.stateful_tractogram import StatefulTractogram from dipy.tracking.streamline import set_number_of_points @@ -212,8 +231,12 @@ def tractogram_visualize_overlap(in_tractogram, reference, remove_invalid=True): # Approach (1) density_1 = density_map(sft.streamlines, sft.affine, sft.dimensions) img = nib.load(reference) - display(img.get_fdata(), volume_affine=img.affine, - streamlines=sft.streamlines, title='RASMM') + display( + img.get_fdata(), + volume_affine=img.affine, + streamlines=sft.streamlines, + title="RASMM", + ) # Approach (2) sft.to_vox() @@ -221,25 +244,36 @@ def tractogram_visualize_overlap(in_tractogram, reference, remove_invalid=True): # Small difference due to casting of the affine as float32 or float64 diff = density_1 - density_2 - print('Total difference of {} voxels with total value of {}'.format( - np.count_nonzero(diff), np.sum(np.abs(diff)))) + print( + "Total difference of {} voxels with total value of {}".format( + np.count_nonzero(diff), np.sum(np.abs(diff)) + ) + ) - display(img.get_fdata(), streamlines=sft.streamlines, title='VOX') + display(img.get_fdata(), streamlines=sft.streamlines, title="VOX") # Try VOXMM sft.to_voxmm() affine = np.eye(4) affine[0:3, 0:3] *= sft.voxel_sizes - display(img.get_fdata(), volume_affine=affine, - streamlines=sft.streamlines, title='VOXMM') + display( + img.get_fdata(), + volume_affine=affine, + streamlines=sft.streamlines, + title="VOXMM", + ) -def validate_tractogram(in_tractogram, reference, out_tractogram, - remove_identical_streamlines=True, precision=1): - +def validate_tractogram( + in_tractogram, + reference, + out_tractogram, + remove_identical_streamlines=True, + precision=1, +): if not dipy_available: - logging.error('Dipy library is missing, scripts are not available.') + logging.error("Dipy library is missing, scripts are not available.") return None from dipy.io.stateful_tractogram import StatefulTractogram @@ -257,34 +291,42 @@ def validate_tractogram(in_tractogram, reference, out_tractogram, invalid_coord_ind, _ = sft.remove_invalid_streamlines() tot_remove += len(invalid_coord_ind) - logging.warning('Removed {} streamlines with invalid coordinates.'.format( - len(invalid_coord_ind))) + logging.warning( + 
"Removed {} streamlines with invalid coordinates.".format( + len(invalid_coord_ind) + ) + ) indices = [i for i in range(len(sft)) if len(sft.streamlines[i]) <= 1] - tot_remove = + len(indices) - logging.warning('Removed {} invalid streamlines (1 or 0 points).'.format( - len(indices))) + tot_remove = +len(indices) + logging.warning( + "Removed {} invalid streamlines (1 or 0 points).".format(len(indices)) + ) for i in np.setdiff1d(range(len(sft)), indices): - norm = np.linalg.norm(np.diff(sft.streamlines[i], - axis=0), axis=1) + norm = np.linalg.norm(np.diff(sft.streamlines[i], axis=0), axis=1) if (norm < 0.001).any(): indices.append(i) indices_val = np.setdiff1d(range(len(sft)), indices).astype(np.uint32) - logging.warning('Removed {} invalid streamlines (overlapping points).'.format( - ori_len - len(indices_val))) + logging.warning( + "Removed {} invalid streamlines (overlapping points).".format( + ori_len - len(indices_val) + ) + ) tot_remove += ori_len - len(indices_val) if remove_identical_streamlines: - _, indices_uniq = perform_streamlines_operation(intersection, - [sft.streamlines], - precision=precision) - indices_final = np.intersect1d( - indices_val, indices_uniq).astype(np.uint32) - logging.warning('Removed {} overlapping streamlines.'.format( - ori_len - len(indices_final) - tot_remove)) + _, indices_uniq = perform_streamlines_operation( + intersection, [sft.streamlines], precision=precision + ) + indices_final = np.intersect1d(indices_val, indices_uniq).astype(np.uint32) + logging.warning( + "Removed {} overlapping streamlines.".format( + ori_len - len(indices_final) - tot_remove + ) + ) indices_final = np.intersect1d(indices_val, indices_uniq) else: @@ -299,20 +341,19 @@ def validate_tractogram(in_tractogram, reference, out_tractogram, dps = {} for key in sft.data_per_streamline.keys(): dps[key] = sft.data_per_streamline[key][indices_final] - new_sft = StatefulTractogram.from_sft(streamlines, sft, - data_per_point=dpp, - data_per_streamline=dps) + new_sft = StatefulTractogram.from_sft( + streamlines, sft, data_per_point=dpp, data_per_streamline=dps + ) new_sft.dtype_dict = ori_dtype save(new_sft, out_tractogram) def _load_streamlines_from_csv(positions_csv): """Load streamlines from CSV file.""" - with open(positions_csv, newline='') as f: + with open(positions_csv, newline="") as f: reader = csv.reader(f) data = list(reader) - data = [np.reshape(i, (len(i) // 3, 3)).astype(float) - for i in data] + data = [np.reshape(i, (len(i) // 3, 3)).astype(float) for i in data] return ArraySequence(data) @@ -328,13 +369,16 @@ def _load_streamlines_from_arrays(positions, offsets): return streamlines, offsets -def _apply_spatial_transforms(streamlines, reference, space_str, origin_str, - verify_invalid, offsets): +def _apply_spatial_transforms( + streamlines, reference, space_str, origin_str, verify_invalid, offsets +): """Apply spatial transforms and verify streamlines.""" if not dipy_available: - logging.error('Dipy library is missing, advanced options ' - 'related to spatial transforms and invalid ' - 'streamlines are not available.') + logging.error( + "Dipy library is missing, advanced options " + "related to spatial transforms and invalid " + "streamlines are not available." 
+ ) return None from dipy.io.stateful_tractogram import StatefulTractogram @@ -343,8 +387,9 @@ def _apply_spatial_transforms(streamlines, reference, space_str, origin_str, sft = StatefulTractogram(streamlines, reference, space, origin) if verify_invalid: rem, _ = sft.remove_invalid_streamlines() - print('{} streamlines were removed becaused they were ' - 'invalid.'.format(len(rem))) + print( + "{} streamlines were removed becaused they were invalid.".format(len(rem)) + ) sft.to_rasmm() sft.to_center() streamlines = sft.streamlines @@ -359,41 +404,37 @@ def _write_header(tmp_dir_name, reference, streamlines): "DIMENSIONS": dimensions.tolist(), "VOXEL_TO_RASMM": affine.tolist(), "NB_VERTICES": len(streamlines._data), - "NB_STREAMLINES": len(streamlines)-1, + "NB_STREAMLINES": len(streamlines) - 1, } - if header['NB_STREAMLINES'] <= 1: - raise IOError('To use this script, you need at least 2' - 'streamlines.') + if header["NB_STREAMLINES"] <= 1: + raise IOError("To use this script, you need at least 2streamlines.") with open(os.path.join(tmp_dir_name, "header.json"), "w") as out_json: json.dump(header, out_json) -def _write_streamline_data(tmp_dir_name, streamlines, positions_dtype, - offsets_dtype): +def _write_streamline_data(tmp_dir_name, streamlines, positions_dtype, offsets_dtype): """Write streamline position and offset data.""" - curr_filename = os.path.join(tmp_dir_name, 'positions.3.{}'.format( - positions_dtype)) + curr_filename = os.path.join(tmp_dir_name, "positions.3.{}".format(positions_dtype)) streamlines._data.astype(positions_dtype).tofile(curr_filename) - curr_filename = os.path.join(tmp_dir_name, 'offsets.{}'.format( - offsets_dtype)) + curr_filename = os.path.join(tmp_dir_name, "offsets.{}".format(offsets_dtype)) streamlines._offsets.astype(offsets_dtype).tofile(curr_filename) def _normalize_dtype(dtype_str): """Normalize dtype string format.""" - return 'bit' if dtype_str == 'bool' else dtype_str + return "bit" if dtype_str == "bool" else dtype_str def _write_data_array(tmp_dir_name, subdir_name, args, is_dpg=False): """Write data array to file.""" if is_dpg: - os.makedirs(os.path.join(tmp_dir_name, 'dpg', args[0]), exist_ok=True) + os.makedirs(os.path.join(tmp_dir_name, "dpg", args[0]), exist_ok=True) curr_arr = load_matrix_in_any_format(args[1]).astype(args[2]) basename = os.path.basename(os.path.splitext(args[1])[0]) - dtype_str = _normalize_dtype(args[1]) if args[1] != 'bool' else 'bit' + dtype_str = _normalize_dtype(args[1]) if args[1] != "bool" else "bit" dtype = args[2] else: os.makedirs(os.path.join(tmp_dir_name, subdir_name), exist_ok=True) @@ -403,32 +444,51 @@ def _write_data_array(tmp_dir_name, subdir_name, args, is_dpg=False): dtype = dtype_str if curr_arr.ndim > 2: - raise IOError('Maximum of 2 dimensions for dpv/dps/dpg.') + raise IOError("Maximum of 2 dimensions for dpv/dps/dpg.") if curr_arr.shape == (1, 1): curr_arr = curr_arr.reshape((1,)) - dim = '' if curr_arr.ndim == 1 else '{}.'.format(curr_arr.shape[-1]) + dim = "" if curr_arr.ndim == 1 else "{}.".format(curr_arr.shape[-1]) if is_dpg: - curr_filename = os.path.join(tmp_dir_name, 'dpg', args[0], - '{}.{}{}'.format(basename, dim, dtype)) + curr_filename = os.path.join( + tmp_dir_name, "dpg", args[0], "{}.{}{}".format(basename, dim, dtype) + ) else: - curr_filename = os.path.join(tmp_dir_name, subdir_name, - '{}.{}{}'.format(basename, dim, dtype)) + curr_filename = os.path.join( + tmp_dir_name, subdir_name, "{}.{}{}".format(basename, dim, dtype) + ) curr_arr.tofile(curr_filename) def 
generate_trx_from_scratch( # noqa: C901 - reference, out_tractogram, positions_csv=False, - positions=False, offsets=False, - positions_dtype='float32', offsets_dtype='uint64', - space_str='rasmm', origin_str='nifti', - verify_invalid=True, dpv=[], dps=[], - groups=[], dpg=[], + reference, + out_tractogram, + positions_csv=False, + positions=False, + offsets=False, + positions_dtype="float32", + offsets_dtype="uint64", + space_str="rasmm", + origin_str="nifti", + verify_invalid=True, + dpv=None, + dps=None, + groups=None, + dpg=None, ): """Generate TRX file from scratch using various input formats.""" + if dpv is None: + dpv = [] + if dps is None: + dps = [] + if groups is None: + groups = [] + if dpg is None: + dpg = [] + with get_trx_tmp_dir() as tmp_dir_name: if positions_csv: streamlines = _load_streamlines_from_csv(positions_csv) @@ -436,34 +496,37 @@ def generate_trx_from_scratch( # noqa: C901 else: streamlines, offsets = _load_streamlines_from_arrays(positions, offsets) - if (space_str.lower() != 'rasmm' or origin_str.lower() != 'nifti' or - verify_invalid): + if ( + space_str.lower() != "rasmm" + or origin_str.lower() != "nifti" + or verify_invalid + ): streamlines = _apply_spatial_transforms( - streamlines, reference, space_str, origin_str, - verify_invalid, offsets + streamlines, reference, space_str, origin_str, verify_invalid, offsets ) if streamlines is None: return _write_header(tmp_dir_name, reference, streamlines) - _write_streamline_data(tmp_dir_name, streamlines, positions_dtype, - offsets_dtype) + _write_streamline_data( + tmp_dir_name, streamlines, positions_dtype, offsets_dtype + ) if dpv: for arg in dpv: - _write_data_array(tmp_dir_name, 'dpv', arg) + _write_data_array(tmp_dir_name, "dpv", arg) if dps: for arg in dps: - _write_data_array(tmp_dir_name, 'dps', arg) + _write_data_array(tmp_dir_name, "dps", arg) if groups: for arg in groups: - _write_data_array(tmp_dir_name, 'groups', arg) + _write_data_array(tmp_dir_name, "groups", arg) if dpg: for arg in dpg: - _write_data_array(tmp_dir_name, 'dpg', arg, is_dpg=True) + _write_data_array(tmp_dir_name, "dpg", arg, is_dpg=True) trx = tmm.load(tmp_dir_name) tmm.save(trx, out_tractogram) @@ -476,53 +539,63 @@ def manipulate_trx_datatype(in_filename, out_filename, dict_dtype): # noqa: C90 # For each key in dict_dtype, we create a new memmap with the new dtype # and we copy the data from the old memmap to the new one. 
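    # Sketch of the pattern every branch of the loop below follows
    # (new_dtype / old_array stand in for the concrete entry being cast):
    #
    #     tmp_mm = np.memmap(tempfile.NamedTemporaryFile(),
    #                        dtype=new_dtype, mode="w+",
    #                        shape=old_array.shape)
    #     tmp_mm[:] = old_array[:]    # the dtype conversion happens on copy
    #     <owner>.<field> = tmp_mm    # swap the memmap in place
    #
    # positions/offsets live on trx.streamlines, while dpv/dps/dpg/groups
    # sit in their respective dictionaries, handled case by case below.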
for key in dict_dtype: - if key == 'positions': - tmp_mm = np.memmap(tempfile.NamedTemporaryFile(), - dtype=dict_dtype[key], - mode='w+', - shape=trx.streamlines._data.shape) + if key == "positions": + tmp_mm = np.memmap( + tempfile.NamedTemporaryFile(), + dtype=dict_dtype[key], + mode="w+", + shape=trx.streamlines._data.shape, + ) tmp_mm[:] = trx.streamlines._data[:] trx.streamlines._data = tmp_mm - elif key == 'offsets': - tmp_mm = np.memmap(tempfile.NamedTemporaryFile(), - dtype=dict_dtype[key], - mode='w+', - shape=trx.streamlines._offsets.shape) + elif key == "offsets": + tmp_mm = np.memmap( + tempfile.NamedTemporaryFile(), + dtype=dict_dtype[key], + mode="w+", + shape=trx.streamlines._offsets.shape, + ) tmp_mm[:] = trx.streamlines._offsets[:] trx.streamlines._offsets = tmp_mm - elif key == 'dpv': + elif key == "dpv": for key_dpv in dict_dtype[key]: - tmp_mm = np.memmap(tempfile.NamedTemporaryFile(), - dtype=dict_dtype[key][key_dpv], - mode='w+', - shape=trx.data_per_vertex[key_dpv]._data.shape) + tmp_mm = np.memmap( + tempfile.NamedTemporaryFile(), + dtype=dict_dtype[key][key_dpv], + mode="w+", + shape=trx.data_per_vertex[key_dpv]._data.shape, + ) tmp_mm[:] = trx.data_per_vertex[key_dpv]._data[:] trx.data_per_vertex[key_dpv]._data = tmp_mm - elif key == 'dps': + elif key == "dps": for key_dps in dict_dtype[key]: - tmp_mm = np.memmap(tempfile.NamedTemporaryFile(), - dtype=dict_dtype[key][key_dps], - mode='w+', - shape=trx.data_per_streamline[key_dps].shape) + tmp_mm = np.memmap( + tempfile.NamedTemporaryFile(), + dtype=dict_dtype[key][key_dps], + mode="w+", + shape=trx.data_per_streamline[key_dps].shape, + ) tmp_mm[:] = trx.data_per_streamline[key_dps][:] trx.data_per_streamline[key_dps] = tmp_mm - elif key == 'dpg': + elif key == "dpg": for key_group in dict_dtype[key]: for key_dpg in dict_dtype[key][key_group]: tmp_mm = np.memmap( tempfile.NamedTemporaryFile(), dtype=dict_dtype[key][key_group][key_dpg], - mode='w+', + mode="w+", shape=trx.data_per_group[key_group][key_dpg].shape, ) tmp_mm[:] = trx.data_per_group[key_group][key_dpg][:] trx.data_per_group[key_group][key_dpg] = tmp_mm - elif key == 'groups': + elif key == "groups": for key_group in dict_dtype[key]: - tmp_mm = np.memmap(tempfile.NamedTemporaryFile(), - dtype=dict_dtype[key][key_group], - mode='w+', - shape=trx.groups[key_group].shape) + tmp_mm = np.memmap( + tempfile.NamedTemporaryFile(), + dtype=dict_dtype[key][key_group], + mode="w+", + shape=trx.groups[key_group].shape, + ) tmp_mm[:] = trx.groups[key_group][:] trx.groups[key_group] = tmp_mm
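The dtype handling reworked above all flows through one dictionary convention: plain dtypes for "positions" and "offsets", nested dicts for "dpv"/"dps" (and "dpg"/"groups" where applicable). A minimal sketch of how the pieces fit together, assuming a hypothetical bundle.trx / bundle_f64.trx pair on disk and the signatures shown in the hunks above:

    import numpy as np

    import trx.trx_file_memmap as tmm
    from trx.utils import verify_trx_dtype
    from trx.workflows import manipulate_trx_datatype

    # Hypothetical file names, purely for illustration.
    in_trx, out_trx = "bundle.trx", "bundle_f64.trx"

    # Same dictionary layout the refactored APIs expect.
    target_dtype = {
        "positions": np.float64,
        "offsets": np.uint64,
        "dpv": {},
        "dps": {},
    }

    trx = tmm.load(in_trx)
    try:
        if not verify_trx_dtype(trx, target_dtype):
            # Rewrites each array as a memmap of the requested dtype,
            # following the loop shown in manipulate_trx_datatype above.
            manipulate_trx_datatype(in_trx, out_trx, target_dtype)
    finally:
        trx.close()

The same dictionary, with "dpv" renamed to "dpp", is what to_sft()/from_sft() exchange with a StatefulTractogram's dtype_dict.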