diff --git a/.github/workflows/README.md b/.github/workflows/README.md new file mode 100644 index 0000000..a781129 --- /dev/null +++ b/.github/workflows/README.md @@ -0,0 +1,194 @@ +# GitHub Actions Workflows + +This directory contains GitHub Actions workflows for automated testing and CI/CD. + +## Workflows + +### `test.yml` - Main Test Suite + +Runs on every push and pull request to main branches. Includes: + +- **test-cpu**: Unit tests on CPU across Python 3.10, 3.11, and 3.12 + - Uses PyTorch CPU version to avoid GPU dependencies + - Runs tests marked with `unit` and excludes GPU-requiring tests + - Generates coverage reports + +- **test-gpu**: Tests on self-hosted GPU runners (if available) + - Uses PyTorch with CUDA 12.6 support + - Runs all tests except those marked as slow + - Requires self-hosted runner with `[self-hosted, linux, gpu]` labels + +- **test-integration**: Integration tests on CPU + - Runs after CPU tests pass + - Tests marked with `integration` marker + +- **code-quality**: Static code analysis + - Black, isort, ruff, flake8 checks + - Does not fail the build (continue-on-error: true) + +### `test-slow.yml` - Long-Running Tests + +Runs nightly or on manual trigger. Includes: + +- **test-slow-gpu**: Slow tests requiring GPU + - Tests marked with `slow` marker + - Extended timeout (3600 seconds) + - Uses self-hosted GPU runners + +## Caching Strategy + +The workflows use multiple caching layers to speed up builds: + +1. **Python package cache** via `setup-python` action + - Caches pip packages based on `pyproject.toml` hash + +2. **Additional pip cache** via `actions/cache` + - Caches `~/.cache/pip` directory + - Separate keys for CPU, GPU, and integration tests + - Hierarchical restore keys for fallback + +## GPU Support + +### Self-Hosted Runners + +GPU tests require self-hosted runners with: +- Linux OS +- NVIDIA GPU with CUDA 12.6+ support +- Runner labels: `[self-hosted, linux, gpu]` + +### Setting Up Self-Hosted GPU Runners + +1. 
**Install GitHub Actions Runner**: + ```bash + # Download and configure runner from GitHub repository Settings > Actions > Runners + ``` + +2. **Install NVIDIA Drivers and CUDA**: + ```bash + # Install NVIDIA drivers + sudo apt-get install nvidia-driver-535 + + # Install CUDA toolkit 12.6 + wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.0-1_all.deb + sudo dpkg -i cuda-keyring_1.0-1_all.deb + sudo apt-get update + sudo apt-get install cuda-toolkit-12-6 + ``` + +3. **Configure Runner Labels**: + - Add labels: `self-hosted`, `linux`, `gpu` + - Verify GPU is accessible: `nvidia-smi` + +4. **Start the Runner**: + ```bash + ./run.sh + ``` + +### GitHub-Hosted Runners + +GitHub-hosted runners do **not** have GPU support. GPU test jobs are marked non-blocking (`continue-on-error: true`), so they will not fail the build if no self-hosted runner is available; note that the jobs are not skipped — they wait for a runner and time out after the configured `timeout-minutes`. + +## Test Dependencies + +Test dependencies are installed from `pyproject.toml`: + +```bash +pip install -e ".[test]" +``` + +This installs: +- pytest >= 7.0.0 +- pytest-cov >= 4.0.0 +- pytest-xdist >= 3.0.0 +- pytest-timeout >= 2.0.0 +- coverage[toml] >= 7.0.0 + +## Test Markers + +Tests should be marked appropriately: + +```python +import pytest + +@pytest.mark.unit +def test_simple_function(): + """Fast unit test""" + pass + +@pytest.mark.integration +def test_full_pipeline(): + """Integration test""" + pass + +@pytest.mark.slow +def test_long_running(): + """Long-running test""" + pass + +@pytest.mark.requires_gpu +def test_gpu_function(): + """Test requiring GPU""" + if not torch.cuda.is_available(): + pytest.skip("GPU not available") + pass +``` + +## Running Tests Locally + +### CPU Tests +```bash +# Install dependencies +pip install -e ".[test]" + +# Run unit tests +pytest tests/ -m "unit and not requires_gpu" + +# Run with coverage +pytest tests/ -m "unit and not requires_gpu" --cov=physiomotion4d +``` + +### GPU Tests +```bash +# Install with CUDA support +pip install torch torchvision 
torchaudio --index-url https://download.pytorch.org/whl/cu126 +pip install -e ".[test]" + +# Run all tests (including GPU) +pytest tests/ -m "not slow" + +# Run slow tests +pytest tests/ -m "slow" +``` + +## Coverage Reports + +Coverage reports are: +1. Uploaded to Codecov (if configured) +2. Stored as artifacts for 7 days +3. Available as HTML reports in the `htmlcov/` directory + +## Troubleshooting + +### GPU Tests Not Running + +If GPU tests are not running: +1. Verify self-hosted runner is online: Settings > Actions > Runners +2. Check runner labels include `gpu` +3. Verify `nvidia-smi` works on the runner +4. Check workflow logs for runner assignment + +### Cache Not Working + +If builds are slow: +1. Check cache hit/miss in workflow logs +2. Verify `pyproject.toml` hasn't changed unexpectedly +3. Try clearing caches: Settings > Actions > Caches + +### Test Failures + +For test failures: +1. Check individual test logs in the workflow run +2. Run tests locally to reproduce +3. Use `pytest -v --tb=long` for detailed error traces +4. 
Check if tests are marked correctly (unit/integration/slow/requires_gpu) + diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..ce3f641 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,395 @@ +name: CI + +# Comprehensive CI workflow combining unit tests, integration tests, GPU tests, and code quality checks +# +# Test organization: +# - Unit tests: Run on all PRs and pushes, cross-platform (Ubuntu + Windows), multiple Python versions +# - Integration tests: Run with external data, Ubuntu only +# - GPU tests: Self-hosted runners with CUDA support +# - Code quality: Linting and formatting checks +# +# Test markers: +# - requires_data: Tests that need external data downloads +# - slow: Tests that are computationally intensive or require GPU + +on: + push: + branches: [ main, master, develop ] + pull_request: + branches: [ main, master, develop ] + workflow_dispatch: + +jobs: + # ============================================================================ + # Cross-Platform Unit Tests (Ubuntu + Windows) + # ============================================================================ + unit-tests: + name: Unit Tests (${{ matrix.os }}, Python ${{ matrix.python-version }}) + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, windows-latest] + python-version: ['3.10', '3.11', '3.12'] + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Free up disk space on Ubuntu + if: matrix.os == 'ubuntu-latest' + run: | + echo "Disk space before cleanup:" + df -h + sudo rm -rf /usr/share/dotnet + sudo rm -rf /usr/local/lib/android + sudo rm -rf /opt/ghc + sudo rm -rf /opt/hostedtoolcache/CodeQL + sudo docker image prune --all --force + echo "Disk space after cleanup:" + df -h + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: pyproject.toml + + 
- name: Cache pip packages + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('pyproject.toml') }} + restore-keys: | + ${{ runner.os }}-pip-${{ matrix.python-version }}- + ${{ runner.os }}-pip- + + - name: Install system dependencies (Ubuntu) + if: matrix.os == 'ubuntu-latest' + run: | + sudo apt-get update + sudo apt-get install -y \ + libgl1 \ + libglib2.0-0 \ + libgomp1 \ + libsm6 \ + libxrender1 \ + libxext6 \ + libxrandr2 \ + libxi6 + sudo apt-get clean + sudo rm -rf /var/lib/apt/lists/* + + - name: Upgrade pip and build tools + run: | + python -m pip install --upgrade pip setuptools wheel + + - name: Install PyTorch (CPU) + run: | + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu + + - name: Install package with test dependencies + run: | + pip install -e ".[test]" + + - name: Clear pip cache + run: | + pip cache purge || true + + - name: List installed packages + run: | + pip list + + - name: Run unit tests (fast, no external data) + run: | + pytest tests/ -v -m "not slow and not requires_data" --cov=physiomotion4d --cov-report=xml --cov-report=term --cov-report=html + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v4 + if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.10' + with: + file: ./coverage.xml + flags: unittests + name: codecov-unit-${{ matrix.os }}-py${{ matrix.python-version }} + fail_ci_if_error: false + + - name: Upload coverage artifacts + uses: actions/upload-artifact@v4 + if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.10' + with: + name: coverage-report-unit-tests + path: htmlcov/ + retention-days: 7 + + # ============================================================================ + # Integration Tests with External Data (Ubuntu only, on PRs) + # ============================================================================ + integration-tests: + name: Integration Tests (with 
data) + runs-on: ubuntu-latest + needs: unit-tests + if: github.event_name == 'pull_request' + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Free up disk space + run: | + echo "Disk space before cleanup:" + df -h + sudo rm -rf /usr/share/dotnet + sudo rm -rf /usr/local/lib/android + sudo rm -rf /opt/ghc + sudo rm -rf /opt/hostedtoolcache/CodeQL + sudo docker image prune --all --force + echo "Disk space after cleanup:" + df -h + + - name: Set up Python 3.10 + uses: actions/setup-python@v5 + with: + python-version: '3.10' + cache: 'pip' + cache-dependency-path: pyproject.toml + + - name: Cache pip packages + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-integration-${{ hashFiles('pyproject.toml') }} + restore-keys: | + ${{ runner.os }}-pip-integration- + ${{ runner.os }}-pip- + + - name: Cache test data + uses: actions/cache@v4 + with: + path: | + tests/data/ + tests/results/ + key: test-data-${{ hashFiles('tests/test_*.py') }}-v2 + restore-keys: | + test-data- + + - name: Install system dependencies + run: | + sudo apt-get update + sudo apt-get install -y \ + libgl1 \ + libglib2.0-0 \ + libgomp1 \ + libsm6 \ + libxrender1 \ + libxext6 \ + libxrandr2 \ + libxi6 + sudo apt-get clean + sudo rm -rf /var/lib/apt/lists/* + + - name: Upgrade pip and build tools + run: | + python -m pip install --upgrade pip setuptools wheel + + - name: Install PyTorch (CPU) + run: | + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu + + - name: Install package with test dependencies + run: | + pip install -e ".[test]" + + - name: Clear pip cache + run: | + pip cache purge || true + + - name: Run data download tests + run: | + pytest tests/test_download_heart_data.py -v --cov=physiomotion4d --cov-report=xml + continue-on-error: true + + - name: Run data conversion tests + run: | + pytest tests/test_convert_nrrd_4d_to_3d.py -v --cov=physiomotion4d --cov-append --cov-report=xml + 
continue-on-error: true + + - name: Run contour tools tests + run: | + pytest tests/test_contour_tools.py -v -m "not slow" --cov=physiomotion4d --cov-append --cov-report=xml + continue-on-error: true + + - name: Run USD conversion tests + run: | + pytest tests/test_convert_vtk_4d_to_usd_polymesh.py -v -m "not slow" --cov=physiomotion4d --cov-append --cov-report=xml + continue-on-error: true + + - name: Run USD utility tests + run: | + pytest tests/test_usd_merge.py tests/test_usd_time_preservation.py -v --cov=physiomotion4d --cov-append --cov-report=xml + continue-on-error: true + + - name: Run all integration tests + run: | + pytest tests/ -v -m "not slow" + continue-on-error: true + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v4 + with: + file: ./coverage.xml + flags: integration-tests + name: codecov-integration + fail_ci_if_error: false + + # ============================================================================ + # GPU Tests (Self-hosted runners with CUDA) + # ============================================================================ + gpu-tests: + name: GPU Tests (Python ${{ matrix.python-version }}) + runs-on: [self-hosted, linux, gpu] + needs: unit-tests + timeout-minutes: 15 + strategy: + fail-fast: false + matrix: + python-version: ["3.10", "3.11"] + # Only run GPU tests if self-hosted runners are available + # Timeout after 15 minutes if no runner picks up the job + continue-on-error: true + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Check GPU availability + run: | + nvidia-smi || echo "No GPU available" + python -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}')" || echo "PyTorch not installed yet" + + - name: Cache pip packages + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-gpu-${{ 
matrix.python-version }}-${{ hashFiles('pyproject.toml') }} + restore-keys: | + ${{ runner.os }}-pip-gpu-${{ matrix.python-version }}- + ${{ runner.os }}-pip-gpu- + ${{ runner.os }}-pip- + + - name: Upgrade pip and build tools + run: | + python -m pip install --upgrade pip setuptools wheel + + - name: Install PyTorch (CUDA 12.6) + run: | + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126 + + - name: Install package with test dependencies + run: | + pip install -e ".[test]" + + - name: Verify GPU setup + run: | + python -c "import torch; print(f'PyTorch version: {torch.__version__}')" + python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')" + python -c "import torch; print(f'CUDA version: {torch.version.cuda}')" || echo "CUDA not available" + python -c "import torch; print(f'Number of GPUs: {torch.cuda.device_count()}')" || echo "No GPUs" + + - name: List installed packages + run: | + pip list + + - name: Run GPU tests + run: | + pytest tests/ -v \ + -m "not slow" \ + --cov=physiomotion4d \ + --cov-report=xml \ + --cov-report=term \ + --cov-report=html + env: + CUDA_VISIBLE_DEVICES: 0 + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v4 + if: matrix.python-version == '3.10' + with: + file: ./coverage.xml + flags: gpu-tests + name: codecov-gpu-py${{ matrix.python-version }} + fail_ci_if_error: false + + - name: Upload coverage artifacts + uses: actions/upload-artifact@v4 + if: matrix.python-version == '3.10' + with: + name: coverage-report-gpu + path: htmlcov/ + retention-days: 7 + + # ============================================================================ + # Code Quality Checks + # ============================================================================ + code-quality: + name: Code Quality Checks + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python 3.10 + uses: actions/setup-python@v5 + with: + python-version: 
"3.10" + cache: 'pip' + + - name: Install dev dependencies + run: | + python -m pip install --upgrade pip + pip install black isort ruff flake8 pylint mypy + + - name: Check code formatting with Black + run: | + black --check src/ tests/ || echo "Black check failed" + continue-on-error: true + + - name: Check import sorting with isort + run: | + isort --check-only src/ tests/ || echo "isort check failed" + continue-on-error: true + + - name: Lint with Ruff + run: | + ruff check src/ tests/ || echo "Ruff check failed" + continue-on-error: true + + - name: Lint with Flake8 + run: | + flake8 src/ tests/ || echo "Flake8 check failed" + continue-on-error: true + +# ============================================================================== +# Notes on Excluded Tests +# ============================================================================== +# +# The following tests are excluded from CI and should be run locally: +# +# Slow/GPU-intensive tests: +# - tests/test_register_images_ants.py (slow, computationally intensive) +# - tests/test_register_images_icon.py (requires CUDA for ICON) +# - tests/test_transform_tools.py (depends on slow registration tests) +# - tests/test_segment_chest_total_segmentator.py (requires CUDA for TotalSegmentator) +# - tests/test_segment_chest_vista_3d.py (requires CUDA for VISTA-3D, 20GB+ RAM) +# +# To run locally: +# pytest tests/ -v -m "slow" # Run all slow tests +# pytest tests/test_register_images_ants.py -v +# pytest tests/test_register_images_icon.py -v +# pytest tests/test_segment_chest_total_segmentator.py -v diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index ebc5b9c..9647847 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -16,10 +16,37 @@ jobs: - name: Checkout code uses: actions/checkout@v4 + - name: Free up disk space + run: | + echo "Disk space before cleanup:" + df -h + sudo rm -rf /usr/share/dotnet + sudo rm -rf /usr/local/lib/android + sudo rm -rf /opt/ghc + sudo rm -rf 
/opt/hostedtoolcache/CodeQL + sudo docker image prune --all --force + echo "Disk space after cleanup:" + df -h + - name: Set up Python uses: actions/setup-python@v5 with: python-version: '3.11' + + - name: Install system dependencies + run: | + sudo apt-get update + sudo apt-get install -y \ + libgl1 \ + libglib2.0-0 \ + libgomp1 \ + libsm6 \ + libxrender1 \ + libxext6 \ + libxrandr2 \ + libxi6 + sudo apt-get clean + sudo rm -rf /var/lib/apt/lists/* - name: Install dependencies run: | @@ -27,6 +54,10 @@ jobs: pip install uv uv pip install --system torch torchvision --index-url https://download.pytorch.org/whl/cpu uv pip install --system -e ".[docs]" + + - name: Clear pip cache + run: | + pip cache purge || true - name: Build documentation run: | @@ -36,7 +67,8 @@ jobs: - name: Check for warnings run: | cd docs - make html SPHINXOPTS="-W --keep-going" + make html SPHINXOPTS="--keep-going" + continue-on-error: true - name: Upload documentation artifacts uses: actions/upload-artifact@v4 diff --git a/.github/workflows/test-slow.yml b/.github/workflows/test-slow.yml new file mode 100644 index 0000000..3644b78 --- /dev/null +++ b/.github/workflows/test-slow.yml @@ -0,0 +1,65 @@ +name: Slow Tests + +on: + schedule: + # Run nightly at 2 AM UTC + - cron: '0 2 * * *' + workflow_dispatch: + +jobs: + test-slow-gpu: + name: Slow Tests on GPU + runs-on: [self-hosted, linux, gpu] + timeout-minutes: 15 + continue-on-error: true + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python 3.10 + uses: actions/setup-python@v5 + with: + python-version: "3.10" + + - name: Check GPU availability + run: | + nvidia-smi + python -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}')" || echo "PyTorch not installed yet" + + - name: Cache pip packages + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-slow-gpu-${{ hashFiles('pyproject.toml') }} + restore-keys: | + ${{ runner.os }}-pip-slow-gpu- + ${{ 
runner.os }}-pip-gpu- + ${{ runner.os }}-pip- + + - name: Upgrade pip and build tools + run: | + python -m pip install --upgrade pip setuptools wheel + + - name: Install PyTorch (CUDA 12.6) + run: | + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126 + + - name: Install package with test dependencies + run: | + pip install -e ".[test]" + + - name: Run slow tests + run: | + pytest tests/ -v -m "slow" --cov=physiomotion4d --cov-report=xml --cov-report=term + env: + CUDA_VISIBLE_DEVICES: 0 + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v4 + with: + file: ./coverage.xml + flags: slow-tests-gpu + name: codecov-slow-gpu + fail_ci_if_error: false + diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml deleted file mode 100644 index 7e80cf4..0000000 --- a/.github/workflows/tests.yml +++ /dev/null @@ -1,128 +0,0 @@ -name: Tests - -# Test organization: -# - Unit tests: Run on all PRs and pushes, no external data required -# - Integration tests with data: Run on PRs, download and cache test data -# - GPU tests (segmentation): Excluded from CI (requires CUDA), run locally -# -# Test markers: -# - requires_data: Tests that need external data downloads -# - slow: Tests that are computationally intensive or require GPU - -on: - pull_request: - branches: [ main ] - push: - branches: [ main ] - -jobs: - test: - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - os: [ubuntu-latest, windows-latest, macos-latest] - python-version: ['3.10', '3.11', '3.12'] - - steps: - - uses: actions/checkout@v4 - - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -e . 
- pip install pytest pytest-cov - - - name: Run unit tests (no data required) - run: | - pytest tests/ -v -m "not requires_data and not slow" --cov=src/physiomotion4d --cov-report=xml - - - name: Upload coverage to Codecov - uses: codecov/codecov-action@v4 - if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.10' - with: - file: ./coverage.xml - flags: unittests - name: codecov-umbrella - fail_ci_if_error: false - - test-with-data: - runs-on: ubuntu-latest - if: github.event_name == 'pull_request' - - steps: - - uses: actions/checkout@v4 - - - name: Set up Python 3.10 - uses: actions/setup-python@v5 - with: - python-version: '3.10' - - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -e . - pip install pytest pytest-cov - - - name: Cache test data - uses: actions/cache@v4 - with: - path: | - tests/data/ - tests/results/ - key: test-data-${{ hashFiles('tests/test_*.py') }}-v2 - restore-keys: | - test-data- - - - name: Run data download tests - run: | - pytest tests/test_download_heart_data.py -v --cov=src/physiomotion4d --cov-report=xml - continue-on-error: true - - - name: Run data conversion tests - run: | - pytest tests/test_convert_nrrd_4d_to_3d.py -v --cov=src/physiomotion4d --cov-append --cov-report=xml - continue-on-error: true - - - name: Run contour tools tests - run: | - pytest tests/test_contour_tools.py -v -m "not slow" --cov=src/physiomotion4d --cov-append --cov-report=xml - continue-on-error: true - - - name: Run USD conversion tests - run: | - pytest tests/test_convert_vtk_4d_to_usd_polymesh.py -v -m "not slow" --cov=src/physiomotion4d --cov-append --cov-report=xml - continue-on-error: true - - - name: Run existing USD tests - run: | - pytest tests/test_usd_merge.py tests/test_usd_time_preservation.py -v --cov=src/physiomotion4d --cov-append --cov-report=xml - continue-on-error: true - - - name: Upload coverage - uses: codecov/codecov-action@v4 - with: - file: ./coverage.xml - flags: 
integration-tests - fail_ci_if_error: false - - # Note: Slow and GPU-dependent tests are excluded from CI - # These tests should be run locally: - # - tests/test_register_images_ants.py (slow, computationally intensive) - # - tests/test_register_images_icon.py (requires CUDA for ICON) - # - tests/test_transform_tools.py (depends on slow registration tests) - # - tests/test_segment_chest_total_segmentator.py (requires CUDA for TotalSegmentator) - # - tests/test_segment_chest_vista_3d.py (requires CUDA for VISTA-3D) - # - # To run slow/GPU tests locally: - # pytest tests/ -v -m "slow" # Run all slow tests - # pytest tests/test_register_images_ants.py -v -s - # pytest tests/test_register_images_icon.py -v -s - # pytest tests/test_transform_tools.py -v -s - # pytest tests/test_segment_chest_total_segmentator.py -v -s - # pytest tests/test_segment_chest_vista_3d.py -v -s diff --git a/.gitignore b/.gitignore index 157bc21..6e23e4b 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ .coverage coverage.xml *~ +*.swp venv* .venv build @@ -17,6 +18,7 @@ network_weights # Data files *.gz +*.mat *.mha *.mhd *.zip @@ -41,4 +43,4 @@ results* tests/data/ tests/results/ tests/baseline_staging/ -test_output.txt \ No newline at end of file +test_output.txt diff --git a/README.md b/README.md index cec243a..ba016eb 100644 --- a/README.md +++ b/README.md @@ -98,9 +98,9 @@ print(f"PhysioMotion4D version: {physiomotion4d.__version__}") - `RegisterImagesANTs`: Classical deformable registration using ANTs - `RegisterTimeSeriesImages`: Specialized time series registration for 4D CT - Model-to-Image/Model Registration: - - `RegisterModelToImagePCA`: PCA-based statistical shape model registration - - `RegisterModelToModelICP`: ICP-based surface registration - - `RegisterModelToModelMasks`: Mask-based deformable model registration + - `RegisterModelsPCA`: PCA-based statistical shape model registration + - `RegisterModelsICP`: ICP-based surface registration + - 
`RegisterModelsDistanceMaps`: Mask-based deformable model registration - `RegisterImagesBase`: Base class for custom registration methods - **Base Classes**: Foundation classes providing common functionality - `PhysioMotion4DBase`: Base class providing standardized logging and debug settings @@ -227,8 +227,8 @@ transforms = time_series_reg.register_time_series( ) # Get forward and inverse displacement fields -phi_FM = results["phi_FM"] # Fixed to moving -phi_MF = results["phi_MF"] # Moving to fixed +inverse_transform = results["inverse_transform"] # Fixed to moving +forward_transform = results["forward_transform"] # Moving to fixed ``` ### Logging and Debug Control @@ -243,7 +243,7 @@ from physiomotion4d import HeartModelToPatientWorkflow, PhysioMotion4DBase PhysioMotion4DBase.set_log_level(logging.DEBUG) # Or filter to show logs from specific classes only -PhysioMotion4DBase.set_log_classes(["HeartModelToPatientWorkflow", "RegisterModelToImagePCA"]) +PhysioMotion4DBase.set_log_classes(["HeartModelToPatientWorkflow", "RegisterModelsPCA"]) # Show all classes again PhysioMotion4DBase.set_log_all_classes() diff --git a/data/KCL-Heart-Model/1-input_meshes_to_input_surfaces.ipynb b/data/KCL-Heart-Model/1-input_meshes_to_input_surfaces.ipynb index f76d5b6..4450258 100644 --- a/data/KCL-Heart-Model/1-input_meshes_to_input_surfaces.ipynb +++ b/data/KCL-Heart-Model/1-input_meshes_to_input_surfaces.ipynb @@ -78,7 +78,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Process average_mesh.vtk\n", + "# Process average_mesh.vtk provided from the KCL data collection.\n", "\n", "mesh = pv.read('average_mesh.vtk')\n", "\n", diff --git a/data/KCL-Heart-Model/2-input_surfaces_to_surfaces_aligned.ipynb b/data/KCL-Heart-Model/2-input_surfaces_to_surfaces_aligned.ipynb index 9addfa5..070ddd0 100644 --- a/data/KCL-Heart-Model/2-input_surfaces_to_surfaces_aligned.ipynb +++ b/data/KCL-Heart-Model/2-input_surfaces_to_surfaces_aligned.ipynb @@ -1,341 +1,341 @@ { - "cells": [ - { - 
"cell_type": "markdown", - "metadata": {}, - "source": [ - "# ICP Rigid Registration: Align Heart Models to Average\n", - "\n", - "This notebook performs ICP (Iterative Closest Point) rigid registration to align each individual heart model to the average model.\n", - "\n", - "**Workflow:**\n", - "1. Load the average mesh (`input_meshes/average.vtk`)\n", - "2. Load each individual mesh (`input_meshes/01.vtk` through `20.vtk`)\n", - "3. Use ICP rigid registration to align each mesh to the average\n", - "4. Save the aligned meshes to `icp_aligned_meshes/`\n", - "5. Visualize the results\n" - ] + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# ICP Rigid Registration: Align Heart Models to Average\n", + "\n", + "This notebook performs ICP (Iterative Closest Point) rigid registration to align each individual heart model to the average model.\n", + "\n", + "**Workflow:**\n", + "1. Load the average mesh (`input_meshes/average.vtk`)\n", + "2. Load each individual mesh (`input_meshes/01.vtk` through `20.vtk`)\n", + "3. Use ICP rigid registration to align each mesh to the average\n", + "4. Save the aligned meshes to `icp_aligned_meshes/`\n", + "5. Visualize the results\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "from pathlib import Path\n", + "\n", + "import numpy as np\n", + "import pyvista as pv\n", + "\n", + "# Add the src directory to the path to import the registration class\n", + "sys.path.insert(0, str(Path.cwd().parent.parent / 'src'))\n", + "\n", + "from physiomotion4d.register_model_to_model_icp import RegisterModelsICP\n", + "\n", + "# Enable interactive plotting\n", + "pv.set_jupyter_backend('trame')\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. 
Load the Average Mesh (Target)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load the average mesh - this will be our fixed target\n", + "template_mesh_path = Path('./average.vtk')\n", + "template_mesh = pv.read(template_mesh_path)\n", + "\n", + "print(f\"Average mesh loaded:\")\n", + "print(f\" Points: {template_mesh.n_points}\")\n", + "print(f\" Cells: {template_mesh.n_cells}\")\n", + "print(f\" Center: {template_mesh.center}\")\n", + "print(f\" Bounds: {template_mesh.bounds}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Find All Individual Mesh Files\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get all individual mesh files (excluding average.vtk)\n", + "input_meshes_dir = Path('surfaces')\n", + "mesh_files = sorted([f for f in input_meshes_dir.glob('??.vtp')])\n", + "\n", + "print(f\"Found {len(mesh_files)} individual mesh files:\")\n", + "for mesh_file in mesh_files:\n", + " print(f\" {mesh_file.name}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. 
Perform ICP Rigid Registration for Each Mesh\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create output directory for aligned meshes\n", + "output_dir = Path('surfaces_aligned')\n", + "output_dir.mkdir(exist_ok=True)\n", + "\n", + "# Store results\n", + "aligned_meshes = {}\n", + "transforms_point_forward = {} # Moving to Fixed point transforms (forward_point_transform)\n", + "transforms_point_inverse = {} # Fixed to Moving point transforms (inverse_point_transform)\n", + "\n", + "# Process each mesh\n", + "for mesh_file in mesh_files:\n", + " print(f\"\\n{'='*60}\")\n", + " print(f\"Processing: {mesh_file.name}\")\n", + " print(f\"{'='*60}\")\n", + " \n", + " # Load the moving mesh\n", + " moving_mesh = pv.read(mesh_file)\n", + " print(f\" Loaded mesh: {moving_mesh.n_points} points\")\n", + " \n", + " # Extract surface if needed (in case it's a volume mesh)\n", + " if isinstance(moving_mesh, pv.UnstructuredGrid):\n", + " print(f\" Extracting surface from volume mesh...\")\n", + " moving_mesh = moving_mesh.extract_surface()\n", + " print(f\" Surface mesh: {moving_mesh.n_points} points\")\n", + " \n", + " # Initialize registrar\n", + " registrar = RegisterModelsICP(\n", + " moving_mesh=moving_mesh,\n", + " fixed_mesh=template_mesh\n", + " )\n", + " \n", + " # Perform rigid ICP registration\n", + " result = registrar.register(mode='rigid', max_iterations=2000)\n", + " \n", + " # Store results\n", + " mesh_id = mesh_file.stem\n", + " aligned_meshes[mesh_id] = result['moving_mesh']\n", + " transforms_point_forward[mesh_id] = result['forward_point_transform']\n", + " transforms_point_inverse[mesh_id] = result['inverse_point_transform']\n", + " \n", + " # Save aligned mesh\n", + " output_path = output_dir / f\"{mesh_id}.vtp\"\n", + " result['moving_mesh'].save(output_path)\n", + " print(f\"\\n Saved aligned mesh: {output_path}\")\n", + "\n", + "print(f\"\\n{'='*60}\")\n", + "print(f\"ICP 
registration complete for all {len(mesh_files)} meshes!\")\n", + "print(f\"{'='*60}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Visualize Results: Before and After Registration\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Select a few examples to visualize (e.g., 01, 05, 10, 15, 20)\n", + "example_ids = ['01', '05', '10', '15', '20']\n", + "\n", + "for mesh_id in example_ids:\n", + " if mesh_id not in aligned_meshes:\n", + " continue\n", + " \n", + " # Load original mesh\n", + " original_mesh = pv.read(f'input_meshes/{mesh_id}.vtk')\n", + " if isinstance(original_mesh, pv.UnstructuredGrid):\n", + " original_mesh = original_mesh.extract_surface()\n", + " \n", + " # Create side-by-side comparison\n", + " plotter = pv.Plotter(shape=(1, 2))\n", + " \n", + " # Left: Before registration\n", + " plotter.subplot(0, 0)\n", + " plotter.add_mesh(template_mesh, color='lightblue', opacity=1.0, label='Average')\n", + " plotter.add_mesh(original_mesh, color='red', opacity=1.0, label=f'Original {mesh_id}')\n", + " plotter.add_text(f'Before ICP Registration - {mesh_id}', position='upper_left', font_size=10)\n", + " plotter.add_legend()\n", + " plotter.show_axes()\n", + " \n", + " # Right: After registration\n", + " plotter.subplot(0, 1)\n", + " plotter.add_mesh(template_mesh, color='lightblue', opacity=1.0, label='Average')\n", + " plotter.add_mesh(aligned_meshes[mesh_id], color='green', opacity=1.0, label=f'Aligned {mesh_id}')\n", + " plotter.add_text(f'After ICP Registration - {mesh_id}', position='upper_left', font_size=10)\n", + " plotter.add_legend()\n", + " plotter.show_axes()\n", + " \n", + " plotter.link_views()\n", + " plotter.show()\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. 
Calculate Registration Statistics\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "# Calculate statistics for each registration\n", + "stats_data = []\n", + "\n", + "for mesh_id, aligned_mesh in aligned_meshes.items():\n", + " # Calculate distance from aligned mesh to average mesh\n", + " # Using point-to-point distance as a metric\n", + " \n", + " # Get closest points on average mesh for each point in aligned mesh\n", + " closest_points = template_mesh.find_closest_cell(aligned_mesh.points, return_closest_point=True)[1]\n", + " \n", + " # Calculate distances\n", + " distances = np.linalg.norm(aligned_mesh.points - closest_points, axis=1)\n", + " \n", + " stats_data.append({\n", + " 'Mesh ID': mesh_id,\n", + " 'Mean Distance (mm)': np.mean(distances),\n", + " 'Median Distance (mm)': np.median(distances),\n", + " 'Std Distance (mm)': np.std(distances),\n", + " 'Max Distance (mm)': np.max(distances),\n", + " 'Min Distance (mm)': np.min(distances)\n", + " })\n", + "\n", + "# Create DataFrame and display\n", + "stats_df = pd.DataFrame(stats_data)\n", + "stats_df = stats_df.sort_values('Mesh ID')\n", + "\n", + "print(\"\\nRegistration Statistics (Distance from aligned mesh to average mesh):\")\n", + "print(\"=\"*80)\n", + "print(stats_df.to_string(index=False))\n", + "print(\"=\"*80)\n", + "\n", + "# Summary statistics\n", + "print(f\"\\nOverall Summary:\")\n", + "print(f\" Average mean distance: {stats_df['Mean Distance (mm)'].mean():.3f} mm\")\n", + "print(f\" Average median distance: {stats_df['Median Distance (mm)'].mean():.3f} mm\")\n", + "print(f\" Range of mean distances: {stats_df['Mean Distance (mm)'].min():.3f} - {stats_df['Mean Distance (mm)'].max():.3f} mm\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. 
Save Registration Statistics\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Save statistics to CSV\n", + "stats_csv_path = output_dir / 'registration_statistics.csv'\n", + "stats_df.to_csv(stats_csv_path, index=False)\n", + "print(f\"\\nStatistics saved to: {stats_csv_path}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 8. Visualize Distance Distributions\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "# Create bar plot of mean distances\n", + "fig, axes = plt.subplots(2, 1, figsize=(12, 8))\n", + "\n", + "# Plot 1: Mean distances\n", + "axes[0].bar(stats_df['Mesh ID'], stats_df['Mean Distance (mm)'], color='steelblue')\n", + "axes[0].set_xlabel('Mesh ID')\n", + "axes[0].set_ylabel('Mean Distance (mm)')\n", + "axes[0].set_title('Mean Distance from Aligned Mesh to Average Mesh (After ICP Registration)')\n", + "axes[0].grid(axis='y', alpha=0.3)\n", + "\n", + "# Plot 2: Box plot style visualization\n", + "axes[1].errorbar(\n", + " stats_df['Mesh ID'], \n", + " stats_df['Median Distance (mm)'],\n", + " yerr=stats_df['Std Distance (mm)'],\n", + " fmt='o',\n", + " capsize=5,\n", + " capthick=2,\n", + " color='coral',\n", + " ecolor='gray',\n", + " label='Median ± Std'\n", + ")\n", + "axes[1].set_xlabel('Mesh ID')\n", + "axes[1].set_ylabel('Distance (mm)')\n", + "axes[1].set_title('Median Distance ± Standard Deviation')\n", + "axes[1].legend()\n", + "axes[1].grid(axis='y', alpha=0.3)\n", + "\n", + "plt.tight_layout()\n", + "plt.savefig(output_dir / 'registration_statistics.png', dpi=150, bbox_inches='tight')\n", + "plt.show()\n", + "\n", + "print(f\"\\nPlot saved to: {output_dir / 'registration_statistics.png'}\")\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + 
"language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "import os\n", - "from pathlib import Path\n", - "import numpy as np\n", - "import pyvista as pv\n", - "\n", - "# Add the src directory to the path to import the registration class\n", - "sys.path.insert(0, str(Path.cwd().parent.parent / 'src'))\n", - "\n", - "from physiomotion4d.register_model_to_model_icp import RegisterModelToModelICP\n", - "\n", - "# Enable interactive plotting\n", - "pv.set_jupyter_backend('trame')\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1. Load the Average Mesh (Target)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Load the average mesh - this will be our fixed target\n", - "average_mesh_path = Path('./average.vtk')\n", - "average_mesh = pv.read(average_mesh_path)\n", - "\n", - "print(f\"Average mesh loaded:\")\n", - "print(f\" Points: {average_mesh.n_points}\")\n", - "print(f\" Cells: {average_mesh.n_cells}\")\n", - "print(f\" Center: {average_mesh.center}\")\n", - "print(f\" Bounds: {average_mesh.bounds}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2. 
Find All Individual Mesh Files\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Get all individual mesh files (excluding average.vtk)\n", - "input_meshes_dir = Path('surfaces')\n", - "mesh_files = sorted([f for f in input_meshes_dir.glob('??.vtp')])\n", - "\n", - "print(f\"Found {len(mesh_files)} individual mesh files:\")\n", - "for mesh_file in mesh_files:\n", - " print(f\" {mesh_file.name}\")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3. Perform ICP Rigid Registration for Each Mesh\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Create output directory for aligned meshes\n", - "output_dir = Path('surfaces_aligned')\n", - "output_dir.mkdir(exist_ok=True)\n", - "\n", - "# Store results\n", - "aligned_meshes = {}\n", - "transforms_MF = {} # Moving to Fixed transforms\n", - "transforms_FM = {} # Fixed to Moving transforms\n", - "\n", - "# Process each mesh\n", - "for mesh_file in mesh_files:\n", - " print(f\"\\n{'='*60}\")\n", - " print(f\"Processing: {mesh_file.name}\")\n", - " print(f\"{'='*60}\")\n", - " \n", - " # Load the moving mesh\n", - " moving_mesh = pv.read(mesh_file)\n", - " print(f\" Loaded mesh: {moving_mesh.n_points} points\")\n", - " \n", - " # Extract surface if needed (in case it's a volume mesh)\n", - " if isinstance(moving_mesh, pv.UnstructuredGrid):\n", - " print(f\" Extracting surface from volume mesh...\")\n", - " moving_mesh = moving_mesh.extract_surface()\n", - " print(f\" Surface mesh: {moving_mesh.n_points} points\")\n", - " \n", - " # Initialize registrar\n", - " registrar = RegisterModelToModelICP(\n", - " moving_mesh=moving_mesh,\n", - " fixed_mesh=average_mesh\n", - " )\n", - " \n", - " # Perform rigid ICP registration\n", - " result = registrar.register(mode='rigid', max_iterations=2000)\n", - " \n", - " # Store results\n", - " mesh_id = mesh_file.stem\n", - 
" aligned_meshes[mesh_id] = result['moving_mesh']\n", - " transforms_MF[mesh_id] = result['phi_MF']\n", - " transforms_FM[mesh_id] = result['phi_FM']\n", - " \n", - " # Save aligned mesh\n", - " output_path = output_dir / f\"{mesh_id}.vtp\"\n", - " result['moving_mesh'].save(output_path)\n", - " print(f\"\\n Saved aligned mesh: {output_path}\")\n", - "\n", - "print(f\"\\n{'='*60}\")\n", - "print(f\"ICP registration complete for all {len(mesh_files)} meshes!\")\n", - "print(f\"{'='*60}\")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 4. Visualize Results: Before and After Registration\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Select a few examples to visualize (e.g., 01, 05, 10, 15, 20)\n", - "example_ids = ['01', '05', '10', '15', '20']\n", - "\n", - "for mesh_id in example_ids:\n", - " if mesh_id not in aligned_meshes:\n", - " continue\n", - " \n", - " # Load original mesh\n", - " original_mesh = pv.read(f'input_meshes/{mesh_id}.vtk')\n", - " if isinstance(original_mesh, pv.UnstructuredGrid):\n", - " original_mesh = original_mesh.extract_surface()\n", - " \n", - " # Create side-by-side comparison\n", - " plotter = pv.Plotter(shape=(1, 2))\n", - " \n", - " # Left: Before registration\n", - " plotter.subplot(0, 0)\n", - " plotter.add_mesh(average_mesh, color='lightblue', opacity=1.0, label='Average')\n", - " plotter.add_mesh(original_mesh, color='red', opacity=1.0, label=f'Original {mesh_id}')\n", - " plotter.add_text(f'Before ICP Registration - {mesh_id}', position='upper_left', font_size=10)\n", - " plotter.add_legend()\n", - " plotter.show_axes()\n", - " \n", - " # Right: After registration\n", - " plotter.subplot(0, 1)\n", - " plotter.add_mesh(average_mesh, color='lightblue', opacity=1.0, label='Average')\n", - " plotter.add_mesh(aligned_meshes[mesh_id], color='green', opacity=1.0, label=f'Aligned {mesh_id}')\n", - " plotter.add_text(f'After ICP 
Registration - {mesh_id}', position='upper_left', font_size=10)\n", - " plotter.add_legend()\n", - " plotter.show_axes()\n", - " \n", - " plotter.link_views()\n", - " plotter.show()\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 6. Calculate Registration Statistics\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import pandas as pd\n", - "\n", - "# Calculate statistics for each registration\n", - "stats_data = []\n", - "\n", - "for mesh_id, aligned_mesh in aligned_meshes.items():\n", - " # Calculate distance from aligned mesh to average mesh\n", - " # Using point-to-point distance as a metric\n", - " \n", - " # Get closest points on average mesh for each point in aligned mesh\n", - " closest_points = average_mesh.find_closest_cell(aligned_mesh.points, return_closest_point=True)[1]\n", - " \n", - " # Calculate distances\n", - " distances = np.linalg.norm(aligned_mesh.points - closest_points, axis=1)\n", - " \n", - " stats_data.append({\n", - " 'Mesh ID': mesh_id,\n", - " 'Mean Distance (mm)': np.mean(distances),\n", - " 'Median Distance (mm)': np.median(distances),\n", - " 'Std Distance (mm)': np.std(distances),\n", - " 'Max Distance (mm)': np.max(distances),\n", - " 'Min Distance (mm)': np.min(distances)\n", - " })\n", - "\n", - "# Create DataFrame and display\n", - "stats_df = pd.DataFrame(stats_data)\n", - "stats_df = stats_df.sort_values('Mesh ID')\n", - "\n", - "print(\"\\nRegistration Statistics (Distance from aligned mesh to average mesh):\")\n", - "print(\"=\"*80)\n", - "print(stats_df.to_string(index=False))\n", - "print(\"=\"*80)\n", - "\n", - "# Summary statistics\n", - "print(f\"\\nOverall Summary:\")\n", - "print(f\" Average mean distance: {stats_df['Mean Distance (mm)'].mean():.3f} mm\")\n", - "print(f\" Average median distance: {stats_df['Median Distance (mm)'].mean():.3f} mm\")\n", - "print(f\" Range of mean distances: {stats_df['Mean Distance 
(mm)'].min():.3f} - {stats_df['Mean Distance (mm)'].max():.3f} mm\")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 7. Save Registration Statistics\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Save statistics to CSV\n", - "stats_csv_path = output_dir / 'registration_statistics.csv'\n", - "stats_df.to_csv(stats_csv_path, index=False)\n", - "print(f\"\\nStatistics saved to: {stats_csv_path}\")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 8. Visualize Distance Distributions\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import matplotlib.pyplot as plt\n", - "\n", - "# Create bar plot of mean distances\n", - "fig, axes = plt.subplots(2, 1, figsize=(12, 8))\n", - "\n", - "# Plot 1: Mean distances\n", - "axes[0].bar(stats_df['Mesh ID'], stats_df['Mean Distance (mm)'], color='steelblue')\n", - "axes[0].set_xlabel('Mesh ID')\n", - "axes[0].set_ylabel('Mean Distance (mm)')\n", - "axes[0].set_title('Mean Distance from Aligned Mesh to Average Mesh (After ICP Registration)')\n", - "axes[0].grid(axis='y', alpha=0.3)\n", - "\n", - "# Plot 2: Box plot style visualization\n", - "axes[1].errorbar(\n", - " stats_df['Mesh ID'], \n", - " stats_df['Median Distance (mm)'],\n", - " yerr=stats_df['Std Distance (mm)'],\n", - " fmt='o',\n", - " capsize=5,\n", - " capthick=2,\n", - " color='coral',\n", - " ecolor='gray',\n", - " label='Median ± Std'\n", - ")\n", - "axes[1].set_xlabel('Mesh ID')\n", - "axes[1].set_ylabel('Distance (mm)')\n", - "axes[1].set_title('Median Distance ± Standard Deviation')\n", - "axes[1].legend()\n", - "axes[1].grid(axis='y', alpha=0.3)\n", - "\n", - "plt.tight_layout()\n", - "plt.savefig(output_dir / 'registration_statistics.png', dpi=150, bbox_inches='tight')\n", - "plt.show()\n", - "\n", - "print(f\"\\nPlot saved to: {output_dir / 
'registration_statistics.png'}\")\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.11" - } - }, - "nbformat": 4, - "nbformat_minor": 2 + "nbformat": 4, + "nbformat_minor": 2 } diff --git a/data/KCL-Heart-Model/4-surfaces_aligned_correspond_to_pca_inputs.ipynb b/data/KCL-Heart-Model/4-surfaces_aligned_correspond_to_pca_inputs.ipynb index c7a2b8d..9deb841 100644 --- a/data/KCL-Heart-Model/4-surfaces_aligned_correspond_to_pca_inputs.ipynb +++ b/data/KCL-Heart-Model/4-surfaces_aligned_correspond_to_pca_inputs.ipynb @@ -30,10 +30,8 @@ "from pathlib import Path\n", "\n", "import itk\n", - "import numpy as np\n", "import pyvista as pv\n", - "\n", - "from physiomotion4d.register_model_to_model_masks import RegisterModelToModelMasks\n", + "from physiomotion4d.register_model_to_model_masks import RegisterModelsDistanceMaps\n", "\n", "# Enable interactive plotting\n", "pv.set_jupyter_backend('trame')\n" @@ -61,7 +59,7 @@ "output_dir.mkdir(exist_ok=True)\n", "\n", "# Find all VTK/VTP files in correspondence directory\n", - "correspond_files = sorted(list(correspond_dir.glob('*.vtp')))\n", + "correspond_files = sorted(correspond_dir.glob('*.vtp'))\n", "\n", "print(f\"Found {len(correspond_files)} files in {correspond_dir}/\")\n", "for f in correspond_files:\n", @@ -99,10 +97,10 @@ " \"\"\"\n", " Process a correspondence mesh by averaging surface positions based on vtkOriginalPointIds.\n", " \"\"\"\n", - " \n", + "\n", " # Load the correspondence mesh\n", " correspond_mesh = pv.read(correspond_file)\n", - " \n", + "\n", " base_name = correspond_file.stem.replace('_correspond', '')\n", " surface_file = surfaces_dir / f\"{base_name}.vtp\"\n", " if not 
surface_file.exists():\n", @@ -110,7 +108,7 @@ " return None\n", " surface_mesh = pv.read(surface_file)\n", "\n", - " registrar = RegisterModelToModelMasks(\n", + " registrar = RegisterModelsDistanceMaps(\n", " moving_mesh=correspond_mesh,\n", " fixed_mesh=surface_mesh,\n", " reference_image=ref_image,\n", @@ -124,7 +122,7 @@ " output_file = output_dir / f\"{base_name}.vtk\"\n", " registered_mesh.save(output_file)\n", " print(f\" Saved to: {output_file}\")\n", - " \n", + "\n", " return registered_mesh\n" ] }, @@ -145,8 +143,8 @@ "processed_meshes = {}\n", "failed_files = []\n", "\n", - "average_mesh = pv.read(\"average_surface.vtp\")\n", - "bounds = average_mesh.bounds\n", + "template_mesh = pv.read(\"average_surface.vtp\")\n", + "bounds = template_mesh.bounds\n", "xmin = bounds[0]\n", "xmax = bounds[1]\n", "ymin = bounds[2]\n", @@ -188,9 +186,9 @@ " except Exception as e:\n", " print(f\"\\n ERROR processing {correspond_file.name}: {str(e)}\")\n", " failed_files.append(correspond_file.name)\n", - " \n", + "\n", " print(f\"\\n{'='*70}\")\n", - " print(f\"Processing Complete!\")\n", + " print(\"Processing Complete!\")\n", " print(f\"{'='*70}\")\n", " print(f\" Successfully processed: {len(processed_meshes)} files\")\n", " print(f\" Failed: {len(failed_files)} files\")\n", diff --git a/data/KCL-Heart-Model/6-create_template_mask.ipynb b/data/KCL-Heart-Model/6-create_template_mask.ipynb new file mode 100644 index 0000000..eea4963 --- /dev/null +++ b/data/KCL-Heart-Model/6-create_template_mask.ipynb @@ -0,0 +1,66 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "id": "2616aac8", + "metadata": {}, + "outputs": [], + "source": [ + "import itk\n", + "import pyvista as pv\n", + "\n", + "from physiomotion4d import ContourTools\n", + "\n", + "template_model_surface = pv.read('average_surface.vtp')\n", + "\n", + "contour_tools = ContourTools()\n", + "\n", + "reference_image = contour_tools.create_reference_image(\n", + " template_model_surface,\n", + 
")\n", + "\n", + "template_mask = contour_tools.create_mask_from_mesh(\n", + " template_model_surface,\n", + " reference_image,\n", + ")\n", + "\n", + "itk.imwrite(template_mask, \"average_binary_mask.nii.gz\")\n", + "\n", + "\n", + "# Then use 3D Slicer to assign the following labels:\n", + "# 1: heart muscle\n", + "# 2: right ventricle\n", + "# 3: left ventricle\n", + "# 4: right atrium\n", + "# 5: left atrium\n", + "# 6: background rim around heart muscle\n", + "\n", + "# save the 3D Slicer labelmap file as \"average_labelmap.nii.gz\"\n", + "\n", + "\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/LOGGING_API_REFERENCE.md b/docs/LOGGING_API_REFERENCE.md index c25f294..6d0fdbb 100644 --- a/docs/LOGGING_API_REFERENCE.md +++ b/docs/LOGGING_API_REFERENCE.md @@ -15,7 +15,7 @@ PhysioMotion4DBase.set_log_level(logging.INFO) PhysioMotion4DBase.set_log_level('DEBUG') # Can use string too # Filter to show only specific classes -PhysioMotion4DBase.set_log_classes(["RegisterModelToImagePCA", "HeartModelToPatientWorkflow"]) +PhysioMotion4DBase.set_log_classes(["RegisterModelsPCA", "HeartModelToPatientWorkflow"]) # Show all classes (disable filtering) PhysioMotion4DBase.set_log_all_classes() @@ -44,24 +44,24 @@ obj.log_progress(current, total, prefix="Processing") ### PhysioMotion4DBase Class Methods -| Method | Parameters | Description | -|--------|------------|-------------| -| `set_log_level(log_level)` | `log_level: int \| str` | Set logging level for all classes | -| `set_log_classes(class_names)` | `class_names: list[str]` | Show logs only from specified classes | -| 
`set_log_all_classes()` | None | Show logs from all classes | -| `get_log_classes()` | None | Get list of currently filtered classes | +| Method | Parameters | Description | +| ------------------------------ | ------------------------ | -------------------------------------- | +| `set_log_level(log_level)` | `log_level: int \| str` | Set logging level for all classes | +| `set_log_classes(class_names)` | `class_names: list[str]` | Show logs only from specified classes | +| `set_log_all_classes()` | None | Show logs from all classes | +| `get_log_classes()` | None | Get list of currently filtered classes | ### Instance Methods -| Method | Parameters | Description | -|--------|------------|-------------| -| `log_debug(message)` | `message: str` | Log DEBUG level message | -| `log_info(message)` | `message: str` | Log INFO level message | -| `log_warning(message)` | `message: str` | Log WARNING level message | -| `log_error(message)` | `message: str` | Log ERROR level message | -| `log_critical(message)` | `message: str` | Log CRITICAL level message | -| `log_section(title, width, char)` | `title: str, width: int=70, char: str='='` | Log formatted section header | -| `log_progress(current, total, prefix)` | `current: int, total: int, prefix: str='Progress'` | Log progress information | +| Method | Parameters | Description | +| -------------------------------------- | -------------------------------------------------- | ---------------------------- | +| `log_debug(message)` | `message: str` | Log DEBUG level message | +| `log_info(message)` | `message: str` | Log INFO level message | +| `log_warning(message)` | `message: str` | Log WARNING level message | +| `log_error(message)` | `message: str` | Log ERROR level message | +| `log_critical(message)` | `message: str` | Log CRITICAL level message | +| `log_section(title, width, char)` | `title: str, width: int=70, char: str='='` | Log formatted section header | +| `log_progress(current, total, prefix)` | `current: int, 
total: int, prefix: str='Progress'` | Log progress information | ## Usage Patterns @@ -108,7 +108,7 @@ def process_items(self, items): ### Pattern 4: Global Log Control ```python # Create multiple objects -pca_reg = RegisterModelToImagePCA(..., log_level=logging.INFO) +pca_reg = RegisterModelsPCA(..., log_level=logging.INFO) workflow = HeartModelToPatientWorkflow(..., log_level=logging.INFO) # Change log level for both at once @@ -125,14 +125,14 @@ PhysioMotion4DBase.set_log_level(logging.WARNING) ### Pattern 5: Selective Class Filtering ```python # Show only PCA registration logs -PhysioMotion4DBase.set_log_classes(["RegisterModelToImagePCA"]) +PhysioMotion4DBase.set_log_classes(["RegisterModelsPCA"]) pca_reg.log_info("This is shown") workflow.log_info("This is hidden") # Show both classes PhysioMotion4DBase.set_log_classes([ - "RegisterModelToImagePCA", + "RegisterModelsPCA", "HeartModelToPatientWorkflow" ]) @@ -151,7 +151,7 @@ workflow.run_workflow() # Focus on specific class for debugging PhysioMotion4DBase.set_log_level(logging.DEBUG) -PhysioMotion4DBase.set_log_classes(["RegisterModelToImagePCA"]) +PhysioMotion4DBase.set_log_classes(["RegisterModelsPCA"]) # Now only PCA registration shows detailed debug output workflow.run_workflow() @@ -163,13 +163,13 @@ PhysioMotion4DBase.set_log_all_classes() ## Log Levels -| Level | Numeric Value | When to Use | -|-------|---------------|-------------| -| `logging.DEBUG` | 10 | Detailed diagnostic information for debugging | -| `logging.INFO` | 20 | General informational messages (default) | -| `logging.WARNING` | 30 | Warning messages about potential issues | -| `logging.ERROR` | 40 | Error messages for serious problems | -| `logging.CRITICAL` | 50 | Critical errors that may cause termination | +| Level | Numeric Value | When to Use | +| ------------------ | ------------- | --------------------------------------------- | +| `logging.DEBUG` | 10 | Detailed diagnostic information for debugging | +| `logging.INFO` | 20 | 
General informational messages (default) | +| `logging.WARNING` | 30 | Warning messages about potential issues | +| `logging.ERROR` | 40 | Error messages for serious problems | +| `logging.CRITICAL` | 50 | Critical errors that may cause termination | ## Output Format @@ -180,15 +180,15 @@ TIMESTAMP - PhysioMotion4D - LEVEL - [ClassName] Message Example: ``` -2025-12-13 11:35:27 - PhysioMotion4D - INFO - [RegisterModelToImagePCA] Converting mean shape points... +2025-12-13 11:35:27 - PhysioMotion4D - INFO - [RegisterModelsPCA] Converting mean shape points... 2025-12-13 11:35:27 - PhysioMotion4D - DEBUG - [HeartModelToPatientWorkflow] Auto-generating masks... -2025-12-13 11:35:27 - PhysioMotion4D - WARNING - [RegisterModelToImagePCA] No points found within threshold +2025-12-13 11:35:27 - PhysioMotion4D - WARNING - [RegisterModelsPCA] No points found within threshold ``` ## Available Classes Current PhysioMotion4D classes with logging support: -- `RegisterModelToImagePCA` - PCA-based model-to-image registration +- `RegisterModelsPCA` - PCA-based model-to-image registration - `HeartModelToPatientWorkflow` - Multi-stage heart model registration - (More classes will be added as they are converted) @@ -197,7 +197,7 @@ Current PhysioMotion4D classes with logging support: ### Use Case 1: Normal Operation ```python # Default: INFO level, all classes shown -registrar = RegisterModelToImagePCA(..., log_level=logging.INFO) +registrar = RegisterModelsPCA(..., log_level=logging.INFO) result = registrar.register(...) ``` @@ -205,15 +205,15 @@ result = registrar.register(...) ```python # WARNING level: only warnings and errors PhysioMotion4DBase.set_log_level(logging.WARNING) -registrar = RegisterModelToImagePCA(...) +registrar = RegisterModelsPCA(...) result = registrar.register(...) 
``` ### Use Case 3: Debug Specific Component ```python -# Debug only RegisterModelToImagePCA +# Debug only RegisterModelsPCA PhysioMotion4DBase.set_log_level(logging.DEBUG) -PhysioMotion4DBase.set_log_classes(["RegisterModelToImagePCA"]) +PhysioMotion4DBase.set_log_classes(["RegisterModelsPCA"]) workflow = HeartModelToPatientWorkflow(...) # Only PCA component will show debug messages @@ -223,7 +223,7 @@ workflow.run_workflow() ### Use Case 4: Production Logging to File ```python # Log everything to file for analysis -registrar = RegisterModelToImagePCA( +registrar = RegisterModelsPCA( ..., log_level=logging.DEBUG, log_to_file="registration.log" @@ -239,7 +239,7 @@ PhysioMotion4DBase.set_log_level(logging.INFO) PhysioMotion4DBase.set_log_level(logging.DEBUG) # Focus on specific class -PhysioMotion4DBase.set_log_classes(["RegisterModelToImagePCA"]) +PhysioMotion4DBase.set_log_classes(["RegisterModelsPCA"]) # Re-run the problematic part registrar.optimize_rigid_alignment(...) diff --git a/docs/ants_initial_transform_guide.md b/docs/ants_initial_transform_guide.md index a97a8dc..f30f427 100644 --- a/docs/ants_initial_transform_guide.md +++ b/docs/ants_initial_transform_guide.md @@ -30,12 +30,12 @@ def itk_transform_to_ants_transform( 3. Creates an ANTsPy transform object using `ants.transform_from_displacement_field()` 4. Returns the ANTsPy transform object (no disk I/O) -### Updated Method: `registration_method()` +### Updated Method: `register()` -The registration method now accepts an `initial_phi_MF` parameter: +The registration method now accepts an `initial_forward_transform` parameter: **New Parameter:** -- `initial_phi_MF` (itk.Transform, optional): Initial transform from moving to fixed space. Can be any ITK transform type. +- `initial_forward_transform` (itk.Transform, optional): Initial transform from moving to fixed space. Can be any ITK transform type. 
**Supported Transform Types:** - `itk.AffineTransform` @@ -69,7 +69,7 @@ registrar.set_fixed_image(fixed_image) result = registrar.register( moving_image=moving_image, - initial_phi_MF=initial_tfm # Pass ITK transform directly! + initial_forward_transform=initial_tfm # Pass ITK transform directly! ) ``` @@ -86,7 +86,7 @@ registrar_deform = RegisterImagesANTs() registrar_deform.set_fixed_image(fixed_image) result_deform = registrar_deform.register( moving_image=moving_image, - initial_phi_MF=result_rigid["phi_MF"] # Use previous result + initial_forward_transform=result_rigid["forward_transform"] # Use previous result ) ``` @@ -105,7 +105,7 @@ registrar = RegisterImagesANTs() registrar.set_fixed_image(fixed_image) result = registrar.register( moving_image=moving_image, - initial_phi_MF=disp_tfm + initial_forward_transform=disp_tfm ) ``` @@ -118,8 +118,7 @@ registrar.set_fixed_image(fixed_labels) result = registrar.register( moving_image=moving_labels, - images_are_labelmaps=True, - initial_phi_MF=initial_transform # Works with label registration too! + initial_forward_transform=initial_transform ) ``` @@ -144,7 +143,7 @@ registrar = RegisterImagesANTs() registrar.set_fixed_image(fixed_image) result = registrar.register( moving_image=moving_image, - initial_phi_MF=composite # Composite is automatically converted + initial_forward_transform=composite # Composite is automatically converted ) ``` @@ -163,21 +162,21 @@ When you pass an `initial_transform` to `ants.registration()`: ```python result = registrar.register( moving_image=moving_image, - initial_phi_MF=initial_transform + initial_forward_transform=initial_transform ) # The returned transforms INCLUDE the initial transform! 
-phi_MF = result["phi_MF"] # = initial_phi_MF ∘ registration_refinement -phi_FM = result["phi_FM"] # = registration_refinement^(-1) ∘ initial_phi_MF^(-1) +forward_transform = result["forward_transform"] +inverse_transform = result["inverse_transform"] ``` The composition is done as follows: -- **phi_MF**: `initial_phi_MF` → `registration_result` (applied in sequence) -- **phi_FM**: `registration_result_inverse` → `initial_phi_MF_inverse` (applied in sequence) +- **forward_transform**: `initial_forward_transform` → `registration_result` (applied in sequence) +- **inverse_transform**: `registration_result_inverse` → `initial_forward_transform_inverse` (applied in sequence) ### Coordinate Systems -- The initial transform should map points from **moving space to fixed space** (phi_MF) +- The initial forward transform maps points from **moving space to fixed space** - The displacement field represents the displacement at each voxel in the **fixed image space** - ANTs internally handles the transform composition with its optimization @@ -189,8 +188,8 @@ The composition is done as follows: ### Transform Direction -The parameter `initial_phi_MF` represents: -- **phi_MF**: Transform from Moving → Fixed space +The parameter `initial_forward_transform` represents: +- **forward_transform**: Transform from Moving → Fixed space - This aligns with ANTs' expected initial transform direction ### Interpolation @@ -231,7 +230,7 @@ result = ants.registration( The implementation handles several edge cases: -1. **Null transforms**: If `initial_phi_MF=None`, uses identity transform (default behavior) +1. **Null transforms**: If `initial_forward_transform=None`, uses identity transform (default behavior) 2. **List transforms**: If ITK returns a list with one transform, extracts the single transform 3. **Transform composition**: Composite transforms are flattened to a single displacement field 4. 
**Type checking**: Validates that inputs are proper ITK transforms @@ -256,11 +255,11 @@ registrar = RegisterImagesANTs() registrar.set_fixed_image(fixed_image) result = registrar.register( moving_image=moving_image, - initial_phi_MF=initial_affine # Pass initial alignment + initial_forward_transform=initial_affine # Pass initial alignment ) # Step 4: The returned transform includes BOTH the initial affine AND the deformation -phi_MF = result["phi_MF"] +forward_transform = result["forward_transform"] # phi_MF is a CompositeTransform containing: # 1. initial_affine (applied first) # 2. deformation from ANTs registration (applied second) @@ -307,11 +306,11 @@ warped = ants.apply_transforms( # ✅ Composition is automatic! result = registrar.register( moving_image=moving_image, - initial_phi_MF=initial_tfm + initial_forward_transform=initial_tfm ) # The returned transform already includes everything -phi_MF = result["phi_MF"] # Complete transform ready to use +forward_transform = result["forward_transform"] # Complete transform ready to use ``` ## Comparison with File-Based Approach @@ -342,11 +341,11 @@ os.remove("temp_transform.mat") # ✅ No file I/O, automatic composition! 
result = registrar.register( moving_image=moving_image, - initial_phi_MF=itk_tfm # ITK transform object + initial_forward_transform=itk_tfm # ITK transform object ) # Transform is complete and ready to use -phi_MF = result["phi_MF"] +forward_transform = result["forward_transform"] ``` ## References diff --git a/docs/api/registration.rst b/docs/api/registration.rst index 81a6163..2e411b7 100644 --- a/docs/api/registration.rst +++ b/docs/api/registration.rst @@ -73,15 +73,15 @@ ICON (Inverse Consistent Optimization Network) results = registerer.register(moving_image) # Get results - phi_FM = results["phi_FM"] # Forward deformation field - phi_MF = results["phi_MF"] # Inverse deformation field + forward_transform = results["forward_transform"] # Forward deformation field + inverse_transform = results["inverse_transform"] # Inverse deformation field registered_image = results["registered_image"] similarity = results["similarity_score"] **Output Dictionary:** - * ``phi_FM``: Forward deformation field (fixed → moving) - * ``phi_MF``: Inverse deformation field (moving → fixed) + * ``forward_transform``: Used to warp an image from moving to fixed space + * ``inverse_transform``: Used to warp an image from fixed to moving space * ``registered_image``: Moving image warped to fixed space * ``similarity_score``: Registration quality metric * ``inverse_consistency_error``: Inverse consistency metric @@ -189,9 +189,9 @@ Time Series Registration ) # Access results - phi_MF_list = results["phi_MF_list"] # List of transforms - phi_FM_list = results["phi_FM_list"] # List of inverse transforms - losses = results["losses"] # List of registration losses + forward_transforms_list = results["forward_transforms"] # List of transforms + inverse_transforms_list = results["inverse_transforms"] # List of inverse transforms + registration_losses = results["losses"] # List of registration losses **Use Cases:** @@ -218,7 +218,7 @@ These methods register statistical shape models or segmentation 
meshes to medica PCA-based Registration ---------------------- -.. autoclass:: physiomotion4d.RegisterModelToImagePCA +.. autoclass:: physiomotion4d.RegisterModelsPCA :members: :undoc-members: :show-inheritance: @@ -236,10 +236,10 @@ PCA-based Registration .. code-block:: python - from physiomotion4d import RegisterModelToImagePCA + from physiomotion4d import RegisterModelsPCA import itk - registerer = RegisterModelToImagePCA() + registerer = RegisterModelsPCA() # Load shape model and image shape_model = load_statistical_shape_model("heart_model.pkl") @@ -292,7 +292,7 @@ Model-to-Model Registration ICP Registration ---------------- -.. autoclass:: physiomotion4d.RegisterModelToModelICP +.. autoclass:: physiomotion4d.RegisterModelsICP :members: :undoc-members: :show-inheritance: @@ -310,25 +310,29 @@ ICP Registration .. code-block:: python - from physiomotion4d import RegisterModelToModelICP + from physiomotion4d import RegisterModelsICP import pyvista as pv - registerer = RegisterModelToModelICP() - # Load meshes fixed_mesh = pv.read("reference_mesh.vtp") moving_mesh = pv.read("target_mesh.vtp") - registerer.set_fixed_mesh(fixed_mesh) - registerer.set_max_iterations(100) + # Initialize registrar + registrar = RegisterModelsICP( + moving_mesh=moving_mesh, + fixed_mesh=fixed_mesh + ) # Register - aligned_mesh, transform = registerer.register(moving_mesh) + result = registrar.register(mode='rigid', max_iterations=100) + aligned_mesh = result['moving_mesh'] + forward_point_transform = result['forward_point_transform'] + inverse_point_transform = result['inverse_point_transform'] Mask-based Registration ----------------------- -.. autoclass:: physiomotion4d.RegisterModelToModelMasks +.. 
autoclass:: physiomotion4d.RegisterModelsDistanceMaps :members: :undoc-members: :show-inheritance: @@ -396,10 +400,10 @@ PhysioMotion4D provides utilities for evaluating registration quality: tre = compute_target_registration_error(fixed_landmarks, warped_landmarks) # Inverse consistency (bidirectional registration) - ice = compute_inverse_consistency_error(phi_FM, phi_MF) + ice = compute_inverse_consistency_error(inverse_transform, forward_transform) # Jacobian determinant (topology preservation) - jac_det = compute_jacobian_determinant(phi_FM) + jac_det = compute_jacobian_determinant(inverse_transform) folding_points = jac_det < 0 # Locations with folding Best Practices @@ -419,12 +423,12 @@ Choosing a Registration Method 3. **For initialization/coarse alignment**: - * Start with :class:`RegisterModelToModelICP` + * Start with :class:`RegisterModelsICP` * Then refine with image-based registration 4. **For shape model fitting**: - * Use :class:`RegisterModelToImagePCA` + * Use :class:`RegisterModelsPCA` * Especially when prior knowledge available Parameter Selection diff --git a/docs/api/utilities.rst b/docs/api/utilities.rst index 927a5c9..b36ecc0 100644 --- a/docs/api/utilities.rst +++ b/docs/api/utilities.rst @@ -481,7 +481,7 @@ Class-Based Filtering # Show logs only from specific classes PhysioMotion4DBase.set_log_classes([ - "RegisterModelToImagePCA", + "RegisterModelsPCA", "HeartModelToPatientWorkflow" ]) diff --git a/docs/conf.py b/docs/conf.py index 243dfa6..7b62473 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -5,10 +5,22 @@ import os import sys from datetime import datetime +from unittest.mock import MagicMock # Add the source directory to the path sys.path.insert(0, os.path.abspath('../src')) +# Create a more robust mock for complex packages +class Mock(MagicMock): + @classmethod + def __getattr__(cls, name): + return MagicMock() + +# Mock modules that need special handling +sys.modules['itk.TubeTK'] = Mock() +sys.modules['icon_registration.losses'] = 
Mock() +sys.modules['icon_registration.network_wrappers'] = Mock() + # -- Project information ----------------------------------------------------- project = 'PhysioMotion4D' copyright = f'{datetime.now().year}, Stephen R. Aylward, NVIDIA Corporation' @@ -55,7 +67,7 @@ html_theme_options = { 'logo_only': False, - 'display_version': True, + # 'display_version': True, # Deprecated in sphinx_rtd_theme >= 1.0 'prev_next_buttons_location': 'bottom', 'style_external_links': False, 'vcs_pageview_mode': '', @@ -109,6 +121,8 @@ } autodoc_typehints = 'description' autodoc_typehints_description_target = 'documented' +autodoc_inherit_docstrings = True +autodoc_warningiserror = False # Don't treat import warnings as errors # Intersphinx mapping intersphinx_mapping = { @@ -152,6 +166,23 @@ 'nibabel', 'pynrrd', 'transformers', + 'SimpleITK', + 'cupy', + 'cupyx', + 'pxr', + 'scipy', + 'matplotlib', + 'ants', + 'antspyx', + 'cv2', + 'skimage', + 'PIL', + 'Usd', + 'UsdGeom', + 'UsdShade', + 'Gf', + 'Vt', + 'Sdf', ] # Copybutton configuration @@ -160,8 +191,20 @@ # -- Custom setup ------------------------------------------------------------ +def autodoc_skip_member(app, what, name, obj, skip, options): + """Custom function to skip certain members during autodoc processing.""" + # Skip private methods unless explicitly documented + if name.startswith('_') and not name.startswith('__'): + return True + return skip + def setup(app): """Custom setup function for Sphinx.""" - # You can add custom setup here if needed - pass + # Connect the autodoc-skip-member event + app.connect('autodoc-skip-member', autodoc_skip_member) + + # Suppress specific warnings + import warnings + warnings.filterwarnings('ignore', category=DeprecationWarning) + warnings.filterwarnings('ignore', category=FutureWarning) diff --git a/docs/examples.rst b/docs/examples.rst index 395b48b..d2e0cec 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -149,13 +149,14 @@ Fast GPU-accelerated registration: results = 
registerer.register(moving) # Get results - phi_FM = results["phi_FM"] - phi_MF = results["phi_MF"] + inverse_transform = results["inverse_transform"] + forward_transform = results["forward_transform"] registered = results["registered_image"] # Save itk.imwrite(registered, "registered.mha") - itk.imwrite(phi_FM, "transform_forward.mha") + itk.transformwrite(forward_transform, "transform_forward.hdf") + itk.transformwrite(inverse_transform, "transform_inverse.hdf") Multi-Phase Cardiac Registration --------------------------------- @@ -182,7 +183,7 @@ Register all cardiac phases to reference: for frame_file in frame_files[1:]: # Skip reference moving = itk.imread(frame_file) results = registerer.register(moving) - transforms.append(results["phi_FM"]) + transforms.append(results["inverse_transform"]) print(f"Registered {frame_file}: similarity = {results['similarity_score']:.3f}") @@ -577,7 +578,7 @@ Mix and match different components: results = registerer.register(frame) warped_mesh = transform_tools.apply_transform_to_contour( reference_mesh, - results["phi_FM"] + results["inverse_transform"] ) meshes.append(warped_mesh) diff --git a/docs/index.rst b/docs/index.rst index 7fed975..4db898c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -69,6 +69,7 @@ PhysioMotion4D is a comprehensive medical imaging package that converts 4D CT sc tutorials/vtk_to_usd tutorials/colormap_rendering tutorials/model_to_image_registration + ants_initial_transform_guide .. toctree:: :maxdepth: 3 @@ -87,6 +88,10 @@ PhysioMotion4D is a comprehensive medical imaging package that converts 4D CT sc architecture testing changelog + README + DOCUMENTATION_SETUP + LOGGING_API_REFERENCE + PYPI_RELEASE_GUIDE .. 
toctree:: :maxdepth: 1 diff --git a/docs/quickstart.rst b/docs/quickstart.rst index d327b27..ec6d8df 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -158,8 +158,8 @@ For standalone registration: results = registerer.register(moving_image) # Get transformation fields - phi_FM = results["phi_FM"] # Forward transform - phi_MF = results["phi_MF"] # Inverse transform + inverse_transform = results["inverse_transform"] # Fixed to moving space + forward_transform = results["forward_transform"] # Moving to fixed space VTK to USD Conversion --------------------- diff --git a/docs/tutorials/basic_workflow.rst b/docs/tutorials/basic_workflow.rst index d5bceb6..3b95e05 100644 --- a/docs/tutorials/basic_workflow.rst +++ b/docs/tutorials/basic_workflow.rst @@ -402,8 +402,8 @@ The workflow generates several intermediate files: │ ├── frame_001.mha │ └── ... ├── transforms/ # Registration transforms - │ ├── phi_FM_001.mha - │ ├── phi_MF_001.mha + │ ├── inverse_transform_001.mha + │ ├── forward_transform_001.mha │ └── ... ├── masks/ # Segmentation masks │ ├── heart_mask.nrrd diff --git a/docs/user_guide/logging.rst b/docs/user_guide/logging.rst index a1ea0d0..48d2f19 100644 --- a/docs/user_guide/logging.rst +++ b/docs/user_guide/logging.rst @@ -51,7 +51,7 @@ Example output: .. 
code-block:: text - 2025-12-13 11:24:49 - PhysioMotion4D - INFO - [RegisterModelToImagePCA] Processing started + 2025-12-13 11:24:49 - PhysioMotion4D - INFO - [RegisterModelsPCA] Processing started 2025-12-13 11:24:49 - PhysioMotion4D - DEBUG - [HeartModelToPatientWorkflow] Detailed debug info Multiple Log Levels @@ -79,12 +79,12 @@ Filter to show logs from only specific classes: from physiomotion4d import PhysioMotion4DBase - # Show only RegisterModelToImagePCA logs - PhysioMotion4DBase.set_log_classes(["RegisterModelToImagePCA"]) + # Show only RegisterModelsPCA logs + PhysioMotion4DBase.set_log_classes(["RegisterModelsPCA"]) # Show logs from multiple classes PhysioMotion4DBase.set_log_classes([ - "RegisterModelToImagePCA", + "RegisterModelsPCA", "HeartModelToPatientWorkflow" ]) @@ -225,13 +225,13 @@ Class Filtering for Selective Debugging from physiomotion4d import ( PhysioMotion4DBase, - RegisterModelToImagePCA, + RegisterModelsPCA, HeartModelToPatientWorkflow ) import logging # Create multiple objects - registrar1 = RegisterModelToImagePCA(..., log_level=logging.INFO) + registrar1 = RegisterModelsPCA(..., log_level=logging.INFO) registrar2 = HeartModelToPatientWorkflow(..., log_level=logging.INFO) # Show logs from all classes (default) @@ -239,14 +239,14 @@ Class Filtering for Selective Debugging registrar2.log_info("Message from Workflow") # Both messages are shown - # Filter to show only RegisterModelToImagePCA - PhysioMotion4DBase.set_log_classes(["RegisterModelToImagePCA"]) + # Filter to show only RegisterModelsPCA + PhysioMotion4DBase.set_log_classes(["RegisterModelsPCA"]) registrar1.log_info("This is shown") registrar2.log_info("This is hidden") # Show multiple specific classes PhysioMotion4DBase.set_log_classes([ - "RegisterModelToImagePCA", + "RegisterModelsPCA", "HeartModelToPatientWorkflow" ]) diff --git a/docs/user_guide/registration.rst b/docs/user_guide/registration.rst index b4df917..e45ca70 100644 --- a/docs/user_guide/registration.rst +++ 
b/docs/user_guide/registration.rst @@ -37,19 +37,19 @@ Model-to-Image/Model Registration For registering anatomical models to patient data: -* **ICP (RegisterModelToModelICP)**: Iterative Closest Point for surface alignment +* **ICP (RegisterModelsICP)**: Iterative Closest Point for surface alignment - Best for: Initial rough alignment, mesh registration - Speed: Very fast (<10 seconds) - Type: Rigid/affine registration -* **Mask-based (RegisterModelToModelMasks)**: Deformable registration using binary masks +* **Mask-based (RegisterModelsDistanceMaps)**: Deformable registration using binary masks - Best for: Model-to-patient fitting - Features: Dice coefficient optimization - Type: Deformable registration -* **PCA-based (RegisterModelToImagePCA)**: Statistical shape model registration +* **PCA-based (RegisterModelsPCA)**: Statistical shape model registration - Best for: Shape prior constraints - Features: Low-dimensional optimization @@ -87,8 +87,8 @@ Image Registration with ICON results = registerer.register(moving) # Get displacement fields - phi_FM = results["phi_FM"] # Fixed to moving - phi_MF = results["phi_MF"] # Moving to fixed + inverse_transform = results["inverse_transform"] # Fixed to moving + forward_transform = results["forward_transform"] # Moving to fixed Time Series Registration ------------------------- @@ -108,8 +108,8 @@ Time Series Registration results = reg.register_time_series(image_filenames=image_files) # Access transforms - transforms_FM = results["phi_FM_list"] - transforms_MF = results["phi_MF_list"] + transforms_inverse = results["inverse_transforms"] + transforms_forward = results["forward_transforms"] Model to Patient Registration ------------------------------ diff --git a/experiments/Colormap-VTK_To_USD/colormap_vtk_to_usd.ipynb b/experiments/Colormap-VTK_To_USD/colormap_vtk_to_usd.ipynb index c4539ac..f9ef180 100644 --- a/experiments/Colormap-VTK_To_USD/colormap_vtk_to_usd.ipynb +++ 
b/experiments/Colormap-VTK_To_USD/colormap_vtk_to_usd.ipynb @@ -1,445 +1,446 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Colormap Features for VTK to USD Conversion\n", - "\n", - "This notebook demonstrates the colormap features of `ConvertVTK4DToUSD` for visualizing point data arrays in NVIDIA Omniverse.\n", - "\n", - "## Features Demonstrated\n", - "\n", - "1. **Pre-defined colormaps**: plasma, viridis, rainbow, heat, coolwarm, grayscale, random\n", - "2. **Custom intensity ranges**: Control value-to-color mapping\n", - "3. **Point data visualization**: Map scalar data to colors on 3D meshes\n", - "4. **Time-varying data**: Create animated USD files with colored meshes\n", - "\n", - "## Requirements\n", - "\n", - "```bash\n", - "pip install physiomotion4d pyvista numpy\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Setup and Imports" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import pyvista as pv\n", - "from pathlib import Path\n", - "\n", - "from physiomotion4d import ConvertVTK4DToUSD\n", - "\n", - "# Create output directory\n", - "output_dir = Path(\"./output\")\n", - "output_dir.mkdir(exist_ok=True)\n", - "\n", - "print(\"PhysioMotion4D Colormap Examples\")\n", - "print(\"=\" * 50)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Helper Function: Create Example Meshes\n", - "\n", - "This function creates sphere meshes with synthetic time-varying data to demonstrate colormap functionality." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def create_example_mesh_with_data(time_step):\n", - " \"\"\"\n", - " Create a sphere mesh with synthetic data for demonstration.\n", - " \n", - " Parameters\n", - " ----------\n", - " time_step : int\n", - " Current time step for animation\n", - " \n", - " Returns\n", - " -------\n", - " pyvista.PolyData\n", - " Sphere mesh with point data arrays\n", - " \"\"\"\n", - " # Create a sphere\n", - " sphere = pv.Sphere(radius=1.0, theta_resolution=30, phi_resolution=30)\n", - " \n", - " # Add synthetic point data (e.g., simulating transmembrane potential)\n", - " points = sphere.points\n", - " z_coords = points[:, 2]\n", - " theta = 2 * np.pi * time_step / 10.0 # Full cycle every 10 frames\n", - " \n", - " # Simulate transmembrane potential: -80 mV (rest) to +20 mV (depolarized)\n", - " potential = -80.0 + 100.0 * (0.5 + 0.5 * np.sin(3 * z_coords + theta))\n", - " sphere.point_data['transmembrane_potential'] = potential\n", - " \n", - " # Add temperature data example\n", - " temperature = 20.0 + 15.0 * (0.5 + 0.5 * np.cos(2 * z_coords - theta))\n", - " sphere.point_data['temperature'] = temperature\n", - " \n", - " return sphere\n", - "\n", - "print(\"Helper function defined successfully\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Example 1: Default Plasma Colormap with Automatic Range\n", - "\n", - "The plasma colormap is the default and provides a perceptually uniform gradient from purple to pink to orange." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(\"\\nExample 1: Plasma colormap (default) with auto range\")\n", - "print(\"-\" * 50)\n", - "\n", - "# Create time series of meshes\n", - "meshes = [create_example_mesh_with_data(t) for t in range(20)]\n", - "\n", - "# Initialize converter\n", - "converter = ConvertVTK4DToUSD(\n", - " data_basename=\"CardiacModel\",\n", - " input_polydata=meshes,\n", - " mask_ids=None\n", - ")\n", - "\n", - "# List available arrays for coloring\n", - "print(\"Available point data arrays:\")\n", - "available = converter.list_available_arrays()\n", - "for name, info in available.items():\n", - " print(f\" - {name}: range={info['range']}, dtype={info['dtype']}\")\n", - "\n", - "# Set colormap (automatic range)\n", - "converter.set_colormap(\n", - " color_by_array=\"transmembrane_potential\",\n", - " colormap=\"plasma\",\n", - " intensity_range=None # Auto-detect from data\n", - ")\n", - "\n", - "# Convert to USD\n", - "output_file = output_dir / \"example1_plasma_auto.usd\"\n", - "stage = converter.convert(str(output_file))\n", - "print(f\"\\n✓ Created: {output_file}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Example 2: Rainbow Colormap with Custom Range\n", - "\n", - "The rainbow colormap provides a classic ROYGBIV spectrum. Here we specify a custom physiological range for cardiac action potentials." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(\"\\nExample 2: Rainbow colormap with custom range [-80, 20] mV\")\n", - "print(\"-\" * 50)\n", - "\n", - "meshes = [create_example_mesh_with_data(t) for t in range(20)]\n", - "\n", - "converter = ConvertVTK4DToUSD(\n", - " data_basename=\"CardiacModel\",\n", - " input_polydata=meshes,\n", - " mask_ids=None\n", - ")\n", - "\n", - "converter.set_colormap(\n", - " color_by_array=\"transmembrane_potential\",\n", - " colormap=\"rainbow\",\n", - " intensity_range=(-80.0, 20.0) # Physiological range for action potential\n", - ")\n", - "\n", - "output_file = output_dir / \"example2_rainbow_custom.usd\"\n", - "stage = converter.convert(str(output_file))\n", - "print(f\"✓ Created: {output_file}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Example 3: Heat Colormap for Temperature Data\n", - "\n", - "The heat colormap (black-red-yellow-white) is ideal for temperature or intensity visualizations." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(\"\\nExample 3: Heat colormap for temperature visualization\")\n", - "print(\"-\" * 50)\n", - "\n", - "meshes = [create_example_mesh_with_data(t) for t in range(20)]\n", - "\n", - "converter = ConvertVTK4DToUSD(\n", - " data_basename=\"TemperatureModel\",\n", - " input_polydata=meshes,\n", - " mask_ids=None\n", - ")\n", - "\n", - "converter.set_colormap(\n", - " color_by_array=\"temperature\",\n", - " colormap=\"heat\",\n", - " intensity_range=(15.0, 40.0) # Temperature range in Celsius\n", - ")\n", - "\n", - "output_file = output_dir / \"example3_heat_temperature.usd\"\n", - "stage = converter.convert(str(output_file))\n", - "print(f\"✓ Created: {output_file}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Example 4: Coolwarm (Diverging) Colormap\n", - "\n", - "The coolwarm colormap is a diverging colormap (blue-white-red) useful for data centered around a midpoint." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(\"\\nExample 4: Coolwarm colormap for diverging data\")\n", - "print(\"-\" * 50)\n", - "\n", - "meshes = [create_example_mesh_with_data(t) for t in range(20)]\n", - "\n", - "converter = ConvertVTK4DToUSD(\n", - " data_basename=\"CardiacModel\",\n", - " input_polydata=meshes,\n", - " mask_ids=None\n", - ")\n", - "\n", - "converter.set_colormap(\n", - " color_by_array=\"transmembrane_potential\",\n", - " colormap=\"coolwarm\",\n", - " intensity_range=(-80.0, 20.0)\n", - ")\n", - "\n", - "output_file = output_dir / \"example4_coolwarm_diverging.usd\"\n", - "stage = converter.convert(str(output_file))\n", - "print(f\"✓ Created: {output_file}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Example 5: Grayscale Colormap\n", - "\n", - "The grayscale colormap provides a simple black-to-white gradient for monochrome visualizations." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(\"\\nExample 5: Grayscale colormap\")\n", - "print(\"-\" * 50)\n", - "\n", - "meshes = [create_example_mesh_with_data(t) for t in range(20)]\n", - "\n", - "converter = ConvertVTK4DToUSD(\n", - " data_basename=\"CardiacModel\",\n", - " input_polydata=meshes,\n", - " mask_ids=None\n", - ")\n", - "\n", - "converter.set_colormap(\n", - " color_by_array=\"transmembrane_potential\",\n", - " colormap=\"grayscale\",\n", - " intensity_range=None\n", - ")\n", - "\n", - "output_file = output_dir / \"example5_grayscale.usd\"\n", - "stage = converter.convert(str(output_file))\n", - "print(f\"✓ Created: {output_file}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Example 6: Random Colormap for Categorical Data\n", - "\n", - "The random colormap assigns random colors to different values, making it useful for visualizing categorical or region-based data." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(\"\\nExample 6: Random colormap for categorical visualization\")\n", - "print(\"-\" * 50)\n", - "\n", - "meshes = [create_example_mesh_with_data(t) for t in range(20)]\n", - "\n", - "# Add categorical-like data (discrete regions)\n", - "for mesh in meshes:\n", - " z_values = mesh.points[:, 2]\n", - " regions = np.floor(3 * (z_values + 1) / 2) # 3 regions\n", - " mesh.point_data['region_id'] = regions\n", - "\n", - "converter = ConvertVTK4DToUSD(\n", - " data_basename=\"RegionModel\",\n", - " input_polydata=meshes,\n", - " mask_ids=None\n", - ")\n", - "\n", - "converter.set_colormap(\n", - " color_by_array=\"region_id\",\n", - " colormap=\"random\",\n", - " intensity_range=None\n", - ")\n", - "\n", - "output_file = output_dir / \"example6_random_categorical.usd\"\n", - "stage = converter.convert(str(output_file))\n", - "print(f\"✓ Created: {output_file}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Example 7: Method Chaining for Concise API Usage\n", - "\n", - "The `set_colormap()` method supports chaining, allowing for more concise code." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(\"\\nExample 7: Method chaining with viridis colormap\")\n", - "print(\"-\" * 50)\n", - "\n", - "meshes = [create_example_mesh_with_data(t) for t in range(20)]\n", - "\n", - "output_file = output_dir / \"example7_viridis_chained.usd\"\n", - "\n", - "# Method chaining for concise code\n", - "stage = (\n", - " ConvertVTK4DToUSD(\n", - " data_basename=\"CardiacModel\",\n", - " input_polydata=meshes,\n", - " mask_ids=None\n", - " )\n", - " .set_colormap(\n", - " color_by_array=\"transmembrane_potential\",\n", - " colormap=\"viridis\",\n", - " intensity_range=(-80.0, 20.0)\n", - " )\n", - " .convert(str(output_file))\n", - ")\n", - "\n", - "print(f\"✓ Created: {output_file}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Summary: Available Colormaps and Features\n", - "\n", - "### Colormap Options\n", - "\n", - "| Colormap | Description | Best For |\n", - "|----------|-------------|----------|\n", - "| `plasma` | Purple-pink-orange gradient (default) | General purpose, perceptually uniform |\n", - "| `viridis` | Blue-green-yellow gradient | General purpose, colorblind-friendly |\n", - "| `rainbow` | Classic rainbow spectrum (ROYGBIV) | Full range visualization |\n", - "| `heat` | Black-red-yellow-white | Temperature, intensity data |\n", - "| `coolwarm` | Blue-white-red (diverging) | Data centered around zero/midpoint |\n", - "| `grayscale` | Black to white linear | Monochrome, publication figures |\n", - "| `random` | Random colors per value | Categorical/discrete data |\n", - "\n", - "### Intensity Range Options\n", - "\n", - "- **`None`**: Automatic range from data min/max\n", - "- **`(vmin, vmax)`**: Custom range tuple, e.g., `(-80.0, 20.0)`\n", - "\n", - "### Key API Methods\n", - "\n", - "- **`list_available_arrays()`**: List all point data arrays available for coloring\n", - "- **`set_colormap()`**: Configure 
colormap settings (supports method chaining)\n", - "- **`convert(output_file)`**: Perform USD conversion with specified output path" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(\"\\n\" + \"=\" * 50)\n", - "print(\"All examples completed!\")\n", - "print(\"=\" * 50)\n", - "print(f\"\\nOutput files created in: {output_dir.absolute()}\")\n", - "print(\"\\nView these USD files in NVIDIA Omniverse to see the colormap visualizations.\")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.11" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Colormap Features for VTK to USD Conversion\n", + "\n", + "This notebook demonstrates the colormap features of `ConvertVTK4DToUSD` for visualizing point data arrays in NVIDIA Omniverse.\n", + "\n", + "## Features Demonstrated\n", + "\n", + "1. **Pre-defined colormaps**: plasma, viridis, rainbow, heat, coolwarm, grayscale, random\n", + "2. **Custom intensity ranges**: Control value-to-color mapping\n", + "3. **Point data visualization**: Map scalar data to colors on 3D meshes\n", + "4. 
**Time-varying data**: Create animated USD files with colored meshes\n", + "\n", + "## Requirements\n", + "\n", + "```bash\n", + "pip install physiomotion4d pyvista numpy\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup and Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "import numpy as np\n", + "import pyvista as pv\n", + "\n", + "from physiomotion4d import ConvertVTK4DToUSD\n", + "\n", + "# Create output directory\n", + "output_dir = Path(\"./output\")\n", + "output_dir.mkdir(exist_ok=True)\n", + "\n", + "print(\"PhysioMotion4D Colormap Examples\")\n", + "print(\"=\" * 50)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Helper Function: Create Example Meshes\n", + "\n", + "This function creates sphere meshes with synthetic time-varying data to demonstrate colormap functionality." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def create_example_mesh_with_data(time_step):\n", + " \"\"\"\n", + " Create a sphere mesh with synthetic data for demonstration.\n", + "\n", + " Parameters\n", + " ----------\n", + " time_step : int\n", + " Current time step for animation\n", + "\n", + " Returns\n", + " -------\n", + " pyvista.PolyData\n", + " Sphere mesh with point data arrays\n", + " \"\"\"\n", + " # Create a sphere\n", + " sphere = pv.Sphere(radius=1.0, theta_resolution=30, phi_resolution=30)\n", + "\n", + " # Add synthetic point data (e.g., simulating transmembrane potential)\n", + " points = sphere.points\n", + " z_coords = points[:, 2]\n", + " theta = 2 * np.pi * time_step / 10.0 # Full cycle every 10 frames\n", + "\n", + " # Simulate transmembrane potential: -80 mV (rest) to +20 mV (depolarized)\n", + " potential = -80.0 + 100.0 * (0.5 + 0.5 * np.sin(3 * z_coords + theta))\n", + " 
sphere.point_data['transmembrane_potential'] = potential\n", + "\n", + " # Add temperature data example\n", + " temperature = 20.0 + 15.0 * (0.5 + 0.5 * np.cos(2 * z_coords - theta))\n", + " sphere.point_data['temperature'] = temperature\n", + "\n", + " return sphere\n", + "\n", + "print(\"Helper function defined successfully\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 1: Default Plasma Colormap with Automatic Range\n", + "\n", + "The plasma colormap is the default and provides a perceptually uniform gradient from purple to pink to orange." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"\\nExample 1: Plasma colormap (default) with auto range\")\n", + "print(\"-\" * 50)\n", + "\n", + "# Create time series of meshes\n", + "meshes = [create_example_mesh_with_data(t) for t in range(20)]\n", + "\n", + "# Initialize converter\n", + "converter = ConvertVTK4DToUSD(\n", + " data_basename=\"CardiacModel\",\n", + " input_polydata=meshes,\n", + " mask_ids=None\n", + ")\n", + "\n", + "# List available arrays for coloring\n", + "print(\"Available point data arrays:\")\n", + "available = converter.list_available_arrays()\n", + "for name, info in available.items():\n", + " print(f\" - {name}: range={info['range']}, dtype={info['dtype']}\")\n", + "\n", + "# Set colormap (automatic range)\n", + "converter.set_colormap(\n", + " color_by_array=\"transmembrane_potential\",\n", + " colormap=\"plasma\",\n", + " intensity_range=None # Auto-detect from data\n", + ")\n", + "\n", + "# Convert to USD\n", + "output_file = output_dir / \"example1_plasma_auto.usd\"\n", + "stage = converter.convert(str(output_file))\n", + "print(f\"\\n✓ Created: {output_file}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 2: Rainbow Colormap with Custom Range\n", + "\n", + "The rainbow colormap provides a classic ROYGBIV spectrum. 
Here we specify a custom physiological range for cardiac action potentials." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"\\nExample 2: Rainbow colormap with custom range [-80, 20] mV\")\n", + "print(\"-\" * 50)\n", + "\n", + "meshes = [create_example_mesh_with_data(t) for t in range(20)]\n", + "\n", + "converter = ConvertVTK4DToUSD(\n", + " data_basename=\"CardiacModel\",\n", + " input_polydata=meshes,\n", + " mask_ids=None\n", + ")\n", + "\n", + "converter.set_colormap(\n", + " color_by_array=\"transmembrane_potential\",\n", + " colormap=\"rainbow\",\n", + " intensity_range=(-80.0, 20.0) # Physiological range for action potential\n", + ")\n", + "\n", + "output_file = output_dir / \"example2_rainbow_custom.usd\"\n", + "stage = converter.convert(str(output_file))\n", + "print(f\"✓ Created: {output_file}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 3: Heat Colormap for Temperature Data\n", + "\n", + "The heat colormap (black-red-yellow-white) is ideal for temperature or intensity visualizations." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"\\nExample 3: Heat colormap for temperature visualization\")\n", + "print(\"-\" * 50)\n", + "\n", + "meshes = [create_example_mesh_with_data(t) for t in range(20)]\n", + "\n", + "converter = ConvertVTK4DToUSD(\n", + " data_basename=\"TemperatureModel\",\n", + " input_polydata=meshes,\n", + " mask_ids=None\n", + ")\n", + "\n", + "converter.set_colormap(\n", + " color_by_array=\"temperature\",\n", + " colormap=\"heat\",\n", + " intensity_range=(15.0, 40.0) # Temperature range in Celsius\n", + ")\n", + "\n", + "output_file = output_dir / \"example3_heat_temperature.usd\"\n", + "stage = converter.convert(str(output_file))\n", + "print(f\"✓ Created: {output_file}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 4: Coolwarm (Diverging) Colormap\n", + "\n", + "The coolwarm colormap is a diverging colormap (blue-white-red) useful for data centered around a midpoint." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"\\nExample 4: Coolwarm colormap for diverging data\")\n", + "print(\"-\" * 50)\n", + "\n", + "meshes = [create_example_mesh_with_data(t) for t in range(20)]\n", + "\n", + "converter = ConvertVTK4DToUSD(\n", + " data_basename=\"CardiacModel\",\n", + " input_polydata=meshes,\n", + " mask_ids=None\n", + ")\n", + "\n", + "converter.set_colormap(\n", + " color_by_array=\"transmembrane_potential\",\n", + " colormap=\"coolwarm\",\n", + " intensity_range=(-80.0, 20.0)\n", + ")\n", + "\n", + "output_file = output_dir / \"example4_coolwarm_diverging.usd\"\n", + "stage = converter.convert(str(output_file))\n", + "print(f\"✓ Created: {output_file}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 5: Grayscale Colormap\n", + "\n", + "The grayscale colormap provides a simple black-to-white gradient for monochrome visualizations." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"\\nExample 5: Grayscale colormap\")\n", + "print(\"-\" * 50)\n", + "\n", + "meshes = [create_example_mesh_with_data(t) for t in range(20)]\n", + "\n", + "converter = ConvertVTK4DToUSD(\n", + " data_basename=\"CardiacModel\",\n", + " input_polydata=meshes,\n", + " mask_ids=None\n", + ")\n", + "\n", + "converter.set_colormap(\n", + " color_by_array=\"transmembrane_potential\",\n", + " colormap=\"grayscale\",\n", + " intensity_range=None\n", + ")\n", + "\n", + "output_file = output_dir / \"example5_grayscale.usd\"\n", + "stage = converter.convert(str(output_file))\n", + "print(f\"✓ Created: {output_file}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 6: Random Colormap for Categorical Data\n", + "\n", + "The random colormap assigns random colors to different values, making it useful for visualizing categorical or region-based data." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"\\nExample 6: Random colormap for categorical visualization\")\n", + "print(\"-\" * 50)\n", + "\n", + "meshes = [create_example_mesh_with_data(t) for t in range(20)]\n", + "\n", + "# Add categorical-like data (discrete regions)\n", + "for mesh in meshes:\n", + " z_values = mesh.points[:, 2]\n", + " regions = np.floor(3 * (z_values + 1) / 2) # 3 regions\n", + " mesh.point_data['region_id'] = regions\n", + "\n", + "converter = ConvertVTK4DToUSD(\n", + " data_basename=\"RegionModel\",\n", + " input_polydata=meshes,\n", + " mask_ids=None\n", + ")\n", + "\n", + "converter.set_colormap(\n", + " color_by_array=\"region_id\",\n", + " colormap=\"random\",\n", + " intensity_range=None\n", + ")\n", + "\n", + "output_file = output_dir / \"example6_random_categorical.usd\"\n", + "stage = converter.convert(str(output_file))\n", + "print(f\"✓ Created: {output_file}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 7: Method Chaining for Concise API Usage\n", + "\n", + "The `set_colormap()` method supports chaining, allowing for more concise code." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"\\nExample 7: Method chaining with viridis colormap\")\n", + "print(\"-\" * 50)\n", + "\n", + "meshes = [create_example_mesh_with_data(t) for t in range(20)]\n", + "\n", + "output_file = output_dir / \"example7_viridis_chained.usd\"\n", + "\n", + "# Method chaining for concise code\n", + "stage = (\n", + " ConvertVTK4DToUSD(\n", + " data_basename=\"CardiacModel\",\n", + " input_polydata=meshes,\n", + " mask_ids=None\n", + " )\n", + " .set_colormap(\n", + " color_by_array=\"transmembrane_potential\",\n", + " colormap=\"viridis\",\n", + " intensity_range=(-80.0, 20.0)\n", + " )\n", + " .convert(str(output_file))\n", + ")\n", + "\n", + "print(f\"✓ Created: {output_file}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary: Available Colormaps and Features\n", + "\n", + "### Colormap Options\n", + "\n", + "| Colormap | Description | Best For |\n", + "|----------|-------------|----------|\n", + "| `plasma` | Purple-pink-orange gradient (default) | General purpose, perceptually uniform |\n", + "| `viridis` | Blue-green-yellow gradient | General purpose, colorblind-friendly |\n", + "| `rainbow` | Classic rainbow spectrum (ROYGBIV) | Full range visualization |\n", + "| `heat` | Black-red-yellow-white | Temperature, intensity data |\n", + "| `coolwarm` | Blue-white-red (diverging) | Data centered around zero/midpoint |\n", + "| `grayscale` | Black to white linear | Monochrome, publication figures |\n", + "| `random` | Random colors per value | Categorical/discrete data |\n", + "\n", + "### Intensity Range Options\n", + "\n", + "- **`None`**: Automatic range from data min/max\n", + "- **`(vmin, vmax)`**: Custom range tuple, e.g., `(-80.0, 20.0)`\n", + "\n", + "### Key API Methods\n", + "\n", + "- **`list_available_arrays()`**: List all point data arrays available for coloring\n", + "- **`set_colormap()`**: Configure 
colormap settings (supports method chaining)\n", + "- **`convert(output_file)`**: Perform USD conversion with specified output path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"\\n\" + \"=\" * 50)\n", + "print(\"All examples completed!\")\n", + "print(\"=\" * 50)\n", + "print(f\"\\nOutput files created in: {output_dir.absolute()}\")\n", + "print(\"\\nView these USD files in NVIDIA Omniverse to see the colormap visualizations.\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file diff --git a/experiments/DisplacementField_To_USD/displacement_field_to_usd.ipynb b/experiments/DisplacementField_To_USD/displacement_field_to_usd.ipynb index a531c22..c21f4fa 100644 --- a/experiments/DisplacementField_To_USD/displacement_field_to_usd.ipynb +++ b/experiments/DisplacementField_To_USD/displacement_field_to_usd.ipynb @@ -1,325 +1,334 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Displacement Field to USD for Omniverse Visualization\n", - "\n", - "This notebook demonstrates how to convert time-varying 3D displacement fields (stored as NIfTI images) into USD format for visualization in NVIDIA Omniverse using the PhysicsNeMo extension.\n", - "\n", - "## Pipeline Overview\n", - "\n", - "1. Load 3D vector fields from NIfTI files using ITK\n", - "2. Convert ITK images to VTK data structures\n", - "3. Create PhysicsNeMo-compatible USD stages for time-varying visualization\n", - "4. 
Export animated USD stage for Omniverse Create/Kit\n", - "\n", - "## Architecture\n", - "\n", - "The `DisplacementFieldToUSD` class encapsulates all pipeline logic for converting medical imaging displacement fields to Omniverse-compatible USD format.\n", - "\n", - "## Required Libraries\n", - "\n", - "- ITK (InsightToolkit) for medical image I/O\n", - "- VTK (Visualization Toolkit) for data structure conversion \n", - "- numpy for array processing\n", - "- PhysicsNeMo and PhysicsNeMo-Sym for physics-based visualization\n", - "- Omniverse USD Python API for stage creation\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1. Install Dependencies\n", - "\n", - "Install the latest versions of PhysicsNeMo and PhysicsNeMo-Sym from GitHub.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Install PhysicsNeMo and PhysicsNeMo-Sym from GitHub\n", - "# Uncomment these lines if running for the first time\n", - "\n", - "# !pip install git+https://github.com/NVIDIA/physicsnemo.git\n", - "# !pip install git+https://github.com/NVIDIA/physicsnemo-sym.git\n", - "\n", - "# Install other required packages\n", - "# !pip install itk vtk numpy pxr\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2. Import Libraries\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "from pathlib import Path\n", - "from typing import List, Optional, Tuple\n", - "\n", - "import itk\n", - "import numpy as np\n", - "import vtk\n", - "from pxr import Gf, Usd, UsdGeom, Vt\n", - "\n", - "print(\"Libraries imported successfully!\")\n", - "print(f\"ITK version: {itk.Version.GetITKVersion()}\")\n", - "print(f\"VTK version: {vtk.vtkVersion.GetVTKVersion()}\")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3. 
Import DisplacementFieldToUSD Class\n", - "\n", - "Import the converter class from the local module.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Import the DisplacementFieldToUSD class from the local module\n", - "from displacement_field_converter import DisplacementFieldToUSD\n", - "\n", - "# Display class documentation\n", - "help(DisplacementFieldToUSD)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 4. Helper: Generate Sample Displacement Fields\n", - "\n", - "For demonstration purposes, this function creates synthetic time-varying displacement fields.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def generate_sample_displacement_fields(\n", - " output_dir: str,\n", - " n_timesteps: int = 10,\n", - " size: Tuple[int, int, int] = (32, 32, 32)\n", - ") -> List[str]:\n", - " \"\"\"\n", - " Generate synthetic time-varying displacement fields for demonstration.\n", - " \n", - " Creates a rotating/pulsating vector field pattern.\n", - " \n", - " Args:\n", - " output_dir: Directory to save NIfTI files\n", - " n_timesteps: Number of time steps to generate\n", - " size: 3D size of the displacement field\n", - " \n", - " Returns:\n", - " List of file paths to generated NIfTI files\n", - " \"\"\"\n", - " os.makedirs(output_dir, exist_ok=True)\n", - " file_paths = []\n", - " \n", - " # Create coordinate grid\n", - " z, y, x = np.meshgrid(\n", - " np.linspace(-1, 1, size[2]),\n", - " np.linspace(-1, 1, size[1]),\n", - " np.linspace(-1, 1, size[0]),\n", - " indexing='ij'\n", - " )\n", - " \n", - " for t in range(n_timesteps):\n", - " # Time-varying rotation angle\n", - " theta = 2 * np.pi * t / n_timesteps\n", - " \n", - " # Create rotating vector field with radial component\n", - " r = np.sqrt(x**2 + y**2 + z**2)\n", - " \n", - " # Displacement components (rotating + pulsating)\n", - " 
displacement_x = -y * np.cos(theta) + z * np.sin(theta)\n", - " displacement_y = x * np.cos(theta) - r * 0.2 * np.sin(theta)\n", - " displacement_z = -x * np.sin(theta) + y * np.cos(theta)\n", - " \n", - " # Scale by distance from center (creates flow pattern)\n", - " scale_factor = 5.0 * (1 - r / np.max(r))\n", - " displacement_x *= scale_factor\n", - " displacement_y *= scale_factor\n", - " displacement_z *= scale_factor\n", - " \n", - " # Stack into vector field (z, y, x, 3)\n", - " displacement_field = np.stack(\n", - " [displacement_x, displacement_y, displacement_z],\n", - " axis=-1\n", - " ).astype(np.float32)\n", - " \n", - " # Convert to ITK image\n", - " itk_image = itk.image_from_array(displacement_field, is_vector=True)\n", - " itk_image.SetSpacing([1.0, 1.0, 1.0])\n", - " itk_image.SetOrigin([0.0, 0.0, 0.0])\n", - " \n", - " # Save as NIfTI\n", - " file_path = os.path.join(output_dir, f\"displacement_t{t:03d}.nii.gz\")\n", - " itk.imwrite(itk_image, file_path, compression=True)\n", - " file_paths.append(file_path)\n", - " \n", - " print(f\"Generated: {file_path}\")\n", - " \n", - " return file_paths\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 5. 
Example Usage\n", - "\n", - "Demonstrate the complete pipeline with synthetic data.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Configuration\n", - "output_dir = \"./sample_displacement_fields\"\n", - "usd_output_path = \"./displacement_field_animation.usd\"\n", - "n_timesteps = 10\n", - "\n", - "# Generate sample data\n", - "print(\"Generating sample displacement fields...\")\n", - "nifti_files = generate_sample_displacement_fields(\n", - " output_dir,\n", - " n_timesteps=n_timesteps,\n", - " size=(32, 32, 32)\n", - ")\n", - "\n", - "print(f\"\\\\nGenerated {len(nifti_files)} sample files\")\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Create converter instance\n", - "converter = DisplacementFieldToUSD(\n", - " subsample_factor=4,\n", - " vector_scale=2.0\n", - ")\n", - "\n", - "# Run complete pipeline\n", - "stage = converter.process_pipeline(\n", - " nifti_files=nifti_files,\n", - " output_path=usd_output_path,\n", - " fps=24.0\n", - ")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Example with real displacement field data\n", - "# Uncomment and modify paths to use your own data:\n", - "\n", - "# nifti_files = [\n", - "# \"path/to/displacement_t00.nii.gz\",\n", - "# \"path/to/displacement_t01.nii.gz\",\n", - "# \"path/to/displacement_t02.nii.gz\",\n", - "# # ... more files\n", - "# ]\n", - "#\n", - "# converter = DisplacementFieldToUSD(subsample_factor=8, vector_scale=5.0)\n", - "# stage = converter.process_pipeline(\n", - "# nifti_files=nifti_files,\n", - "# output_path=\"cardiac_motion.usd\",\n", - "# fps=10.0\n", - "# )\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 7. 
Visualization in Omniverse\n", - "\n", - "### Steps to Visualize:\n", - "\n", - "1. **Open Omniverse Create or Kit**\n", - " - Launch NVIDIA Omniverse Create\n", - " - File → Open and select the generated USD file\n", - "\n", - "2. **Enable PhysicsNeMo Extension**\n", - " - Window → Extensions\n", - " - Search for \"PhysicsNeMo\"\n", - " - Enable the extension\n", - "\n", - "3. **Visualize the Vector Field**\n", - " - Select `/DisplacementField/VectorField` prim in the stage\n", - " - In PhysicsNeMo panel, choose visualization mode:\n", - " - **Streamlines**: Flow trajectories\n", - " - **Vector Glyphs**: Direction and magnitude at points\n", - " - **Volume Rendering**: Field magnitude as volume\n", - " - **Flow Particles**: Animated particles\n", - "\n", - "4. **Play Animation**\n", - " - Use timeline controls to play through time steps\n", - " - Adjust playback speed as needed\n", - "\n", - "### Class Methods Summary:\n", - "\n", - "- `load_nifti_files()`: Load displacement fields from NIfTI\n", - "- `convert_to_vtk()`: Convert ITK images to VTK format\n", - "- `extract_all_vector_fields()`: Extract subsampled data\n", - "- `create_usd_stage()`: Create time-varying USD stage\n", - "- `process_pipeline()`: Run complete pipeline\n", - "\n", - "### Key Features:\n", - "\n", - "✅ Class-based architecture for clean, reusable code\n", - "✅ Complete ITK → VTK → USD pipeline\n", - "✅ Time-varying animation support \n", - "✅ PhysicsNeMo-compatible velocities attribute\n", - "✅ Configurable subsampling and vector scaling\n", - "✅ Production-ready for medical imaging workflows\n", - "\n", - "This implementation encapsulates all logic in the `DisplacementFieldToUSD` class, making it easy to integrate into larger pipelines or customize for specific use cases.\n" - ] - } - ], - "metadata": { - "language_info": { - "name": "python" - } + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Displacement Field to USD for Omniverse Visualization\n", + 
"\n", + "This notebook demonstrates how to convert time-varying 3D displacement fields (stored as NIfTI images) into USD format for visualization in NVIDIA Omniverse using the PhysicsNeMo extension.\n", + "\n", + "## Pipeline Overview\n", + "\n", + "1. Load 3D vector fields from NIfTI files using ITK\n", + "2. Convert ITK images to VTK data structures\n", + "3. Create PhysicsNeMo-compatible USD stages for time-varying visualization\n", + "4. Export animated USD stage for Omniverse Create/Kit\n", + "\n", + "## Architecture\n", + "\n", + "The `DisplacementFieldToUSD` class encapsulates all pipeline logic for converting medical imaging displacement fields to Omniverse-compatible USD format.\n", + "\n", + "## Required Libraries\n", + "\n", + "- ITK (InsightToolkit) for medical image I/O\n", + "- VTK (Visualization Toolkit) for data structure conversion \n", + "- numpy for array processing\n", + "- PhysicsNeMo and PhysicsNeMo-Sym for physics-based visualization\n", + "- Omniverse USD Python API for stage creation\n" + ] }, - "nbformat": 4, - "nbformat_minor": 2 + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Install Dependencies\n", + "\n", + "Install the latest versions of PhysicsNeMo and PhysicsNeMo-Sym from GitHub.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Install PhysicsNeMo and PhysicsNeMo-Sym from GitHub\n", + "# Uncomment these lines if running for the first time\n", + "\n", + "# !pip install git+https://github.com/NVIDIA/physicsnemo.git\n", + "# !pip install git+https://github.com/NVIDIA/physicsnemo-sym.git\n", + "\n", + "# Install other required packages\n", + "# !pip install itk vtk numpy pxr\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. 
Import Libraries\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from pathlib import Path\n", + "from typing import List, Optional, Tuple\n", + "\n", + "import itk\n", + "import numpy as np\n", + "import vtk\n", + "from pxr import Gf, Usd, UsdGeom, Vt\n", + "\n", + "print(\"Libraries imported successfully!\")\n", + "print(f\"ITK version: {itk.Version.GetITKVersion()}\")\n", + "print(f\"VTK version: {vtk.vtkVersion.GetVTKVersion()}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Import DisplacementFieldToUSD Class\n", + "\n", + "Import the converter class from the local module.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Import the DisplacementFieldToUSD class from the local module\n", + "from displacement_field_converter import DisplacementFieldToUSD\n", + "\n", + "# Display class documentation\n", + "help(DisplacementFieldToUSD)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. 
Helper: Generate Sample Displacement Fields\n", + "\n", + "For demonstration purposes, this function creates synthetic time-varying displacement fields.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def generate_sample_displacement_fields(\n", + " output_dir: str,\n", + " n_timesteps: int = 10,\n", + " size: Tuple[int, int, int] = (32, 32, 32)\n", + ") -> List[str]:\n", + " \"\"\"\n", + " Generate synthetic time-varying displacement fields for demonstration.\n", + "\n", + " Creates a rotating/pulsating vector field pattern.\n", + "\n", + " Args:\n", + " output_dir: Directory to save NIfTI files\n", + " n_timesteps: Number of time steps to generate\n", + " size: 3D size of the displacement field\n", + "\n", + " Returns:\n", + " List of file paths to generated NIfTI files\n", + " \"\"\"\n", + " os.makedirs(output_dir, exist_ok=True)\n", + " file_paths = []\n", + "\n", + " # Create coordinate grid\n", + " z, y, x = np.meshgrid(\n", + " np.linspace(-1, 1, size[2]),\n", + " np.linspace(-1, 1, size[1]),\n", + " np.linspace(-1, 1, size[0]),\n", + " indexing='ij'\n", + " )\n", + "\n", + " for t in range(n_timesteps):\n", + " # Time-varying rotation angle\n", + " theta = 2 * np.pi * t / n_timesteps\n", + "\n", + " # Create rotating vector field with radial component\n", + " r = np.sqrt(x**2 + y**2 + z**2)\n", + "\n", + " # Displacement components (rotating + pulsating)\n", + " displacement_x = -y * np.cos(theta) + z * np.sin(theta)\n", + " displacement_y = x * np.cos(theta) - r * 0.2 * np.sin(theta)\n", + " displacement_z = -x * np.sin(theta) + y * np.cos(theta)\n", + "\n", + " # Scale by distance from center (creates flow pattern)\n", + " scale_factor = 5.0 * (1 - r / np.max(r))\n", + " displacement_x *= scale_factor\n", + " displacement_y *= scale_factor\n", + " displacement_z *= scale_factor\n", + "\n", + " # Stack into vector field (z, y, x, 3)\n", + " displacement_field = np.stack(\n", + " 
[displacement_x, displacement_y, displacement_z],\n", + "            axis=-1\n", + "        ).astype(np.float32)\n", + "\n", + "        # Convert to ITK image\n", + "        itk_image = itk.image_from_array(displacement_field, is_vector=True)\n", + "        itk_image.SetSpacing([1.0, 1.0, 1.0])\n", + "        itk_image.SetOrigin([0.0, 0.0, 0.0])\n", + "\n", + "        # Save as NIfTI\n", + "        file_path = os.path.join(output_dir, f\"displacement_t{t:03d}.nii.gz\")\n", + "        itk.imwrite(itk_image, file_path, compression=True)\n", + "        file_paths.append(file_path)\n", + "\n", + "        print(f\"Generated: {file_path}\")\n", + "\n", + "    return file_paths\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Example Usage\n", + "\n", + "Demonstrate the complete pipeline with synthetic data.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Configuration\n", + "output_dir = \"./sample_displacement_fields\"\n", + "usd_output_path = \"./displacement_field_animation.usd\"\n", + "n_timesteps = 10\n", + "\n", + "# Generate sample data\n", + "print(\"Generating sample displacement fields...\")\n", + "nifti_files = generate_sample_displacement_fields(\n", + "    output_dir,\n", + "    n_timesteps=n_timesteps,\n", + "    size=(32, 32, 32)\n", + ")\n", + "\n", + "print(f\"\\nGenerated {len(nifti_files)} sample files\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create converter instance\n", + "converter = DisplacementFieldToUSD(\n", + "    subsample_factor=4,\n", + "    vector_scale=2.0\n", + ")\n", + "\n", + "# Run complete pipeline\n", + "stage = converter.process_pipeline(\n", + "    nifti_files=nifti_files,\n", + "    output_path=usd_output_path,\n", + "    fps=24.0\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Example with real displacement field data\n", + "# Uncomment and modify paths to use your own 
data:\n", + "\n", + "# nifti_files = [\n", + "#     \"path/to/displacement_t00.nii.gz\",\n", + "#     \"path/to/displacement_t01.nii.gz\",\n", + "#     \"path/to/displacement_t02.nii.gz\",\n", + "#     # ... more files\n", + "# ]\n", + "#\n", + "# converter = DisplacementFieldToUSD(subsample_factor=8, vector_scale=5.0)\n", + "# stage = converter.process_pipeline(\n", + "#     nifti_files=nifti_files,\n", + "#     output_path=\"cardiac_motion.usd\",\n", + "#     fps=10.0\n", + "# )\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Visualization in Omniverse\n", + "\n", + "### Steps to Visualize:\n", + "\n", + "1. **Open Omniverse Create or Kit**\n", + "   - Launch NVIDIA Omniverse Create\n", + "   - File → Open and select the generated USD file\n", + "\n", + "2. **Enable PhysicsNeMo Extension**\n", + "   - Window → Extensions\n", + "   - Search for \"PhysicsNeMo\"\n", + "   - Enable the extension\n", + "\n", + "3. **Visualize the Vector Field**\n", + "   - Select `/DisplacementField/VectorField` prim in the stage\n", + "   - In PhysicsNeMo panel, choose visualization mode:\n", + "     - **Streamlines**: Flow trajectories\n", + "     - **Vector Glyphs**: Direction and magnitude at points\n", + "     - **Volume Rendering**: Field magnitude as volume\n", + "     - **Flow Particles**: Animated particles\n", + "\n", + "4. 
**Play Animation**\n", + " - Use timeline controls to play through time steps\n", + " - Adjust playback speed as needed\n", + "\n", + "### Class Methods Summary:\n", + "\n", + "- `load_nifti_files()`: Load displacement fields from NIfTI\n", + "- `convert_to_vtk()`: Convert ITK images to VTK format\n", + "- `extract_all_vector_fields()`: Extract subsampled data\n", + "- `create_usd_stage()`: Create time-varying USD stage\n", + "- `process_pipeline()`: Run complete pipeline\n", + "\n", + "### Key Features:\n", + "\n", + "✅ Class-based architecture for clean, reusable code\n", + "✅ Complete ITK → VTK → USD pipeline\n", + "✅ Time-varying animation support \n", + "✅ PhysicsNeMo-compatible velocities attribute\n", + "✅ Configurable subsampling and vector scaling\n", + "✅ Production-ready for medical imaging workflows\n", + "\n", + "This implementation encapsulates all logic in the `DisplacementFieldToUSD` class, making it easy to integrate into larger pipelines or customize for specific use cases.\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 } diff --git a/experiments/Heart-GatedCT_To_USD/1-register_images.ipynb b/experiments/Heart-GatedCT_To_USD/1-register_images.ipynb index 69aa147..760a32a 100644 --- a/experiments/Heart-GatedCT_To_USD/1-register_images.ipynb +++ b/experiments/Heart-GatedCT_To_USD/1-register_images.ipynb @@ -1,231 +1,231 @@ { - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "3ce61753-11ad-4ade-9afe-6ad1bc748e25", - "metadata": {}, - "outputs": [], - "source": [ - "import itk\n", - "import os\n", - "\n", - "from physiomotion4d.register_images_icon import 
RegisterImagesICON\n", - "from physiomotion4d.segment_chest_total_segmentator import SegmentChestTotalSegmentator\n", - "\n", - "from physiomotion4d.transform_tools import TransformTools" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9d2e5d21", - "metadata": {}, - "outputs": [], - "source": [ - "data_dir = \"../../data/Slicer-Heart-CT\"\n", - "\n", - "output_dir = os.path.join(\".\", \"results\")\n", - "if not os.path.exists(output_dir):\n", - " os.makedirs(output_dir)\n", - "\n", - "fixed_image_filename = os.path.join(output_dir, \"slice_fixed.mha\")\n", - "fixed_image = itk.imread(fixed_image_filename)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b35f90c6", - "metadata": {}, - "outputs": [], - "source": [ - "seg = SegmentChestTotalSegmentator()\n", - "seg.contrast_threshold = 500\n", - "result = seg.segment(fixed_image, contrast_enhanced_study=True)\n", - "labelmap_mask = result[\"labelmap\"]\n", - "lung_mask = result[\"lung\"]\n", - "heart_mask = result[\"heart\"]\n", - "major_vessels_mask = result[\"major_vessels\"]\n", - "bone_mask = result[\"bone\"]\n", - "soft_tissue_mask = result[\"soft_tissue\"]\n", - "other_mask = result[\"other\"]\n", - "contrast_mask = result[\"contrast\"]\n", - "\n", - "fixed_image_labelmap = labelmap_mask\n", - "itk.imwrite(\n", - " fixed_image_labelmap,\n", - " os.path.join(output_dir, \"slice_fixed_mask.mha\"),\n", - " compression=True,\n", - ")\n", - "\n", - "heart_arr = itk.GetArrayFromImage(heart_mask)\n", - "contrast_arr = itk.GetArrayFromImage(contrast_mask)\n", - "major_vessels_arr = itk.GetArrayFromImage(major_vessels_mask)\n", - "fixed_image_dynamic_anatomy_mask = itk.GetImageFromArray(\n", - " heart_arr + contrast_arr + major_vessels_arr\n", - ")\n", - "fixed_image_dynamic_anatomy_mask.CopyInformation(fixed_image)\n", - "itk.imwrite(\n", - " fixed_image_dynamic_anatomy_mask,\n", - " os.path.join(output_dir, \"slice_fixed.dynamic_anatomy_mask.mha\"),\n", - " 
compression=True,\n", - ")\n", - "\n", - "lung_arr = itk.GetArrayFromImage(lung_mask)\n", - "bone_arr = itk.GetArrayFromImage(bone_mask)\n", - "other_arr = itk.GetArrayFromImage(other_mask)\n", - "fixed_image_static_mask = itk.GetImageFromArray(lung_arr + bone_arr + other_arr)\n", - "fixed_image_static_mask.CopyInformation(fixed_image)\n", - "itk.imwrite(\n", - " fixed_image_static_mask,\n", - " os.path.join(output_dir, \"slice_fixed.static_anatomy_mask.mha\"),\n", - " compression=True,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "10ffbbf6", - "metadata": {}, - "outputs": [], - "source": [ - "reg = RegisterImagesICON()\n", - "reg.set_mask_dilation(5)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cc9418da", - "metadata": {}, - "outputs": [], - "source": [ - "for i in range(21):\n", - " print(f\"Processing slice {i:03d}\")\n", - " moving_image = itk.imread(os.path.join(data_dir, f\"slice_{i:03d}.mha\"))\n", - " result = seg.segment(moving_image, contrast_enhanced_study=True)\n", - " labelmap_mask = result[\"labelmap\"]\n", - " lung_mask = result[\"lung\"]\n", - " heart_mask = result[\"heart\"]\n", - " major_vessels_mask = result[\"major_vessels\"]\n", - " bone_mask = result[\"bone\"]\n", - " soft_tissue_mask = result[\"soft_tissue\"]\n", - " other_mask = result[\"other\"]\n", - " contrast_mask = result[\"contrast\"]\n", - " itk.imwrite(\n", - " labelmap_mask,\n", - " os.path.join(output_dir, f\"slice_{i:03d}_mask.mha\"),\n", - " compression=True,\n", - " )\n", - "\n", - " # Register the whole image\n", - " reg.set_fixed_image(fixed_image)\n", - " results = reg.register(moving_image)\n", - " phi_FM = results[\"phi_FM\"]\n", - " phi_MF = results[\"phi_MF\"]\n", - " moving_image_reg = TransformTools().transform_image(\n", - " moving_image, phi_MF, fixed_image, \"sinc\"\n", - " ) # Final resampling with sinc\n", - " itk.imwrite(moving_image_reg, os.path.join(output_dir, f\"slice_{i:03d}.reg_all.mha\"), 
compression=True)\n", - " itk.transformwrite([phi_FM], os.path.join(output_dir, f\"slice_{i:03d}.reg_all.phi_FM.hdf\"), compression=True)\n", - " itk.transformwrite([phi_MF], os.path.join(output_dir, f\"slice_{i:03d}.reg_all.phi_MF.hdf\"), compression=True)\n", - "\n", - " # Register the dynamic anatomy mask\n", - " heart_arr = itk.GetArrayFromImage(heart_mask)\n", - " contrast_arr = itk.GetArrayFromImage(contrast_mask)\n", - " major_vessels_arr = itk.GetArrayFromImage(major_vessels_mask)\n", - " dynamic_anatomy_arr = heart_arr + contrast_arr + major_vessels_arr\n", - " moving_image_dynamic_anatomy_mask = itk.GetImageFromArray(dynamic_anatomy_arr)\n", - " moving_image_dynamic_anatomy_mask.CopyInformation(moving_image)\n", - " reg.set_fixed_image(fixed_image)\n", - " reg.set_fixed_image_mask(fixed_image_dynamic_anatomy_mask)\n", - " results = reg.register(\n", - " moving_image, moving_image_dynamic_anatomy_mask\n", - " )\n", - " phi_FM = results[\"phi_FM\"]\n", - " phi_MF = results[\"phi_MF\"]\n", - " moving_image_reg_dynamic_anatomy = TransformTools().transform_image(\n", - " moving_image, phi_MF, fixed_image, \"sinc\"\n", - " ) # Final resampling with sinc\n", - " itk.imwrite(\n", - " moving_image_dynamic_anatomy_mask,\n", - " os.path.join(output_dir, f\"slice_{i:03d}.dynamic_anatomy_mask.mha\"),\n", - " compression=True,\n", - " )\n", - " itk.imwrite(\n", - " moving_image_reg_dynamic_anatomy,\n", - " os.path.join(output_dir, f\"slice_{i:03d}.reg_dynamic_anatomy.mha\"),\n", - " compression=True,\n", - " )\n", - " itk.transformwrite(\n", - " [phi_FM],\n", - " os.path.join(output_dir, f\"slice_{i:03d}.reg_dynamic_anatomy.phi_FM.hdf\"),\n", - " compression=True,\n", - " )\n", - " itk.transformwrite(\n", - " [phi_MF],\n", - " os.path.join(output_dir, f\"slice_{i:03d}.reg_dynamic_anatomy.phi_MF.hdf\"),\n", - " compression=True,\n", - " )\n", - "\n", - " # Register the static anatomy mask\n", - " lung_arr = itk.GetArrayFromImage(lung_mask)\n", - " bone_arr = 
itk.GetArrayFromImage(bone_mask)\n", - " other_arr = itk.GetArrayFromImage(other_mask)\n", - " moving_image_static_mask = itk.GetImageFromArray(lung_arr + bone_arr + other_arr)\n", - " moving_image_static_mask.CopyInformation(moving_image)\n", - " reg.set_fixed_image(fixed_image)\n", - " reg.set_fixed_image_mask(fixed_image_static_mask)\n", - " results = reg.register(moving_image, moving_image_static_mask)\n", - " phi_FM = results[\"phi_FM\"]\n", - " phi_MF = results[\"phi_MF\"]\n", - " moving_image_reg_static = TransformTools().transform_image(\n", - " moving_image, phi_MF, fixed_image, \"sinc\"\n", - " ) # Final resampling with sinc\n", - " itk.imwrite(\n", - " moving_image_static_mask,\n", - " os.path.join(output_dir, f\"slice_{i:03d}.static_anatomy_mask.mha\"),\n", - " compression=True,\n", - " )\n", - " itk.imwrite(\n", - " moving_image_reg_static,\n", - " os.path.join(output_dir, f\"slice_{i:03d}.reg_static_anatomy.mha\"),\n", - " compression=True,\n", - " )\n", - " itk.transformwrite(\n", - " [phi_FM],\n", - " os.path.join(output_dir, f\"slice_{i:03d}.reg_static_anatomy.phi_FM.hdf\"),\n", - " compression=True,\n", - " )\n", - " itk.transformwrite(\n", - " [phi_MF],\n", - " os.path.join(output_dir, f\"slice_{i:03d}.reg_static_anatomy.phi_MF.hdf\"),\n", - " compression=True,\n", - " )" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.11" - } - }, - "nbformat": 4, - "nbformat_minor": 5 + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "3ce61753-11ad-4ade-9afe-6ad1bc748e25", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "import itk\n", + "\n", + "from physiomotion4d.register_images_icon import 
RegisterImagesICON\n", + "from physiomotion4d.segment_chest_total_segmentator import SegmentChestTotalSegmentator\n", + "from physiomotion4d.transform_tools import TransformTools\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9d2e5d21", + "metadata": {}, + "outputs": [], + "source": [ + "data_dir = \"../../data/Slicer-Heart-CT\"\n", + "\n", + "output_dir = os.path.join(\".\", \"results\")\n", + "if not os.path.exists(output_dir):\n", + " os.makedirs(output_dir)\n", + "\n", + "fixed_image_filename = os.path.join(output_dir, \"slice_fixed.mha\")\n", + "fixed_image = itk.imread(fixed_image_filename)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b35f90c6", + "metadata": {}, + "outputs": [], + "source": [ + "seg = SegmentChestTotalSegmentator()\n", + "seg.contrast_threshold = 500\n", + "result = seg.segment(fixed_image, contrast_enhanced_study=True)\n", + "labelmap_mask = result[\"labelmap\"]\n", + "lung_mask = result[\"lung\"]\n", + "heart_mask = result[\"heart\"]\n", + "major_vessels_mask = result[\"major_vessels\"]\n", + "bone_mask = result[\"bone\"]\n", + "soft_tissue_mask = result[\"soft_tissue\"]\n", + "other_mask = result[\"other\"]\n", + "contrast_mask = result[\"contrast\"]\n", + "\n", + "fixed_image_labelmap = labelmap_mask\n", + "itk.imwrite(\n", + " fixed_image_labelmap,\n", + " os.path.join(output_dir, \"slice_fixed_mask.mha\"),\n", + " compression=True,\n", + ")\n", + "\n", + "heart_arr = itk.GetArrayFromImage(heart_mask)\n", + "contrast_arr = itk.GetArrayFromImage(contrast_mask)\n", + "major_vessels_arr = itk.GetArrayFromImage(major_vessels_mask)\n", + "fixed_image_dynamic_anatomy_mask = itk.GetImageFromArray(\n", + " heart_arr + contrast_arr + major_vessels_arr\n", + ")\n", + "fixed_image_dynamic_anatomy_mask.CopyInformation(fixed_image)\n", + "itk.imwrite(\n", + " fixed_image_dynamic_anatomy_mask,\n", + " os.path.join(output_dir, \"slice_fixed.dynamic_anatomy_mask.mha\"),\n", + " 
compression=True,\n", + ")\n", + "\n", + "lung_arr = itk.GetArrayFromImage(lung_mask)\n", + "bone_arr = itk.GetArrayFromImage(bone_mask)\n", + "other_arr = itk.GetArrayFromImage(other_mask)\n", + "fixed_image_static_mask = itk.GetImageFromArray(lung_arr + bone_arr + other_arr)\n", + "fixed_image_static_mask.CopyInformation(fixed_image)\n", + "itk.imwrite(\n", + " fixed_image_static_mask,\n", + " os.path.join(output_dir, \"slice_fixed.static_anatomy_mask.mha\"),\n", + " compression=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10ffbbf6", + "metadata": {}, + "outputs": [], + "source": [ + "reg = RegisterImagesICON()\n", + "reg.set_mask_dilation(5)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cc9418da", + "metadata": {}, + "outputs": [], + "source": [ + "for i in range(21):\n", + " print(f\"Processing slice {i:03d}\")\n", + " moving_image = itk.imread(os.path.join(data_dir, f\"slice_{i:03d}.mha\"))\n", + " result = seg.segment(moving_image, contrast_enhanced_study=True)\n", + " labelmap_mask = result[\"labelmap\"]\n", + " lung_mask = result[\"lung\"]\n", + " heart_mask = result[\"heart\"]\n", + " major_vessels_mask = result[\"major_vessels\"]\n", + " bone_mask = result[\"bone\"]\n", + " soft_tissue_mask = result[\"soft_tissue\"]\n", + " other_mask = result[\"other\"]\n", + " contrast_mask = result[\"contrast\"]\n", + " itk.imwrite(\n", + " labelmap_mask,\n", + " os.path.join(output_dir, f\"slice_{i:03d}_mask.mha\"),\n", + " compression=True,\n", + " )\n", + "\n", + " # Register the whole image\n", + " reg.set_fixed_image(fixed_image)\n", + " results = reg.register(moving_image)\n", + " inverse_transform = results[\"inverse_transform\"]\n", + " forward_transform = results[\"forward_transform\"]\n", + " moving_image_reg = TransformTools().transform_image(\n", + " moving_image, forward_transform, fixed_image, \"sinc\"\n", + " ) # Final resampling with sinc\n", + " itk.imwrite(moving_image_reg, 
os.path.join(output_dir, f\"slice_{i:03d}.reg_all.mha\"), compression=True)\n", + " itk.transformwrite([forward_transform], os.path.join(output_dir, f\"slice_{i:03d}.reg_all.forward.hdf\"), compression=True)\n", + " itk.transformwrite([inverse_transform], os.path.join(output_dir, f\"slice_{i:03d}.reg_all.inverse.hdf\"), compression=True)\n", + "\n", + " # Register the dynamic anatomy mask\n", + " heart_arr = itk.GetArrayFromImage(heart_mask)\n", + " contrast_arr = itk.GetArrayFromImage(contrast_mask)\n", + " major_vessels_arr = itk.GetArrayFromImage(major_vessels_mask)\n", + " dynamic_anatomy_arr = heart_arr + contrast_arr + major_vessels_arr\n", + " moving_image_dynamic_anatomy_mask = itk.GetImageFromArray(dynamic_anatomy_arr)\n", + " moving_image_dynamic_anatomy_mask.CopyInformation(moving_image)\n", + " reg.set_fixed_image(fixed_image)\n", + " reg.set_fixed_mask(fixed_image_dynamic_anatomy_mask)\n", + " results = reg.register(\n", + " moving_image, moving_image_dynamic_anatomy_mask\n", + " )\n", + " inverse_transform = results[\"inverse_transform\"]\n", + " forward_transform = results[\"forward_transform\"]\n", + " moving_image_reg_dynamic_anatomy = TransformTools().transform_image(\n", + " moving_image, forward_transform, fixed_image, \"sinc\"\n", + " ) # Final resampling with sinc\n", + " itk.imwrite(\n", + " moving_image_dynamic_anatomy_mask,\n", + " os.path.join(output_dir, f\"slice_{i:03d}.dynamic_anatomy_mask.mha\"),\n", + " compression=True,\n", + " )\n", + " itk.imwrite(\n", + " moving_image_reg_dynamic_anatomy,\n", + " os.path.join(output_dir, f\"slice_{i:03d}.reg_dynamic_anatomy.mha\"),\n", + " compression=True,\n", + " )\n", + " itk.transformwrite(\n", + " [forward_transform],\n", + " os.path.join(output_dir, f\"slice_{i:03d}.reg_dynamic_anatomy.forward.hdf\"),\n", + " compression=True,\n", + " )\n", + " itk.transformwrite(\n", + " [inverse_transform],\n", + " os.path.join(output_dir, f\"slice_{i:03d}.reg_dynamic_anatomy.inverse.hdf\"),\n", + " 
compression=True,\n", + " )\n", + "\n", + " # Register the static anatomy mask\n", + " lung_arr = itk.GetArrayFromImage(lung_mask)\n", + " bone_arr = itk.GetArrayFromImage(bone_mask)\n", + " other_arr = itk.GetArrayFromImage(other_mask)\n", + " moving_image_static_mask = itk.GetImageFromArray(lung_arr + bone_arr + other_arr)\n", + " moving_image_static_mask.CopyInformation(moving_image)\n", + " reg.set_fixed_image(fixed_image)\n", + " reg.set_fixed_mask(fixed_image_static_mask)\n", + " results = reg.register(moving_image, moving_image_static_mask)\n", + " inverse_transform = results[\"inverse_transform\"]\n", + " forward_transform = results[\"forward_transform\"]\n", + " moving_image_reg_static = TransformTools().transform_image(\n", + " moving_image, forward_transform, fixed_image, \"sinc\"\n", + " ) # Final resampling with sinc\n", + " itk.imwrite(\n", + " moving_image_static_mask,\n", + " os.path.join(output_dir, f\"slice_{i:03d}.static_anatomy_mask.mha\"),\n", + " compression=True,\n", + " )\n", + " itk.imwrite(\n", + " moving_image_reg_static,\n", + " os.path.join(output_dir, f\"slice_{i:03d}.reg_static_anatomy.mha\"),\n", + " compression=True,\n", + " )\n", + " itk.transformwrite(\n", + " [forward_transform],\n", + " os.path.join(output_dir, f\"slice_{i:03d}.reg_static_anatomy.forward.hdf\"),\n", + " compression=True,\n", + " )\n", + " itk.transformwrite(\n", + " [inverse_transform],\n", + " os.path.join(output_dir, f\"slice_{i:03d}.reg_static_anatomy.inverse.hdf\"),\n", + " compression=True,\n", + " )" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 } diff --git 
a/experiments/Heart-GatedCT_To_USD/2-generate_segmentation.ipynb b/experiments/Heart-GatedCT_To_USD/2-generate_segmentation.ipynb index 285879b..ebab1e8 100644 --- a/experiments/Heart-GatedCT_To_USD/2-generate_segmentation.ipynb +++ b/experiments/Heart-GatedCT_To_USD/2-generate_segmentation.ipynb @@ -1,234 +1,234 @@ { - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "3ce61753-11ad-4ade-9afe-6ad1bc748e25", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "\n", - "import itk\n", - "import numpy as np\n", - "import pyvista as pv\n", - "\n", - "from physiomotion4d.contour_tools import ContourTools\n", - "from physiomotion4d.segment_chest_total_segmentator import SegmentChestTotalSegmentator\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b60954cf", - "metadata": {}, - "outputs": [], - "source": [ - "# When re-running, you can bypass certain long-running steps\n", - "re_run_image_max = False\n", - "use_fixed_image = True\n", - "re_run_image_segmentation = True" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5c2f5e00", - "metadata": {}, - "outputs": [], - "source": [ - "output_dir = \"./results\"\n", - "max_image = None\n", - "print(\"Computing max image...\")\n", - "if re_run_image_max and not use_fixed_image:\n", - " # Compute max of all images\n", - " image = None\n", - " try:\n", - " image = itk.imread(os.path.join(output_dir, \"slice_000.reg_dynamic_anatomy.mha\"))\n", - " except (FileNotFoundError, OSError):\n", - " print(\"No image found. Aborting. 
Please run 1-generate_images.ipynb first.\")\n", - " exit(1)\n", - " arr = itk.array_from_image(image)\n", - " print(arr.shape)\n", - " arr = np.where(arr == 0, -1000, arr)\n", - " for i in range(1, 21):\n", - " print(f\"Processing slice {i:03d}...\")\n", - " tmp_arr = itk.array_from_image(\n", - " itk.imread(os.path.join(output_dir, f\"slice_{i:03d}.reg_dynamic_anatomy.mha\"))\n", - " )\n", - " tmp_arr = np.where(tmp_arr == 0, -1000, tmp_arr)\n", - " arr = np.maximum(arr, tmp_arr)\n", - " print(\"Max image computed.\")\n", - " max_image = itk.image_from_array(arr)\n", - " max_image.CopyInformation(image)\n", - " itk.imwrite(\n", - " max_image,\n", - " os.path.join(output_dir, \"slice_max.reg_dynamic_anatomy.mha\"),\n", - " compression=True,\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9438634d", - "metadata": {}, - "outputs": [], - "source": [ - "if use_fixed_image:\n", - " max_image = itk.imread(os.path.join(output_dir, \"slice_fixed.mha\"))\n", - " outname = \"slice_fixed\"\n", - "else:\n", - " max_image = itk.imread(os.path.join(output_dir, \"slice_max.reg_dynamic_anatomy.mha\"))\n", - " outname = \"slice_max\"\n", - "\n", - "seg = SegmentChestTotalSegmentator()\n", - "seg.contrast_threshold = 500\n", - "if re_run_image_segmentation:\n", - " result = seg.segment(max_image, contrast_enhanced_study=True)\n", - " labelmap_image = result[\"labelmap\"]\n", - " lung_mask = result[\"lung\"]\n", - " heart_mask = result[\"heart\"]\n", - " major_vessels_mask = result[\"major_vessels\"]\n", - " bone_mask = result[\"bone\"]\n", - " soft_tissue_mask = result[\"soft_tissue\"]\n", - " other_mask = result[\"other\"]\n", - " contrast_mask = result[\"contrast\"]\n", - " itk.imwrite(\n", - " labelmap_image,\n", - " os.path.join(output_dir, f\"{outname}.all_mask.mha\"),\n", - " compression=True,\n", - " )\n", - "else:\n", - " labelmap_image = itk.imread(\n", - " os.path.join(output_dir, f\"{outname}.all_mask.mha\")\n", - " )" - ] - }, - { - 
"cell_type": "code", - "execution_count": null, - "id": "a2325199", - "metadata": {}, - "outputs": [], - "source": [ - "con = ContourTools()\n", - "all_contours = con.extract_contours(labelmap_image)\n", - "all_contours.save(os.path.join(output_dir, f\"{outname}.all_mask.vtp\"))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "10db4bfa", - "metadata": {}, - "outputs": [], - "source": [ - "label_arr = itk.array_from_image(labelmap_image)\n", - "lung_arr = itk.array_from_image(lung_mask)\n", - "heart_arr = itk.array_from_image(heart_mask)\n", - "major_vessels_arr = itk.array_from_image(major_vessels_mask)\n", - "bone_arr = itk.array_from_image(bone_mask)\n", - "soft_tissue_arr = itk.array_from_image(soft_tissue_mask)\n", - "other_arr = itk.array_from_image(other_mask)\n", - "contrast_arr = itk.array_from_image(contrast_mask)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4d19335d", - "metadata": {}, - "outputs": [], - "source": [ - "dynamic_anatomy_arr = np.maximum(heart_arr, contrast_arr)\n", - "dynamic_anatomy_arr = np.maximum(dynamic_anatomy_arr, major_vessels_arr)\n", - "dynamic_anatomy_arr = np.where(dynamic_anatomy_arr, label_arr, 0)\n", - "dynamic_anatomy_image = itk.image_from_array(dynamic_anatomy_arr.astype(np.int16))\n", - "dynamic_anatomy_image.CopyInformation(labelmap_image)\n", - "itk.imwrite(\n", - " dynamic_anatomy_image,\n", - " os.path.join(output_dir, f\"{outname}.dynamic_anatomy_mask.mha\"),\n", - " compression=True,\n", - ")\n", - "\n", - "contours = con.extract_contours(dynamic_anatomy_image)\n", - "contours.save(os.path.join(output_dir, f\"{outname}.dynamic_anatomy_mask.vtp\"))\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e4601d28", - "metadata": {}, - "outputs": [], - "source": [ - "static_anatomy_arr = lung_arr + bone_arr + soft_tissue_arr + other_arr\n", - "static_anatomy_arr = np.where(static_anatomy_arr, label_arr, 0)\n", - "static_anatomy_image = 
itk.image_from_array(static_anatomy_arr.astype(np.int16))\n", - "static_anatomy_image.CopyInformation(labelmap_image)\n", - "itk.imwrite(\n", - " static_anatomy_image,\n", - " os.path.join(output_dir, f\"{outname}.static_anatomy_mask.mha\"),\n", - " compression=True,\n", - ")\n", - "\n", - "contours = con.extract_contours(static_anatomy_image)\n", - "contours.save(os.path.join(output_dir, f\"{outname}.static_anatomy_mask.vtp\"))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b9abb28c", - "metadata": {}, - "outputs": [], - "source": [ - "input_image = None\n", - "if use_fixed_image:\n", - " input_image = itk.imread(os.path.join(output_dir, \"slice_fixed.mha\"), itk.SS)\n", - "else:\n", - " input_image = itk.imread(os.path.join(output_dir, \"slice_max.reg_dynamic_anatomy.mha\"), itk.SS)\n", - "arr = itk.array_from_image(input_image)\n", - "flipped_input_image = itk.image_from_array(arr)\n", - "flipped_input_image.CopyInformation(input_image)\n", - "\n", - "image = pv.wrap(itk.vtk_image_from_image(flipped_input_image))\n", - "\n", - "pl = pv.Plotter()\n", - "pl.add_mesh(image.slice(normal=\"z\"), cmap=\"bone\", show_scalar_bar=False, opacity=0.5)\n", - "pl.add_mesh(\n", - " contours.slice(normal=\"z\"),\n", - " cmap=\"pink\",\n", - " clim=[50, 800],\n", - " show_scalar_bar=False,\n", - " opacity=1.0,\n", - ")\n", - "pl.set_background(\"black\")\n", - "pl.camera_position = \"xy\"\n", - "pl.show()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.11" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "3ce61753-11ad-4ade-9afe-6ad1bc748e25", + "metadata": 
{}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "import numpy as np\n", + "import itk\n", + "import pyvista as pv\n", + "\n", + "from physiomotion4d.contour_tools import ContourTools\n", + "from physiomotion4d.segment_chest_total_segmentator import SegmentChestTotalSegmentator\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b60954cf", + "metadata": {}, + "outputs": [], + "source": [ + "# When re-running, you can bypass certain long-running steps\n", + "re_run_image_max = False\n", + "use_fixed_image = True\n", + "re_run_image_segmentation = True" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5c2f5e00", + "metadata": {}, + "outputs": [], + "source": [ + "output_dir = \"./results\"\n", + "max_image = None\n", + "print(\"Computing max image...\")\n", + "if re_run_image_max and not use_fixed_image:\n", + " # Compute max of all images\n", + " image = None\n", + " try:\n", + " image = itk.imread(os.path.join(output_dir, \"slice_000.reg_dynamic_anatomy.mha\"))\n", + " except (FileNotFoundError, OSError):\n", + " print(\"No image found. Aborting. 
Please run 1-generate_images.ipynb first.\")\n", + " exit(1)\n", + " arr = itk.array_from_image(image)\n", + " print(arr.shape)\n", + " arr = np.where(arr == 0, -1000, arr)\n", + " for i in range(1, 21):\n", + " print(f\"Processing slice {i:03d}...\")\n", + " tmp_arr = itk.array_from_image(\n", + " itk.imread(os.path.join(output_dir, f\"slice_{i:03d}.reg_dynamic_anatomy.mha\"))\n", + " )\n", + " tmp_arr = np.where(tmp_arr == 0, -1000, tmp_arr)\n", + " arr = np.maximum(arr, tmp_arr)\n", + " print(\"Max image computed.\")\n", + " max_image = itk.image_from_array(arr)\n", + " max_image.CopyInformation(image)\n", + " itk.imwrite(\n", + " max_image,\n", + " os.path.join(output_dir, \"slice_max.reg_dynamic_anatomy.mha\"),\n", + " compression=True,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9438634d", + "metadata": {}, + "outputs": [], + "source": [ + "if use_fixed_image:\n", + " max_image = itk.imread(os.path.join(output_dir, \"slice_fixed.mha\"))\n", + " outname = \"slice_fixed\"\n", + "else:\n", + " max_image = itk.imread(os.path.join(output_dir, \"slice_max.reg_dynamic_anatomy.mha\"))\n", + " outname = \"slice_max\"\n", + "\n", + "seg = SegmentChestTotalSegmentator()\n", + "seg.contrast_threshold = 500\n", + "if re_run_image_segmentation:\n", + " result = seg.segment(max_image, contrast_enhanced_study=True)\n", + " labelmap_image = result[\"labelmap\"]\n", + " lung_mask = result[\"lung\"]\n", + " heart_mask = result[\"heart\"]\n", + " major_vessels_mask = result[\"major_vessels\"]\n", + " bone_mask = result[\"bone\"]\n", + " soft_tissue_mask = result[\"soft_tissue\"]\n", + " other_mask = result[\"other\"]\n", + " contrast_mask = result[\"contrast\"]\n", + " itk.imwrite(\n", + " labelmap_image,\n", + " os.path.join(output_dir, f\"{outname}.all_mask.mha\"),\n", + " compression=True,\n", + " )\n", + "else:\n", + " labelmap_image = itk.imread(\n", + " os.path.join(output_dir, f\"{outname}.all_mask.mha\")\n", + " )" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "id": "a2325199", + "metadata": {}, + "outputs": [], + "source": [ + "con = ContourTools()\n", + "all_contours = con.extract_contours(labelmap_image)\n", + "all_contours.save(os.path.join(output_dir, f\"{outname}.all_mask.vtp\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10db4bfa", + "metadata": {}, + "outputs": [], + "source": [ + "label_arr = itk.array_from_image(labelmap_image)\n", + "lung_arr = itk.array_from_image(lung_mask)\n", + "heart_arr = itk.array_from_image(heart_mask)\n", + "major_vessels_arr = itk.array_from_image(major_vessels_mask)\n", + "bone_arr = itk.array_from_image(bone_mask)\n", + "soft_tissue_arr = itk.array_from_image(soft_tissue_mask)\n", + "other_arr = itk.array_from_image(other_mask)\n", + "contrast_arr = itk.array_from_image(contrast_mask)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4d19335d", + "metadata": {}, + "outputs": [], + "source": [ + "dynamic_anatomy_arr = np.maximum(heart_arr, contrast_arr)\n", + "dynamic_anatomy_arr = np.maximum(dynamic_anatomy_arr, major_vessels_arr)\n", + "dynamic_anatomy_arr = np.where(dynamic_anatomy_arr, label_arr, 0)\n", + "dynamic_anatomy_image = itk.image_from_array(dynamic_anatomy_arr.astype(np.int16))\n", + "dynamic_anatomy_image.CopyInformation(labelmap_image)\n", + "itk.imwrite(\n", + " dynamic_anatomy_image,\n", + " os.path.join(output_dir, f\"{outname}.dynamic_anatomy_mask.mha\"),\n", + " compression=True,\n", + ")\n", + "\n", + "contours = con.extract_contours(dynamic_anatomy_image)\n", + "contours.save(os.path.join(output_dir, f\"{outname}.dynamic_anatomy_mask.vtp\"))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e4601d28", + "metadata": {}, + "outputs": [], + "source": [ + "static_anatomy_arr = lung_arr + bone_arr + soft_tissue_arr + other_arr\n", + "static_anatomy_arr = np.where(static_anatomy_arr, label_arr, 0)\n", + "static_anatomy_image = 
itk.image_from_array(static_anatomy_arr.astype(np.int16))\n", + "static_anatomy_image.CopyInformation(labelmap_image)\n", + "itk.imwrite(\n", + " static_anatomy_image,\n", + " os.path.join(output_dir, f\"{outname}.static_anatomy_mask.mha\"),\n", + " compression=True,\n", + ")\n", + "\n", + "contours = con.extract_contours(static_anatomy_image)\n", + "contours.save(os.path.join(output_dir, f\"{outname}.static_anatomy_mask.vtp\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b9abb28c", + "metadata": {}, + "outputs": [], + "source": [ + "input_image = None\n", + "if use_fixed_image:\n", + " input_image = itk.imread(os.path.join(output_dir, \"slice_fixed.mha\"), itk.SS)\n", + "else:\n", + " input_image = itk.imread(os.path.join(output_dir, \"slice_max.reg_dynamic_anatomy.mha\"), itk.SS)\n", + "arr = itk.array_from_image(input_image)\n", + "flipped_input_image = itk.image_from_array(arr)\n", + "flipped_input_image.CopyInformation(input_image)\n", + "\n", + "image = pv.wrap(itk.vtk_image_from_image(flipped_input_image))\n", + "\n", + "pl = pv.Plotter()\n", + "pl.add_mesh(image.slice(normal=\"z\"), cmap=\"bone\", show_scalar_bar=False, opacity=0.5)\n", + "pl.add_mesh(\n", + " contours.slice(normal=\"z\"),\n", + " cmap=\"pink\",\n", + " clim=[50, 800],\n", + " show_scalar_bar=False,\n", + " opacity=1.0,\n", + ")\n", + "pl.set_background(\"black\")\n", + "pl.camera_position = \"xy\"\n", + "pl.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/experiments/Heart-GatedCT_To_USD/3-transform_dynamic_and_static_contours.ipynb 
b/experiments/Heart-GatedCT_To_USD/3-transform_dynamic_and_static_contours.ipynb index 83f1b4d..541a006 100644 --- a/experiments/Heart-GatedCT_To_USD/3-transform_dynamic_and_static_contours.ipynb +++ b/experiments/Heart-GatedCT_To_USD/3-transform_dynamic_and_static_contours.ipynb @@ -1,140 +1,163 @@ { - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "3ce61753-11ad-4ade-9afe-6ad1bc748e25", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "\n", - "import itk\n", - "import pyvista as pv\n", - "\n", - "from physiomotion4d.contour_tools import ContourTools\n", - "from physiomotion4d.convert_vtk_4d_to_usd import ConvertVTK4DToUSD\n", - "from physiomotion4d.segment_chest_total_segmentator import SegmentChestTotalSegmentator\n", - "from physiomotion4d.usd_anatomy_tools import USDAnatomyTools\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "240f1d14", - "metadata": {}, - "outputs": [], - "source": [ - "data_dir = \"../../data\"\n", - "output_dir = \"results\"\n", - "\n", - "base_name = \"slice_fixed\"\n", - "#base_name = \"slice_max.reg_dynamic_anatomy\"\n", - "\n", - "project_name = \"Slicer_CardiacGatedCT\"\n", - "\n", - "do_transform_contours = True" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ef6002e8", - "metadata": {}, - "outputs": [], - "source": [ - "def transform_contours(contours, transform_filenames, base_name, output_dir, project_name):\n", - " con = ContourTools()\n", - " for i, transform_filename in enumerate(transform_filenames):\n", - " phi_MF = itk.transformread(transform_filename)[0]\n", - " print(f\"Applying transform {transform_filename} to {base_name}\")\n", - "\n", - " new_contours = con.transform_contours(contours, phi_MF, with_deformation_magnitude=True)\n", - " new_contours.save(\n", - " os.path.join(output_dir, f\"slice_{i:03d}.reg_{base_name}_inv.{base_name}_mask.vtp\")\n", - " )\n", - "\n", - " files = [\n", - " 
f\"{output_dir}/slice_{i:03d}.reg_{base_name}_inv.{base_name}_mask.vtp\"\n", - " for i in range(21)\n", - " ]\n", - " seg = SegmentChestTotalSegmentator()\n", - " all_mask_ids = seg.all_mask_ids\n", - "\n", - " polydata = [pv.read(f) for f in files]\n", - "\n", - " print(\"Converting vtp models to USD\")\n", - " converter = ConvertVTK4DToUSD(\n", - " project_name,\n", - " polydata,\n", - " all_mask_ids,\n", - " )\n", - " stage = converter.convert(\n", - " os.path.join(output_dir, f\"{project_name}.{base_name}.usd\"),\n", - " )\n", - "\n", - " painter = USDAnatomyTools(stage)\n", - " painter.enhance_meshes(seg)\n", - " if os.path.exists(os.path.join(output_dir, f\"{project_name}.{base_name}_painted.usd\")):\n", - " os.remove(os.path.join(output_dir, f\"{project_name}.{base_name}_painted.usd\"))\n", - " stage.Export(os.path.join(output_dir, f\"{project_name}.{base_name}_painted.usd\"))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "90c291bb", - "metadata": {}, - "outputs": [], - "source": [ - "dynamic_transform_filenames = [\n", - " os.path.join(output_dir, f\"slice_{i:03d}.reg_dynamic_anatomy.phi_MF.hdf\")\n", - " for i in range(21)\n", - "]\n", - "static_transform_filenames = [\n", - " os.path.join(output_dir, f\"slice_{i:03d}.reg_static_anatomy.phi_MF.hdf\")\n", - " for i in range(21)\n", - "]\n", - "all_transform_filenames = [\n", - " os.path.join(output_dir, f\"slice_{i:03d}.reg_all.phi_MF.hdf\")\n", - " for i in range(21)\n", - "]\n", - "\n", - "dynamic_anatomy_contours = pv.read(\n", - " os.path.join(output_dir, f\"{base_name}.dynamic_anatomy_mask.vtp\")\n", - ")\n", - "static_anatomy_contours = pv.read(\n", - " os.path.join(output_dir, f\"{base_name}.static_anatomy_mask.vtp\")\n", - ")\n", - "all_contours = pv.read(\n", - " os.path.join(output_dir, f\"{base_name}.all_mask.vtp\")\n", - ")\n", - "\n", - "transform_contours(all_contours, all_transform_filenames, \"all\", output_dir, project_name)\n", - 
"transform_contours(dynamic_anatomy_contours, dynamic_transform_filenames, \"dynamic_anatomy\", output_dir, project_name)\n", - "transform_contours(static_anatomy_contours, static_transform_filenames, \"static_anatomy\", output_dir, project_name)\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.11" - } - }, - "nbformat": 4, - "nbformat_minor": 5 + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "3ce61753-11ad-4ade-9afe-6ad1bc748e25", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "import itk\n", + "import pyvista as pv\n", + "\n", + "from physiomotion4d.contour_tools import ContourTools\n", + "from physiomotion4d.convert_vtk_4d_to_usd import ConvertVTK4DToUSD\n", + "from physiomotion4d.segment_chest_total_segmentator import SegmentChestTotalSegmentator\n", + "from physiomotion4d.usd_anatomy_tools import USDAnatomyTools\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "240f1d14", + "metadata": {}, + "outputs": [], + "source": [ + "data_dir = \"../../data\"\n", + "output_dir = \"results\"\n", + "\n", + "base_name = \"slice_fixed\"\n", + "#base_name = \"slice_max.reg_dynamic_anatomy\"\n", + "\n", + "project_name = \"Slicer_CardiacGatedCT\"\n", + "\n", + "do_transform_contours = True" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ef6002e8", + "metadata": {}, + "outputs": [], + "source": [ + "def transform_contours(contours, transform_filenames, base_name, output_dir, project_name):\n", + " con = ContourTools()\n", + " for i, transform_filename in enumerate(transform_filenames):\n", + " forward_transform = 
itk.transformread(transform_filename)[0]\n", + " print(f\"Applying transform {transform_filename} to {base_name}\")\n", + "\n", + " new_contours = con.transform_contours(contours, forward_transform, with_deformation_magnitude=True)\n", + " new_contours.save(\n", + " os.path.join(output_dir, f\"slice_{i:03d}.reg_{base_name}_inv.{base_name}_mask.vtp\")\n", + " )\n", + "\n", + "def convert_contours(contours, transform_filenames, base_name, output_dir, project_name, compute_normals=False):\n", + " files = [\n", + " f\"{output_dir}/slice_{i:03d}.reg_{base_name}_inv.{base_name}_mask.vtp\"\n", + " for i in range(21)\n", + " ]\n", + " seg = SegmentChestTotalSegmentator()\n", + " all_mask_ids = seg.all_mask_ids\n", + "\n", + " polydata = [pv.read(f) for f in files]\n", + "\n", + " converter = ConvertVTK4DToUSD(\n", + " project_name,\n", + " polydata,\n", + " all_mask_ids,\n", + " compute_normals=compute_normals,\n", + " )\n", + " stage = converter.convert(\n", + " os.path.join(output_dir, f\"{project_name}.{base_name}.usd\"),\n", + " )\n", + "\n", + " painter = USDAnatomyTools(stage)\n", + " painter.enhance_meshes(seg)\n", + " if os.path.exists(os.path.join(output_dir, f\"{project_name}.{base_name}_painted.usd\")):\n", + " os.remove(os.path.join(output_dir, f\"{project_name}.{base_name}_painted.usd\"))\n", + " stage.Export(os.path.join(output_dir, f\"{project_name}.{base_name}_painted.usd\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "90c291bb", + "metadata": {}, + "outputs": [], + "source": [ + "dynamic_transform_filenames = [\n", + " os.path.join(output_dir, f\"slice_{i:03d}.reg_dynamic_anatomy.forward.hdf\")\n", + " for i in range(21)\n", + "]\n", + "static_transform_filenames = [\n", + " os.path.join(output_dir, f\"slice_{i:03d}.reg_static_anatomy.forward.hdf\")\n", + " for i in range(21)\n", + "]\n", + "all_transform_filenames = [\n", + " os.path.join(output_dir, f\"slice_{i:03d}.reg_all.forward.hdf\")\n", + " for i in range(21)\n", + 
"]\n", + "\n", + "dynamic_anatomy_contours = pv.read(\n", + " os.path.join(output_dir, f\"{base_name}.dynamic_anatomy_mask.vtp\")\n", + ")\n", + "static_anatomy_contours = pv.read(\n", + " os.path.join(output_dir, f\"{base_name}.static_anatomy_mask.vtp\")\n", + ")\n", + "all_contours = pv.read(\n", + " os.path.join(output_dir, f\"{base_name}.all_mask.vtp\")\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e3d48ddc", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "if False:\n", + " transform_contours(all_contours, all_transform_filenames, \"all\", output_dir, project_name)\n", + " transform_contours(dynamic_anatomy_contours, dynamic_transform_filenames, \"dynamic_anatomy\", output_dir, project_name)\n", + " transform_contours(static_anatomy_contours, static_transform_filenames, \"static_anatomy\", output_dir, project_name)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "06d06123", + "metadata": {}, + "outputs": [], + "source": [ + "convert_contours(all_contours, all_transform_filenames, \"all\", output_dir, project_name, compute_normals=True)\n", + "convert_contours(dynamic_anatomy_contours, dynamic_transform_filenames, \"dynamic_anatomy\", output_dir, project_name, compute_normals=True)\n", + "convert_contours(static_anatomy_contours, static_transform_filenames, \"static_anatomy\", output_dir, project_name, compute_normals=True)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 } diff --git a/experiments/Heart-GatedCT_To_USD/test_vista3d_class.ipynb b/experiments/Heart-GatedCT_To_USD/test_vista3d_class.ipynb index 
f7615b1..0da013e 100644 --- a/experiments/Heart-GatedCT_To_USD/test_vista3d_class.ipynb +++ b/experiments/Heart-GatedCT_To_USD/test_vista3d_class.ipynb @@ -7,13 +7,7 @@ "metadata": {}, "outputs": [], "source": [ - "import itk\n", - "import os\n", - "\n", - "from physiomotion4d.segment_chest_vista_3d import SegmentChestVista3D\n", - "\n", - "output_dir = \"./results\"\n", - "max_image = itk.imread(os.path.join(output_dir, \"slice_max.reg_dynamic_anatomy.mha\"))" + "import itk\nimport os\n\nfrom physiomotion4d.segment_chest_vista_3d import SegmentChestVista3D\n\noutput_dir = \"./results\"\nmax_image = itk.imread(os.path.join(output_dir, \"slice_max.reg_dynamic_anatomy.mha\"))" ] }, { @@ -62,21 +56,7 @@ } ], "source": [ - "seg = SegmentChestVista3D()\n", - "result = seg.segment(max_image, contrast_enhanced_study=True)\n", - "labelmap_image = result[\"labelmap\"]\n", - "lung_mask = result[\"lung\"]\n", - "heart_mask = result[\"heart\"]\n", - "major_vessels_mask = result[\"major_vessels\"]\n", - "bone_mask = result[\"bone\"]\n", - "soft_tissue_mask = result[\"soft_tissue\"]\n", - "other_mask = result[\"other\"]\n", - "contrast_mask = result[\"contrast\"]\n", - "itk.imwrite(\n", - " labelmap_image,\n", - " os.path.join(output_dir, \"slice_max.reg_dynamic_anatomy.all_mask-test.mha\"),\n", - " compression=True,\n", - ")" + "seg = SegmentChestVista3D()\nresult = seg.segment(max_image, contrast_enhanced_study=True)\nlabelmap_image = result[\"labelmap\"]\nlung_mask = result[\"lung\"]\nheart_mask = result[\"heart\"]\nmajor_vessels_mask = result[\"major_vessels\"]\nbone_mask = result[\"bone\"]\nsoft_tissue_mask = result[\"soft_tissue\"]\nother_mask = result[\"other\"]\ncontrast_mask = result[\"contrast\"]\nitk.imwrite(\n labelmap_image,\n os.path.join(output_dir, \"slice_max.reg_dynamic_anatomy.all_mask-test.mha\"),\n compression=True,\n)" ] } ], @@ -101,4 +81,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git 
a/experiments/Heart-GatedCT_To_USD/test_vista3d_inMem.ipynb b/experiments/Heart-GatedCT_To_USD/test_vista3d_inMem.ipynb index fc79e0b..6f0fda6 100644 --- a/experiments/Heart-GatedCT_To_USD/test_vista3d_inMem.ipynb +++ b/experiments/Heart-GatedCT_To_USD/test_vista3d_inMem.ipynb @@ -1,263 +1,263 @@ { - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "30066e92", - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import torch\n", - "\n", - "\n", - "def vista3d_inference_from_itk(\n", - " itk_image,\n", - " label_prompt=None,\n", - " points=None,\n", - " point_labels=None,\n", - " device=None,\n", - " bundle_path=None,\n", - " model_cache_dir=None,\n", - "):\n", - " # 1. Import dependencies\n", - " import itk\n", - " from monai.bundle import download\n", - " from monai.data.itk_torch_bridge import itk_image_to_metatensor\n", - " from monai.transforms import (\n", - " EnsureChannelFirst,\n", - " Spacing,\n", - " ScaleIntensityRange,\n", - " CropForeground,\n", - " EnsureType,\n", - " )\n", - " from monai.inferers import sliding_window_inference\n", - " from monai.networks.nets import vista3d132\n", - " from monai.utils import set_determinism\n", - "\n", - " set_determinism(seed=42)\n", - " if device is None:\n", - " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", - "\n", - " # 2. Handle \"no prompts\" case: segment all classes\n", - " if label_prompt is None and points is None:\n", - " everything_labels = list(\n", - " set([i + 1 for i in range(132)]) - set([2, 16, 18, 20, 21, 23, 24, 25, 26])\n", - " )\n", - " label_prompt = everything_labels\n", - " print(f\"No prompt provided. Using everything_labels for {len(everything_labels)} classes.\")\n", - "\n", - " if points is not None and point_labels is None:\n", - " raise ValueError(\"point_labels must be provided when points are specified\")\n", - "\n", - " # 3. 
Download model bundle if needed\n", - " if bundle_path is None:\n", - " import tempfile\n", - "\n", - " if model_cache_dir is None:\n", - " model_cache_dir = tempfile.mkdtemp()\n", - " try:\n", - " download(name=\"vista3d\", bundle_dir=model_cache_dir, source=\"monaihosting\")\n", - " except Exception:\n", - " download(name=\"vista3d\", bundle_dir=model_cache_dir, source=\"github\")\n", - " bundle_path = f\"{model_cache_dir}/vista3d\"\n", - "\n", - " # 4. ITK->MetaTensor (in memory)\n", - " meta_tensor = itk_image_to_metatensor(itk_image, channel_dim=None, dtype=torch.float32)\n", - "\n", - " input_size = itk_image.GetLargestPossibleRegion().GetSize()\n", - "\n", - " # 5. Preprocessing pipeline\n", - " processed = meta_tensor\n", - " processed = EnsureChannelFirst(channel_dim=None)(processed)\n", - " processed = EnsureType(dtype=torch.float32)(processed)\n", - " processed = Spacing(pixdim=[1.5, 1.5, 1.5], mode=\"bilinear\")(processed)\n", - " processed = ScaleIntensityRange(a_min=-1024, a_max=1024, b_min=0.0, b_max=1.0, clip=True)(\n", - " processed\n", - " )\n", - " processed = CropForeground()(processed)\n", - "\n", - " # 6. Load VISTA3D\n", - " model = vista3d132(encoder_embed_dim=48, in_channels=1)\n", - " model_path = f\"{bundle_path}/models/model.pt\"\n", - " checkpoint = torch.load(model_path, map_location=device)\n", - " model.load_state_dict(checkpoint)\n", - " model.eval()\n", - " model.to(device)\n", - "\n", - " # 7. Prepare input tensor\n", - " input_tensor = processed\n", - " if not isinstance(input_tensor, torch.Tensor):\n", - " input_tensor = torch.tensor(np.asarray(input_tensor), dtype=torch.float32)\n", - " if input_tensor.dim() == 3:\n", - " input_tensor = input_tensor.unsqueeze(0)\n", - " if input_tensor.dim() == 4:\n", - " input_tensor = input_tensor.unsqueeze(0)\n", - " input_tensor = input_tensor.to(device)\n", - "\n", - " # 8. 
Prepare model inputs\n", - " model_inputs = {\"image\": input_tensor}\n", - " if label_prompt is not None:\n", - " label_prompt_tensor = torch.tensor(label_prompt, dtype=torch.long, device=device)\n", - " model_inputs[\"label_prompt\"] = label_prompt_tensor\n", - " print('label_prompt_tensor shape', label_prompt_tensor.shape)\n", - " if points is not None:\n", - " point_coords = torch.tensor(points, dtype=torch.float32, device=device).unsqueeze(0)\n", - " point_labels_tensor = torch.tensor(\n", - " point_labels, dtype=torch.float32, device=device\n", - " ).unsqueeze(0)\n", - " model_inputs[\"points\"] = point_coords\n", - " model_inputs[\"point_labels\"] = point_labels_tensor\n", - " print('point_coords shape', point_coords.shape)\n", - "\n", - " # 9. Sliding window inference for large images\n", - " def predictor_fn(x):\n", - " args = {k: v for k, v in model_inputs.items() if k != \"image\"}\n", - " print(x.shape)\n", - " return model(x, **args)\n", - "\n", - " with torch.no_grad():\n", - " if any(dim > 128 for dim in input_tensor.shape[2:]):\n", - " print(\"Sliding window inference\")\n", - " output = sliding_window_inference(\n", - " input_tensor,\n", - " roi_size=[128, 128, 128],\n", - " sw_batch_size=1,\n", - " predictor=predictor_fn,\n", - " overlap=0.5,\n", - " mode=\"gaussian\",\n", - " device=device,\n", - " )\n", - " else:\n", - " print(\"Single window inference\")\n", - " output = model(input_tensor, **{k: v for k, v in model_inputs.items() if k != \"image\"})\n", - "\n", - " print('output shape', output.shape)\n", - " # 10. 
Postprocess: multi-class to label map\n", - " output = output.cpu()\n", - " if hasattr(output, \"detach\"):\n", - " output = output.detach()\n", - " if isinstance(output, dict):\n", - " if \"pred\" in output:\n", - " output = output[\"pred\"]\n", - " else:\n", - " output = list(output.values())[0]\n", - "\n", - " if output.shape[1] > 1:\n", - " label_map = torch.argmax(output, dim=1).squeeze(0).numpy().astype(np.uint16)\n", - " else:\n", - " label_map = (output > 0.5).squeeze(0).cpu().numpy().astype(np.uint8)\n", - "\n", - " # Ensure output is zyx order for ITK\n", - " if label_map.shape != tuple(reversed(input_size)):\n", - " # Some transforms may flip axes; reorder as needed.\n", - " label_map_for_itk = np.transpose(label_map, axes=range(label_map.ndim)[::-1])\n", - " else:\n", - " label_map_for_itk = label_map\n", - "\n", - " # ITK expects z,y,x ordering for GetImageFromArray\n", - " output_itk = itk.GetImageFromArray(label_map_for_itk)\n", - " itk.imwrite(output_itk, 'output_itk.mha')\n", - "\n", - " # Return output in ITK format matching the input (size, spacing, origin, direction, type)\n", - " return output_itk" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "c0e5a477", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "No prompt provided. 
Using everything_labels for 123 classes.\n", - "2025-07-17 13:09:51,844 - INFO - --- input summary of monai.bundle.scripts.download ---\n", - "2025-07-17 13:09:51,844 - INFO - > name: 'vista3d'\n", - "2025-07-17 13:09:51,845 - INFO - > bundle_dir: './network_weights'\n", - "2025-07-17 13:09:51,845 - INFO - > source: 'monaihosting'\n", - "2025-07-17 13:09:51,846 - INFO - > remove_prefix: 'monai_'\n", - "2025-07-17 13:09:51,846 - INFO - > progress: True\n", - "2025-07-17 13:09:51,847 - INFO - ---\n", - "\n", - "\n" - ] + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "30066e92", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import torch\n", + "\n", + "\n", + "def vista3d_inference_from_itk(\n", + " itk_image,\n", + " label_prompt=None,\n", + " points=None,\n", + " point_labels=None,\n", + " device=None,\n", + " bundle_path=None,\n", + " model_cache_dir=None,\n", + "):\n", + " # 1. Import dependencies\n", + " import itk\n", + " from monai.bundle import download\n", + " from monai.data.itk_torch_bridge import itk_image_to_metatensor\n", + " from monai.transforms import (\n", + " EnsureChannelFirst,\n", + " Spacing,\n", + " ScaleIntensityRange,\n", + " CropForeground,\n", + " EnsureType,\n", + " )\n", + " from monai.inferers import sliding_window_inference\n", + " from monai.networks.nets import vista3d132\n", + " from monai.utils import set_determinism\n", + "\n", + " set_determinism(seed=42)\n", + " if device is None:\n", + " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", + " # 2. Handle \"no prompts\" case: segment all classes\n", + " if label_prompt is None and points is None:\n", + " everything_labels = list(\n", + " set([i + 1 for i in range(132)]) - set([2, 16, 18, 20, 21, 23, 24, 25, 26])\n", + " )\n", + " label_prompt = everything_labels\n", + " print(f\"No prompt provided. 
Using everything_labels for {len(everything_labels)} classes.\")\n", + "\n", + " if points is not None and point_labels is None:\n", + " raise ValueError(\"point_labels must be provided when points are specified\")\n", + "\n", + " # 3. Download model bundle if needed\n", + " if bundle_path is None:\n", + " import tempfile\n", + "\n", + " if model_cache_dir is None:\n", + " model_cache_dir = tempfile.mkdtemp()\n", + " try:\n", + " download(name=\"vista3d\", bundle_dir=model_cache_dir, source=\"monaihosting\")\n", + " except Exception:\n", + " download(name=\"vista3d\", bundle_dir=model_cache_dir, source=\"github\")\n", + " bundle_path = f\"{model_cache_dir}/vista3d\"\n", + "\n", + " # 4. ITK->MetaTensor (in memory)\n", + " meta_tensor = itk_image_to_metatensor(itk_image, channel_dim=None, dtype=torch.float32)\n", + "\n", + " input_size = itk_image.GetLargestPossibleRegion().GetSize()\n", + "\n", + " # 5. Preprocessing pipeline\n", + " processed = meta_tensor\n", + " processed = EnsureChannelFirst(channel_dim=None)(processed)\n", + " processed = EnsureType(dtype=torch.float32)(processed)\n", + " processed = Spacing(pixdim=[1.5, 1.5, 1.5], mode=\"bilinear\")(processed)\n", + " processed = ScaleIntensityRange(a_min=-1024, a_max=1024, b_min=0.0, b_max=1.0, clip=True)(\n", + " processed\n", + " )\n", + " processed = CropForeground()(processed)\n", + "\n", + " # 6. Load VISTA3D\n", + " model = vista3d132(encoder_embed_dim=48, in_channels=1)\n", + " model_path = f\"{bundle_path}/models/model.pt\"\n", + " checkpoint = torch.load(model_path, map_location=device)\n", + " model.load_state_dict(checkpoint)\n", + " model.eval()\n", + " model.to(device)\n", + "\n", + " # 7. 
Prepare input tensor\n", + " input_tensor = processed\n", + " if not isinstance(input_tensor, torch.Tensor):\n", + " input_tensor = torch.tensor(np.asarray(input_tensor), dtype=torch.float32)\n", + " if input_tensor.dim() == 3:\n", + " input_tensor = input_tensor.unsqueeze(0)\n", + " if input_tensor.dim() == 4:\n", + " input_tensor = input_tensor.unsqueeze(0)\n", + " input_tensor = input_tensor.to(device)\n", + "\n", + " # 8. Prepare model inputs\n", + " model_inputs = {\"image\": input_tensor}\n", + " if label_prompt is not None:\n", + " label_prompt_tensor = torch.tensor(label_prompt, dtype=torch.long, device=device)\n", + " model_inputs[\"label_prompt\"] = label_prompt_tensor\n", + " print('label_prompt_tensor shape', label_prompt_tensor.shape)\n", + " if points is not None:\n", + " point_coords = torch.tensor(points, dtype=torch.float32, device=device).unsqueeze(0)\n", + " point_labels_tensor = torch.tensor(\n", + " point_labels, dtype=torch.float32, device=device\n", + " ).unsqueeze(0)\n", + " model_inputs[\"points\"] = point_coords\n", + " model_inputs[\"point_labels\"] = point_labels_tensor\n", + " print('point_coords shape', point_coords.shape)\n", + "\n", + " # 9. Sliding window inference for large images\n", + " def predictor_fn(x):\n", + " args = {k: v for k, v in model_inputs.items() if k != \"image\"}\n", + " print(x.shape)\n", + " return model(x, **args)\n", + "\n", + " with torch.no_grad():\n", + " if any(dim > 128 for dim in input_tensor.shape[2:]):\n", + " print(\"Sliding window inference\")\n", + " output = sliding_window_inference(\n", + " input_tensor,\n", + " roi_size=[128, 128, 128],\n", + " sw_batch_size=1,\n", + " predictor=predictor_fn,\n", + " overlap=0.5,\n", + " mode=\"gaussian\",\n", + " device=device,\n", + " )\n", + " else:\n", + " print(\"Single window inference\")\n", + " output = model(input_tensor, **{k: v for k, v in model_inputs.items() if k != \"image\"})\n", + "\n", + " print('output shape', output.shape)\n", + " # 10. 
Postprocess: multi-class to label map\n", + " output = output.cpu()\n", + " if hasattr(output, \"detach\"):\n", + " output = output.detach()\n", + " if isinstance(output, dict):\n", + " if \"pred\" in output:\n", + " output = output[\"pred\"]\n", + " else:\n", + " output = list(output.values())[0]\n", + "\n", + " if output.shape[1] > 1:\n", + " label_map = torch.argmax(output, dim=1).squeeze(0).numpy().astype(np.uint16)\n", + " else:\n", + " label_map = (output > 0.5).squeeze(0).cpu().numpy().astype(np.uint8)\n", + "\n", + " # Ensure output is zyx order for ITK\n", + " if label_map.shape != tuple(reversed(input_size)):\n", + " # Some transforms may flip axes; reorder as needed.\n", + " label_map_for_itk = np.transpose(label_map, axes=range(label_map.ndim)[::-1])\n", + " else:\n", + " label_map_for_itk = label_map\n", + "\n", + " # ITK expects z,y,x ordering for GetImageFromArray\n", + " output_itk = itk.GetImageFromArray(label_map_for_itk)\n", + " itk.imwrite(output_itk, 'output_itk.mha')\n", + "\n", + " # Return output in ITK format matching the input (size, spacing, origin, direction, type)\n", + " return output_itk" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "c0e5a477", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No prompt provided. 
Using everything_labels for 123 classes.\n", + "2025-07-17 13:09:51,844 - INFO - --- input summary of monai.bundle.scripts.download ---\n", + "2025-07-17 13:09:51,844 - INFO - > name: 'vista3d'\n", + "2025-07-17 13:09:51,845 - INFO - > bundle_dir: './network_weights'\n", + "2025-07-17 13:09:51,845 - INFO - > source: 'monaihosting'\n", + "2025-07-17 13:09:51,846 - INFO - > remove_prefix: 'monai_'\n", + "2025-07-17 13:09:51,846 - INFO - > progress: True\n", + "2025-07-17 13:09:51,847 - INFO - ---\n", + "\n", + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c5b02e2bf4d94e769d2eaae5942b3804", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Fetching 22 files: 0%| | 0/22 [00:00 0).astype(np.uint8)\n", - "binary_mask_image = itk.GetImageFromArray(binary_mask_arr)\n", - "binary_mask_image.CopyInformation(heart_mask_image)\n", - "\n", - "edge_filter = itk.BinaryContourImageFilter.New(Input=binary_mask_image)\n", - "edge_filter.SetForegroundValue(1)\n", - "edge_filter.SetBackgroundValue(0)\n", - "edge_filter.SetFullyConnected(False)\n", - "edge_filter.Update()\n", - "edge_mask_image = edge_filter.GetOutput()\n", - "\n", - "# Compute signed distance map (positive inside, negative outside)\n", - "print(\" Computing signed distance map...\")\n", - "distance_filter = itk.SignedMaurerDistanceMapImageFilter.New(Input=edge_mask_image)\n", - "distance_filter.SetSquaredDistance(False)\n", - "distance_filter.SetUseImageSpacing(True)\n", - "distance_filter.SetInsideIsPositive(False)\n", - "distance_filter.Update()\n", - "distance_image = distance_filter.GetOutput()\n", - "\n", - "# Normalize to [0, 1000] range for better optimization\n", - "print(\" Normalizing intensity values...\")\n", - "dist_arr = itk.GetArrayFromImage(distance_image)\n", - "min_val = dist_arr.min()\n", - "max_val = dist_arr.max()\n", - "normalized_arr = ((1.0 - (dist_arr - min_val) / (max_val - min_val)) * 100.0).astype(np.float32)\n", - 
"target_image = itk.GetImageFromArray(normalized_arr)\n", - "target_image.CopyInformation(distance_image)\n", - "\n", - "# Save intermediate and final images\n", - "itk.imwrite(binary_mask_image, str(output_dir / 'binary_mask.mha'), compression=True)\n", - "itk.imwrite(distance_image, str(output_dir / 'distance_map.mha'), compression=True)\n", - "itk.imwrite(target_image, str(output_dir / 'target_intensity_image.mha'), compression=True)\n", - "\n", - "print(f\"✓ Target intensity image created\")\n", - "print(f\" Min intensity: {normalized_arr.min():.2f}\")\n", - "print(f\" Max intensity: {normalized_arr.max():.2f}\")\n", - "print(f\" Mean intensity: {normalized_arr.mean():.2f}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Load PCA Heart Model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Load the average model\n", - "print(\"Loading PCA heart model...\")\n", - "average_mesh_original = pv.read(str(average_model_path))\n", - "print(f\" Average model: {average_mesh_original.n_points} points, {average_mesh_original.n_cells} cells\")\n", - "print(f\" Model bounds: {average_mesh_original.bounds}\")\n", - "print(f\" Model center: {average_mesh_original.center}\")\n", - "\n", - "# Save a copy for reference\n", - "average_mesh_original.save(str(output_dir / 'average_model_original.vtk'))\n", - "\n", - "# Load PCA data from JSON\n", - "print(f\"\\nLoading PCA data from JSON...\")\n", - "with open(str(pca_json_path), 'r') as f:\n", - " pca_data = json.load(f)\n", - "\n", - "# Extract PCA group data\n", - "group_data = pca_data[pca_group_key]\n", - "\n", - "# Extract eigenvalues and convert to standard deviations\n", - "eigenvalues = np.array(group_data['eigenvalues'])\n", - "std_deviations = np.sqrt(eigenvalues)\n", - "print(f\" Loaded {len(std_deviations)} eigenvalues (converted to std deviations)\")\n", - "\n", - "# Extract eigenvector components\n", - "eigenvectors = 
np.array(group_data['components'], dtype=np.float64)\n", - "print(f\" Loaded eigenvectors with shape {eigenvectors.shape}\")\n", - "print(f\" ✓ PCA data loaded successfully!\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Perform Initial ICP Affine Registration\n", - "\n", - "Use ICP (Iterative Closest Point) with affine mode to align the model surface to the patient surface extracted from the segmentation mask. This provides a good initial alignment for the PCA-based registration.\n", - "\n", - "The ICP registration pipeline:\n", - "1. Centroid alignment (automatic)\n", - "2. Rigid ICP alignment\n", - "3. Affine ICP alignment (allows scaling and shearing)\n", - "\n", - "The PCA registration will then refine this initial alignment with shape model constraints." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Perform initial ICP-based affine registration to patient surface\n", - "print(\"Performing initial ICP affine registration...\")\n", - "print(\"=\"*70)\n", - "\n", - "# Extract surface from patient mask for ICP target\n", - "\n", - "contour_tools = ContourTools()\n", - "patient_surface = contour_tools.extract_contours(heart_mask_image)\n", - "print(f\" Extracted patient surface: {patient_surface.n_points} points\")\n", - "\n", - "# Extract surface from average model for ICP source\n", - "model_surface = average_mesh_original.extract_surface()\n", - "print(f\" Extracted model surface: {model_surface.n_points} points\")\n", - "\n", - "# Perform ICP affine registration\n", - "icp_registrar = RegisterModelToModelICP(\n", - " moving_mesh=model_surface,\n", - " fixed_mesh=patient_surface\n", - ")\n", - "\n", - "icp_result = icp_registrar.register(mode='rigid', max_iterations=200)\n", - "\n", - "# Get the aligned mesh and transform\n", - "aligned_model_surface = icp_result['moving_mesh']\n", - "phi_FM = icp_result['phi_FM']\n", - "\n", - "print(\"\\n✓ ICP affine 
registration complete\")\n", - "print(\" Initial alignment obtained using centroid + rigid + ICP\")\n", - "\n", - "# Save aligned model for visualization\n", - "aligned_model_surface.save(str(output_dir / 'icp_aligned_model_surface.vtp'))\n", - "print(\" Saved ICP-aligned model surface\")\n", - "\n", - "# Apply ICP transform to the full average mesh (not just surface)\n", - "# This gets the volumetric mesh into patient space for PCA registration\n", - "transform_tools = TransformTools()\n", - "average_mesh_icp_aligned = transform_tools.transform_pvcontour(\n", - " average_mesh_original,\n", - " phi_FM\n", - ")\n", - "print(\"\\n✓ Applied ICP transform to full average mesh\")\n", - "print(f\" Aligned mesh center: {average_mesh_icp_aligned.center}\")\n", - "\n", - "# Save ICP-aligned full mesh\n", - "average_mesh_icp_aligned.save(str(output_dir / 'average_model_icp_aligned.vtk'))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Initialize PCA Registration" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(\"Initializing RegisterModelToImagePCA...\")\n", - "print(\"=\"*70)\n", - "\n", - "# Use the ICP-aligned mesh as the starting point for PCA registration\n", - "pca_registrar = RegisterModelToImagePCA(\n", - " average_mesh=average_mesh_icp_aligned,\n", - " eigenvectors=eigenvectors,\n", - " std_deviations=std_deviations,\n", - " reference_image=target_image\n", - ")\n", - "\n", - "print(\"✓ PCA registrar initialized\")\n", - "print(\" Using ICP-aligned mesh as starting point\")\n", - "print(f\" Number of points: {len(pca_registrar.average_mesh.points)}\")\n", - "print(f\" Number of PCA modes: {pca_registrar.n_pca_modes}\")\n", - "print(f\" Reference image size: {itk.size(target_image)}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Run PCA-Based Shape Optimization\n", - "\n", - "Now that we have a good initial alignment from ICP 
affine registration, we run the PCA-based registration to optimize the shape parameters while allowing small rigid refinements.\n", - "\n", - "The PCA registration pipeline includes:\n", - "1. **Stage 1**: Minor rigid refinement (starting from ICP-aligned position)\n", - "2. **Stage 2**: Joint optimization of rigid parameters + PCA shape coefficients\n", - "\n", - "Since the ICP already provided good alignment, we use reduced bounds for rigid refinement and focus on PCA shape optimization.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(\"\\n\" + \"=\"*70)\n", - "print(\"PCA-BASED SHAPE OPTIMIZATION\")\n", - "print(\"=\"*70)\n", - "print(\"\\nRunning complete PCA registration pipeline...\")\n", - "print(\" (Starting from ICP-aligned mesh, so skipping rigid stage)\")\n", - "print(\" Using identity transform as initial guess\")\n", - "\n", - "# Run complete registration\n", - "# Since we already have good alignment from ICP, we can focus on PCA shape fitting\n", - "result = pca_registrar.register(\n", - " n_pca_modes=10, # Use first 10 PCA modes\n", - " stage1_max_iterations=10, # Fewer iterations for rigid since already aligned\n", - " stage2_max_iterations=200, # More iterations for PCA optimization\n", - " pca_coefficient_bounds=3.0, # ±3 std deviations per mode\n", - " rigid_refinement_bounds={'versor': 0.1, 'translation_mm': 10.0} # Small refinements only\n", - ")\n", - "\n", - "print(\"\\n✓ PCA registration complete\")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Display Registration Results\n", - "\n", - "Review the optimization results from the PCA registration pipeline.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(\"\\n\" + \"=\"*70)\n", - "print(\"REGISTRATION RESULTS\")\n", - "print(\"=\"*70)\n", - "\n", - "# Display results\n", - "print(f\"\\nFinal Registration 
Metrics:\")\n", - "print(f\" Final mean intensity: {result['intensity']:.4f}\")\n", - "\n", - "print(f\"\\nOptimized Rigid Coefficients:\")\n", - "print(result['pre_phi_FM'])\n", - "print(f\"\\nOptimized PCA Coefficients (in units of std deviations):\")\n", - "for i, coef in enumerate(result['pca_coefficients_FM']):\n", - " print(f\" Mode {i+1:2d}: {coef:7.4f}\")\n", - "\n", - "# Get the final registered mesh\n", - "registered_mesh = result['registered_mesh']\n", - "print(f\"\\nRegistered Mesh Properties:\")\n", - "print(f\" Number of points: {registered_mesh.n_points}\")\n", - "print(f\" Number of cells: {registered_mesh.n_cells}\")\n", - "print(f\" Center: {registered_mesh.center}\")\n", - "print(f\" Bounds: {registered_mesh.bounds}\")\n", - "\n", - "print(\"\\n✓ Registration pipeline complete!\")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Save Registration Results\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(\"\\nSaving results...\")\n", - "\n", - "# Save final PCA-registered mesh\n", - "registered_mesh.save(str(output_dir / 'registered_mesh_pca.vtk'))\n", - "print(f\" Saved final PCA-registered mesh\")\n", - "\n", - "# Save ICP-aligned mesh for comparison\n", - "average_mesh_icp_aligned.save(str(output_dir / 'registered_mesh_icp_only.vtk'))\n", - "print(f\" Saved ICP-only aligned mesh\")\n", - "\n", - "# Save patient surface\n", - "patient_surface.save(str(output_dir / 'patient_surface.vtp'))\n", - "print(f\" Saved patient surface\")\n", - "\n", - "# Save transforms\n", - "itk.transformwrite([phi_FM], str(output_dir / 'icp_rigid_transform.hdf'), compression=True)\n", - "itk.transformwrite([result['pre_phi_FM']], str(output_dir / 'pca_pre_rigid_transform.hdf'), compression=True)\n", - "print(f\" Saved transforms\")\n", - "\n", - "# Save PCA coefficients\n", - "np.savetxt(\n", - " str(output_dir / 'pca_coefficients.txt'),\n", - " 
result['pca_coefficients_FM'],\n", - " header=f\"PCA coefficients for {len(result['pca_coefficients_FM'])} modes\"\n", - ")\n", - "print(f\" Saved PCA coefficients\")\n", - "\n", - "print(f\"\\n✓ All results saved to: {output_dir}\")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Visualize Results\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Extract surface from patient mask\n", - "patient_surface = ContourTools().extract_contours(heart_mask_image)\n", - "patient_surface.save(str(output_dir / 'patient_surface.vtp'))\n", - "\n", - "# Create side-by-side comparison\n", - "plotter = pv.Plotter(shape=(1, 2), window_size=[1000, 600])\n", - "\n", - "plotter.subplot(0, 0)\n", - "plotter.add_mesh(patient_surface, color='red', opacity=1.0, label='Patient')\n", - "plotter.add_mesh(average_mesh_icp_aligned, color='green', opacity=0.6, label='ICP Registered')\n", - "plotter.add_title('After PCA Shape Fitting')\n", - "plotter.add_axes()\n", - "\n", - "# After PCA shape fitting\n", - "plotter.subplot(0, 1)\n", - "plotter.add_mesh(patient_surface, color='red', opacity=1.0, label='Patient')\n", - "plotter.add_mesh(registered_mesh, color='green', opacity=0.6, label='PCA Registered')\n", - "plotter.add_title('After PCA Shape Fitting')\n", - "plotter.add_axes()\n", - "\n", - "plotter.link_views()\n", - "plotter.show()\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.11" - } - }, - "nbformat": 4, - "nbformat_minor": 4 + "nbformat": 4, + "nbformat_minor": 4 } diff --git a/experiments/Heart-Model_To_Patient/heart_model_to_patient.ipynb 
b/experiments/Heart-Model_To_Patient/heart_model_to_patient.ipynb new file mode 100644 index 0000000..c49cb9c --- /dev/null +++ b/experiments/Heart-Model_To_Patient/heart_model_to_patient.ipynb @@ -0,0 +1,419 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup and Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from pathlib import Path\n", + "\n", + "import itk\n", + "import numpy as np\n", + "import pyvista as pv\n", + "\n", + "# Import from PhysioMotion4D package\n", + "from physiomotion4d import (\n", + " ContourTools,\n", + " HeartModelToPatientWorkflow,\n", + " SegmentChestTotalSegmentator,\n", + ")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define File Paths" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Patient CT image (defines coordinate frame)\n", + "patient_data_dir = Path.cwd().parent / '..' / 'data' / 'Slicer-Heart-CT'\n", + "patient_ct_path = patient_data_dir / 'patient_img.mha'\n", + "patient_ct_heart_mask_path = patient_data_dir / 'patient_heart_wall_mask.nii.gz'\n", + "\n", + "# Atlas template model (moving)\n", + "atlas_data_dir = Path.cwd().parent / '..' / 'data' / 'KCL-Heart-Model'\n", + "atlas_vtu_path = atlas_data_dir / 'average_mesh.vtk'\n", + "atlas_labelmap_path = atlas_data_dir / 'average_labelmap_bkg.mha'\n", + "\n", + "pca_data_dir = Path.cwd().parent / '..' 
/ 'data' / 'KCL-Heart-Model' / 'pca'\n", + "pca_json_path = pca_data_dir / 'pca.json'\n", + "pca_group_key = 'All'\n", + "pca_n_modes = 10\n", + "\n", + "# Output directory\n", + "output_dir = Path.cwd() / 'results'\n", + "\n", + "os.makedirs(output_dir, exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "patient_image = itk.imread(str(patient_ct_path))\n", + "itk.imwrite(patient_image, str(output_dir / 'patient_image.mha'), compression=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if False:\n", + " segmentator = SegmentChestTotalSegmentator()\n", + " segmentator.contrast_threshold = 500\n", + " patient_segmentation_data = segmentator.segment(patient_image, contrast_enhanced_study=False)\n", + " labelmap = patient_segmentation_data[\"labelmap\"]\n", + " lung_mask = patient_segmentation_data[\"lung\"]\n", + " heart_mask = patient_segmentation_data[\"heart\"]\n", + " major_vessels_mask = patient_segmentation_data[\"major_vessels\"]\n", + " bone_mask = patient_segmentation_data[\"bone\"]\n", + " soft_tissue_mask = patient_segmentation_data[\"soft_tissue\"]\n", + " other_mask = patient_segmentation_data[\"other\"]\n", + " contrast_mask = patient_segmentation_data[\"contrast\"]\n", + "\n", + "\n", + " itk.imwrite(labelmap, str(output_dir / 'patient_labelmap.mha'), compression=True)\n", + "\n", + " heart_arr = itk.GetArrayFromImage(heart_mask)\n", + " #contrast_arr = itk.GetArrayFromImage(contrast_mask)\n", + " mask_arr = (heart_arr > 0).astype(np.uint8) #((heart_arr + contrast_arr) > 0).astype(np.uint8)\n", + " patient_mask = itk.GetImageFromArray(mask_arr)\n", + " patient_mask.CopyInformation(patient_image)\n", + "\n", + " itk.imwrite(patient_mask, str(output_dir / 'patient_heart_mask_draft.mha'), compression=True)\n", + "\n", + " # hand edit fixed_mask to make patient_heart_wall_mask.nii.gz that is saved in 
patient_data_dir\n", + "else:\n", + " patient_mask = itk.imread(str(patient_ct_heart_mask_path))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "flip0 = np.array(patient_mask.GetDirection())[0,0] < 0\n", + "flip1 = np.array(patient_mask.GetDirection())[1,1] < 0\n", + "flip2 = np.array(patient_mask.GetDirection())[2,2] < 0\n", + "if flip0 or flip1 or flip2:\n", + " print(\"Flipping patient image...\")\n", + " print(flip0, flip1, flip2)\n", + " flip_filter = itk.FlipImageFilter.New(Input=patient_image)\n", + " flip_filter.SetFlipAxes([int(flip0), int(flip1), int(flip2)])\n", + " flip_filter.SetFlipAboutOrigin(True)\n", + " flip_filter.Update()\n", + " patient_image = flip_filter.GetOutput()\n", + " id_mat = itk.Matrix[itk.D, 3, 3]()\n", + " id_mat.SetIdentity()\n", + " patient_image.SetDirection(id_mat)\n", + " itk.imwrite(patient_image, str(output_dir / 'patient_image.mha'), compression=True)\n", + " print(\"Flipping patient mask image...\")\n", + " flip_filter = itk.FlipImageFilter.New(Input=patient_mask)\n", + " flip_filter.SetFlipAxes([int(flip0), int(flip1), int(flip2)])\n", + " flip_filter.SetFlipAboutOrigin(True)\n", + " flip_filter.Update()\n", + " patient_mask = flip_filter.GetOutput()\n", + " patient_mask.SetDirection(id_mat)\n", + " itk.imwrite(patient_mask, str(output_dir / 'patient_mask.mha'), compression=True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "patient_model = ContourTools().extract_contours(patient_mask)\n", + "patient_model.save(str(output_dir / 'patient_mesh.vtp'))\n", + "patient_model = pv.read(str(output_dir / 'patient_mesh.vtp'))\n", + "\n", + "template_model = pv.read(str(atlas_vtu_path))\n", + "template_model_surface = template_model.extract_surface()\n", + "template_model_surface.save(str(output_dir / 'model_surface.vtp'))\n", + "template_model_surface = pv.read(str(output_dir / 
'model_surface.vtp'))\n", + "template_labelmap = itk.imread(str(atlas_labelmap_path))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "registrar = HeartModelToPatientWorkflow(\n", + " template_model=template_model,\n", + " template_labelmap=template_labelmap,\n", + " template_labelmap_heart_muscle_ids=[1],\n", + " template_labelmap_chamber_ids=[2, 3, 4, 5],\n", + " template_labelmap_background_ids=[6],\n", + " patient_image=patient_image,\n", + " patient_models=[patient_model],\n", + " pca_json_filename=pca_json_path,\n", + " pca_group_key=pca_group_key,\n", + " pca_number_of_modes=pca_n_modes,\n", + ")\n", + "\n", + "registrar.set_mask_dilation_mm(0)\n", + "registrar.set_roi_dilation_mm(25)\n", + "\n", + "patient_image = registrar.patient_image\n", + "itk.imwrite(\n", + " patient_image,\n", + " str(output_dir / 'patient_image_preprocessed.mha'),\n", + " compression=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Rough alignment using ICP\n", + "icp_results = registrar.register_model_to_model_icp()\n", + "icp_inverse_point_transform = icp_results['inverse_point_transform']\n", + "icp_forward_point_transform = icp_results['forward_point_transform']\n", + "icp_model_surface = icp_results['registered_template_model_surface']\n", + "icp_labelmap = icp_results['registered_template_labelmap']\n", + "\n", + "icp_model_surface.save(str(output_dir / \"icp_model_surface.vtp\"))\n", + "itk.imwrite(icp_labelmap, str(output_dir / \"icp_labelmap.mha\"), compression=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pca_results = registrar.register_model_to_model_pca()\n", + "pca_coefficients = pca_results['pca_coefficients']\n", + "pca_model_surface = pca_results['registered_template_model_surface']\n", + "pca_labelmap = 
pca_results['registered_template_labelmap']\n", + "\n", + "pca_model_surface.save(str(output_dir / \"pca_model_surface.vtp\"))\n", + "itk.imwrite(pca_labelmap, str(output_dir / \"pca_labelmap.mha\"), compression=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Mask Alignment" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Perform deformable registration\n", + "print(\"Starting deformable mask-to-mask registration...\")\n", + "\n", + "m2m_results = registrar.register_mask_to_mask(use_icon_refinement=False)\n", + "m2m_inverse_transform = m2m_results['inverse_transform']\n", + "m2m_forward_transform = m2m_results['forward_transform']\n", + "m2m_model_surface = m2m_results['registered_template_model_surface']\n", + "m2m_labelmap = m2m_results['registered_template_labelmap']\n", + "\n", + "print(\"Registration complete!\")\n", + "\n", + "m2m_model_surface.save(str(output_dir / \"m2m_model_surface.vtp\"))\n", + "itk.imwrite(m2m_labelmap, str(output_dir / \"m2m_labelmap.mha\"), compression=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Starting deformable registration...\")\n", + "print(\"This may take several minutes depending on GPU availability.\")\n", + "\n", + "m2i_results = registrar.register_labelmap_to_image()\n", + "m2i_inverse_transform = m2i_results['inverse_transform']\n", + "m2i_forward_transform = m2i_results['forward_transform']\n", + "m2i_surface = m2i_results['registered_template_model_surface']\n", + "m2i_labelmap = m2i_results['registered_template_labelmap']\n", + "print(\"\\nRegistration complete!\")\n", + "\n", + "# Save registration results to output folder\n", + "m2i_surface.save(str(output_dir / \"m2i_model_surface.vtp\"))\n", + "itk.imwrite(m2i_labelmap, str(output_dir / \"m2i_labelmap.mha\"), compression=True)" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "\n", + "tmp_p = itk.Point[itk.D, 3]()\n", + "point = registrar.template_model.points[0]\n", + "tmp_p[0] = float(point[0])\n", + "tmp_p[1] = float(point[1])\n", + "tmp_p[2] = float(point[2])\n", + "\n", + "start_time = time.time()\n", + "# Don't save the results since ICP transform is applied as a post-PCA transform\n", + "_ = registrar.icp_registrar.forward_point_transform.TransformPoint(tmp_p)\n", + "print(f\"--- ICP forward transform time: {time.time() - start_time} seconds\", flush=True)\n", + "\n", + "start_time = time.time()\n", + "# Don't apply the post PCA transform since this is just for setup\n", + "_ = registrar.pca_registrar.transform_point(tmp_p, include_post_pca_transform=False)\n", + "print(f\"--- PCA setup time: {time.time() - start_time} seconds\", flush=True)\n", + "start_time = time.time()\n", + "# Apply the post PCA transform since this is the actual transform\n", + "tmp_p = registrar.pca_registrar.transform_point(tmp_p, include_post_pca_transform=True)\n", + "print(f\"PCA + ICP transform time: {time.time() - start_time} seconds\", flush=True)\n", + "\n", + "start_time = time.time()\n", + "tmp_p = registrar.m2m_inverse_transform.TransformPoint(tmp_p)\n", + "print(f\"M2M inverse transform time: {time.time() - start_time} seconds\", flush=True)\n", + "\n", + "start_time = time.time()\n", + "tmp_p = registrar.m2i_inverse_transform.TransformPoint(tmp_p)\n", + "print(f\"M2I inverse transform time: {time.time() - start_time} seconds\", flush=True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Verify registration using the transform member function\n", + "surface_transformed = registrar.m2i_template_model_surface\n", + "surface_transformed.save(str(output_dir / \"registered_template_surface.vtp\"))\n", + "\n", + "model_transformed = registrar.transform_model()\n", + 
"model_transformed.save(str(output_dir / \"registered_template.vtu\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualize Final Results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load meshes from registrar member variables\n", + "patient_surface = registrar.patient_model_surface\n", + "registered_surface = registrar.registered_template_model_surface\n", + "icp_surface = registrar.icp_template_model_surface\n", + "pca_surface = registrar.pca_template_model_surface\n", + "m2m_surface = registrar.m2m_template_model_surface\n", + "m2i_surface = registrar.m2i_template_model_surface\n", + "\n", + "# Create side-by-side comparison\n", + "plotter = pv.Plotter(shape=(1, 2))\n", + "\n", + "# After rough alignment\n", + "plotter.subplot(0, 0)\n", + "plotter.add_mesh(patient_surface, color='red', opacity=0.5, label='Patient')\n", + "plotter.add_mesh(pca_surface, color='green', opacity=1.0, label='After ICP')\n", + "plotter.add_title('PCA Alignment')\n", + "\n", + "# After deformable registration\n", + "plotter.subplot(0, 1)\n", + "plotter.add_mesh(patient_surface, color='red', opacity=0.5, label='Patient')\n", + "plotter.add_mesh(m2i_surface, color='blue', opacity=1.0, label='Registered')\n", + "plotter.add_title('Final Registration')\n", + "\n", + "plotter.link_views()\n", + "plotter.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualize Deformation Magnitude" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# The transformed mesh has deformation magnitude stored as point data\n", + "if 'DeformationMagnitude' in registered_surface.point_data:\n", + " plotter = pv.Plotter()\n", + " plotter.add_mesh(\n", + " registered_surface,\n", + " scalars='DeformationMagnitude',\n", + " cmap='jet',\n", + " show_scalar_bar=True,\n", + " scalar_bar_args={'title': 'Deformation 
(mm)'}\n", + " )\n", + " plotter.add_title('Deformation Magnitude')\n", + " plotter.show()\n", + "\n", + " # Print statistics\n", + " deformation = registered_surface['DeformationMagnitude']\n", + " print(\"Deformation statistics:\")\n", + " print(f\" Min: {deformation.min():.2f} mm\")\n", + " print(f\" Max: {deformation.max():.2f} mm\")\n", + " print(f\" Mean: {deformation.mean():.2f} mm\")\n", + " print(f\" Std: {deformation.std():.2f} mm\")\n", + "else:\n", + " print(\"DeformationMagnitude not found in mesh point data\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/experiments/Heart-Model_To_Patient/heart_model_to_patient_wip.ipynb b/experiments/Heart-Model_To_Patient/heart_model_to_patient_wip.ipynb deleted file mode 100644 index a7082b7..0000000 --- a/experiments/Heart-Model_To_Patient/heart_model_to_patient_wip.ipynb +++ /dev/null @@ -1,435 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Setup and Imports" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "from pathlib import Path\n", - "\n", - "import itk\n", - "import numpy as np\n", - "import pyvista as pv\n", - "from itk import TubeTK as ttk\n", - "\n", - "# Import from PhysioMotion4D package\n", - "from physiomotion4d import (\n", - " ContourTools,\n", - " HeartModelToPatientWorkflow,\n", - " SegmentChestTotalSegmentator,\n", - ")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Define File Paths" - ] - }, - { - "cell_type": "code", - "execution_count": null, - 
"metadata": {}, - "outputs": [], - "source": [ - "# Patient CT image (defines coordinate frame)\n", - "patient_data_dir = Path.cwd().parent / '..' / 'data' / 'Slicer-Heart-CT'\n", - "patient_ct_path = patient_data_dir / 'slice_007.mha'\n", - "patient_ct_heart_mask_path = patient_data_dir / 'slice_007_heart_mask3.nii.gz'\n", - "\n", - "# Atlas template model (moving)\n", - "atlas_data_dir = Path.cwd().parent / '..' / 'data' / 'KCL-Heart-Model'\n", - "atlas_vtu_path = atlas_data_dir / 'average_mesh.vtk'\n", - "\n", - "pca_data_dir = Path.cwd().parent / '..' / 'data' / 'KCL-Heart-Model' / 'pca'\n", - "pca_json_path = pca_data_dir / 'pca.json'\n", - "pca_group_key = 'All'\n", - "pca_n_modes = 10\n", - "\n", - "# Output directory\n", - "output_dir = Path.cwd() / 'results'\n", - "\n", - "os.makedirs(output_dir, exist_ok=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fixed_image = itk.imread(str(patient_ct_path))\n", - "itk.imwrite(fixed_image, str(output_dir / 'patient_image.mha'), compression=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "if False:\n", - " segmentator = SegmentChestTotalSegmentator()\n", - " segmentator.contrast_threshold = 500\n", - " fixed_data = segmentator.segment(fixed_image, contrast_enhanced_study=False)\n", - " labelmap_image = fixed_data[\"labelmap\"]\n", - " lung_mask_image = fixed_data[\"lung\"]\n", - " heart_mask_image = fixed_data[\"heart\"]\n", - " major_vessels_mask_image = fixed_data[\"major_vessels\"]\n", - " bone_mask_image = fixed_data[\"bone\"]\n", - " soft_tissue_mask_image = fixed_data[\"soft_tissue\"]\n", - " other_mask_image = fixed_data[\"other\"]\n", - " contrast_mask_image = fixed_data[\"contrast\"]\n", - "\n", - "\n", - " itk.imwrite(labelmap_image, str(output_dir / 'fixed_labelmap.mha'), compression=True)\n", - "\n", - " heart_arr = itk.GetArrayFromImage(heart_mask_image)\n", - 
" #contrast_arr = itk.GetArrayFromImage(contrast_mask_image)\n", - " mask_arr = (heart_arr > 0).astype(np.uint8) #((heart_arr + contrast_arr) > 0).astype(np.uint8)\n", - " fixed_mask_image = itk.GetImageFromArray(mask_arr)\n", - " fixed_mask_image.CopyInformation(fixed_image)\n", - "\n", - " itk.imwrite(fixed_mask_image, str(output_dir / 'fixed_mask_draft.mha'), compression=True)\n", - "\n", - " # hand edit fixed_mask to make slice_007_heart_mask.nii.gz that is saved in patient_data_dir\n", - "else:\n", - " fixed_mask_image = itk.imread(str(patient_ct_heart_mask_path))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "flip0 = np.array(fixed_mask_image.GetDirection())[0,0] < 0\n", - "flip1 = np.array(fixed_mask_image.GetDirection())[1,1] < 0\n", - "flip2 = np.array(fixed_mask_image.GetDirection())[2,2] < 0\n", - "if flip0 or flip1 or flip2:\n", - " print(\"Flipping fixed image...\")\n", - " print(flip0, flip1, flip2)\n", - " flip_filter = itk.FlipImageFilter.New(Input=fixed_image)\n", - " flip_filter.SetFlipAxes([int(flip0), int(flip1), int(flip2)])\n", - " flip_filter.SetFlipAboutOrigin(True)\n", - " flip_filter.Update()\n", - " fixed_image = flip_filter.GetOutput()\n", - " id_mat = itk.Matrix[itk.D, 3, 3]()\n", - " id_mat.SetIdentity()\n", - " fixed_image.SetDirection(id_mat)\n", - " itk.imwrite(fixed_image, str(output_dir / 'fixed_image.mha'), compression=True)\n", - " print(\"Flipping fixed mask image...\")\n", - " flip_filter = itk.FlipImageFilter.New(Input=fixed_mask_image)\n", - " flip_filter.SetFlipAxes([int(flip0), int(flip1), int(flip2)])\n", - " flip_filter.SetFlipAboutOrigin(True)\n", - " flip_filter.Update()\n", - " fixed_mask_image = flip_filter.GetOutput()\n", - " fixed_mask_image.SetDirection(id_mat)\n", - " itk.imwrite(fixed_mask_image, str(output_dir / 'fixed_mask.mha'), compression=True)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": 
[], - "source": [ - "\n", - "fixed_mesh = ContourTools().extract_contours(fixed_mask_image)\n", - "fixed_mesh.save(str(output_dir / 'fixed_mesh.vtp'))\n", - "fixed_mesh = pv.read(str(output_dir / 'fixed_mesh.vtp'))\n", - "\n", - "moving_original_mesh = pv.read(str(atlas_vtu_path))\n", - "moving_mesh = moving_original_mesh.extract_surface()\n", - "moving_mesh.save(str(output_dir / 'moving_mesh.vtp'))\n", - "moving_mesh = pv.read(str(output_dir / 'moving_mesh.vtp'))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "registrar = HeartModelToPatientWorkflow(\n", - " moving_mesh=moving_mesh,\n", - " fixed_image=fixed_image,\n", - " fixed_meshes=[fixed_mesh],\n", - ")\n", - "\n", - "registrar.set_masks(\n", - " moving_mask_image=None,\n", - " fixed_mask_image=fixed_mask_image,\n", - ")\n", - "\n", - "registrar.set_mask_dilation_mm(5)\n", - "registrar.set_roi_dilation_mm(25)\n", - "\n", - "registrar.set_pca_data_from_slicersalt(\n", - " json_filename=pca_json_path,\n", - " group_key=pca_group_key,\n", - " n_modes=pca_n_modes,\n", - ")\n", - "\n", - "fixed_image = registrar.fixed_image\n", - "itk.imwrite(fixed_image, str(output_dir / 'fixed_image.mha'), compression=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Rough alignment using ICP\n", - "reg_results = registrar.register_mesh_to_mesh_icp()\n", - "icp_phi_FM = reg_results['phi_FM']\n", - "icp_phi_MF = reg_results['phi_MF']\n", - "moving_icp_mesh = reg_results['moving_mesh']\n", - "moving_icp_mask_image = reg_results['moving_mask_image']\n", - "moving_icp_mask_roi_image = reg_results['moving_mask_roi_image']\n", - "\n", - "fixed_roi_image = registrar.fixed_mask_roi_image\n", - "moving_mask_image = registrar.moving_mask_image\n", - "moving_roi_image = registrar.moving_mask_roi_image\n", - "\n", - "# Save masks for inspection\n", - "itk.imwrite(moving_mask_image, str(output_dir / 
'moving_mask.mha'), compression=True)\n", - "itk.imwrite(moving_roi_image, str(output_dir / 'moving_roi.mha'), compression=True)\n", - "itk.imwrite(fixed_roi_image, str(output_dir / 'fixed_roi.mha'), compression=True)\n", - "\n", - "print(\"New center =\", moving_icp_mesh.center)\n", - "print(\" Rough alignment using ICP completed.\")\n", - "itk.imwrite(moving_icp_mask_image, str(output_dir / \"moving_icp_mask.nii.gz\"), compression=True)\n", - "itk.imwrite(moving_icp_mask_roi_image, str(output_dir / \"moving_icp_mask_roi.nii.gz\"), compression=True)\n", - "moving_icp_mesh.save(str(output_dir / \"moving_icp_mesh.vtp\"))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#import logging\n", - "#from physiomotion4d import PhysioMotion4DBase\n", - "#\n", - "#PhysioMotion4DBase.set_log_level(logging.DEBUG)\n", - "\n", - "reg_results = registrar.register_mesh_to_mesh_pca()\n", - "pca_rigid_transform = reg_results['pre_phi_FM']\n", - "pca_coefficients = reg_results['pca_coefficients_FM']\n", - "moving_pca_mesh = reg_results['moving_mesh']\n", - "moving_pca_mask_image = reg_results[\"moving_mask_image\"]\n", - "moving_pca_mask_roi_image = reg_results[\"moving_mask_roi_image\"]\n", - "\n", - "itk.imwrite(moving_pca_mask_image, str(output_dir / \"moving_pca_mask.nii.gz\"), compression=True)\n", - "itk.imwrite(moving_pca_mask_roi_image, str(output_dir / \"moving_pca_mask_roi.nii.gz\"), compression=True)\n", - "moving_pca_mesh.save(str(output_dir / \"moving_pca_mesh.vtp\"))\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Mask Alignment" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Perform deformable registration\n", - "print(\"Starting deformable mask-to-mask registration...\")\n", - "\n", - "reg_results = registrar.register_mask_to_mask()\n", - "m2m_phi_FM = reg_results['phi_FM']\n", - "m2m_phi_MF = 
reg_results['phi_MF']\n", - "moving_m2m_mesh = reg_results['moving_mesh']\n", - "moving_m2m_mask_image = reg_results['moving_mask_image']\n", - "moving_m2m_mask_roi_image = reg_results['moving_mask_roi_image']\n", - "\n", - "print(\"Registration complete!\")\n", - "\n", - "# Save registration results to output folder\n", - "itk.transformwrite([m2m_phi_FM], str(output_dir / \"m2m_phi_FM.hdf\"), compression=True)\n", - "itk.transformwrite([m2m_phi_MF], str(output_dir / \"m2m_phi_MF.hdf\"), compression=True)\n", - "itk.imwrite(moving_m2m_mask_image, str(output_dir / \"moving_m2m_mask.nii.gz\"), compression=True)\n", - "itk.imwrite(moving_m2m_mask_roi_image, str(output_dir / \"moving_m2m_mask_roi.nii.gz\"), compression=True)\n", - "moving_m2m_mesh.save(str(output_dir / \"moving_m2m_mesh.vtp\"))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Perform Icon-based deformable registration\n", - "# This is the most computationally intensive step (requires GPU)\n", - "print(\"Starting deformable registration...\")\n", - "print(\"This may take several minutes depending on GPU availability.\")\n", - "\n", - "reg_results = registrar.register_mask_to_image()\n", - "m2i_phi_FM = reg_results['phi_FM']\n", - "m2i_phi_MF = reg_results['phi_MF']\n", - "moving_m2i_mesh = reg_results['moving_mesh']\n", - "moving_m2i_mask_image = reg_results['moving_mask_image']\n", - "moving_m2i_mask_roi_image = reg_results['moving_mask_roi_image']\n", - "\n", - "print(\"\\nRegistration complete!\")\n", - "\n", - "# Save registration results to output folder\n", - "itk.transformwrite([m2i_phi_FM], str(output_dir / \"m2i_phi_FM.hdf\"), compression=True)\n", - "itk.transformwrite([m2i_phi_MF], str(output_dir / \"m2i_phi_MF.hdf\"), compression=True)\n", - "itk.imwrite(moving_m2i_mask_image, str(output_dir / \"moving_m2i_mask.nii.gz\"), compression=True)\n", - "itk.imwrite(moving_m2i_mask_roi_image, str(output_dir / 
\"moving_m2i_mask_roi.nii.gz\"), compression=True)\n", - "moving_m2i_mesh.save(str(output_dir / \"moving_m2i_mesh.vtp\"))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "moving_registered_surface_mesh = registrar.moving_mesh.copy(deep=True)\n", - "new_points = moving_registered_surface_mesh.points\n", - "for i in range(new_points.shape[0]):\n", - " p = itk.Point[itk.D, 3]()\n", - " new_p = itk.Point[itk.D, 3]()\n", - " p[0], p[1], p[2] = float(new_points[i, 0]), float(new_points[i, 1]), float(new_points[i, 2])\n", - " tmp_p = registrar.icp_phi_FM.TransformPoint(p)\n", - " new_p[0], new_p[1], new_p[2] = tmp_p[0], tmp_p[1], tmp_p[2]\n", - " tmp_p = registrar.registrar_pca.transform_point(new_p)\n", - " new_p[0], new_p[1], new_p[2] = tmp_p[0], tmp_p[1], tmp_p[2]\n", - " tmp_p = registrar.m2m_phi_FM.TransformPoint(new_p)\n", - " new_p[0], new_p[1], new_p[2] = tmp_p[0], tmp_p[1], tmp_p[2]\n", - " tmp_p = registrar.m2i_phi_FM.TransformPoint(new_p)\n", - " new_p[0], new_p[1], new_p[2] = tmp_p[0], tmp_p[1], tmp_p[2]\n", - " new_points[i, 0], new_points[i, 1], new_points[i, 2] = new_p[0], new_p[1], new_p[2]\n", - "\n", - "moving_registered_surface_mesh.points = new_points\n", - "moving_registered_surface_mesh.save(str(output_dir / \"moving_registered_surface_mesh.vtp\"))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "new_original_mesh = registrar.apply_transforms_to_original_mesh(include_m2i=True)\n", - "new_original_mesh.save(str(output_dir / \"moving_registered_original_mesh.vtu\"))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Visualize Final Results" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Load meshes from registrar member variables\n", - "moving_mesh = registrar.moving_original_mesh\n", - "aligned_mesh = 
registrar.moving_icp_mesh\n", - "registered_mesh = registrar.moving_m2i_mesh\n", - "fixed_mesh = registrar.fixed_mesh\n", - "\n", - "# Create side-by-side comparison\n", - "plotter = pv.Plotter(shape=(1, 2))\n", - "\n", - "# After rough alignment\n", - "plotter.subplot(0, 0)\n", - "plotter.add_mesh(fixed_mesh, color='red', opacity=1.0, label='Patient')\n", - "plotter.add_mesh(aligned_mesh, color='green', opacity=0.6, label='After ICP')\n", - "plotter.add_title('Rough Alignment')\n", - "\n", - "# After deformable registration\n", - "plotter.subplot(0, 1)\n", - "plotter.add_mesh(fixed_mesh, color='red', opacity=0.6, label='Patient')\n", - "plotter.add_mesh(registered_mesh, color='blue', opacity=0.6, label='Registered')\n", - "plotter.add_title('Final Registration')\n", - "\n", - "plotter.link_views()\n", - "plotter.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Visualize Deformation Magnitude" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# The transformed mesh has deformation magnitude stored as point data\n", - "if 'DeformationMagnitude' in moving_m2i_mesh.point_data:\n", - " plotter = pv.Plotter()\n", - " plotter.add_mesh(\n", - " moving_m2i_mesh,\n", - " scalars='DeformationMagnitude',\n", - " cmap='jet',\n", - " show_scalar_bar=True,\n", - " scalar_bar_args={'title': 'Deformation (mm)'}\n", - " )\n", - " plotter.add_title('Deformation Magnitude')\n", - " plotter.show()\n", - "\n", - " # Print statistics\n", - " deformation = registered_mesh['DeformationMagnitude']\n", - " print(f\"Deformation statistics:\")\n", - " print(f\" Min: {deformation.min():.2f} mm\")\n", - " print(f\" Max: {deformation.max():.2f} mm\")\n", - " print(f\" Mean: {deformation.mean():.2f} mm\")\n", - " print(f\" Std: {deformation.std():.2f} mm\")\n", - "else:\n", - " print(\"DeformationMagnitude not found in mesh point data\")" - ] - } - ], - "metadata": { - "kernelspec": { - 
"display_name": "venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.11" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/experiments/Heart-VTKSeries_To_USD/1-heart_vtkseries_to_usd.ipynb b/experiments/Heart-VTKSeries_To_USD/1-heart_vtkseries_to_usd.ipynb index 765d51f..cd96882 100644 --- a/experiments/Heart-VTKSeries_To_USD/1-heart_vtkseries_to_usd.ipynb +++ b/experiments/Heart-VTKSeries_To_USD/1-heart_vtkseries_to_usd.ipynb @@ -9,9 +9,10 @@ "source": [ "import glob\n", "import os\n", + "\n", "import pyvista as pv\n", "\n", - "from physiomotion4d.convert_vtk_4d_to_usd import ConvertVTK4DToUSD" + "from physiomotion4d.convert_vtk_4d_to_usd import ConvertVTK4DToUSD\n" ] }, { @@ -26,8 +27,11 @@ "):\n", " # Segment chest from CT images to generate vtk files\n", " import itk\n", + "\n", " from physiomotion4d.contour_tools import ContourTools\n", - " from physiomotion4d.segment_chest_total_segmentator import SegmentChestTotalSegmentator\n", + " from physiomotion4d.segment_chest_total_segmentator import (\n", + " SegmentChestTotalSegmentator,\n", + " )\n", "\n", " input_images = sorted(\n", " glob.glob(\n", diff --git a/experiments/Lung-GatedCT_To_USD/0-register_dirlab_4dct.ipynb b/experiments/Lung-GatedCT_To_USD/0-register_dirlab_4dct.ipynb index b1318a4..1e0cbac 100644 --- a/experiments/Lung-GatedCT_To_USD/0-register_dirlab_4dct.ipynb +++ b/experiments/Lung-GatedCT_To_USD/0-register_dirlab_4dct.ipynb @@ -8,12 +8,11 @@ "source": [ "import os\n", "\n", + "import numpy as np\n", "import itk\n", + "from data_dirlab_4d_ct import DataDirLab4DCT\n", "from itk import TubeTK as tube\n", "\n", - "import numpy as np\n", - "\n", - "from data_dirlab_4d_ct import DataDirLab4DCT\n", "from 
physiomotion4d.register_images_icon import RegisterImagesICON\n", "from physiomotion4d.segment_chest_total_segmentator import SegmentChestTotalSegmentator\n", "from physiomotion4d.transform_tools import TransformTools\n", @@ -77,13 +76,13 @@ " if fixed_mask is not None:\n", " fixed_mask_d = dilate_mask(fixed_mask, 5)\n", " moving_mask_d = dilate_mask(moving_mask, 5)\n", - " reg_images.set_fixed_image_mask(fixed_mask_d)\n", + " reg_images.set_fixed_mask(fixed_mask_d)\n", " results = reg_images.register(moving_image, moving_mask_d)\n", - " phi_FM = results[\"phi_FM\"]\n", - " phi_MF = results[\"phi_MF\"]\n", + " inverse_transform = results[\"inverse_transform\"]\n", + " forward_transform = results[\"forward_transform\"]\n", " print(\"Registering image...Done!\")\n", " moving_image_reg = TransformTools().transform_image(\n", - " moving_image, phi_MF, fixed_image, \"sinc\"\n", + " moving_image, forward_transform, fixed_image, \"sinc\"\n", " ) # Final resampling with sinc\n", " itk.imwrite(\n", " moving_image_reg,\n", @@ -92,14 +91,14 @@ " )\n", "\n", " itk.transformwrite(\n", - " [phi_FM],\n", - " f\"{output_dir}/{case_name}_T{image_num * 10:02d}_{mask_name}_phi_FM.hdf\",\n", + " [forward_transform],\n", + " f\"{output_dir}/{case_name}_T{image_num * 10:02d}_{mask_name}_forward.hdf\",\n", " compression=True,\n", " )\n", "\n", " itk.transformwrite(\n", - " [phi_MF],\n", - " f\"{output_dir}/{case_name}_T{image_num * 10:02d}_{mask_name}_phi_MF.hdf\",\n", + " [inverse_transform],\n", + " f\"{output_dir}/{case_name}_T{image_num * 10:02d}_{mask_name}_inverse.hdf\",\n", " compression=True,\n", " )" ] @@ -289,13 +288,13 @@ "\n", " itk.transformwrite(\n", " [composite_transform],\n", - " f\"{output_dir}/{case_name}_T{image_num * 10:02d}_{mask_name}_phi_FM.hdf\",\n", + " f\"{output_dir}/{case_name}_T{image_num * 10:02d}_{mask_name}_forward.hdf\",\n", " compression=True,\n", " )\n", "\n", " itk.transformwrite(\n", " [composite_transform],\n", - " 
f\"{output_dir}/{case_name}_T{image_num * 10:02d}_{mask_name}_phi_MF.hdf\",\n", + " f\"{output_dir}/{case_name}_T{image_num * 10:02d}_{mask_name}_inverse.hdf\",\n", " compression=True,\n", " )\n", "\n", diff --git a/experiments/Lung-GatedCT_To_USD/1-make_dirlab_models.ipynb b/experiments/Lung-GatedCT_To_USD/1-make_dirlab_models.ipynb index f7b71ab..0a4c401 100644 --- a/experiments/Lung-GatedCT_To_USD/1-make_dirlab_models.ipynb +++ b/experiments/Lung-GatedCT_To_USD/1-make_dirlab_models.ipynb @@ -10,11 +10,11 @@ "import itk\n", "import numpy as np\n", "import pyvista as pv\n", + "from data_dirlab_4d_ct import DataDirLab4DCT\n", "\n", "from physiomotion4d.contour_tools import ContourTools\n", "from physiomotion4d.convert_vtk_4d_to_usd import ConvertVTK4DToUSD\n", "from physiomotion4d.segment_chest_total_segmentator import SegmentChestTotalSegmentator\n", - "from data_dirlab_4d_ct import DataDirLab4DCT\n", "\n", "case_names = DataDirLab4DCT().case_names\n", "case_names = [case_names[0]]\n", @@ -43,12 +43,12 @@ " con_tools = ContourTools()\n", " new_contours = []\n", " for i in range(10):\n", - " phi_FM = itk.transformread(\n", - " f\"{output_dir}/{case_name}_T{i * 10:02d}_{mask_name}_phi_FM.hdf\"\n", + " inverse_transform = itk.transformread(\n", + " f\"{output_dir}/{case_name}_T{i * 10:02d}_{mask_name}_inverse.hdf\"\n", " )[0]\n", "\n", " print(f\"Transforming {case_name} - {mask_name} - T{i * 10:02d}\")\n", - " new_contours.append(con_tools.transform_contours(contours, phi_FM))\n", + " new_contours.append(con_tools.transform_contours(contours, inverse_transform))\n", "\n", " return new_contours" ] diff --git a/experiments/Lung-GatedCT_To_USD/2-paint_dirlab_models.ipynb b/experiments/Lung-GatedCT_To_USD/2-paint_dirlab_models.ipynb index f9d3de8..3692ff9 100644 --- a/experiments/Lung-GatedCT_To_USD/2-paint_dirlab_models.ipynb +++ b/experiments/Lung-GatedCT_To_USD/2-paint_dirlab_models.ipynb @@ -2,16 +2,17 @@ "cells": [ { "cell_type": "code", - "execution_count": null, 
+ "execution_count": 1, "id": "3ce61753-11ad-4ade-9afe-6ad1bc748e25", "metadata": {}, "outputs": [], "source": [ "from data_dirlab_4d_ct import DataDirLab4DCT\n", - "from physiomotion4d.usd_anatomy_tools import USDAnatomyTools\n", - "from physiomotion4d.segment_chest_total_segmentator import SegmentChestTotalSegmentator\n", "from pxr import Usd\n", "\n", + "from physiomotion4d.segment_chest_total_segmentator import SegmentChestTotalSegmentator\n", + "from physiomotion4d.usd_anatomy_tools import USDAnatomyTools\n", + "\n", "case_names = DataDirLab4DCT().case_names\n", "\n", "case_names = [case_names[0]]\n", @@ -21,7 +22,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "3cc90c5c", "metadata": {}, "outputs": [], diff --git a/experiments/Lung-GatedCT_To_USD/Experiment_ArrangeOnStage.ipynb b/experiments/Lung-GatedCT_To_USD/Experiment_ArrangeOnStage.ipynb new file mode 100644 index 0000000..bf60f5d --- /dev/null +++ b/experiments/Lung-GatedCT_To_USD/Experiment_ArrangeOnStage.ipynb @@ -0,0 +1,52 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n\nfrom physiomotion4d.data_dirlab_4d_ct import DataDirLab4DCT\nfrom physiomotion4d.usd_tools import USDTools" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "os.makedirs(\"Results_ArrangeOnStage\", exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "case_names = DataDirLab4DCT().get_case_names()\n\nusd_tools = USDTools()\n\nfor label in [\"all\", \"lung\"]:\n usd_file_names = [\n f\"results/{case_name}_{label}_lungGated_painted.usd\"\n for case_name in case_names\n ]\n new_stage = usd_tools.save_usd_file_arrangement(f\"stage-{label}.usd\", usd_file_names)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": 
"python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} \ No newline at end of file diff --git a/experiments/Lung-GatedCT_To_USD/Experiment_CombineModels.ipynb b/experiments/Lung-GatedCT_To_USD/Experiment_CombineModels.ipynb new file mode 100644 index 0000000..f3ecac3 --- /dev/null +++ b/experiments/Lung-GatedCT_To_USD/Experiment_CombineModels.ipynb @@ -0,0 +1,107 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import itk\n", + "\n", + "from physiomotion4d.transform_tools import TransformTools" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "os.makedirs(\"results_CombineModels\", exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "all_mask_t00 = itk.imread(\"results/Case1Pack_T00_mask_org.mha\")\n", + "lung_mask_t30 = itk.imread(\"results/Case1Pack_T30_lung_mask_org.mha\")\n", + "other_mask_t30 = itk.imread(\"results/Case1Pack_T30_other_mask_org.mha\")\n", + "\n", + "img_tfm_lung_t00 = itk.transformread(\"results/Case1Pack_T00_dynamic_forward.hdf\")\n", + "img_tfm_other_t00 = itk.transformread(\"results/Case1Pack_T00_static_forward.hdf\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "TfmTools = TransformTools()\n", + "\n", + "lung_mask_t00 = TfmTools.transform_image(lung_mask_t30, img_tfm_lung_t00, all_mask_t00)\n", + "other_mask_t00 = TfmTools.transform_image(other_mask_t30, img_tfm_other_t00, all_mask_t00)\n", + "\n", + "itk.imwrite(lung_mask_t00, \"results_CombineModels/lung_mask_t00.mha\", 
compression=True)\n", + "itk.imwrite(other_mask_t00, \"results_CombineModels/other_mask_t00.mha\", compression=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "import pyvista as pv" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "lung_model_t30 = pv.read(\"TransformedModels/DirLab-4DCT/Case1Pack_T30_lung_lungGatedBase.vtp\")\n", + "other_model_t30 = pv.read(\"TransformedModels/DirLab-4DCT/Case1Pack_T30_other_lungGatedBase.vtp\")\n", + "\n", + "model_tfm_lung_t00 = itk.transformread(\n", + " \"TransformedImages/DirLab-4DCT/Case1Pack_T00_dynamic_inverse.hdf\"\n", + ")\n", + "model_tfm_other_t00 = itk.transformread(\n", + " \"TransformedImages/DirLab-4DCT/Case1Pack_T00_static_inverse.hdf\"\n", + ")\n", + "\n", + "lung_model_t00 = TfmTools.transform_pvcontour(lung_model_t30, model_tfm_lung_t00)\n", + "other_model_t00 = TfmTools.transform_pvcontour(other_model_t30, model_tfm_other_t00)\n", + "\n", + "lung_model_t00.save(\"Experiment_CombineModels/lung_model_t00.vtp\")\n", + "other_model_t00.save(\"Experiment_CombineModels/other_model_t00.vtp\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/experiments/Lung-GatedCT_To_USD/Experiment_SegReg.ipynb b/experiments/Lung-GatedCT_To_USD/Experiment_SegReg.ipynb new file mode 100644 index 0000000..b1ed29b --- /dev/null +++ b/experiments/Lung-GatedCT_To_USD/Experiment_SegReg.ipynb @@ -0,0 +1,70 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ 
+ "import itk\n\nimport os\n\nfrom physiomotion4d.data_dirlab_4d_ct import DataDirLab4DCT\nfrom physiomotion4d.register_images import RegisterImages\nfrom physiomotion4d.segment_chest_total_segmentator import SegmentChestTotalSegmentator\nfrom physiomotion4d.segment_chest_vista_3d import SegmentChestVista3D" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "fixed_image = DataDirLab4DCT().fix_image(itk.imread(\"../../data/DirLab-4DCT/Case1Pack_T30.mhd\"))\nmoving_image = DataDirLab4DCT().fix_image(itk.imread(\"../../data/DirLab-4DCT/Case1Pack_T00.mhd\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Register images\nreg_images = RegisterImages()\nreg_images.set_fixed_image(fixed_image)\nimg, phi_ab, phi_ba = reg_images.register(moving_image)\nos.makedirs(\"results_SegReg\", exist_ok=True)\nitk.imwrite(img, \"results_SegReg/Experiment_reg.mha\", compression=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "img = itk.imread(\"results_SegReg/Experiment_reg.mha\")\ntot_seg = SegmentChestTotalSegmentator()\nseg_image = tot_seg.segment(img, contrast_enhanced_study=False)\nitk.imwrite(seg_image[0], \"results_SegReg/Experiment_totseg.mha\", compression=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# This section requires the Vista3D container to be running\n\nvista3d_running = False\nif vista3d_running:\n img = itk.imread(\"Experiment_SegReg/Experiment_reg.mha\")\n\n tot_seg = SegmentChestVista3D()\n\n seg_image = tot_seg.segment(img, contrast_enhanced_study=False)\n\n itk.imwrite(seg_image[0], \"Experiment_SegReg/Experiment_vista3d.mha\", compression=True)\n itk.imwrite(seg_image[1], \"Experiment_SegReg/Experiment_vista3d_lung.mha\", compression=True)\n itk.imwrite(seg_image[2], 
\"Experiment_SegReg/Experiment_vista3d_heart.mha\", compression=True)\n itk.imwrite(seg_image[3], \"Experiment_SegReg/Experiment_vista3d_bone.mha\", compression=True)\n itk.imwrite(\n seg_image[4], \"Experiment_SegReg/Experiment_vista3d_soft_tissue.mha\", compression=True\n )\n itk.imwrite(seg_image[5], \"Experiment_SegReg/Experiment_vista3d_other.mha\", compression=True)\n itk.imwrite(seg_image[6], \"Experiment_SegReg/Experiment_vista3d_contrast.mha\", compression=True)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} \ No newline at end of file diff --git a/experiments/Lung-GatedCT_To_USD/Experiment_SubSurfaceScatter.ipynb b/experiments/Lung-GatedCT_To_USD/Experiment_SubSurfaceScatter.ipynb new file mode 100644 index 0000000..dc315ff --- /dev/null +++ b/experiments/Lung-GatedCT_To_USD/Experiment_SubSurfaceScatter.ipynb @@ -0,0 +1,52 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from pxr import Usd, UsdGeom, Sdf, UsdShade, Gf\n\nimport omni\n\nimport os" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "os.makedirs(\"results_SubsurfaceScatter\", exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize stage and define material/shader paths\nstage = Usd.Stage.CreateNew(\"results_SubsurfaceScatter/Experiment_SubSurfaceScatter.usda\")\n\nscope = UsdGeom.Scope.Define(stage, \"/World\")\n\nsphere1 = UsdGeom.Sphere.Define(stage, 
\"/World/Sphere\")\nsphere1.CreateRadiusAttr().Set(10)\n\nmtl_path = Sdf.Path(\"/World/Looks/OmniSurface_Subsurface\")\nshader_path = mtl_path.AppendPath(\"Shader\")\nvista3d_running = False\n# Create material and shader prims\nmtl = UsdShade.Material.Define(stage, mtl_path)\nshader = UsdShade.Shader.Define(stage, shader_path)\n\n# Configure MDL source\nshader.CreateImplementationSourceAttr(UsdShade.Tokens.sourceAsset)\nshader.SetSourceAsset(\"OmniSurface.mdl\", \"mdl\") # Use subsurface-capable MDL\nshader.SetSourceAssetSubIdentifier(\"OmniSurface\", \"mdl\")\n\n# Enable and configure subsurface scattering\nshader.CreateInput(\"enable_diffuse_transmission\", Sdf.ValueTypeNames.Bool).Set(True)\nshader.CreateInput(\"subsurface_weight\", Sdf.ValueTypeNames.Float).Set(0.8) # Intensity (0-1)\nshader.CreateInput(\"subsurface_scattering_color\", Sdf.ValueTypeNames.Color3f).Set(\n (0.8, 0.2, 0.1)\n) # RGB\nshader.CreateInput(\"subsurface_scale\", Sdf.ValueTypeNames.Float).Set(1.5) # Scattering depth\n\n# Connect shader outputs to material\nmtl.CreateSurfaceOutput(\"mdl\").ConnectToSource(shader.ConnectableAPI(), \"out\")\nmtl.CreateDisplacementOutput(\"mdl\").ConnectToSource(shader.ConnectableAPI(), \"out\")\n# Get target prim and bind material\nbinding_api = UsdShade.MaterialBindingAPI.Apply(sphere1.GetPrim())\nbinding_api.Bind(mtl)\n\nstage.GetRootLayer().Save()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} \ No newline at end of file diff --git a/experiments/Lung-VesselsAirways/0-GenData.ipynb b/experiments/Lung-VesselsAirways/0-GenData.ipynb index 0f2a101..e764b8a 100644 --- 
a/experiments/Lung-VesselsAirways/0-GenData.ipynb +++ b/experiments/Lung-VesselsAirways/0-GenData.ipynb @@ -6,19 +6,7 @@ "metadata": {}, "outputs": [], "source": [ - "import itk\n", - "from itk import TubeTK as tube\n", - "\n", - "import monai\n", - "\n", - "import numpy as np\n", - "\n", - "from scipy.interpolate import UnivariateSpline\n", - "\n", - "import torch\n", - "from torch.utils.data import Dataset, DataLoader\n", - "\n", - "import matplotlib.pyplot as plt" + "import itk\nfrom itk import TubeTK as tube\n\nimport monai\n\nimport numpy as np\n\nfrom scipy.interpolate import UnivariateSpline\n\nimport torch\nfrom torch.utils.data import Dataset, DataLoader\n\nimport matplotlib.pyplot as plt" ] }, { @@ -27,9 +15,7 @@ "metadata": {}, "outputs": [], "source": [ - "# globals\n", - "patch_size = 16\n", - "patch_depth = 7" + "# globals\npatch_size = 16\npatch_depth = 7" ] }, { @@ -38,56 +24,7 @@ "metadata": {}, "outputs": [], "source": [ - "def add_tube_to_patch(patch, patch_size, tube_i_range, tube_r_range):\n", - " tube_start_x = np.random.uniform(0.1, 0.9) * patch_size[0]\n", - " tube_start_y = np.random.uniform(0.1, 0.9) * patch_size[1]\n", - " tube_start_r = np.random.uniform(tube_r_range[0], tube_r_range[1])\n", - " tube_mid_x = np.random.uniform(0.3, 0.7) * patch_size[0]\n", - " tube_mid_y = np.random.uniform(0.3, 0.7) * patch_size[1]\n", - " tube_mid_r = np.random.uniform(tube_r_range[0], tube_r_range[1])\n", - " tube_end_x = np.random.uniform(0.1, 0.9) * patch_size[0]\n", - " tube_end_y = np.random.uniform(0.1, 0.9) * patch_size[1]\n", - " tube_end_r = np.random.uniform(tube_r_range[0], tube_r_range[1])\n", - " tube_extent = np.random.uniform(0, 1)\n", - " tube_terminate = tube_extent < 0.1\n", - " if not tube_terminate:\n", - " tube_extent = 1.0\n", - " else:\n", - " tube_extent = 0.51 + 0.4 * np.random.uniform(0, 1)\n", - "\n", - " tube_i = (\n", - " np.random.uniform(tube_i_range[0], tube_i_range[1])\n", - " )\n", - "\n", - " if np.random.uniform(0, 1) 
< 0.5:\n", - " tube_ct = True\n", - " else:\n", - " tube_ct = False\n", - "\n", - " tube_x_spline = UnivariateSpline([0.0, 0.5, 1.0], [tube_start_x, tube_mid_x, tube_end_x], k=2)\n", - " tube_y_spline = UnivariateSpline([0.0, 0.5, 1.0], [tube_start_y, tube_mid_y, tube_end_y], k=2)\n", - " tube_r_spline = UnivariateSpline([0.0, 0.5, 1.0], [tube_start_r, tube_mid_r, tube_end_r], k=2)\n", - " for i,t in enumerate(np.linspace(0.0, 1.0*tube_extent, patch_size[2])):\n", - " tube_x = tube_x_spline(t)\n", - " tube_y = tube_y_spline(t)\n", - " tube_r = tube_r_spline(t)\n", - " if i >= patch_size[2] // 2 and i <= (patch_size[2] // 2 + 1):\n", - " tube_center_x = tube_x\n", - " tube_center_y = tube_y\n", - " tube_center_r = tube_r\n", - " tube_center_i = tube_i\n", - " z = np.clip(int(t * patch_size[2]), 0, patch_size[2] - 1)\n", - " for rx in range(int(-tube_r - 0.5), int(tube_r + 0.5)):\n", - " for ry in range(int(-tube_r - 0.5), int(tube_r + 0.5)):\n", - " if rx**2 + ry**2 < tube_r**2:\n", - " x = np.clip(int(tube_x + rx), 0, patch_size[0] - 1)\n", - " y = np.clip(int(tube_y + ry), 0, patch_size[1] - 1)\n", - " if tube_ct:\n", - " patch[z, y, x] = tube_i\n", - " else:\n", - " patch[z, y, x] = tube_i * (1 - (rx**2 + ry**2) / (1.5 * tube_r) ** 2)\n", - "\n", - " return tube_center_x, tube_center_y, tube_center_r, tube_center_i, tube_terminate" + "def add_tube_to_patch(patch, patch_size, tube_i_range, tube_r_range):\n tube_start_x = np.random.uniform(0.1, 0.9) * patch_size[0]\n tube_start_y = np.random.uniform(0.1, 0.9) * patch_size[1]\n tube_start_r = np.random.uniform(tube_r_range[0], tube_r_range[1])\n tube_mid_x = np.random.uniform(0.3, 0.7) * patch_size[0]\n tube_mid_y = np.random.uniform(0.3, 0.7) * patch_size[1]\n tube_mid_r = np.random.uniform(tube_r_range[0], tube_r_range[1])\n tube_end_x = np.random.uniform(0.1, 0.9) * patch_size[0]\n tube_end_y = np.random.uniform(0.1, 0.9) * patch_size[1]\n tube_end_r = np.random.uniform(tube_r_range[0], tube_r_range[1])\n 
tube_extent = np.random.uniform(0, 1)\n tube_terminate = tube_extent < 0.1\n if not tube_terminate:\n tube_extent = 1.0\n else:\n tube_extent = 0.51 + 0.4 * np.random.uniform(0, 1)\n\n tube_i = (\n np.random.uniform(tube_i_range[0], tube_i_range[1])\n )\n\n if np.random.uniform(0, 1) < 0.5:\n tube_ct = True\n else:\n tube_ct = False\n\n tube_x_spline = UnivariateSpline([0.0, 0.5, 1.0], [tube_start_x, tube_mid_x, tube_end_x], k=2)\n tube_y_spline = UnivariateSpline([0.0, 0.5, 1.0], [tube_start_y, tube_mid_y, tube_end_y], k=2)\n tube_r_spline = UnivariateSpline([0.0, 0.5, 1.0], [tube_start_r, tube_mid_r, tube_end_r], k=2)\n for i,t in enumerate(np.linspace(0.0, 1.0*tube_extent, patch_size[2])):\n tube_x = tube_x_spline(t)\n tube_y = tube_y_spline(t)\n tube_r = tube_r_spline(t)\n if i >= patch_size[2] // 2 and i <= (patch_size[2] // 2 + 1):\n tube_center_x = tube_x\n tube_center_y = tube_y\n tube_center_r = tube_r\n tube_center_i = tube_i\n z = np.clip(int(t * patch_size[2]), 0, patch_size[2] - 1)\n for rx in range(int(-tube_r - 0.5), int(tube_r + 0.5)):\n for ry in range(int(-tube_r - 0.5), int(tube_r + 0.5)):\n if rx**2 + ry**2 < tube_r**2:\n x = np.clip(int(tube_x + rx), 0, patch_size[0] - 1)\n y = np.clip(int(tube_y + ry), 0, patch_size[1] - 1)\n if tube_ct:\n patch[z, y, x] = tube_i\n else:\n patch[z, y, x] = tube_i * (1 - (rx**2 + ry**2) / (1.5 * tube_r) ** 2)\n\n return tube_center_x, tube_center_y, tube_center_r, tube_center_i, tube_terminate" ] }, { @@ -96,71 +33,7 @@ "metadata": {}, "outputs": [], "source": [ - "def add_noise_to_patch(\n", - " patch,\n", - " patch_size,\n", - " noise_point_mean,\n", - " noise_point_stddev,\n", - " noise_edge_intensity_range,\n", - " noise_edge_spread_range,\n", - " noise_slope_intensity_range,\n", - "):\n", - " noise_point_mean = np.random.uniform(0, noise_point_mean)\n", - " noise_point_stddev = np.random.uniform(0.01, noise_point_stddev)\n", - "\n", - " noise_edge_intensity = (\n", - " 
np.random.uniform(noise_edge_intensity_range[0], noise_edge_intensity_range[1])\n", - " )\n", - " noise_edge_spread = np.random.uniform(noise_edge_spread_range[0], noise_edge_spread_range[1])\n", - " noise_edge_position = np.empty(3)\n", - " noise_edge_position[0] = np.random.uniform(0, 1) * patch_size[0]\n", - " noise_edge_position[1] = np.random.uniform(0, 1) * patch_size[1]\n", - " noise_edge_position[2] = np.random.uniform(0, 1) * patch_size[2]\n", - " noise_edge_orientation = np.empty(3)\n", - " noise_edge_orientation[0] = np.random.uniform(-1, 1)\n", - " noise_edge_orientation[1] = np.random.uniform(-1, 1)\n", - " noise_edge_orientation[2] = np.random.uniform(-1, 1)\n", - " noise_edge_orientation = noise_edge_orientation / np.linalg.norm(noise_edge_orientation)\n", - "\n", - " noise_slope_intensity = (\n", - " np.random.uniform(noise_slope_intensity_range[0], noise_slope_intensity_range[1])\n", - " )\n", - " noise_slope_position = np.empty(3)\n", - " noise_slope_position[0] = np.random.uniform(0, 1) * patch_size[0]\n", - " noise_slope_position[1] = np.random.uniform(0, 1) * patch_size[1]\n", - " noise_slope_position[2] = np.random.uniform(0, 1) * patch_size[2]\n", - " noise_slope_orientation = np.empty(3)\n", - " noise_slope_orientation[0] = np.random.uniform(-1, 1)\n", - " noise_slope_orientation[1] = np.random.uniform(-1, 1)\n", - " noise_slope_orientation[2] = np.random.uniform(-1, 1)\n", - " noise_slope_orientation = noise_slope_orientation / np.linalg.norm(noise_slope_orientation)\n", - "\n", - " for x in range(0, patch_size[0]):\n", - " for y in range(0, patch_size[1]):\n", - " for z in range(0, patch_size[2]):\n", - " point_noise = np.random.normal(noise_point_mean, noise_point_stddev)\n", - " dist_from_edge = (\n", - " (x - noise_edge_position[0]) * noise_edge_orientation[0]\n", - " + (y - noise_edge_position[1]) * noise_edge_orientation[1]\n", - " + (z - noise_edge_position[2]) * noise_edge_orientation[2]\n", - " )\n", - " edge_noise = 
noise_edge_intensity / (\n", - " 1 + np.exp(-dist_from_edge / noise_edge_spread)\n", - " )\n", - " dist_along_slope = (\n", - " (x - noise_slope_position[0]) * noise_slope_orientation[0]\n", - " + (y - noise_slope_position[1]) * noise_slope_orientation[1]\n", - " + (z - noise_slope_position[2]) * noise_slope_orientation[2]\n", - " )\n", - " slope_noise = (\n", - " noise_slope_intensity\n", - " * dist_along_slope\n", - " / (patch_size[0] + patch_size[1] + patch_size[2])\n", - " )\n", - " patch[z, y, x] += point_noise + edge_noise + slope_noise\n", - "\n", - " patch_max = np.max(patch)\n", - " patch = patch / patch_max" + "def add_noise_to_patch(\n patch,\n patch_size,\n noise_point_mean,\n noise_point_stddev,\n noise_edge_intensity_range,\n noise_edge_spread_range,\n noise_slope_intensity_range,\n):\n noise_point_mean = np.random.uniform(0, noise_point_mean)\n noise_point_stddev = np.random.uniform(0.01, noise_point_stddev)\n\n noise_edge_intensity = (\n np.random.uniform(noise_edge_intensity_range[0], noise_edge_intensity_range[1])\n )\n noise_edge_spread = np.random.uniform(noise_edge_spread_range[0], noise_edge_spread_range[1])\n noise_edge_position = np.empty(3)\n noise_edge_position[0] = np.random.uniform(0, 1) * patch_size[0]\n noise_edge_position[1] = np.random.uniform(0, 1) * patch_size[1]\n noise_edge_position[2] = np.random.uniform(0, 1) * patch_size[2]\n noise_edge_orientation = np.empty(3)\n noise_edge_orientation[0] = np.random.uniform(-1, 1)\n noise_edge_orientation[1] = np.random.uniform(-1, 1)\n noise_edge_orientation[2] = np.random.uniform(-1, 1)\n noise_edge_orientation = noise_edge_orientation / np.linalg.norm(noise_edge_orientation)\n\n noise_slope_intensity = (\n np.random.uniform(noise_slope_intensity_range[0], noise_slope_intensity_range[1])\n )\n noise_slope_position = np.empty(3)\n noise_slope_position[0] = np.random.uniform(0, 1) * patch_size[0]\n noise_slope_position[1] = np.random.uniform(0, 1) * patch_size[1]\n noise_slope_position[2] = 
np.random.uniform(0, 1) * patch_size[2]\n noise_slope_orientation = np.empty(3)\n noise_slope_orientation[0] = np.random.uniform(-1, 1)\n noise_slope_orientation[1] = np.random.uniform(-1, 1)\n noise_slope_orientation[2] = np.random.uniform(-1, 1)\n noise_slope_orientation = noise_slope_orientation / np.linalg.norm(noise_slope_orientation)\n\n for x in range(0, patch_size[0]):\n for y in range(0, patch_size[1]):\n for z in range(0, patch_size[2]):\n point_noise = np.random.normal(noise_point_mean, noise_point_stddev)\n dist_from_edge = (\n (x - noise_edge_position[0]) * noise_edge_orientation[0]\n + (y - noise_edge_position[1]) * noise_edge_orientation[1]\n + (z - noise_edge_position[2]) * noise_edge_orientation[2]\n )\n edge_noise = noise_edge_intensity / (\n 1 + np.exp(-dist_from_edge / noise_edge_spread)\n )\n dist_along_slope = (\n (x - noise_slope_position[0]) * noise_slope_orientation[0]\n + (y - noise_slope_position[1]) * noise_slope_orientation[1]\n + (z - noise_slope_position[2]) * noise_slope_orientation[2]\n )\n slope_noise = (\n noise_slope_intensity\n * dist_along_slope\n / (patch_size[0] + patch_size[1] + patch_size[2])\n )\n patch[z, y, x] += point_noise + edge_noise + slope_noise\n\n patch_max = np.max(patch)\n patch = patch / patch_max" ] }, { @@ -169,39 +42,7 @@ "metadata": {}, "outputs": [], "source": [ - "def gen_training_patch(p_size, p_depth):\n", - " patch_size = [p_size, p_size, p_depth]\n", - " tube_intensity_range = [0.2, 0.9]\n", - " tube_r_range = [p_size * 0.1, p_size * 0.4]\n", - " noise_point_mean = 0.5\n", - " noise_point_stddev = 0.4\n", - " noise_edge_intensity_range = [0.0, 0.25]\n", - " noise_edge_spread_range = [0.1, 4]\n", - " noise_slope_intensity_range = [0.0, 0.25]\n", - "\n", - " patch = np.zeros(patch_size[::-1], dtype=np.float32)\n", - "\n", - " tube_branch = np.random.uniform(0, 1) > 0.95\n", - " if tube_branch:\n", - " add_tube_to_patch(patch, patch_size, tube_intensity_range, tube_r_range)\n", - " patch *= 0.25\n", - 
"\n", - " tube_center = [0, 0]\n", - " tube_center[0], tube_center[1], tube_radius, tube_intensity, tube_terminate = add_tube_to_patch(\n", - " patch, patch_size, tube_intensity_range, tube_r_range\n", - " )\n", - "\n", - " add_noise_to_patch(\n", - " patch,\n", - " patch_size,\n", - " noise_point_mean,\n", - " noise_point_stddev,\n", - " noise_edge_intensity_range,\n", - " noise_edge_spread_range,\n", - " noise_slope_intensity_range,\n", - " )\n", - "\n", - " return patch, tube_center, tube_terminate, tube_radius, tube_intensity, tube_branch" + "def gen_training_patch(p_size, p_depth):\n patch_size = [p_size, p_size, p_depth]\n tube_intensity_range = [0.2, 0.9]\n tube_r_range = [p_size * 0.1, p_size * 0.4]\n noise_point_mean = 0.5\n noise_point_stddev = 0.4\n noise_edge_intensity_range = [0.0, 0.25]\n noise_edge_spread_range = [0.1, 4]\n noise_slope_intensity_range = [0.0, 0.25]\n\n patch = np.zeros(patch_size[::-1], dtype=np.float32)\n\n tube_branch = np.random.uniform(0, 1) > 0.95\n if tube_branch:\n add_tube_to_patch(patch, patch_size, tube_intensity_range, tube_r_range)\n patch *= 0.25\n\n tube_center = [0, 0]\n tube_center[0], tube_center[1], tube_radius, tube_intensity, tube_terminate = add_tube_to_patch(\n patch, patch_size, tube_intensity_range, tube_r_range\n )\n\n add_noise_to_patch(\n patch,\n patch_size,\n noise_point_mean,\n noise_point_stddev,\n noise_edge_intensity_range,\n noise_edge_spread_range,\n noise_slope_intensity_range,\n )\n\n return patch, tube_center, tube_terminate, tube_radius, tube_intensity, tube_branch" ] }, { @@ -221,31 +62,7 @@ } ], "source": [ - "plt.figure(0, [10, 40])\n", - "num_patches = 20\n", - "for i in range(num_patches):\n", - " data = gen_training_patch(patch_size, patch_depth)\n", - " for s in range(patch_depth):\n", - " plt.subplot(num_patches, patch_size, i * patch_size + s + 1)\n", - " tmp = data[0][s, :, :]\n", - " tmp[0, 0] = 0\n", - " tmp[1, 1] = 1\n", - " plt.imshow(data[0][s, :, :], cmap='gray')\n", - " if s == 
0:\n", - " plt.title(\n", - " \"Center: \"\n", - " + str(data[1])\n", - " + \"\\n Terminate: \"\n", - " + str(data[2])\n", - " + \"\\n Radius: \"\n", - " + str(data[3])\n", - " + \"\\n Intensity: \"\n", - " + str(data[4])\n", - " + \"\\n Branch: \"\n", - " + str(data[5])\n", - " )\n", - " plt.axis('off')\n", - "plt.show()" + "plt.figure(0, [10, 40])\nnum_patches = 20\nfor i in range(num_patches):\n data = gen_training_patch(patch_size, patch_depth)\n for s in range(patch_depth):\n plt.subplot(num_patches, patch_size, i * patch_size + s + 1)\n tmp = data[0][s, :, :]\n tmp[0, 0] = 0\n tmp[1, 1] = 1\n plt.imshow(data[0][s, :, :], cmap='gray')\n if s == 0:\n plt.title(\n \"Center: \"\n + str(data[1])\n + \"\\n Terminate: \"\n + str(data[2])\n + \"\\n Radius: \"\n + str(data[3])\n + \"\\n Intensity: \"\n + str(data[4])\n + \"\\n Branch: \"\n + str(data[5])\n )\n plt.axis('off')\nplt.show()" ] }, { @@ -254,60 +71,7 @@ "metadata": {}, "outputs": [], "source": [ - "def get_training_data(patch_size, patch_depth):\n", - " data = gen_training_patch(patch_size, patch_depth)\n", - "\n", - " patch = data[0]\n", - " center = data[1]\n", - " terminate = data[2]\n", - " radius = data[3]\n", - " intensity = data[4]\n", - "\n", - " # patch = np.expand_dims(patch, axis=0) # uncomment to make a 3D patch (rather than multi-channel 2D)\n", - " patch = patch.astype(np.float32)\n", - " patch = torch.from_numpy(patch)\n", - "\n", - " center = np.array(center)\n", - " center = center.astype(np.float32)\n", - " center = torch.from_numpy(center)\n", - "\n", - " terminate = np.array([terminate])\n", - " terminate = terminate.astype(np.float32)\n", - " terminate = torch.from_numpy(terminate)\n", - "\n", - " radius = np.array([radius])\n", - " radius = radius.astype(np.float32)\n", - " radius = torch.from_numpy(radius)\n", - "\n", - " intensity = np.array([intensity])\n", - " intensity = intensity.astype(np.float32)\n", - " intensity = torch.from_numpy(intensity)\n", - "\n", - " return patch, 
center, terminate, radius, intensity\n", - "\n", - "\n", - "class TubeDataset(Dataset):\n", - " def __init__(self, patch_size, patch_depth):\n", - " self.patch_size = patch_size\n", - " self.patch_depth = patch_depth\n", - "\n", - " def __len__(self):\n", - " return 1000\n", - "\n", - " def __getitem__(self, _):\n", - " return get_training_data(self.patch_size, self.patch_depth)\n", - "\n", - "\n", - "dataset = TubeDataset(patch_size, patch_depth)\n", - "dataloader = DataLoader(dataset, batch_size=4, shuffle=True)\n", - "# Check if GPU is available and move the model to GPU\n", - "\n", - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "model = monai.networks.nets.resnet18(spatial_dims=2, n_input_channels=patch_depth, num_classes=5).to(device)\n", - "\n", - "# Define the loss function and optimizer\n", - "criterion = torch.nn.MSELoss()\n", - "optimizer = torch.optim.Adam(model.parameters(), 1e-5)\n" + "def get_training_data(patch_size, patch_depth):\n data = gen_training_patch(patch_size, patch_depth)\n\n patch = data[0]\n center = data[1]\n terminate = data[2]\n radius = data[3]\n intensity = data[4]\n\n # patch = np.expand_dims(patch, axis=0) # uncomment to make a 3D patch (rather than multi-channel 2D)\n patch = patch.astype(np.float32)\n patch = torch.from_numpy(patch)\n\n center = np.array(center)\n center = center.astype(np.float32)\n center = torch.from_numpy(center)\n\n terminate = np.array([terminate])\n terminate = terminate.astype(np.float32)\n terminate = torch.from_numpy(terminate)\n\n radius = np.array([radius])\n radius = radius.astype(np.float32)\n radius = torch.from_numpy(radius)\n\n intensity = np.array([intensity])\n intensity = intensity.astype(np.float32)\n intensity = torch.from_numpy(intensity)\n\n return patch, center, terminate, radius, intensity\n\n\nclass TubeDataset(Dataset):\n def __init__(self, patch_size, patch_depth):\n self.patch_size = patch_size\n self.patch_depth = patch_depth\n\n def 
__len__(self):\n return 1000\n\n def __getitem__(self, _):\n return get_training_data(self.patch_size, self.patch_depth)\n\n\ndataset = TubeDataset(patch_size, patch_depth)\ndataloader = DataLoader(dataset, batch_size=4, shuffle=True)\n# Check if GPU is available and move the model to GPU\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\nmodel = monai.networks.nets.resnet18(spatial_dims=2, n_input_channels=patch_depth, num_classes=5).to(device)\n\n# Define the loss function and optimizer\ncriterion = torch.nn.MSELoss()\noptimizer = torch.optim.Adam(model.parameters(), 1e-5)\n" ] }, { @@ -316,30 +80,7 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "if False:\n", - " # Training loop with GPU support\n", - " num_epochs = 100\n", - " for epoch in range(num_epochs):\n", - " model.train()\n", - " running_loss = 0.0\n", - " for patches, centers, terminates, radii, intensities in dataloader:\n", - " patches, centers, terminates, radii, intensities = (\n", - " patches.to(device),\n", - " centers.to(device),\n", - " terminates.to(device),\n", - " radii.to(device),\n", - " intensities.to(device),\n", - " )\n", - " optimizer.zero_grad()\n", - " outputs = model(patches)\n", - " targets = torch.cat((centers, terminates, radii, intensities), dim=1)\n", - " loss = criterion(outputs, targets)\n", - " loss.backward()\n", - " optimizer.step()\n", - " running_loss += loss.item()\n", - " print(f\"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(dataloader)}\")\n", - " torch.save(model, 'vessel_seg_resnet18_2025.03.04.pth')" + "\nif False:\n # Training loop with GPU support\n num_epochs = 100\n for epoch in range(num_epochs):\n model.train()\n running_loss = 0.0\n for patches, centers, terminates, radii, intensities in dataloader:\n patches, centers, terminates, radii, intensities = (\n patches.to(device),\n centers.to(device),\n terminates.to(device),\n radii.to(device),\n intensities.to(device),\n )\n optimizer.zero_grad()\n outputs = 
model(patches)\n targets = torch.cat((centers, terminates, radii, intensities), dim=1)\n loss = criterion(outputs, targets)\n loss.backward()\n optimizer.step()\n running_loss += loss.item()\n print(f\"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(dataloader)}\")\n torch.save(model, 'vessel_seg_resnet18_2025.03.04.pth')" ] }, { @@ -364,79 +105,7 @@ "metadata": {}, "outputs": [], "source": [ - "def extract_patch(image, x, r, t, patch_size=16, patch_depth=7):\n", - " n1 = np.array([t[1], -t[0], t[2]])\n", - " if n1[0] == 0 and n1[1] == 0:\n", - " n1 = np.array([t[2], -t[0], t[1]])\n", - " n1 = n1 / np.linalg.norm(n1)\n", - " n2 = np.cross(t, n1)\n", - " n2 = n2 / np.linalg.norm(n2)\n", - "\n", - " image_spacing = image.GetSpacing()[0]\n", - "\n", - " patch_r = 3\n", - " patch_spacing = r / (patch_r * image_spacing)\n", - " if patch_spacing < image_spacing:\n", - " patch_spacing = image_spacing\n", - "\n", - " print(\"patch_r =\", patch_r)\n", - " print(\"image_spacing =\", image_spacing)\n", - " print(\"patch_spacing =\", patch_spacing)\n", - "\n", - " interp = itk.itkBSplineInterpolateImageFunctionPython(image)\n", - " patch = np.empty((patch_depth, patch_size, patch_size)).astype(np.float32)\n", - " for iz in range(patch_depth):\n", - " pz = (iz - patch_depth / 2) * patch_spacing\n", - " for iy in range(patch_size):\n", - " py = (iy - patch_size / 2) * patch_spacing\n", - " for ix in range(patch_size):\n", - " px = (ix - patch_size / 2) * patch_spacing\n", - " pnt = x + np.multiply(px, n1) + np.multiply(py, n2) + np.multiply(pz, t)\n", - " cindx = image.GetIndexFromPhysicalSpace(pnt)\n", - " v = interp.GetPixel(cindx)\n", - " patch[iz, iy, ix] = v\n", - "\n", - " return patch, patch_spacing, n1, n2\n", - "\n", - " # t_step = patch_spacing * patch_size / 4\n", - " # new_x = x + t*t_step\n", - " # new_t = (new_x - x) / np.linalg.norm(new_x - x)\n", - "\n", - "\n", - "def step(image, model, x, r, t):\n", - " patch_size = 16\n", - " patch_depth = 7\n", - " 
patch, patch_spacing, n1, n2 = extract_patch(image, x, r, t, patch_size, patch_depth)\n", - " patch_img = itk.GetImageFromArray(patch)\n", - " itk.imwrite(patch_img, f\"patch{x[0]}.mha\")\n", - "\n", - " patch = np.expand_dims(patch, axis=0)\n", - " patch = torch.from_numpy(patch).to('cuda:0')\n", - "\n", - " model.eval()\n", - " with torch.no_grad():\n", - " output = model(patch)\n", - "\n", - " t_step = patch_spacing\n", - " n1_step = (output[0][0].item() - 0.5) * patch_size * patch_spacing\n", - " n2_step = (output[0][1].item() - 0.5) * patch_size * patch_spacing\n", - "\n", - " x_step = t * t_step + n1 * n1_step + n2 * n2_step\n", - " print(\"x_step =\", x_step)\n", - " new_x = x + x_step\n", - "\n", - " term = output[0][2].item() > 0.5\n", - "\n", - " new_r = output[0][3].item() * patch_size * patch_spacing\n", - " print(\"new_r =\", output)\n", - " new_r = (r + new_r) / 2\n", - "\n", - " new_t = (t + x_step / np.linalg.norm(x_step)) / 2\n", - " new_t = new_t / np.linalg.norm(new_t)\n", - "\n", - " prob = output[0][4].item()\n", - "\n", - " return new_x, new_r, new_t, term, prob, patch.to('cpu').detach().numpy()" + "def extract_patch(image, x, r, t, patch_size=16, patch_depth=7):\n n1 = np.array([t[1], -t[0], t[2]])\n if n1[0] == 0 and n1[1] == 0:\n n1 = np.array([t[2], -t[0], t[1]])\n n1 = n1 / np.linalg.norm(n1)\n n2 = np.cross(t, n1)\n n2 = n2 / np.linalg.norm(n2)\n\n image_spacing = image.GetSpacing()[0]\n\n patch_r = 3\n patch_spacing = r / (patch_r * image_spacing)\n if patch_spacing < image_spacing:\n patch_spacing = image_spacing\n\n print(\"patch_r =\", patch_r)\n print(\"image_spacing =\", image_spacing)\n print(\"patch_spacing =\", patch_spacing)\n\n interp = itk.itkBSplineInterpolateImageFunctionPython(image)\n patch = np.empty((patch_depth, patch_size, patch_size)).astype(np.float32)\n for iz in range(patch_depth):\n pz = (iz - patch_depth / 2) * patch_spacing\n for iy in range(patch_size):\n py = (iy - patch_size / 2) * patch_spacing\n for ix in 
range(patch_size):\n px = (ix - patch_size / 2) * patch_spacing\n pnt = x + np.multiply(px, n1) + np.multiply(py, n2) + np.multiply(pz, t)\n cindx = image.GetIndexFromPhysicalSpace(pnt)\n v = interp.GetPixel(cindx)\n patch[iz, iy, ix] = v\n\n return patch, patch_spacing, n1, n2\n\n # t_step = patch_spacing * patch_size / 4\n # new_x = x + t*t_step\n # new_t = (new_x - x) / np.linalg.norm(new_x - x)\n\n\ndef step(image, model, x, r, t):\n patch_size = 16\n patch_depth = 7\n patch, patch_spacing, n1, n2 = extract_patch(image, x, r, t, patch_size, patch_depth)\n patch_img = itk.GetImageFromArray(patch)\n itk.imwrite(patch_img, f\"patch{x[0]}.mha\")\n\n patch = np.expand_dims(patch, axis=0)\n patch = torch.from_numpy(patch).to('cuda:0')\n\n model.eval()\n with torch.no_grad():\n output = model(patch)\n\n t_step = patch_spacing\n n1_step = (output[0][0].item() - 0.5) * patch_size * patch_spacing\n n2_step = (output[0][1].item() - 0.5) * patch_size * patch_spacing\n\n x_step = t * t_step + n1 * n1_step + n2 * n2_step\n print(\"x_step =\", x_step)\n new_x = x + x_step\n\n term = output[0][2].item() > 0.5\n\n new_r = output[0][3].item() * patch_size * patch_spacing\n print(\"new_r =\", output)\n new_r = (r + new_r) / 2\n\n new_t = (t + x_step / np.linalg.norm(x_step)) / 2\n new_t = new_t / np.linalg.norm(new_t)\n\n prob = output[0][4].item()\n\n return new_x, new_r, new_t, term, prob, patch.to('cpu').detach().numpy()" ] }, { @@ -445,13 +114,7 @@ "metadata": {}, "outputs": [], "source": [ - "img = itk.imread('Branch.n020.mha')\n", - "\n", - "img_arr = itk.GetArrayFromImage(img).astype(np.float32)\n", - "img_arr /= np.max(img_arr)\n", - "\n", - "img_norm = itk.GetImageFromArray(img_arr)\n", - "img_norm.CopyInformation(img)" + "img = itk.imread('Branch.n020.mha')\n\nimg_arr = itk.GetArrayFromImage(img).astype(np.float32)\nimg_arr /= np.max(img_arr)\n\nimg_norm = itk.GetImageFromArray(img_arr)\nimg_norm.CopyInformation(img)" ] }, { @@ -503,8 +166,7 @@ } ], "source": [ - 
"plt.imshow(patch[0][2, :, :], cmap='gray')\n", - "plt.show()" + "plt.imshow(patch[0][2, :, :], cmap='gray')\nplt.show()" ] }, { @@ -548,23 +210,7 @@ } ], "source": [ - "old_x = np.array([48, 50, 50])\n", - "old_t = np.array([1, 0, 0])\n", - "old_r = 0.5\n", - "\n", - "patch_list = []\n", - "for i in range(20):\n", - " new_x, new_r, new_t, term, prob, patch = step(img_norm, model, old_x, old_r, old_t)\n", - " print(new_x, new_r, new_t)\n", - " old_x = new_x\n", - " old_r = new_r\n", - " old_t = new_t\n", - " patch_list.append(patch)\n", - "\n", - "for i in range(len(patch_list)):\n", - " plt.figure(i)\n", - " plt.imshow(patch_list[i][0, 1, :, :], cmap='gray')\n", - " plt.axis('off')" + "old_x = np.array([48, 50, 50])\nold_t = np.array([1, 0, 0])\nold_r = 0.5\n\npatch_list = []\nfor i in range(20):\n new_x, new_r, new_t, term, prob, patch = step(img_norm, model, old_x, old_r, old_t)\n print(new_x, new_r, new_t)\n old_x = new_x\n old_r = new_r\n old_t = new_t\n patch_list.append(patch)\n\nfor i in range(len(patch_list)):\n plt.figure(i)\n plt.imshow(patch_list[i][0, 1, :, :], cmap='gray')\n plt.axis('off')" ] } ], @@ -589,4 +235,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} +} \ No newline at end of file diff --git a/experiments/Reconstruct4DCT/reconstruct_4d_ct.ipynb b/experiments/Reconstruct4DCT/reconstruct_4d_ct.ipynb index 758ebd2..0eede84 100644 --- a/experiments/Reconstruct4DCT/reconstruct_4d_ct.ipynb +++ b/experiments/Reconstruct4DCT/reconstruct_4d_ct.ipynb @@ -12,7 +12,7 @@ "import itk\n", "import numpy as np\n", "\n", - "from physiomotion4d import RegisterImagesANTs, RegisterImagesICON, TransformTools" + "from physiomotion4d import RegisterImagesANTs, TransformTools\n" ] }, { @@ -87,11 +87,11 @@ " tfm_tools = TransformTools()\n", "\n", " img = images[reference_image_num]\n", - " phi_MF = None\n", - " phi_FM = None\n", + " forward_transform = None\n", + " inverse_transform = None\n", " results = None\n", " reg_image = None\n", - " prior_phi_MF = None\n", + " 
prior_forward_transform = None\n", "\n", " reference_image = images[reference_image_num]\n", " reference_image_indx = files_indx[reference_image_num]\n", @@ -104,21 +104,21 @@ "\n", " if reference_image_reg_use_identity:\n", " print(f\"Registering reference slice {reference_image_indx} using identify transform\")\n", - " phi_MF = identity_tfm\n", - " phi_FM = identity_tfm\n", + " forward_transform = identity_tfm\n", + " inverse_transform = identity_tfm\n", " if portion_of_prior_to_use > 0.0:\n", - " prior_phi_MF = identity_tfm\n", + " prior_forward_transform = identity_tfm\n", " reg_image = img\n", " reg_image_inv = fixed_image\n", " else:\n", " print(f\"Registering reference slice {reference_image_indx} to reference image.\")\n", " results = reg_tool.register(img)\n", - " phi_MF = results[\"phi_MF\"]\n", - " phi_FM = results[\"phi_FM\"]\n", + " forward_transform = results[\"forward_transform\"]\n", + " inverse_transform = results[\"inverse_transform\"]\n", " if portion_of_prior_to_use > 0.0:\n", - " prior_phi_MF = tfm_tools.combine_displacement_field_transforms(\n", + " prior_forward_transform = tfm_tools.combine_displacement_field_transforms(\n", " identity_tfm,\n", - " phi_MF,\n", + " forward_transform,\n", " reference_image,\n", " tfm1_weight=1.0,\n", " tfm2_weight=portion_of_prior_to_use,\n", @@ -126,47 +126,47 @@ " tfm2_blur_sigma=0.5,\n", " mode=\"add\"\n", " )\n", - " reg_image = tfm_tools.transform_image(img, phi_MF, fixed_image)\n", - " reg_image_inv = tfm_tools.transform_image(fixed_image, phi_FM, img)\n", + " reg_image = tfm_tools.transform_image(img, forward_transform, fixed_image)\n", + " reg_image_inv = tfm_tools.transform_image(fixed_image, inverse_transform, img)\n", "\n", " num_images = len(images)\n", "\n", - " phi_MF_arr = [itk.Transform[itk.D, 3].New() for _ in range(num_images)]\n", - " phi_FM_arr = [itk.Transform[itk.D, 3].New() for _ in range(num_images)]\n", - " phi_MF_arr[reference_image_num] = phi_MF\n", - " 
phi_FM_arr[reference_image_num] = phi_FM\n", + " forward_transform_arr = [itk.Transform[itk.D, 3].New() for _ in range(num_images)]\n", + " inverse_transform_arr = [itk.Transform[itk.D, 3].New() for _ in range(num_images)]\n", + " forward_transform_arr[reference_image_num] = forward_transform\n", + " inverse_transform_arr[reference_image_num] = inverse_transform\n", "\n", " debug_mode = True\n", "\n", " if debug_mode:\n", " out_file = os.path.join(\n", - " \"results\", f\"slice_{reg_tool_name}_MF_{reference_image_indx:03d}.mha\"\n", + " \"results\", f\"slice_{reg_tool_name}_forward_{reference_image_indx:03d}.mha\"\n", " )\n", " itk.imwrite(reg_image, out_file, compression=True)\n", "\n", " out_file = os.path.join(\n", - " \"results\", f\"slice_fixed_{reg_tool_name}_FM_{reference_image_indx:03d}.mha\"\n", + " \"results\", f\"slice_fixed_{reg_tool_name}_inverse_{reference_image_indx:03d}.mha\"\n", " )\n", " itk.imwrite(reg_image_inv, out_file, compression=True)\n", "\n", " itk.transformwrite(\n", - " phi_MF,\n", + " forward_transform,\n", " os.path.join(\n", " \"results\",\n", - " f\"slice_{reg_tool_name}_MF_{reference_image_indx:03d}.hdf\"\n", + " f\"slice_{reg_tool_name}_forward_{reference_image_indx:03d}.hdf\"\n", " ),\n", " compression=True\n", " )\n", " itk.transformwrite(\n", - " phi_FM,\n", + " inverse_transform,\n", " os.path.join(\n", " \"results\",\n", - " f\"slice_{reg_tool_name}_FM_{reference_image_indx:03d}.hdf\"\n", + " f\"slice_{reg_tool_name}_inverse_{reference_image_indx:03d}.hdf\"\n", " ),\n", " compression=True\n", " )\n", "\n", - " prior_phi_MF_ref = prior_phi_MF\n", + " prior_forward_transform_ref = prior_forward_transform\n", "\n", " for step_i in [1, -1]:\n", " start_i = 0\n", @@ -178,7 +178,7 @@ " start_i = reference_image_num+1\n", " end_i = num_files\n", "\n", - " prior_phi_MF = prior_phi_MF_ref\n", + " prior_forward_transform = prior_forward_transform_ref\n", "\n", " print(f\"registering: from {files_indx[start_i]} to 
{files_indx[end_i-step_i]} step {step_i}\")\n", " for img_indx in range(start_i, end_i, step_i):\n", @@ -190,10 +190,10 @@ " print(\" Trying init with identity.\")\n", " results_init_identity = reg_tool.register(\n", " img,\n", - " initial_phi_MF=None\n", + " initial_forward_transform=None\n", " )\n", - " phi_FM_init_identity = results_init_identity[\"phi_FM\"]\n", - " phi_MF_init_identity = results_init_identity[\"phi_MF\"]\n", + " inverse_tranform_init_identity = results_init_identity[\"inverse_transform\"]\n", + " forward_transform_init_identity = results_init_identity[\"forward_transform\"]\n", " loss_init_identity = results_init_identity[\"loss\"]\n", " print(\" Final loss:\", results_init_identity[\"loss\"])\n", "\n", @@ -202,26 +202,26 @@ " print(\" Trying with init prior.\")\n", " results_init_prior = reg_tool.register(\n", " img,\n", - " initial_phi_MF=prior_phi_MF\n", + " initial_forward_transform=prior_forward_transform\n", " )\n", - " phi_FM_init_prior = results_init_prior[\"phi_FM\"]\n", - " phi_MF_init_prior = results_init_prior[\"phi_MF\"]\n", + " inverse_transform_init_prior = results_init_prior[\"inverse_transform\"]\n", + " forward_transform_init_prior = results_init_prior[\"forward_transform\"]\n", " loss_init_prior = results_init_prior[\"loss\"]\n", " print(\" Final loss:\", results_init_prior[\"loss\"])\n", "\n", " if loss_init_identity < loss_init_prior:\n", " print(\" Using identity.\")\n", - " prior_phi_MF = identity_tfm\n", - " phi_FM = phi_FM_init_identity\n", - " phi_MF = phi_MF_init_identity\n", + " prior_forward_transform = identity_tfm\n", + " inverse_transform = inverse_tranform_init_identity\n", + " forward_transform = forward_transform_init_identity\n", " else:\n", " print(\" Using prior.\")\n", - " phi_FM = phi_FM_init_prior\n", - " phi_MF = phi_MF_init_prior\n", - " \n", - " prior_phi_MF = tfm_tools.combine_displacement_field_transforms(\n", + " inverse_transform = inverse_transform_init_prior\n", + " forward_transform = 
forward_transform_init_prior\n", + "\n", + " prior_forward_transform = tfm_tools.combine_displacement_field_transforms(\n", " identity_tfm,\n", - " phi_MF,\n", + " forward_transform,\n", " reference_image,\n", " tfm1_weight=1.0,\n", " tfm2_weight=portion_of_prior_to_use,\n", @@ -230,43 +230,43 @@ " mode=\"add\"\n", " )\n", " else:\n", - " phi_FM = phi_FM_init_identity\n", - " phi_MF = phi_MF_init_identity\n", + " inverse_transform = inverse_tranform_init_identity\n", + " forward_transform = forward_transform_init_identity\n", "\n", - " phi_MF_arr[img_indx] = phi_MF\n", - " phi_FM_arr[img_indx] = phi_FM\n", + " forward_transform_arr[img_indx] = forward_transform\n", + " inverse_transform_arr[img_indx] = inverse_transform\n", "\n", " if debug_mode:\n", - " reg_image = tfm_tools.transform_image(img, phi_MF, fixed_image)\n", + " reg_image = tfm_tools.transform_image(img, forward_transform, fixed_image)\n", " out_file = os.path.join(\n", - " \"results\", f\"slice_{reg_tool_name}_MF_{img_file_indx:03d}.mha\"\n", + " \"results\", f\"slice_{reg_tool_name}_forward_{img_file_indx:03d}.mha\"\n", " )\n", " itk.imwrite(reg_image, out_file, compression=True)\n", "\n", - " reg_image = tfm_tools.transform_image(fixed_image, phi_FM, img)\n", + " reg_image = tfm_tools.transform_image(fixed_image, inverse_transform, img)\n", " out_file = os.path.join(\n", - " \"results\", f\"slice_fixed_{reg_tool_name}_FM_{img_file_indx:03d}.mha\"\n", + " \"results\", f\"slice_fixed_{reg_tool_name}_inverse_{img_file_indx:03d}.mha\"\n", " )\n", " itk.imwrite(reg_image, out_file, compression=True)\n", "\n", " itk.transformwrite(\n", - " phi_MF,\n", + " forward_transform,\n", " os.path.join(\n", " \"results\",\n", - " f\"slice_{reg_tool_name}_MF_{img_file_indx:03d}.hdf\"\n", + " f\"slice_{reg_tool_name}_forward_{img_file_indx:03d}.hdf\"\n", " ),\n", " compression=True\n", " )\n", " itk.transformwrite(\n", - " phi_FM,\n", + " inverse_transform,\n", " os.path.join(\n", " \"results\",\n", - " 
f\"slice_{reg_tool_name}_FM_{img_file_indx:03d}.hdf\"\n", + " f\"slice_{reg_tool_name}_inverse_{img_file_indx:03d}.hdf\"\n", " ),\n", " compression=True\n", " )\n", "\n", - " return { \"phi_MF\": phi_MF_arr, \"phi_FM\": phi_FM_arr }" + " return { \"forward_transforms\": forward_transform_arr, \"inverse_transforms\": inverse_transform_arr }" ] }, { @@ -276,7 +276,8 @@ "metadata": {}, "outputs": [], "source": [ - "phi_MF_arr = None\n", + "forward_transform_arr = None\n", + "inverse_transform_arr = None\n", "for reg_tool_name, reg_tool, num_iterations in reg_method_data:\n", " reg_tool.set_fixed_image(fixed_image)\n", " reg_tool.set_number_of_iterations(num_iterations)\n", @@ -290,8 +291,8 @@ " reference_image_reg_use_identity,\n", " portion_of_prior_to_use=0.0\n", " )\n", - " phi_MF_arr = results[\"phi_MF\"]\n", - " phi_FM_arr = results[\"phi_FM\"]\n" + " forward_transform_arr = results[\"forward_transforms\"]\n", + " inverse_transform_arr = results[\"inverse_transforms\"]\n" ] }, { @@ -304,9 +305,6 @@ "import os\n", "\n", "import itk\n", - "import numpy as np\n", - "\n", - "from physiomotion4d.transform_tools import TransformTools\n", "\n", "tfm_tool = TransformTools()\n", "\n", @@ -317,7 +315,7 @@ " files = []\n", " files_indx = []\n", " for f in sorted(os.listdir(data_dir)):\n", - " if f.endswith(\".hdf\") and f.startswith(\"slice_ANTs_MF_\"):\n", + " if f.endswith(\".hdf\") and f.startswith(\"slice_ANTs_forward_\"):\n", " files.append(os.path.join(data_dir, f))\n", " files_indx.append(int(f.split(\"_\")[3].split(\".\")[0]))\n", "\n", @@ -331,29 +329,29 @@ "\n", "for i in range(num_files):\n", " print(files_indx[i])\n", - " phi = itk.transformread(\n", - " os.path.join(\"results\", f\"slice_ANTs_FM_{files_indx[i]:03d}.hdf\")\n", + " inverse_transform = itk.transformread(\n", + " os.path.join(\"results\", f\"slice_ANTs_inverse_{files_indx[i]:03d}.hdf\")\n", " )[0]\n", "\n", - " phi_image = tfm_tool.convert_transform_to_displacement_field(\n", - " phi,\n", + " 
inverse_image = tfm_tool.convert_transform_to_displacement_field(\n", + " inverse_transform,\n", " fixed_image,\n", " np_component_type=np.float32,\n", " )\n", " itk.imwrite(\n", - " phi_image,\n", - " os.path.join(\"results\", f\"slice_ANTs_FM_{files_indx[i]:03d}_hdf.mha\"),\n", + " inverse_image,\n", + " os.path.join(\"results\", f\"slice_ANTs_inverse_{files_indx[i]:03d}_hdf.mha\"),\n", " compression=True\n", " )\n", "\n", - " phi_grid_image = tfm_tool.transform_image(\n", + " inverse_grid_image = tfm_tool.transform_image(\n", " grid_image,\n", - " phi,\n", + " inverse_transform,\n", " fixed_image,\n", " )\n", " itk.imwrite(\n", - " phi_grid_image,\n", - " os.path.join(\"results\", f\"slice_fixed_ANTs_FM_grid_{files_indx[i]:03d}.mha\"),\n", + " inverse_grid_image,\n", + " os.path.join(\"results\", f\"slice_fixed_ANTs_inverse_grid_{files_indx[i]:03d}.mha\"),\n", " compression=True\n", " )\n" ] diff --git a/experiments/Reconstruct4DCT/reconstruct_4d_ct_class.ipynb b/experiments/Reconstruct4DCT/reconstruct_4d_ct_class.ipynb index e333e06..b45c2bf 100644 --- a/experiments/Reconstruct4DCT/reconstruct_4d_ct_class.ipynb +++ b/experiments/Reconstruct4DCT/reconstruct_4d_ct_class.ipynb @@ -1,354 +1,353 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 4D CT Reconstruction Using RegisterTimeSeriesImages Class\n", - "\n", - "This notebook demonstrates the use of the `RegisterTimeSeriesImages` class to register a time series of CT images to a common reference frame.\n", - "\n", - "This is a refactored version of `reconstruct_4d_ct.ipynb` that uses the new class-based approach.\n" - ] + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 4D CT Reconstruction Using RegisterTimeSeriesImages Class\n", + "\n", + "This notebook demonstrates the use of the `RegisterTimeSeriesImages` class to register a time series of CT images to a common reference frame.\n", + "\n", + "This is a refactored version of 
`reconstruct_4d_ct.ipynb` that uses the new class-based approach.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "import itk\n", + "import numpy as np\n", + "\n", + "from physiomotion4d import RegisterTimeSeriesImages, TransformTools\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load Data and Set Parameters\n", + "\n", + "Set `quick_run = True` for a fast test with fewer images, or `quick_run = False` for full processing.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load image files\n", + "data_dir = os.path.join(\"..\", \"..\", \"data\", \"Slicer-Heart-CT\")\n", + "files = [\n", + " os.path.join(data_dir, f)\n", + " for f in sorted(os.listdir(data_dir))\n", + " if f.endswith(\".mha\") and f.startswith(\"slice_\")\n", + "]\n", + "\n", + "print(f\"Found {len(files)} slice files\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Configuration\n", + "quick_run = True # Set to True for quick testing\n", + "\n", + "# Select files and parameters based on mode\n", + "if quick_run:\n", + " print(\"=== QUICK RUN MODE ===\")\n", + " total_num_files = len(files)\n", + " target_num_files = 5\n", + " file_step = total_num_files // target_num_files\n", + " files = files[0:total_num_files:file_step]\n", + " files_indx = list(range(0, total_num_files, file_step))\n", + " num_files = len(files)\n", + " reference_image_num = num_files // 2\n", + "\n", + " # Registration parameters - only ANTs for quick run\n", + " registration_methods = [\"ants\", \"icon\", \"ants_icon\"]\n", + " number_of_iterations_list = [[8, 4, 1], 5, [[8, 4, 1], 5]] # For ANTs and ICON\n", + "else:\n", + " print(\"=== FULL RUN MODE ===\")\n", + " num_files = len(files)\n", + " files_indx = list(range(0, num_files))\n", + " 
reference_image_num = 7\n", + "\n", + " # Registration parameters - both ANTs and ICON for full run\n", + " registration_methods = [\"ants\", \"icon\", \"ants_icon\"]\n", + " number_of_iterations_list = [\n", + " [30, 15, 7, 3], # For ANTs\n", + " 20, # For ICON\n", + " [[30, 15, 7, 3], 20] # For ants_icon\n", + " ]\n", + "\n", + "# Common parameters\n", + "reference_image_file = os.path.join(data_dir, f\"slice_{files_indx[reference_image_num]:03d}.mha\")\n", + "register_start_to_reference = False\n", + "portion_of_prior_transform_to_init_next_transform = 0.0\n", + "\n", + "print(f\"Number of files: {num_files}\")\n", + "print(f\"Reference image: slice_{files_indx[reference_image_num]:03d}.mha\")\n", + "print(f\"Registration methods: {registration_methods}\")\n", + "print(f\"Number of iterations: {number_of_iterations_list}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load Images\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load fixed/reference image\n", + "fixed_image = itk.imread(reference_image_file, pixel_type=itk.F)\n", + "print(f\"Fixed image size: {itk.size(fixed_image)}\")\n", + "print(f\"Fixed image spacing: {itk.spacing(fixed_image)}\")\n", + "\n", + "# Save fixed image for reference\n", + "os.makedirs(\"results\", exist_ok=True)\n", + "out_file = os.path.join(\"results\", f\"slice_fixed.mha\")\n", + "itk.imwrite(fixed_image, out_file)\n", + "print(f\"Saved fixed image to: {out_file}\")\n", + "\n", + "images = []\n", + "for file in files:\n", + " img = itk.imread(file, pixel_type=itk.F)\n", + " images.append(img)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# This cell will be run for each registration method in the loop below\n", + "print(f\"Registration methods to run: {registration_methods}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + 
"## Perform Time Series Registration\n", + "\n", + "Loop through each registration method and perform registration.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Store results for each method\n", + "all_results = {}\n", + "\n", + "# Loop through each registration method\n", + "for method_idx, registration_method in enumerate(registration_methods):\n", + " number_of_iterations = number_of_iterations_list[method_idx]\n", + "\n", + " print(\"\\n\" + \"=\"*70)\n", + " print(f\"Starting registration with {registration_method.upper()}\")\n", + " print(\"=\"*70)\n", + " print(f\" Starting index: {reference_image_num}\")\n", + " print(f\" Register start to reference: {register_start_to_reference}\")\n", + " print(f\" Prior transform weight: {portion_of_prior_transform_to_init_next_transform}\")\n", + " print(f\" Number of iterations: {number_of_iterations}\")\n", + "\n", + " # Create registrar for this method\n", + " registrar = RegisterTimeSeriesImages(registration_method=registration_method)\n", + " registrar.set_modality('ct')\n", + " registrar.set_fixed_image(fixed_image)\n", + " registrar.set_number_of_iterations(number_of_iterations)\n", + "\n", + " # Perform registration\n", + " result = registrar.register_time_series(\n", + " moving_images=images,\n", + " reference_frame=reference_image_num,\n", + " register_reference=register_start_to_reference,\n", + " prior_weight=portion_of_prior_transform_to_init_next_transform,\n", + " )\n", + "\n", + " # Store results\n", + " all_results[registration_method] = result\n", + "\n", + " forward_transforms = result[\"forward_transforms\"]\n", + " inverse_transforms = result[\"inverse_transforms\"]\n", + " losses = result[\"losses\"]\n", + "\n", + " print(f\"\\n{registration_method.upper()} registration complete!\")\n", + " print(f\" Average loss: {np.mean(losses):.6f}\")\n", + " print(f\" Min loss: {np.min(losses):.6f}\")\n", + " print(f\" Max loss: 
{np.max(losses):.6f}\")\n", + "\n", + "print(\"\\n\" + \"=\"*70)\n", + "print(\"All registrations complete!\")\n", + "print(\"=\"*70)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Save registered images and transforms for each method\n", + "tfm_tools = TransformTools()\n", + "\n", + "for registration_method in registration_methods:\n", + " result = all_results[registration_method]\n", + " forward_transforms = result[\"forward_transforms\"]\n", + " inverse_transforms = result[\"inverse_transforms\"]\n", + "\n", + " print(f\"Saving {registration_method.upper()} results...\")\n", + " for i, img_indx in enumerate(files_indx):\n", + " print(f\" Saving slice {img_indx:03d}...\")\n", + "\n", + " # Apply transform and save registered image (moving to fixed)\n", + " reg_image = tfm_tools.transform_image(images[i], forward_transforms[i], fixed_image)\n", + " out_file = os.path.join(\n", + " \"results\", f\"slice_{registration_method}_forward_{img_indx:03d}.mha\"\n", + " )\n", + " itk.imwrite(reg_image, out_file, compression=True)\n", + "\n", + " # Apply inverse transform and save (fixed to moving)\n", + " reg_image_inv = tfm_tools.transform_image(fixed_image, inverse_transforms[i], images[i])\n", + " out_file = os.path.join(\n", + " \"results\", f\"slice_fixed_{registration_method}_inverse_{img_indx:03d}.mha\"\n", + " )\n", + " itk.imwrite(reg_image_inv, out_file, compression=True)\n", + "\n", + " # Save transforms\n", + " itk.transformwrite(\n", + " forward_transforms[i],\n", + " os.path.join(\n", + " \"results\",\n", + " f\"slice_{registration_method}_forward_{img_indx:03d}.hdf\"\n", + " ),\n", + " compression=True\n", + " )\n", + " itk.transformwrite(\n", + " inverse_transforms[i],\n", + " os.path.join(\n", + " \"results\",\n", + " f\"slice_{registration_method}_inverse_{img_indx:03d}.hdf\"\n", + " ),\n", + " compression=True\n", + " )\n", + "\n", + "print(\"✓ Results saved to results/ 
directory\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Print registration losses for each method\n", + "for registration_method in registration_methods:\n", + " result = all_results[registration_method]\n", + " losses = result[\"losses\"]\n", + "\n", + " print(f\"{registration_method.upper()} Registration Losses:\")\n", + " print(\"=\"*50)\n", + " for i, img_indx in enumerate(files_indx):\n", + " status = \"(reference)\" if i == reference_image_num else \"\"\n", + " print(f\" Slice {img_indx:03d}: {losses[i]:.6f} {status}\")\n", + "\n", + " print(f\"{registration_method.upper()} Statistics:\")\n", + " print(f\" Mean loss: {np.mean(losses):.6f}\")\n", + " print(f\" Std loss: {np.std(losses):.6f}\")\n", + " print(f\" Min loss: {np.min(losses):.6f}\")\n", + " print(f\" Max loss: {np.max(losses):.6f}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualize Registration Quality\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Generate grid image for visualization\n", + "grid_image = tfm_tools.generate_grid_image(fixed_image, 30, 1)\n", + "\n", + "for registration_method in registration_methods:\n", + " result = all_results[registration_method]\n", + " inverse_transforms = result[\"inverse_transforms\"]\n", + "\n", + " print(f\"Generating {registration_method.upper()} grid visualizations...\")\n", + " for i, img_indx in enumerate(files_indx):\n", + " print(f\" Generating grid for slice {img_indx:03d}...\")\n", + "\n", + " # Transform grid with inverse transform (FM)\n", + " inverse_grid_image = tfm_tools.transform_image(\n", + " grid_image,\n", + " inverse_transforms[i],\n", + " fixed_image,\n", + " )\n", + " itk.imwrite(\n", + " inverse_grid_image,\n", + " os.path.join(\"results\", f\"slice_fixed_{registration_method}_inverse_grid_{img_indx:03d}.mha\"),\n", + " compression=True\n", + " 
)\n", + "\n", + " # Save displacement field as image\n", + " inverse_transform_image = tfm_tools.convert_transform_to_displacement_field(\n", + " inverse_transforms[i],\n", + " fixed_image,\n", + " np_component_type=np.float32,\n", + " )\n", + " itk.imwrite(\n", + " inverse_transform_image,\n", + " os.path.join(\"results\", f\"slice_{registration_method}_inverse_{img_indx:03d}_field.mha\"),\n", + " compression=True\n", + " )\n", + "\n", + "print(\"✓ Grid visualizations saved\")\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "\n", - "import itk\n", - "import numpy as np\n", - "\n", - "from physiomotion4d import RegisterTimeSeriesImages, TransformTools\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Load Data and Set Parameters\n", - "\n", - "Set `quick_run = True` for a fast test with fewer images, or `quick_run = False` for full processing.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Load image files\n", - "data_dir = os.path.join(\"..\", \"..\", \"data\", \"Slicer-Heart-CT\")\n", - "files = [\n", - " os.path.join(data_dir, f)\n", - " for f in sorted(os.listdir(data_dir))\n", - " if f.endswith(\".mha\") and f.startswith(\"slice_\")\n", - "]\n", - "\n", - "print(f\"Found {len(files)} slice files\")\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Configuration\n", - "quick_run = True # Set to True for quick testing\n", - "\n", - "# Select 
files and parameters based on mode\n", - "if quick_run:\n", - " print(\"=== QUICK RUN MODE ===\")\n", - " total_num_files = len(files)\n", - " target_num_files = 5\n", - " file_step = total_num_files // target_num_files\n", - " files = files[0:total_num_files:file_step]\n", - " files_indx = list(range(0, total_num_files, file_step))\n", - " num_files = len(files)\n", - " reference_image_num = num_files // 2\n", - "\n", - " # Registration parameters - only ANTs for quick run\n", - " registration_methods = [\"ants\", \"icon\", \"ants_icon\"]\n", - " number_of_iterations_list = [[8, 4, 1], 5, [[8, 4, 1], 5]] # For ANTs and ICON\n", - "else:\n", - " print(\"=== FULL RUN MODE ===\")\n", - " num_files = len(files)\n", - " files_indx = list(range(0, num_files))\n", - " reference_image_num = 7\n", - "\n", - " # Registration parameters - both ANTs and ICON for full run\n", - " registration_methods = [\"ants\", \"icon\", \"ants_icon\"]\n", - " number_of_iterations_list = [\n", - " [30, 15, 7, 3], # For ANTs\n", - " 20, # For ICON\n", - " [[30, 15, 7, 3], 20] # For ants_icon\n", - " ]\n", - "\n", - "# Common parameters\n", - "reference_image_file = os.path.join(data_dir, f\"slice_{files_indx[reference_image_num]:03d}.mha\")\n", - "register_start_to_reference = False\n", - "portion_of_prior_transform_to_init_next_transform = 0.0\n", - "\n", - "print(f\"Number of files: {num_files}\")\n", - "print(f\"Reference image: slice_{files_indx[reference_image_num]:03d}.mha\")\n", - "print(f\"Registration methods: {registration_methods}\")\n", - "print(f\"Number of iterations: {number_of_iterations_list}\")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Load Images\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Load fixed/reference image\n", - "fixed_image = itk.imread(reference_image_file, pixel_type=itk.F)\n", - "print(f\"Fixed image size: {itk.size(fixed_image)}\")\n", - 
"print(f\"Fixed image spacing: {itk.spacing(fixed_image)}\")\n", - "\n", - "# Save fixed image for reference\n", - "os.makedirs(\"results\", exist_ok=True)\n", - "out_file = os.path.join(\"results\", f\"slice_fixed.mha\")\n", - "itk.imwrite(fixed_image, out_file)\n", - "print(f\"Saved fixed image to: {out_file}\")\n", - "\n", - "images = []\n", - "for file in files:\n", - " img = itk.imread(file, pixel_type=itk.F)\n", - " images.append(img)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# This cell will be run for each registration method in the loop below\n", - "print(f\"Registration methods to run: {registration_methods}\")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Perform Time Series Registration\n", - "\n", - "Loop through each registration method and perform registration.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Store results for each method\n", - "all_results = {}\n", - "\n", - "# Loop through each registration method\n", - "for method_idx, registration_method in enumerate(registration_methods):\n", - " number_of_iterations = number_of_iterations_list[method_idx]\n", - "\n", - " print(\"\\n\" + \"=\"*70)\n", - " print(f\"Starting registration with {registration_method.upper()}\")\n", - " print(\"=\"*70)\n", - " print(f\" Starting index: {reference_image_num}\")\n", - " print(f\" Register start to reference: {register_start_to_reference}\")\n", - " print(f\" Prior transform weight: {portion_of_prior_transform_to_init_next_transform}\")\n", - " print(f\" Number of iterations: {number_of_iterations}\")\n", - "\n", - " # Create registrar for this method\n", - " registrar = RegisterTimeSeriesImages(registration_method=registration_method)\n", - " registrar.set_modality('ct')\n", - " registrar.set_fixed_image(fixed_image)\n", - " 
registrar.set_number_of_iterations(number_of_iterations)\n", - "\n", - " # Perform registration\n", - " result = registrar.register_time_series(\n", - " moving_images=images,\n", - " starting_index=reference_image_num,\n", - " register_start_to_fixed_image=register_start_to_reference,\n", - " portion_of_prior_transform_to_init_next_transform=portion_of_prior_transform_to_init_next_transform,\n", - " )\n", - "\n", - " # Store results\n", - " all_results[registration_method] = result\n", - "\n", - " phi_MF_list = result[\"phi_MF_list\"]\n", - " phi_FM_list = result[\"phi_FM_list\"]\n", - " losses = result[\"losses\"]\n", - "\n", - " print(f\"\\n{registration_method.upper()} registration complete!\")\n", - " print(f\" Transforms generated: {len(phi_MF_list)}\")\n", - " print(f\" Average loss: {np.mean(losses):.6f}\")\n", - " print(f\" Min loss: {np.min(losses):.6f}\")\n", - " print(f\" Max loss: {np.max(losses):.6f}\")\n", - "\n", - "print(\"\\n\" + \"=\"*70)\n", - "print(\"All registrations complete!\")\n", - "print(\"=\"*70)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Save registered images and transforms for each method\n", - "tfm_tools = TransformTools()\n", - "\n", - "for registration_method in registration_methods:\n", - " result = all_results[registration_method]\n", - " phi_MF_list = result[\"phi_MF_list\"]\n", - " phi_FM_list = result[\"phi_FM_list\"]\n", - "\n", - " print(f\"Saving {registration_method.upper()} results...\")\n", - " for i, img_indx in enumerate(files_indx):\n", - " print(f\" Saving slice {img_indx:03d}...\")\n", - "\n", - " # Apply transform and save registered image (moving to fixed)\n", - " reg_image = tfm_tools.transform_image(images[i], phi_MF_list[i], fixed_image)\n", - " out_file = os.path.join(\n", - " \"results\", f\"slice_{registration_method}_MF_{img_indx:03d}.mha\"\n", - " )\n", - " itk.imwrite(reg_image, out_file, compression=True)\n", - "\n", - " # Apply 
inverse transform and save (fixed to moving)\n", - " reg_image_inv = tfm_tools.transform_image(fixed_image, phi_FM_list[i], images[i])\n", - " out_file = os.path.join(\n", - " \"results\", f\"slice_fixed_{registration_method}_FM_{img_indx:03d}.mha\"\n", - " )\n", - " itk.imwrite(reg_image_inv, out_file, compression=True)\n", - "\n", - " # Save transforms\n", - " itk.transformwrite(\n", - " phi_MF_list[i],\n", - " os.path.join(\n", - " \"results\",\n", - " f\"slice_{registration_method}_MF_{img_indx:03d}.hdf\"\n", - " ),\n", - " compression=True\n", - " )\n", - " itk.transformwrite(\n", - " phi_FM_list[i],\n", - " os.path.join(\n", - " \"results\",\n", - " f\"slice_{registration_method}_FM_{img_indx:03d}.hdf\"\n", - " ),\n", - " compression=True\n", - " )\n", - "\n", - "print(\"✓ Results saved to results/ directory\")\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Print registration losses for each method\n", - "for registration_method in registration_methods:\n", - " result = all_results[registration_method]\n", - " losses = result[\"losses\"]\n", - "\n", - " print(f\"{registration_method.upper()} Registration Losses:\")\n", - " print(\"=\"*50)\n", - " for i, img_indx in enumerate(files_indx):\n", - " status = \"(reference)\" if i == reference_image_num else \"\"\n", - " print(f\" Slice {img_indx:03d}: {losses[i]:.6f} {status}\")\n", - "\n", - " print(f\"{registration_method.upper()} Statistics:\")\n", - " print(f\" Mean loss: {np.mean(losses):.6f}\")\n", - " print(f\" Std loss: {np.std(losses):.6f}\")\n", - " print(f\" Min loss: {np.min(losses):.6f}\")\n", - " print(f\" Max loss: {np.max(losses):.6f}\")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Visualize Registration Quality\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Generate grid image for visualization\n", - "grid_image = 
tfm_tools.generate_grid_image(fixed_image, 30, 1)\n", - "\n", - "for registration_method in registration_methods:\n", - " result = all_results[registration_method]\n", - " phi_FM_list = result[\"phi_FM_list\"]\n", - "\n", - " print(f\"Generating {registration_method.upper()} grid visualizations...\")\n", - " for i, img_indx in enumerate(files_indx):\n", - " print(f\" Generating grid for slice {img_indx:03d}...\")\n", - "\n", - " # Transform grid with inverse transform (FM)\n", - " phi_grid_image = tfm_tools.transform_image(\n", - " grid_image,\n", - " phi_FM_list[i],\n", - " fixed_image,\n", - " )\n", - " itk.imwrite(\n", - " phi_grid_image,\n", - " os.path.join(\"results\", f\"slice_fixed_{registration_method}_FM_grid_{img_indx:03d}.mha\"),\n", - " compression=True\n", - " )\n", - "\n", - " # Save displacement field as image\n", - " phi_image = tfm_tools.convert_transform_to_displacement_field(\n", - " phi_FM_list[i],\n", - " fixed_image,\n", - " np_component_type=np.float32,\n", - " )\n", - " itk.imwrite(\n", - " phi_image,\n", - " os.path.join(\"results\", f\"slice_{registration_method}_FM_{img_indx:03d}_field.mha\"),\n", - " compression=True\n", - " )\n", - "\n", - "print(\"✓ Grid visualizations saved\")\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.11" - } - }, - "nbformat": 4, - "nbformat_minor": 2 + "nbformat": 4, + "nbformat_minor": 2 } diff --git a/pyproject.toml b/pyproject.toml index 6388c14..56038e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,6 +82,8 @@ dependencies = [ "matplotlib>=3.5.0", "jupyterlab>=4.0.0", "typing-extensions>=4.0.0", + "cupy-cuda12x>=13.6.0", + ] [tool.uv.sources] @@ -121,7 +123,8 @@ dev = [ 
"ruff>=0.1.0", ] docs = [ - "linkify>=1.4", + "linkify-it-py>=2.0.0", + "uc-micro-py>=1.0.1", "sphinx>=7.0.0", "sphinx-rtd-theme>=2.0.0", "sphinx-autodoc-typehints>=1.25.0", @@ -283,14 +286,14 @@ addopts = [ "--cov=physiomotion4d", "--cov-report=term-missing", "--cov-report=html", - "--cov-report=xml", - "--timeout=900" + "--cov-report=xml" ] testpaths = ["tests"] markers = [ + "unit: marks tests as unit tests (fast, isolated)", + "integration: marks tests as integration tests (slower, multiple components)", "slow: marks tests as slow (deselect with '-m \"not slow\"')", - "integration: marks tests as integration tests", - "unit: marks tests as unit tests", + "requires_gpu: marks tests that require GPU/CUDA support", "requires_data: marks tests that require external data download" ] timeout = 900 diff --git a/scripts/README.md b/scripts/README.md index 985b66b..3388146 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -192,8 +192,8 @@ The workflow executes the following steps automatically: #### Intermediate Files (in output directory) - `slice_*.mha` - Individual 3D images for each time point - `slice_*.labelmap.mha` - Segmentation masks -- `slice_*.reg_*.phi_FM.hdf` - Forward transformation files -- `slice_*.reg_*.phi_MF.hdf` - Backward transformation files +- `slice_*.reg_*.inverse_transform.hdf` - Backward transformation files +- `slice_*.reg_*.forward_transform.hdf` - Forward transformation files - `slice_max.reg_*.mha` - Maximum intensity projection images - `*.vtk` - VTK mesh files for contours - `*_4d.vtk` - Time series VTK files diff --git a/src/physiomotion4d/__init__.py b/src/physiomotion4d/__init__.py index 0805919..7b4ec52 100644 --- a/src/physiomotion4d/__init__.py +++ b/src/physiomotion4d/__init__.py @@ -16,9 +16,6 @@ __version__ = "2025.05.0" -# Base classes -from .physiomotion4d_base import PhysioMotion4DBase - from .contour_tools import ContourTools # Data processing utilities @@ -34,15 +31,18 @@ # Utility classes from .image_tools import 
ImageTools -from .usd_anatomy_tools import USDAnatomyTools + +# Base classes +from .physiomotion4d_base import PhysioMotion4DBase from .register_images_ants import RegisterImagesANTs # Registration classes from .register_images_base import RegisterImagesBase from .register_images_icon import RegisterImagesICON -from .register_model_to_image_pca import RegisterModelToImagePCA -from .register_model_to_model_icp import RegisterModelToModelICP -from .register_model_to_model_masks import RegisterModelToModelMasks +from .register_models_distance_maps import RegisterModelsDistanceMaps +from .register_models_icp import RegisterModelsICP +from .register_models_icp_itk import RegisterModelsICPITK +from .register_models_pca import RegisterModelsPCA from .register_time_series_images import RegisterTimeSeriesImages # Segmentation classes @@ -52,6 +52,7 @@ from .segment_chest_vista_3d import SegmentChestVista3D from .segment_chest_vista_3d_nim import SegmentChestVista3DNIM from .transform_tools import TransformTools +from .usd_anatomy_tools import USDAnatomyTools from .usd_tools import USDTools __all__ = [ @@ -69,9 +70,9 @@ "RegisterImagesICON", "RegisterImagesANTs", "RegisterTimeSeriesImages", - "RegisterModelToImagePCA", - "RegisterModelToModelICP", - "RegisterModelToModelMasks", + "RegisterModelsPCA", + "RegisterModelsICP", + "RegisterModelsDistanceMaps", # Base classes "PhysioMotion4DBase", # Utility classes diff --git a/src/physiomotion4d/contour_tools.py b/src/physiomotion4d/contour_tools.py index e49e6e2..6b105fc 100644 --- a/src/physiomotion4d/contour_tools.py +++ b/src/physiomotion4d/contour_tools.py @@ -2,21 +2,30 @@ Tools for creating and manipulating contours. 
""" +import logging + import itk import numpy as np import pyvista as pv import trimesh +from physiomotion4d.image_tools import ImageTools +from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase from physiomotion4d.transform_tools import TransformTools -class ContourTools: +class ContourTools(PhysioMotion4DBase): """ Tools for creating and manipulating contours. """ - def __init__(self): - pass + def __init__(self, log_level: int | str = logging.INFO): + """Initialize ContourTools. + + Args: + log_level: Logging level (default: logging.INFO) + """ + super().__init__(class_name=self.__class__.__name__, log_level=log_level) def extract_contours( self, @@ -87,7 +96,7 @@ def merge_meshes(self, meshes): pv.PolyData Merged mesh """ - print("Merging meshes...") + self.log_info("Merging meshes...") if hasattr(meshes[0], 'n_faces_strict'): meshes = [ trimesh.Trimesh( @@ -118,11 +127,39 @@ def merge_meshes(self, meshes): return merged_mesh, pv_meshes + def create_reference_image( + self, + mesh, + spatial_resolution: float = 0.5, + buffer_factor: float = 0.25, + ptype: type = itk.F, + ) -> itk.Image: + """ + Create a reference image from a mesh. 
+ """ + points = np.array(mesh.points) + min_bounds = points.min(axis=0) + max_bounds = points.max(axis=0) + min_bounds = min_bounds - buffer_factor * (max_bounds - min_bounds) + max_bounds = max_bounds + buffer_factor * (max_bounds - min_bounds) + region = ( + ((max_bounds - min_bounds) / spatial_resolution + 1) + .astype(np.int32) + .tolist() + ) + itk_region = itk.ImageRegion[3]() + itk_region.SetSize(region) + reference_image = itk.Image[ptype, 3].New() + reference_image.SetRegions(itk_region) + reference_image.SetSpacing([spatial_resolution] * 3) + reference_image.SetOrigin(min_bounds.tolist()) + reference_image.Allocate() + return reference_image + def create_mask_from_mesh( self, mesh, reference_image, - resample_to_reference=True, ): ref_spacing = np.array(reference_image.GetSpacing()) @@ -177,7 +214,8 @@ def create_mask_from_mesh( binary_image.SetSpacing([voxel_pitch] * 3) # Direction: use identity for now (axis-aligned), will be handled by resampling - ref_dir = np.array(reference_image.GetDirection()) + # Flip Z axis to match ITK convention + ref_dir = np.array(binary_image.GetDirection()) ref_dir[2, 2] = -ref_dir[2, 2] binary_image.SetDirection(ref_dir) @@ -189,67 +227,156 @@ def create_mask_from_mesh( fill_filter.Update() mask_image = fill_filter.GetOutput() - if resample_to_reference: - resampler = itk.ResampleImageFilter.New(Input=mask_image) - resampler.SetReferenceImage(reference_image) - resampler.SetUseReferenceImage(True) - resampler.SetInterpolator( - itk.NearestNeighborInterpolateImageFunction.New(mask_image) - ) - resampler.SetDefaultPixelValue(0) - resampler.Update() - mask_image = resampler.GetOutput() + resampler = itk.ResampleImageFilter.New(Input=mask_image) + resampler.SetReferenceImage(reference_image) + resampler.SetUseReferenceImage(True) + resampler.SetInterpolator( + itk.NearestNeighborInterpolateImageFunction.New(mask_image) + ) + resampler.SetDefaultPixelValue(0) + resampler.Update() + mask_image = resampler.GetOutput() return 
mask_image - def create_contour_distance_map_from_mesh( + def create_distance_map( self, mesh, reference_image, - max_distance: float = 100.0, + squared_distance: bool = False, + max_distance: float = 0.0, invert_distance_map: bool = False, - ): + create_point_map: bool = False, + ) -> itk.Image: + self.log_info("Computing signed distance map...") + # Convert mask to binary - mesh_mask = self.create_mask_from_mesh(mesh, reference_image) - mask_arr = itk.GetArrayFromImage(mesh_mask) - binary_mask_arr = (mask_arr > 0).astype(np.uint8) - binary_mask_image = itk.GetImageFromArray(binary_mask_arr) - binary_mask_image.CopyInformation(reference_image) - - edge_filter = itk.BinaryContourImageFilter.New(Input=binary_mask_image) - edge_filter.SetForegroundValue(1) - edge_filter.SetBackgroundValue(0) - edge_filter.SetFullyConnected(False) - edge_filter.Update() - edge_mask_image = edge_filter.GetOutput() - - # Compute signed distance map (positive inside, negative outside) - print(" Computing signed distance map...") - distance_filter = itk.SignedMaurerDistanceMapImageFilter.New( - Input=edge_mask_image - ) - distance_filter.SetSquaredDistance(False) + points = mesh.points + + size = reference_image.GetLargestPossibleRegion().GetSize() + size = (size[2], size[1], size[0]) + + tmp_arr = np.zeros(size, dtype=np.int32) + itk_point = itk.Point[itk.D, 3]() + for i, point in enumerate(points): + itk_point[0] = float(point[0]) + itk_point[1] = float(point[1]) + itk_point[2] = float(point[2]) + indx = reference_image.TransformPhysicalPointToIndex(itk_point) + tmp_arr[indx[2], indx[1], indx[0]] = i + tmp_binary_arr = (tmp_arr > 0).astype(np.float32) + tmp_binary_image = itk.GetImageFromArray(tmp_binary_arr) + tmp_binary_image.CopyInformation(reference_image) + + distance_filter = itk.DanielssonDistanceMapImageFilter.New( + Input=tmp_binary_image + ) + distance_filter.SetSquaredDistance(squared_distance) distance_filter.SetUseImageSpacing(True) - 
distance_filter.SetInsideIsPositive(False) + distance_filter.SetInputIsBinary(True) distance_filter.Update() distance_image = distance_filter.GetOutput() - distance_arr = itk.GetArrayFromImage(distance_image) - min_val = distance_arr.min() - if max_distance is None: + distance_arr = itk.GetArrayFromImage(distance_image).astype(np.float32) + if max_distance == 0.0: max_val = distance_arr.max() else: max_val = max_distance - distance_arr = np.clip(distance_arr, min_val, max_val) + distance_arr = np.clip(distance_arr, 0.0, max_val) if invert_distance_map: - distance_arr = ( - (1.0 - (distance_arr - min_val) / (max_val - min_val)) * max_distance - ).astype(np.float32) - else: - distance_arr = ( - ((distance_arr - min_val) / (max_val - min_val)) * max_distance - ).astype(np.float32) + distance_arr = max_distance - distance_arr distance_image = itk.GetImageFromArray(distance_arr) distance_image.CopyInformation(reference_image) return distance_image + + def create_deformation_field( + self, + points: np.ndarray, + point_displacements: np.ndarray, + reference_image: itk.Image, + blur_sigma: float = 2.5, + ptype=itk.D, + ) -> itk.Image: + """ + Create a displacement map from model points and displacements. 
+ """ + size = reference_image.GetLargestPossibleRegion().GetSize() + norm_map = np.zeros((size[2], size[1], size[0])).astype(np.float32) + displacement_map_x = np.zeros((size[2], size[1], size[0])).astype(np.float32) + displacement_map_y = np.zeros((size[2], size[1], size[0])).astype(np.float32) + displacement_map_z = np.zeros((size[2], size[1], size[0])).astype(np.float32) + itk_point = itk.Point[itk.D, 3]() + for i, point in enumerate(points): + itk_point[0] = float(point[0]) + itk_point[1] = float(point[1]) + itk_point[2] = float(point[2]) + indx = reference_image.TransformPhysicalPointToIndex(itk_point) + displacement_map_x[int(indx[2]), int(indx[1]), int(indx[0])] = ( + point_displacements[i, 0] + ) + displacement_map_y[int(indx[2]), int(indx[1]), int(indx[0])] = ( + point_displacements[i, 1] + ) + displacement_map_z[int(indx[2]), int(indx[1]), int(indx[0])] = ( + point_displacements[i, 2] + ) + norm_map[int(indx[2]), int(indx[1]), int(indx[0])] = 1 + + norm_img = itk.GetImageFromArray(norm_map) + norm_img.CopyInformation(reference_image) + + blurred_norm = itk.SmoothingRecursiveGaussianImageFilter( + Input=norm_img, Sigma=blur_sigma + ) + blurred_norm_arr = itk.GetArrayFromImage(blurred_norm) + blurred_norm_arr = np.where(blurred_norm_arr < 1.0e-4, 1.0e-4, blurred_norm_arr) + + deformation_field_x_img = itk.GetImageFromArray(displacement_map_x) + deformation_field_x_img.CopyInformation(reference_image) + deformation_field_x_img = itk.SmoothingRecursiveGaussianImageFilter( + Input=deformation_field_x_img, Sigma=blur_sigma + ) + + deformation_field_y_img = itk.GetImageFromArray(displacement_map_y) + deformation_field_y_img.CopyInformation(reference_image) + deformation_field_y_img = itk.SmoothingRecursiveGaussianImageFilter( + Input=deformation_field_y_img, Sigma=blur_sigma + ) + + deformation_field_z_img = itk.GetImageFromArray(displacement_map_z) + deformation_field_z_img.CopyInformation(reference_image) + deformation_field_z_img = 
itk.SmoothingRecursiveGaussianImageFilter( + Input=deformation_field_z_img, Sigma=blur_sigma + ) + + deformation_field_x = ( + itk.GetArrayFromImage(deformation_field_x_img) / blurred_norm_arr + ) + deformation_field_y = ( + itk.GetArrayFromImage(deformation_field_y_img) / blurred_norm_arr + ) + deformation_field_z = ( + itk.GetArrayFromImage(deformation_field_z_img) / blurred_norm_arr + ) + + deformation_field_x = np.where( + blurred_norm_arr > 1.0e-3, deformation_field_x, 0.0 + ) + deformation_field_y = np.where( + blurred_norm_arr > 1.0e-3, deformation_field_y, 0.0 + ) + deformation_field_z = np.where( + blurred_norm_arr > 1.0e-3, deformation_field_z, 0.0 + ) + + deformation_field = np.stack( + [deformation_field_x, deformation_field_y, deformation_field_z], axis=-1 + ) + + image_tools = ImageTools() + deformation_field_img = image_tools.convert_array_to_image_of_vectors( + deformation_field, reference_image, ptype=ptype + ) + + return deformation_field_img diff --git a/src/physiomotion4d/convert_nrrd_4d_to_3d.py b/src/physiomotion4d/convert_nrrd_4d_to_3d.py index 37f87ea..5c11f68 100644 --- a/src/physiomotion4d/convert_nrrd_4d_to_3d.py +++ b/src/physiomotion4d/convert_nrrd_4d_to_3d.py @@ -1,12 +1,21 @@ +import logging import os import itk import nrrd import numpy as np +from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase -class ConvertNRRD4DTo3D: - def __init__(self): + +class ConvertNRRD4DTo3D(PhysioMotion4DBase): + def __init__(self, log_level: int | str = logging.INFO): + """Initialize the NRRD 4D to 3D converter. 
+ + Args: + log_level: Logging level (default: logging.INFO) + """ + super().__init__(class_name=self.__class__.__name__, log_level=log_level) self.nrrd_4d = None self.img_3d = [] diff --git a/src/physiomotion4d/convert_vtk_4d_to_usd.py b/src/physiomotion4d/convert_vtk_4d_to_usd.py index 0cddc57..23048f7 100644 --- a/src/physiomotion4d/convert_vtk_4d_to_usd.py +++ b/src/physiomotion4d/convert_vtk_4d_to_usd.py @@ -1,14 +1,17 @@ """Unified facade for VTK to USD conversion supporting both PolyData and UnstructuredGrid.""" +import logging + import pyvista as pv import vtk from pxr import Usd from .convert_vtk_4d_to_usd_polymesh import ConvertVTK4DToUSDPolyMesh from .convert_vtk_4d_to_usd_tetmesh import ConvertVTK4DToUSDTetMesh +from .physiomotion4d_base import PhysioMotion4DBase -class ConvertVTK4DToUSD: +class ConvertVTK4DToUSD(PhysioMotion4DBase): """ Unified converter supporting both PolyData and UnstructuredGrid. @@ -54,7 +57,15 @@ class ConvertVTK4DToUSD: >>> stage = converter.convert("output.usd") """ - def __init__(self, data_basename, input_polydata, mask_ids=None): + def __init__( + self, + data_basename, + input_polydata, + mask_ids=None, + compute_normals=False, + convert_to_surface=False, + log_level: int | str = logging.INFO, + ): """ Initialize converter and store parameters for later routing. @@ -66,11 +77,17 @@ def __init__(self, data_basename, input_polydata, mask_ids=None): mask_ids (dict or None): Optional mapping of label IDs to label names for organizing meshes by anatomical regions. 
Default: None + log_level: Logging level (default: logging.INFO) """ + super().__init__(class_name=self.__class__.__name__, log_level=log_level) + self.data_basename = data_basename self.input_polydata = input_polydata self.mask_ids = mask_ids + self.compute_normals = compute_normals + self.convert_to_surface = convert_to_surface + # Colormap settings (will be applied to specialized converter) self.color_by_array = None self.colormap = 'plasma' @@ -125,7 +142,9 @@ def set_colormap( self.intensity_range = intensity_range return self - def convert(self, output_usd_file, convert_to_surface=False) -> Usd.Stage: + def convert( + self, output_usd_file, convert_to_surface=None, compute_normals=None + ) -> Usd.Stage: """ Convert meshes to USD, automatically routing by mesh type. @@ -149,7 +168,10 @@ def convert(self, output_usd_file, convert_to_surface=False) -> Usd.Stage: NotImplementedError: If mixed mesh types are detected ValueError: If no valid mesh data found """ - self.convert_to_surface = convert_to_surface + if convert_to_surface is not None: + self.convert_to_surface = convert_to_surface + if compute_normals is not None: + self.compute_normals = compute_normals # Analyze mesh types in input has_polydata = False @@ -166,25 +188,35 @@ def convert(self, output_usd_file, convert_to_surface=False) -> Usd.Stage: # Case 1: Only PolyData (or surface-converted UGrid) if has_polydata and not has_ugrid: - print("Routing to PolyMesh converter (surface meshes)") + self.log_info("Routing to PolyMesh converter (surface meshes)") converter = ConvertVTK4DToUSDPolyMesh( - self.data_basename, self.input_polydata, self.mask_ids + self.data_basename, + self.input_polydata, + self.mask_ids, + convert_to_surface=self.convert_to_surface, + compute_normals=self.compute_normals, + log_level=self.log_level, ) converter.set_colormap( self.color_by_array, self.colormap, self.intensity_range ) - return converter.convert(output_usd_file, convert_to_surface) + return 
converter.convert(output_usd_file) # Case 2: Only UnstructuredGrid (tetmesh) elif has_ugrid and not has_polydata: - print("Routing to TetMesh converter (volumetric meshes)") + self.log_info("Routing to TetMesh converter (volumetric meshes)") converter = ConvertVTK4DToUSDTetMesh( - self.data_basename, self.input_polydata, self.mask_ids + self.data_basename, + self.input_polydata, + self.mask_ids, + convert_to_surface=self.convert_to_surface, + compute_normals=self.compute_normals, + log_level=self.log_level, ) converter.set_colormap( self.color_by_array, self.colormap, self.intensity_range ) - return converter.convert(output_usd_file, convert_to_surface) + return converter.convert(output_usd_file) # Case 3: Mixed - need custom handling elif has_polydata and has_ugrid: diff --git a/src/physiomotion4d/convert_vtk_4d_to_usd_base.py b/src/physiomotion4d/convert_vtk_4d_to_usd_base.py index 62a75d8..c147590 100644 --- a/src/physiomotion4d/convert_vtk_4d_to_usd_base.py +++ b/src/physiomotion4d/convert_vtk_4d_to_usd_base.py @@ -1,13 +1,18 @@ """Abstract base class for converting 4D VTK data to animated USD meshes.""" +import logging import os +import time from abc import ABC, abstractmethod +import cupy as cp import matplotlib.cm as cm import numpy as np import pyvista as pv import vtk -from pxr import Gf, Sdf, Usd, UsdGeom +from pxr import Gf, Usd, UsdGeom, Vt + +from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase # VTK Cell Type Constants VTK_TRIANGLE = 5 @@ -18,7 +23,7 @@ VTK_PYRAMID = 14 -class ConvertVTK4DToUSDBase(ABC): +class ConvertVTK4DToUSDBase(PhysioMotion4DBase, ABC): """ Abstract base class for VTK to USD conversion. @@ -27,7 +32,15 @@ class ConvertVTK4DToUSDBase(ABC): mesh-specific processing and USD creation methods. 
""" - def __init__(self, data_basename, input_polydata, mask_ids=None, convert_to_surface=False): + def __init__( + self, + data_basename, + input_polydata, + mask_ids=None, + convert_to_surface=False, + compute_normals=False, + log_level: int | str = logging.INFO, + ): """ Initialize VTK to USD converter. @@ -41,13 +54,18 @@ def __init__(self, data_basename, input_polydata, mask_ids=None, convert_to_surf convert_to_surface (bool): If True, convert UnstructuredGrid meshes to surface PolyData before processing. Only applicable for PolyMesh converter. Default: False + log_level: Logging level (default: logging.INFO) """ + super().__init__(class_name=self.__class__.__name__, log_level=log_level) + self.data_basename = data_basename self.input_polydata = input_polydata self.mask_ids = mask_ids self.convert_to_surface = convert_to_surface + self.compute_normals = compute_normals + # Colormap settings (set via set_colormap()) self.color_by_array = None self.colormap = 'plasma' @@ -90,7 +108,9 @@ def list_available_arrays(self): if array_name not in available_arrays: array_data = mesh.point_data[array_name] available_arrays[array_name] = { - 'n_components': (array_data.shape[1] if len(array_data.shape) > 1 else 1), + 'n_components': ( + array_data.shape[1] if len(array_data.shape) > 1 else 1 + ), 'dtype': str(array_data.dtype), 'range': (float(np.min(array_data)), float(np.max(array_data))), 'present_in_steps': [time_idx], @@ -107,7 +127,9 @@ def list_available_arrays(self): return available_arrays - def set_colormap(self, color_by_array=None, colormap='plasma', intensity_range=None): + def set_colormap( + self, color_by_array=None, colormap='plasma', intensity_range=None + ): """ Configure colormap settings for vertex coloring. 
@@ -130,20 +152,6 @@ def set_colormap(self, color_by_array=None, colormap='plasma', intensity_range=N self.colormap = colormap self.intensity_range = intensity_range - # Validate colormap choice - self._validate_colormap() - - # Initialize random seed for reproducible random colormap - if self.colormap == 'random': - np.random.seed(42) - - return self - - def _ras_to_usd(self, point): - """Convert RAS coordinates to USD's right-handed Y-up system""" - return Gf.Vec3f(float(point[0]), float(point[2]), float(-point[1])) - - def _validate_colormap(self): """Validate that the chosen colormap is supported""" supported_colormaps = [ 'plasma', @@ -160,6 +168,16 @@ def _validate_colormap(self): f"Choose from: {', '.join(supported_colormaps)}" ) + # Initialize random seed for reproducible random colormap + if self.colormap == 'random': + np.random.seed(42) + + return self + + def _ras_to_usd(self, point): + """Convert RAS coordinates to USD's right-handed Y-up system""" + return Gf.Vec3f(float(point[0]), float(point[2]), float(-point[1])) + def _get_matplotlib_colormap(self, colormap_name): """ Get matplotlib colormap object, with custom implementations for special cases. 
@@ -299,7 +317,7 @@ def _extract_color_array(self, mesh): color_values.append(scalar_value) return np.array(color_values) else: - print(f"Warning: Array '{self.color_by_array}' not found in point data") + self.log_warning("Array '%s' not found in point data", self.color_by_array) return None def _check_topology_changes(self, mesh_time_data): @@ -347,79 +365,81 @@ def _check_topology_changes(self, mesh_time_data): if mesh_type == 'polymesh': curr_num_points = len(curr_data['points']) curr_num_faces = len(curr_data['face_vertex_counts']) - if curr_num_points != ref_num_points or curr_num_faces != ref_num_faces: + if ( + curr_num_points != ref_num_points + or curr_num_faces != ref_num_faces + ): has_change = True break elif mesh_type == 'tetmesh': curr_num_points = len(curr_data['points']) curr_num_tets = len(curr_data['tet_indices']) - if curr_num_points != ref_num_points or curr_num_tets != ref_num_tets: + if ( + curr_num_points != ref_num_points + or curr_num_tets != ref_num_tets + ): has_change = True break topology_changes[label] = has_change if has_change: - print( - f"Detected topology changes for label '{label}' - " - f"will use time-varying mesh approach" + self.log_info( + "Detected topology changes for label '%s' - will use time-varying mesh approach", + label, ) return topology_changes - def _is_unstructured_grid(self, mesh) -> bool: - """Check if mesh is an UnstructuredGrid""" - return isinstance(mesh, (pv.UnstructuredGrid, vtk.vtkUnstructuredGrid)) + def _compute_facevarying_normals_tri( + self, points_vt, faceCounts_vt, faceIndices_vt + ): + """ + Vectorized face-varying normals for a triangulated mesh. 
- def _is_polydata(self, mesh) -> bool: - """Check if mesh is a PolyData""" - return isinstance(mesh, (pv.PolyData, vtk.vtkPolyData)) + points_vt: Vt.Vec3fArray + faceCounts_vt: Vt.IntArray (all must be 3) + faceIndices_vt: Vt.IntArray (len == 3 * numFaces) - def _compute_vertex_normals(self, points, face_vertex_counts, face_vertex_indices): + Returns: Vt.Vec3fArray of length len(faceIndices_vt), one normal per corner. """ - Compute per-vertex normals by averaging face normals. - Required for IndeX renderer compatibility. - """ - num_vertices = len(points) - vertex_normals = np.zeros((num_vertices, 3), dtype=np.float32) - - # Iterate through faces and compute face normals - idx = 0 - for count in face_vertex_counts: - # Get vertex indices for this face - face_verts = [face_vertex_indices[idx + i] for i in range(count)] - - # Get vertex positions - v0 = np.array(points[face_verts[0]]) - v1 = np.array(points[face_verts[1]]) - v2 = np.array(points[face_verts[2]]) - - # Compute face normal using cross product - edge1 = v1 - v0 - edge2 = v2 - v0 - face_normal = np.cross(edge1, edge2) - - # Normalize (handle zero-length normals) - length = np.linalg.norm(face_normal) - if length > 1e-10: - face_normal = face_normal / length - - # Add face normal to all vertices of this face - for vert_idx in face_verts: - vertex_normals[vert_idx] += face_normal - - idx += count - - # Normalize vertex normals - for i in range(num_vertices): - length = np.linalg.norm(vertex_normals[i]) - if length > 1e-10: - vertex_normals[i] = vertex_normals[i] / length - else: - # Default to (0, 1, 0) if normal is zero - vertex_normals[i] = np.array([0.0, 1.0, 0.0]) - return vertex_normals.tolist() + # Convert Vt arrays to NumPy + points = np.array(points_vt).astype(np.float32) # (N, 3) + counts = np.array(faceCounts_vt).astype(np.int32) # (F,) + indices = np.array(faceIndices_vt).astype(np.int32) # (3F,) + + # Sanity: assume triangulated mesh + if not np.all(counts == 3): + raise ValueError( + "Mesh 
must be fully triangulated (all faceVertexCounts == 3)" + ) + + # Reshape indices into (F, 3) + faces = indices.reshape(-1, 3) # (F, 3) + + # Gather per-face vertex positions (F, 3, 3) + tris = points[faces] # (F, 3, 3) + + # Compute normals via vectorized cross product + v1 = tris[:, 1] - tris[:, 0] # (F, 3) + v2 = tris[:, 2] - tris[:, 0] # (F, 3) + v1 = cp.array(v1) + v2 = cp.array(v2) + n = cp.cross(v1, v2) # (F, 3) + + # Normalize + lengths = cp.linalg.norm(n, axis=1, keepdims=True) # (F, 1) + mask = lengths[:, 0] > 0 + n[mask] /= lengths[mask] + + # Broadcast each face normal to 3 corners -> (F, 3, 3), then flatten + n_fv = cp.repeat(n[:, cp.newaxis, :], 3, axis=1).reshape(-1, 3) # (3F, 3) + + # Convert back to Vt.Vec3fArray + fv_vt = Vt.Vec3fArray.FromNumpy(n_fv.get()) + + return fv_vt # Abstract methods that subclasses must implement @@ -465,7 +485,9 @@ def _create_usd_mesh( """ pass - def convert(self, output_usd_file, convert_to_surface=None) -> Usd.Stage: + def convert( + self, output_usd_file, convert_to_surface=None, compute_normals=None + ) -> Usd.Stage: """ Convert VTK meshes to USD format. 
@@ -482,6 +504,9 @@ def convert(self, output_usd_file, convert_to_surface=None) -> Usd.Stage: if convert_to_surface is not None: self.convert_to_surface = convert_to_surface + if compute_normals is not None: + self.compute_normals = compute_normals + # Remove existing file if it exists to avoid USD layer conflicts if os.path.exists(output_usd_file): os.remove(output_usd_file) @@ -497,7 +522,7 @@ def convert(self, output_usd_file, convert_to_surface=None) -> Usd.Stage: UsdGeom.Xform.Define(self.stage, root_path) basename = os.path.basename(output_usd_file).split(".")[0] - print(f"Converting {basename}") + self.log_info("Converting %s", basename) root_path = f"{root_path}/Transform_{basename}" UsdGeom.Xform.Define(self.stage, root_path) @@ -506,38 +531,49 @@ def convert(self, output_usd_file, convert_to_surface=None) -> Usd.Stage: # Collect the label data from each time point polydata_time_data = {} - for fnum in range(len(self.input_polydata)): - polydata_time_data[fnum] = self._process_mesh_data(self.input_polydata[fnum]) - print("Processed time point", fnum) + for fnum, mesh_data in enumerate(self.input_polydata): + self.log_progress( + fnum + 1, len(self.input_polydata), prefix="Processing time point" + ) + polydata_time_data[fnum] = self._process_mesh_data(mesh_data) # Check for topology changes across time steps - print("\nChecking for topology changes...") topology_changes = self._check_topology_changes(polydata_time_data) # Assign a unique color to each label label_colors = {} for fnum in range(len(polydata_time_data)): - for label, data in polydata_time_data[fnum].items(): + for label, _ in polydata_time_data[fnum].items(): if label not in label_colors: - label_colors[label] = self.colors[len(label_colors) % len(self.colors)] + label_colors[label] = self.colors[ + len(label_colors) % len(self.colors) + ] # Process first polydata to get label groups first_data = polydata_time_data[0] # Create a mesh prim for each label group + num_labels = 
len(first_data.items()) for idx, (label, data) in enumerate(first_data.items()): - print(f"Processing {label} {idx}") # Create a transform for each mesh transform_path = f"{root_path}/Transform_{label}" - print(f"Transform path: {transform_path}") UsdGeom.Xform.Define(self.stage, transform_path) # Determine if topology changes for this label has_topology_change = topology_changes.get(label, False) # Call subclass-specific USD mesh creation + start_time = time.time() self._create_usd_mesh( - transform_path, label, polydata_time_data, label_colors, has_topology_change + transform_path, + label, + polydata_time_data, + label_colors, + has_topology_change, + ) + end_time = time.time() + self.log_info( + "Time taken to create USD mesh: %s seconds", end_time - start_time ) # Set time range for the stage @@ -546,5 +582,4 @@ def convert(self, output_usd_file, convert_to_surface=None) -> Usd.Stage: self.stage.SetTimeCodesPerSecond(1.0) self.stage.Save() - return self.stage diff --git a/src/physiomotion4d/convert_vtk_4d_to_usd_polymesh.py b/src/physiomotion4d/convert_vtk_4d_to_usd_polymesh.py index d09cddb..4ddefe9 100644 --- a/src/physiomotion4d/convert_vtk_4d_to_usd_polymesh.py +++ b/src/physiomotion4d/convert_vtk_4d_to_usd_polymesh.py @@ -1,5 +1,8 @@ """Converter for VTK PolyData to USD Mesh with surface meshes.""" +import itertools +import time + import numpy as np import pyvista as pv import vtk @@ -42,9 +45,12 @@ def supports_mesh_type(self, mesh) -> bool: Returns: bool: True if mesh is PolyData or can be converted to surface """ - if self._is_polydata(mesh): + if isinstance(mesh, (pv.PolyData, vtk.vtkPolyData)): return True - if self._is_unstructured_grid(mesh) and self.convert_to_surface: + if ( + isinstance(mesh, (pv.UnstructuredGrid, vtk.vtkUnstructuredGrid)) + and self.convert_to_surface + ): return True return False @@ -58,7 +64,7 @@ def _process_mesh_data(self, mesh) -> dict: Returns: dict: Processed mesh data organized by labels or 'default' """ - if 
self._is_unstructured_grid(mesh): + if isinstance(mesh, (pv.UnstructuredGrid, vtk.vtkUnstructuredGrid)): if self.convert_to_surface: # Convert UnstructuredGrid to surface PolyData first surface_mesh = self._convert_ugrid_to_surface(mesh) @@ -68,7 +74,7 @@ def _process_mesh_data(self, mesh) -> dict: "UnstructuredGrid not supported by PolyMesh converter. " "Use convert_to_surface=True or TetMesh converter." ) - elif self._is_polydata(mesh): + elif isinstance(mesh, (pv.PolyData, vtk.vtkPolyData)): return self._process_polydata(mesh) else: raise TypeError(f"Unsupported mesh type: {type(mesh)}") @@ -91,15 +97,15 @@ def _create_usd_mesh( has_topology_change: Whether topology varies over time """ if has_topology_change: - print( - f"Creating time-varying UsdGeomMesh for label: {label} " - f"(topology changes detected)" + self.log_info( + "Creating time-varying UsdGeomMesh for label: %s (topology changes detected)", + label, ) self._create_usd_polymesh_varying( transform_path, label, mesh_time_data, label_colors ) else: - print(f"Creating UsdGeomMesh for label: {label}") + self.log_info("Creating UsdGeomMesh for label: %s", label) self._create_usd_polymesh( transform_path, label, mesh_time_data, label_colors ) @@ -150,12 +156,12 @@ def _process_polydata(self, polydata) -> dict: ): label_array = polydata.GetCellData().GetArray("boundary_labels") # Get all unique labels from both components - boundary_labels = set() - for i in range(label_array.GetNumberOfTuples()): - tuple_values = label_array.GetTuple(i) - for value in tuple_values: - if int(value) != 0: - boundary_labels.add(value) + tuple_values = [ + label_array.GetTuple(i) for i in range(label_array.GetNumberOfTuples()) + ] + tuple_values_flattened = list(itertools.chain.from_iterable(tuple_values)) + boundary_labels = set(tuple_values_flattened) + boundary_labels.discard(0) # Get deformation magnitude if it exists def_mag = None @@ -172,17 +178,23 @@ def _process_polydata(self, polydata) -> dict: offsets = 
faces.GetOffsetsArray() # Process face data - face_vertex_counts = [] + num_faces = offsets.GetNumberOfValues() - 1 + start_idx = [offsets.GetValue(i) for i in range(num_faces)] + end_idx = [offsets.GetValue(i + 1) for i in range(num_faces)] + face_vertex_counts = [end_idx[i] - start_idx[i] for i in range(num_faces)] face_vertex_indices = [] + for i in range(num_faces): + face_vertex_indices.extend( + [connectivity.GetValue(j) for j in range(start_idx[i], end_idx[i])] + ) - for i in range(offsets.GetNumberOfValues() - 1): - start_idx = offsets.GetValue(i) - end_idx = offsets.GetValue(i + 1) - num_vertices = end_idx - start_idx - face_vertex_counts.append(num_vertices) - - for j in range(start_idx, end_idx): - face_vertex_indices.append(connectivity.GetValue(j)) + # for i in range(offsets.GetNumberOfValues() - 1): + # start_idx = offsets.GetValue(i) + # end_idx = offsets.GetValue(i + 1) + # num_vertices = end_idx - start_idx + # face_vertex_counts.append(num_vertices) + # for j in range(start_idx, end_idx): + # face_vertex_indices.append(connectivity.GetValue(j)) # Create objects for each cell based on its labels if boundary_labels: @@ -208,50 +220,66 @@ def _process_polydata(self, polydata) -> dict: label_array = polydata.GetCellData().GetArray("boundary_labels") # Process each cell + start_time = time.time() for cell_id in range(polydata.GetNumberOfCells()): + if cell_id % 1000000 == 0 or cell_id == polydata.GetNumberOfCells() - 1: + self.log_progress( + cell_id + 1, + polydata.GetNumberOfCells(), + prefix="Processing cells", + ) + cell = polydata.GetCell(cell_id) cell_labels = set() # Get all labels for this cell if label_array: tuple_values = label_array.GetTuple(cell_id) - for label_id in tuple_values: - if int(label_id) != 0: - label = self.mask_ids[int(label_id)] - cell_labels.add(label) - - # For each label of this cell, create a copy of the cell - for label_str in cell_labels: - obj = label_objects[label_str] + labels = [ + self.mask_ids[int(label_id)] + 
for label_id in tuple_values + if int(label_id) != 0 + ] + cell_labels.update(labels) # Get the points of this cell - cell_point_indices = [] - for i in range(cell.GetNumberOfPoints()): - point_id = cell.GetPointId(i) - point = points.GetPoint(point_id) - usd_point = self._ras_to_usd(point) - - # Check if we've already added this point - if point_id not in obj['point_mapping']: - obj['point_mapping'][point_id] = len(obj['points']) - obj['points'].append(usd_point) - if def_mag: - obj['deformation_magnitude'].append(def_mag[point_id]) - if color_array is not None: - obj['color_array'].append(color_array[point_id]) - - cell_point_indices.append(obj['point_mapping'][point_id]) - - # Add the face to this label's object - obj['face_vertex_counts'].append(len(cell_point_indices)) - obj['face_vertex_indices'].extend(cell_point_indices) + n_points = cell.GetNumberOfPoints() + point_ids = [int(cell.GetPointId(i)) for i in range(n_points)] + usd_points = [ + self._ras_to_usd(points.GetPoint(pnt_id)) + for pnt_id in point_ids + ] + + # For each label of this cell, create a copy of the cell + for label_str in cell_labels: + obj = label_objects[label_str] + cell_point_indices = [] + for pnt_num, pnt_id in enumerate(point_ids): + indx = obj['point_mapping'].get(pnt_id, None) + if indx is None: + indx = len(obj['points']) + obj['points'].append(usd_points[pnt_num]) + obj['point_mapping'][pnt_id] = indx + if def_mag: + obj['deformation_magnitude'].append(def_mag[pnt_id]) + if color_array is not None: + obj['color_array'].append(color_array[pnt_id]) + cell_point_indices.append(indx) + + obj['face_vertex_counts'].append(len(cell_point_indices)) + obj['face_vertex_indices'].extend(cell_point_indices) + + end_time = time.time() + self.log_info( + "Time taken to process cells %d: %s seconds", + polydata.GetNumberOfCells(), + end_time - start_time, + ) # Convert color_array lists to numpy arrays - for label in label_objects: - if label_objects[label]['color_array'] is not None: - 
label_objects[label]['color_array'] = np.array( - label_objects[label]['color_array'] - ) + # for label, obj in label_objects.items(): + # if obj['color_array'] is not None: + # obj['color_array'] = np.array(obj['color_array']) return label_objects else: @@ -299,10 +327,11 @@ def _create_usd_polymesh(self, transform_path, label, mesh_time_data, label_colo mesh.CreateSubdivisionSchemeAttr("none") # Prevent unwanted subdivision mesh.CreateDoubleSidedAttr(True) # Ensure visibility from both sides - # Create normals attribute (REQUIRED for IndeX renderer) + # Create normals attribute # Normals will be computed per timestep since mesh deforms - normals_attr = mesh.CreateNormalsAttr() - normals_attr.SetMetadata('interpolation', UsdGeom.Tokens.vertex) + if self.compute_normals: + normals_attr = mesh.CreateNormalsAttr() + normals_attr.SetMetadata('interpolation', UsdGeom.Tokens.vertex) # Set display color - either per-vertex from color array or fixed label color use_color_array = self.color_by_array is not None and any( @@ -351,23 +380,30 @@ def _create_usd_polymesh(self, transform_path, label, mesh_time_data, label_colo global_vmin = float('inf') global_vmax = float('-inf') + num_times = len(self.times) for time_idx, time_code in enumerate(self.times): - print(f"Processing time sample: {time_code} for label: {label}") + if time_idx % 10 == 0 or time_idx == num_times - 1: + self.log_progress( + time_idx + 1, + num_times, + prefix=f"Processing time samples for {label}", + ) time_data = mesh_time_data[time_idx][label] - # Compute per-vertex normals for this timestep (REQUIRED for IndeX renderer) - vertex_normals = self._compute_vertex_normals( - time_data['points'], - time_data['face_vertex_counts'], - time_data['face_vertex_indices'], - ) + if self.compute_normals: + vertex_normals = self._compute_facevarying_normals_tri( + time_data['points'], + time_data['face_vertex_counts'], + time_data['face_vertex_indices'], + ) # Set points first time_samples[time_code] = { 'points': 
time_data['points'], 'extent': UsdGeom.Mesh.ComputeExtent(time_data['points']), - 'normals': vertex_normals, } + if self.compute_normals: + time_samples[time_code]['normals'] = vertex_normals # Compute per-vertex colors if using color array if use_color_array and time_data.get('color_array') is not None: @@ -387,7 +423,8 @@ def _create_usd_polymesh(self, transform_path, label, mesh_time_data, label_colo for t_code, time_data_dict in time_samples.items(): points_attr.Set(time_data_dict['points'], t_code) extent_attr.Set(time_data_dict['extent'], t_code) - normals_attr.Set(time_data_dict['normals'], t_code) + if self.compute_normals: + normals_attr.Set(time_data_dict['normals'], t_code) if use_color_array and 'vertex_colors' in time_data_dict: display_color_primvar.Set(time_data_dict['vertex_colors'], t_code) # Set raw scalar values for colormap control @@ -406,7 +443,8 @@ def _create_usd_polymesh(self, transform_path, label, mesh_time_data, label_colo # Set initial values (non-timewarped) points_attr.Set(time_samples[self.times[0]]['points']) extent_attr.Set(time_samples[self.times[0]]['extent']) - normals_attr.Set(time_samples[self.times[0]]['normals']) + if self.compute_normals: + normals_attr.Set(time_samples[self.times[0]]['normals']) if use_color_array and 'vertex_colors' in time_samples[self.times[0]]: display_color_primvar.Set(time_samples[self.times[0]]['vertex_colors']) scalar_values_0 = time_samples[self.times[0]]['scalar_values'] @@ -471,8 +509,12 @@ def _create_usd_polymesh_varying( UsdGeom.Xform.Define(self.stage, parent_path) # Create separate mesh for each time step + num_times = len(self.times) for time_idx, time_code in enumerate(self.times): - print(f"Creating UsdGeomMesh for label: {label} at time: {time_code}") + if time_idx % 10 == 0 or time_idx == num_times - 1: + self.log_progress( + time_idx + 1, num_times, prefix=f"Creating meshes for {label}" + ) # Skip if label doesn't exist at this timestep if label not in mesh_time_data[time_idx]: 
continue @@ -493,14 +535,15 @@ def _create_usd_polymesh_varying( mesh.CreateDoubleSidedAttr(True) # Compute and set normals - vertex_normals = self._compute_vertex_normals( - time_data['points'], - time_data['face_vertex_counts'], - time_data['face_vertex_indices'], - ) - normals_attr = mesh.CreateNormalsAttr() - normals_attr.SetMetadata('interpolation', UsdGeom.Tokens.vertex) - normals_attr.Set(vertex_normals) + if self.compute_normals: + vertex_normals = self._compute_facevarying_normals_tri( + time_data['points'], + time_data['face_vertex_counts'], + time_data['face_vertex_indices'], + ) + normals_attr = mesh.CreateNormalsAttr() + normals_attr.SetMetadata('interpolation', UsdGeom.Tokens.vertex) + normals_attr.Set(vertex_normals) # Set extent extent_attr = mesh.CreateExtentAttr() diff --git a/src/physiomotion4d/convert_vtk_4d_to_usd_tetmesh.py b/src/physiomotion4d/convert_vtk_4d_to_usd_tetmesh.py index 4b4cbaf..eac9b80 100644 --- a/src/physiomotion4d/convert_vtk_4d_to_usd_tetmesh.py +++ b/src/physiomotion4d/convert_vtk_4d_to_usd_tetmesh.py @@ -45,7 +45,10 @@ def supports_mesh_type(self, mesh) -> bool: Returns: bool: True if mesh is UnstructuredGrid and not surface mode """ - return self._is_unstructured_grid(mesh) and not self.convert_to_surface + return ( + isinstance(mesh, (pv.UnstructuredGrid, vtk.vtkUnstructuredGrid)) + and not self.convert_to_surface + ) def _process_mesh_data(self, mesh) -> dict: """ @@ -57,7 +60,7 @@ def _process_mesh_data(self, mesh) -> dict: Returns: dict: Processed mesh data with tetrahedral or surface cell information """ - if not isinstance(mesh, (pv.UnstructuredGrid, vtk.vtkUnstructuredGrid)): raise TypeError( f"TetMesh converter only supports UnstructuredGrid. 
" f"Got: {type(mesh)}" @@ -87,15 +90,15 @@ def _create_usd_mesh( if mesh_type == 'tetmesh': if has_topology_change: - print( - f"Creating time-varying UsdGeomTetMesh for label: {label} " - f"(topology changes detected)" + self.log_info( + "Creating time-varying UsdGeomTetMesh for label: %s (topology changes detected)", + label, ) self._create_usd_tetmesh_varying( transform_path, label, mesh_time_data, label_colors ) else: - print(f"Creating UsdGeomTetMesh for label: {label}") + self.log_info("Creating UsdGeomTetMesh for label: %s", label) self._create_usd_tetmesh( transform_path, label, mesh_time_data, label_colors ) @@ -323,10 +326,11 @@ def _create_usd_tetmesh(self, transform_path, label, mesh_time_data, label_color # Set mesh attributes for Index renderer compatibility tetmesh.CreateDoubleSidedAttr(True) # Ensure visibility from both sides - # Create normals attribute (REQUIRED for IndeX renderer) + # Create normals attribute # For tetrahedral meshes, we need normals for the surface vertices - normals_attr = tetmesh.CreateNormalsAttr() - normals_attr.SetMetadata('interpolation', UsdGeom.Tokens.vertex) + if self.compute_normals: + normals_attr = tetmesh.CreateNormalsAttr() + normals_attr.SetMetadata('interpolation', UsdGeom.Tokens.vertex) # Assign a unique color to the mesh with proper primvar display_color = label_colors[label] @@ -340,37 +344,45 @@ def _create_usd_tetmesh(self, transform_path, label, mesh_time_data, label_color extent_attr = tetmesh.CreateExtentAttr() time_samples = {} + num_times = len(self.times) for time_idx, time_code in enumerate(self.times): - print(f"Processing time step {time_code} for label: {label}") + if time_idx % 10 == 0 or time_idx == num_times - 1: + self.log_progress( + time_idx + 1, num_times, prefix=f"Processing time steps for {label}" + ) time_data = mesh_time_data[time_idx][label] - # Compute per-vertex normals for surface faces (REQUIRED for IndeX renderer) + # Compute per-vertex normals for surface faces # For tetrahedral 
meshes, compute normals based on surface triangulation # First, need to convert surface face indices to face vertex counts surface_indices = time_data['surface_face_indices'] # Surface faces are triangular, so each face has 3 vertices face_vertex_counts = [3] * (len(surface_indices) // 3) - vertex_normals = self._compute_vertex_normals( - time_data['points'], face_vertex_counts, surface_indices - ) + if self.compute_normals: + vertex_normals = self._compute_vertex_normals( + time_data['points'], face_vertex_counts, surface_indices + ) # Set points first time_samples[time_code] = { 'points': time_data['points'], 'extent': UsdGeom.TetMesh.ComputeExtent(time_data['points']), - 'normals': vertex_normals, } + if self.compute_normals: + time_samples[time_code]['normals'] = vertex_normals # Set points, extents, and normals with explicit time codes for t_code, time_data_dict in time_samples.items(): points_attr.Set(time_data_dict['points'], t_code) extent_attr.Set(time_data_dict['extent'], t_code) - normals_attr.Set(time_data_dict['normals'], t_code) + if self.compute_normals: + normals_attr.Set(time_data_dict['normals'], t_code) # Set initial values (non-timewarped) points_attr.Set(time_samples[self.times[0]]['points']) extent_attr.Set(time_samples[self.times[0]]['extent']) - normals_attr.Set(time_samples[self.times[0]]['normals']) + if self.compute_normals: + normals_attr.Set(time_samples[self.times[0]]['normals']) # Set deformation magnitude if it exists if any( @@ -406,8 +418,12 @@ def _create_usd_tetmesh_varying( UsdGeom.Xform.Define(self.stage, parent_path) # Create separate tetmesh for each time step + num_times = len(self.times) for time_idx, time_code in enumerate(self.times): - print(f"Creating UsdGeomTetMesh for label: {label} at time: {time_code}") + if time_idx % 10 == 0 or time_idx == num_times - 1: + self.log_progress( + time_idx + 1, num_times, prefix=f"Creating tetmeshes for {label}" + ) # Skip if label doesn't exist at this timestep if label not in 
mesh_time_data[time_idx]: continue @@ -431,12 +447,13 @@ def _create_usd_tetmesh_varying( # Compute and set normals surface_indices = time_data['surface_face_indices'] face_vertex_counts = [3] * (len(surface_indices) // 3) - vertex_normals = self._compute_vertex_normals( - time_data['points'], face_vertex_counts, surface_indices - ) - normals_attr = tetmesh.CreateNormalsAttr() - normals_attr.SetMetadata('interpolation', UsdGeom.Tokens.vertex) - normals_attr.Set(vertex_normals) + if self.compute_normals: + vertex_normals = self._compute_vertex_normals( + time_data['points'], face_vertex_counts, surface_indices + ) + normals_attr = tetmesh.CreateNormalsAttr() + normals_attr.SetMetadata('interpolation', UsdGeom.Tokens.vertex) + normals_attr.Set(vertex_normals) # Set extent extent_attr = tetmesh.CreateExtentAttr() diff --git a/src/physiomotion4d/heart_gated_ct_to_usd_workflow.py b/src/physiomotion4d/heart_gated_ct_to_usd_workflow.py index f1723f9..43c250a 100644 --- a/src/physiomotion4d/heart_gated_ct_to_usd_workflow.py +++ b/src/physiomotion4d/heart_gated_ct_to_usd_workflow.py @@ -5,8 +5,9 @@ as demonstrated in the Heart-GatedCT experiment notebooks. 
""" +import logging import os -from typing import List, Optional +from typing import Optional import itk import numpy as np @@ -16,6 +17,7 @@ from physiomotion4d.contour_tools import ContourTools from physiomotion4d.convert_nrrd_4d_to_3d import ConvertNRRD4DTo3D from physiomotion4d.convert_vtk_4d_to_usd import ConvertVTK4DToUSD +from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase from physiomotion4d.register_images_ants import RegisterImagesANTs from physiomotion4d.register_images_icon import RegisterImagesICON from physiomotion4d.segment_chest_total_segmentator import SegmentChestTotalSegmentator @@ -23,7 +25,7 @@ from physiomotion4d.usd_anatomy_tools import USDAnatomyTools -class HeartGatedCTToUSDWorkflow: +class HeartGatedCTToUSDWorkflow(PhysioMotion4DBase): """ Complete workflow for Heart-gated CT images to dynamic USD models. @@ -33,19 +35,20 @@ class HeartGatedCTToUSDWorkflow: def __init__( self, - input_filenames: List[str], + input_filenames: list, contrast_enhanced: bool, output_directory: str, project_name: str, reference_image_filename: Optional[str] = None, number_of_registration_iterations: Optional[int] = 1, registration_method: str = 'icon', + log_level: int | str = logging.INFO, ): """ Initialize the Heart-gated CT to USD workflow. Args: - input_filenames (List[str]): List of paths to the 3D NRRD files containing cardiac CT data. + input_filenames (List): List of paths to the 3D NRRD files containing cardiac CT data. If there is only one file, it will be used as the 4D NRRD file. 
contrast_enhanced (bool): Whether the study uses contrast enhancement output_directory (str): Directory path where output files will be stored @@ -53,7 +56,10 @@ def __init__( reference_image_filename (Optional[str]): Path to reference image file number_of_registration_iterations (Optional[int]): Number of registration iterations registration_method (str): Registration method to use: 'ants' or 'icon' (default: 'icon') + log_level: Logging level (default: logging.INFO) """ + super().__init__(class_name=self.__class__.__name__, log_level=log_level) + self.input_filenames = input_filenames self.contrast_enhanced = contrast_enhanced self.output_directory = output_directory @@ -73,14 +79,14 @@ def __init__( os.makedirs(output_directory, exist_ok=True) # Initialize processing components - self.converter = ConvertNRRD4DTo3D() - self.segmenter = SegmentChestTotalSegmentator() + self.converter = ConvertNRRD4DTo3D(log_level=log_level) + self.segmenter = SegmentChestTotalSegmentator(log_level=log_level) self.segmenter.contrast_threshold = 500 # Initialize registration method if self.registration_method == 'ants': - print(f"Initializing ANTs registration...") - self.registrar = RegisterImagesANTs() + self.log_info("Initializing ANTs registration...") + self.registrar = RegisterImagesANTs(log_level=log_level) self.registrar.set_modality('ct') self.registrar.set_transform_type('SyN') if ( @@ -95,8 +101,8 @@ def __init__( ) ) else: # icon (default) - print(f"Initializing ICON registration...") - self.registrar = RegisterImagesICON() + self.log_info("Initializing ICON registration...") + self.registrar = RegisterImagesICON(log_level=log_level) self.registrar.set_modality('ct') if ( number_of_registration_iterations is not None @@ -124,7 +130,7 @@ def process(self) -> str: Returns: str: Path to the final dynamic anatomy USD file """ - print("Starting Heart-gated CT processing pipeline...") + self.log_section("Heart-gated CT Processing Pipeline") # Load and convert data 
self._load_time_series() @@ -141,12 +147,12 @@ def process(self) -> str: # Create USD files self._create_usd_files() - print("Processing pipeline completed successfully") + self.log_info("Processing pipeline completed successfully") return f"{self.project_name}.dynamic_anatomy_painted.usd" def _load_time_series(self): """Load and convert 4D data to time series images.""" - print("Loading time series data...") + self.log_info("Loading time series data...") if len(self.input_filenames) == 1: self.converter.load_nrrd_4d(self.input_filenames[0]) @@ -156,7 +162,7 @@ def _load_time_series(self): ) ) else: - print(f"Loading {len(self.input_filenames)} 3D NRRD files") + self.log_info("Loading %d 3D NRRD files", len(self.input_filenames)) self.converter.load_nrrd_3d(self.input_filenames) self._num_time_points = self.converter.get_number_of_3d_images() @@ -178,14 +184,14 @@ def _load_time_series(self): compression=True, ) - print(f"Loaded {self._num_time_points} time points") + self.log_info("Loaded %d time points", self._num_time_points) def _segment_and_register_frames(self): """Segment each frame and register to reference image.""" - print("Segmenting and registering frames...") + self.log_info("Segmenting and registering frames...") # Segment reference image - print("Segmenting reference image...") + self.log_info("Segmenting reference image...") self._fixed_segmentation = self.segmenter.segment( self._fixed_image, contrast_enhanced_study=self.contrast_enhanced ) @@ -231,29 +237,29 @@ def _segment_and_register_frames(self): # Process each time point self._time_series_transforms = [] for i in range(self._num_time_points): - print(f"Processing frame {i+1}/{self._num_time_points}") + self.log_progress(i + 1, self._num_time_points, prefix="Processing frames") moving_image = self._time_series_images[i] # Register without mask first self.registrar.set_fixed_image_mask(None) result_all = self.registrar.register(moving_image) - phi_FM_all = result_all["phi_FM"] - phi_MF_all = 
result_all["phi_MF"] + inverse_transform_all = result_all["inverse_transform"] + forward_transform_all = result_all["forward_transform"] itk.transformwrite( - phi_FM_all, + inverse_transform_all, os.path.join(self.output_directory, f"slice_{i:03d}_all_AB.hdf"), compression=True, ) itk.transformwrite( - phi_MF_all, + forward_transform_all, os.path.join(self.output_directory, f"slice_{i:03d}_all_BA.hdf"), compression=True, ) # Estimate the moving dynamic mask by the inverse transform of the fixed dynamic mask moving_dynamic_mask = TransformTools().transform_image( - fixed_dynamic_mask, phi_FM_all, moving_image, "nearest" + fixed_dynamic_mask, inverse_transform_all, moving_image, "nearest" ) itk.imwrite( moving_dynamic_mask, @@ -262,12 +268,12 @@ def _segment_and_register_frames(self): ) self.registrar.set_fixed_image_mask(fixed_dynamic_mask) result_dynamic = self.registrar.register(moving_image, moving_dynamic_mask) - phi_FM_dynamic = result_dynamic["phi_FM"] - phi_MF_dynamic = result_dynamic["phi_MF"] + inverse_transform_dynamic = result_dynamic["inverse_transform"] + forward_transform_dynamic = result_dynamic["forward_transform"] # Estimate the moving static mask by the inverse transform of the fixed static mask moving_static_mask = TransformTools().transform_image( - fixed_static_mask, phi_FM_all, moving_image, "nearest" + fixed_static_mask, inverse_transform_all, moving_image, "nearest" ) itk.imwrite( moving_static_mask, @@ -276,32 +282,41 @@ def _segment_and_register_frames(self): ) self.registrar.set_fixed_image_mask(fixed_static_mask) result_static = self.registrar.register(moving_image, moving_static_mask) - phi_FM_static = result_static["phi_FM"] - phi_MF_static = result_static["phi_MF"] + inverse_transform_static = result_static["inverse_transform"] + forward_transform_static = result_static["forward_transform"] # Store transforms transforms = { - 'dynamic': {'phi_FM': phi_FM_dynamic, 'phi_MF': phi_MF_dynamic}, - 'static': {'phi_FM': phi_FM_static, 
'phi_MF': phi_MF_static}, - 'all': {'phi_FM': phi_FM_all, 'phi_MF': phi_MF_all}, + 'dynamic': { + 'inverse_transform': inverse_transform_dynamic, + 'forward_transform': forward_transform_dynamic, + }, + 'static': { + 'inverse_transform': inverse_transform_static, + 'forward_transform': forward_transform_static, + }, + 'all': { + 'inverse_transform': inverse_transform_all, + 'forward_transform': forward_transform_all, + }, } itk.transformwrite( - phi_FM_dynamic, + inverse_transform_dynamic, os.path.join(self.output_directory, f"slice_{i:03d}_dynamic_AB.hdf"), compression=True, ) itk.transformwrite( - phi_MF_dynamic, + forward_transform_dynamic, os.path.join(self.output_directory, f"slice_{i:03d}_dynamic_BA.hdf"), compression=True, ) itk.transformwrite( - phi_FM_static, + inverse_transform_static, os.path.join(self.output_directory, f"slice_{i:03d}_static_AB.hdf"), compression=True, ) itk.transformwrite( - phi_MF_static, + forward_transform_static, os.path.join(self.output_directory, f"slice_{i:03d}_static_BA.hdf"), compression=True, ) @@ -309,7 +324,7 @@ def _segment_and_register_frames(self): def _generate_reference_contours(self): """Generate contour meshes from reference segmentation.""" - print("Generating reference contours...") + self.log_info("Generating reference contours...") ( labelmap_image, @@ -363,22 +378,26 @@ def _generate_reference_contours(self): def _transform_all_contours(self): """Transform contours for all time points using registration transforms.""" - print("Transforming contours for all time points...") + self.log_info("Transforming contours for all time points...") self._transformed_contours = {'all': [], 'dynamic': [], 'static': []} for i in range(self._num_time_points): - print(f"Transforming contours for frame {i+1}/{self._num_time_points}") + self.log_progress( + i + 1, self._num_time_points, prefix="Transforming contours" + ) frame_contours = {} for anatomy_type in ['all', 'dynamic', 'static']: - # Get the inverse transform for this 
anatomy type and frame - phi_MF = self._time_series_transforms[i][anatomy_type]['phi_MF'] + # Get the forward transform for this anatomy type and frame + forward_transform = self._time_series_transforms[i][anatomy_type][ + 'forward_transform' + ] # Transform the reference contours transformed_contours = self.contour_tools.transform_contours( self._reference_contours[anatomy_type], - phi_MF, + forward_transform, with_deformation_magnitude=False, ) @@ -387,25 +406,26 @@ def _transform_all_contours(self): def _create_usd_files(self): """Create painted USD files for all anatomy types.""" - print("Creating USD files...") + self.log_info("Creating USD files...") # Create USD for each anatomy type for anatomy_type in ['all', 'dynamic', 'static']: - print(f"Creating {anatomy_type} anatomy USD...") + self.log_info("Creating %s anatomy USD...", anatomy_type) # Convert VTK contours to USD converter = ConvertVTK4DToUSD( self.project_name, self._transformed_contours[anatomy_type], self.segmenter.all_mask_ids, - os.path.join( - self.output_directory, f"{self.project_name}.{anatomy_type}.usd" - ), + log_level=self.log_level, + ) + usd_file = os.path.join( + self.output_directory, f"{self.project_name}.{anatomy_type}.usd" ) - stage = converter.convert() + stage = converter.convert(usd_file) # Paint the USD file - print(f"Painting {anatomy_type} anatomy USD...") + self.log_info("Painting %s anatomy USD...", anatomy_type) output_filename = os.path.join( self.output_directory, f"{self.project_name}.{anatomy_type}_painted.usd" ) diff --git a/src/physiomotion4d/heart_model_to_patient_workflow.py b/src/physiomotion4d/heart_model_to_patient_workflow.py index adb1b1b..691f1c0 100644 --- a/src/physiomotion4d/heart_model_to_patient_workflow.py +++ b/src/physiomotion4d/heart_model_to_patient_workflow.py @@ -1,11 +1,11 @@ """Model-to-image and model-to-model registration for anatomical models. 
This module provides the ModelToPatientWorkflow class for registering generic -anatomical models to patient-specific imaging data and surface meshes. +anatomical models to patient-specific imaging data and surface models. The workflow includes: -1. Rough alignment using ICP (RegisterModelToModelICP) -1.5. Optional PCA-based registration (RegisterModelToImagePCA) if PCA data provided -2. Mask-based deformable registration (RegisterModelToModelMasks) +1. Rough alignment using ICP (RegisterModelsICP) +1.5. Optional PCA-based registration (RegisterModelsPCA) if PCA data provided +2. Mask-based deformable registration (RegisterModelsDistanceMaps) 3. Optional final mask-to-image refinement using Icon The registration is particularly useful for cardiac modeling where a generic heart model @@ -13,44 +13,14 @@ Key Features: - Automatic mask generation if not provided by user - - Modular design using RegisterModelToModelICP, RegisterModelToImagePCA, and RegisterModelToModelMasks + - Modular design using RegisterModelsICP, RegisterModelsPCA, and RegisterModelsDistanceMaps - Multi-stage registration pipeline: ICP → (optional PCA) → mask-to-mask → mask-to-image - Optional PCA-based shape fitting with SlicerSALT format support - Support for multi-label anatomical structures - Optional Icon-based final refinement - -Example: - >>> import itk - >>> import pyvista as pv - >>> from physiomotion4d import HeartModelToPatientWorkflow - >>> - >>> # Load patient data - >>> patient_surfaces = [pv.read("lv.stl"), pv.read("mc.stl"), pv.read("rv.stl")] - >>> reference_image = itk.imread("patient_ct.nii.gz") - >>> - >>> # For PCA-based workflow, use a dummy mesh initially (will be replaced) - >>> dummy_mesh = patient_surfaces[0] # Placeholder - >>> - >>> # Initialize registration - >>> registrar = HeartModelToPatientWorkflow( - ... moving_mesh=dummy_mesh, - ... fixed_meshes=patient_surfaces, - ... fixed_image=reference_image, - ... 
) - >>> - >>> # Load PCA model from SlicerSALT format (replaces moving mesh) - >>> registrar.set_pca_data_from_slicersalt( - ... json_filename='path/to/pca.json', - ... group_key='All', - ... n_pca_modes=10 - ... ) - >>> - >>> # Run complete workflow - >>> registered_mesh = registrar.run_workflow() """ import logging -from typing import Optional import itk import numpy as np @@ -61,9 +31,9 @@ from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase from physiomotion4d.register_images_ants import RegisterImagesANTs from physiomotion4d.register_images_icon import RegisterImagesICON -from physiomotion4d.register_model_to_image_pca import RegisterModelToImagePCA -from physiomotion4d.register_model_to_model_icp import RegisterModelToModelICP -from physiomotion4d.register_model_to_model_masks import RegisterModelToModelMasks +from physiomotion4d.register_models_distance_maps import RegisterModelsDistanceMaps +from physiomotion4d.register_models_icp import RegisterModelsICP +from physiomotion4d.register_models_pca import RegisterModelsPCA from physiomotion4d.transform_tools import TransformTools @@ -71,137 +41,126 @@ class HeartModelToPatientWorkflow(PhysioMotion4DBase): """Register anatomical models using multi-stage ICP, mask-based, and image-based registration. This class provides a flexible workflow for registering generic anatomical models - (e.g., cardiac meshes) to patient-specific surface meshes and images. The + (e.g., cardiac models) to patient-specific surface models and images. The registration pipeline combines: - - Initial mesh alignment using RegisterModelToModelICP (centroid + affine ICP) - - Mask-based deformable registration using RegisterModelToModelMasks (ANTs) + - Initial model alignment using RegisterModelsICP (centroid + affine ICP) + - Mask-based deformable registration using RegisterModelsDistanceMaps (ANTs) - Optional final mask-to-image refinement using Icon registration **Registration Pipeline:** - 1. 
**ICP Alignment**: Rough affine alignment using RegisterModelToModelICP - 1.5. **PCA Registration** (optional): If PCA data is provided via set_pca_data(), - performs PCA-based shape fitting using RegisterModelToImagePCA - 2. **Mask-to-Mask**: Deformable registration using RegisterModelToModelMasks - 3. **Mask-to-Image** (optional): Final refinement using Icon registration + 1. **ICP Alignment**: Rough affine alignment using RegisterModelsICP + 2. **PCA Registration**: Performs PCA-based shape fitting using RegisterModelsPCA + 3. **Mask-to-Mask**: Deformable registration using RegisterModelsDistanceMaps + 4. **Mask-to-Image**: Final refinement **Mask Configuration:** - Masks are automatically generated from meshes if not provided by the user + Masks are automatically generated from models if not provided by the user via set_masks(). Auto-generated masks use mask_dilation_mm parameter. Attributes: - moving_original_mesh (pv.UnstructuredGrid): Generic anatomical model to be registered - moving_mesh (pv.PolyData): Surface extracted from moving_mesh - fixed_meshes (list of pv.PolyData): Patient-specific surface meshes - fixed_mesh (pv.PolyData): Primary fixed mesh (first in list) - fixed_image (itk.Image): Reference image providing coordinate frame - moving_mask_image (itk.Image): Binary/multi-label mask for moving model - fixed_mask_image (itk.Image): Binary/multi-label mask for fixed model - moving_mask_roi_image (itk.Image): ROI mask for moving model - fixed_mask_roi_image (itk.Image): ROI mask for fixed model + template_model (pv.UnstructuredGrid): Generic anatomical model to be registered + template_model_surface (pv.PolyData): Surface extracted from template_model_surface + template_model_mask (itk.Image): Binary/multi-label mask for model model + template_model_roi (itk.Image): ROI mask for model model + patient_models (list of pv.PolyData): Patient-specific surface models + patient_model_surface (pv.PolyData): Primary patient model surface (first in list) + 
patient_image (itk.Image): Reference image providing coordinate frame + patient_mask (itk.Image): Binary/multi-label mask for patient model + patient_roi (itk.Image): ROI mask for patient model mask_dilation_mm (float): Dilation for mask generation roi_dilation_mm (float): Dilation for ROI mask transform_tools (TransformTools): Transform utilities registrar_icon (RegisterImagesICON): ICON registration instance registrar_ants (RegisterImagesANTs): ANTs registration instance - pca_eigenvectors (np.ndarray): PCA eigenvectors (optional) - pca_std_deviations (np.ndarray): PCA standard deviations (optional) - n_pca_modes (int): Number of PCA modes to use - use_pca (bool): Whether PCA registration is enabled - icp_phi_FM, icp_phi_MF: ICP transforms - pca_rigid_transform: PCA rigid transform (if PCA used) + pca_json_filename (str): PCA JSON filename (optional) + pca_group_key (str): PCA group key (optional) + pca_number_of_modes (int): Number of PCA modes to use + icp_forward_point_transform : ICP transforms + icp_inverse_point_transform : ICP inverse transforms + icp_template_model_surface: template model surface after ICP alignment pca_coefficients: PCA shape coefficients (if PCA used) - moving_pca_mesh: Mesh after PCA registration (if PCA used) - m2m_phi_FM, m2m_phi_MF: Mask-to-mask transforms - m2i_phi_FM, m2i_phi_MF: Mask-to-image transforms - moving_icp_mesh: Mesh after ICP alignment - moving_m2m_mesh: Mesh after mask-to-mask registration - moving_m2i_mesh: Mesh after mask-to-image registration - moving_registered_mesh: Final registered mesh + pca_template_model_surface: template model surface after PCA registration (if PCA used) + m2m_forward_transform: Mask-to-mask forward transform + m2m_inverse_transform: Mask-to-mask inverse transform + m2m_template_model_surface: template model surface after mask-to-mask registration + m2i_forward_transform: Mask-to-image forward transform + m2i_inverse_transform: Mask-to-image inverse transform + m2i_template_model_surface: 
template model surface after mask-to-image registration + m2i_template_labelmap: template labelmap after mask-to-image registration + registered_template_model: Final registered model + registered_template_model_surface: Final registered model surface Example: >>> # Initialize with minimal parameters >>> registrar = HeartModelToPatientWorkflow( - ... moving_mesh=generic_heart, - ... fixed_meshes=[lv_mesh, mc_mesh, rv_mesh], - ... fixed_image=patient_ct, + ... template_model=heart_model, + ... patient_models=[lv_model, mc_model, rv_model], + ... patient_image=patient_ct, + ... pca_json_filename='path/to/pca.json', + ... pca_group_key='All', + ... pca_number_of_modes=10 ... ) >>> >>> # Optional: Configure parameters (masks auto-generated if not set) >>> registrar.set_roi_dilation_mm(20) >>> - >>> # Optional: Enable PCA registration (Method 1 - from SlicerSALT file) - >>> registrar.set_pca_data_from_slicersalt( - ... json_filename='path/to/pca.json', - ... group_key='All', - ... n_pca_modes=10 - ... ) - >>> - >>> # Alternative: Enable PCA registration (Method 2 - from arrays) - >>> # registrar.set_pca_data( - >>> # eigenvectors=pca_components, - >>> # eigenvalues=pca_eigenvalues, - >>> # n_pca_modes=10 - >>> # ) - >>> >>> # Run registration - >>> final_mesh = registrar.run_workflow() + >>> patient_model = registrar.run_workflow() """ def __init__( self, - moving_mesh: pv.UnstructuredGrid, - fixed_meshes: list, - fixed_image: itk.Image, + template_model: pv.UnstructuredGrid, + template_labelmap: itk.Image, + template_labelmap_heart_muscle_ids: list[int], + template_labelmap_chamber_ids: list[int], + template_labelmap_background_ids: list[int], + patient_models: list, + patient_image: itk.Image, + pca_json_filename: str | None = None, + pca_group_key: str = 'All', + pca_number_of_modes: int = 0, log_level: int | str = logging.INFO, ): """Initialize the model-to-image-and-model registration pipeline. 
Args: - moving_mesh: Generic anatomical model mesh to be registered - fixed_meshes: List of patient-specific surface meshes extracted from imaging - data. Typically 3 meshes for cardiac applications: LV, myocardium, RV. - fixed_image: Patient image data providing the target coordinate frame + template_model: Generic anatomical model to be registered + patient_models: List of patient-specific models extracted from imaging + data. Typically 3 models for cardiac applications: LV, myocardium, RV. + patient_image: Patient image data providing the target coordinate frame (origin, spacing, direction). Used as reference for registration. log_level: Logging level (logging.DEBUG, logging.INFO, logging.WARNING). Default: logging.INFO - - Note: - The fixed image is median-filtered (radius=2) to reduce noise. - Masks are auto-generated if not provided via set_masks(). """ # Initialize base class with logging super().__init__(class_name="HeartModelToPatientWorkflow", log_level=log_level) - self.moving_original_mesh = moving_mesh - self.moving_mesh = moving_mesh.extract_surface() + self.template_model = template_model + self.template_model_surface = template_model.extract_surface() + self.template_labelmap = template_labelmap + self.template_labelmap_heart_muscle_ids = template_labelmap_heart_muscle_ids + self.template_labelmap_chamber_ids = template_labelmap_chamber_ids + self.template_labelmap_background_ids = template_labelmap_background_ids - self.fixed_meshes = fixed_meshes - fixed_meshes_surfaces = [mesh.extract_surface() for mesh in fixed_meshes] - combined_fixed_mesh = pv.merge(fixed_meshes_surfaces) - self.fixed_mesh = combined_fixed_mesh.extract_surface() + self.patient_models = patient_models + patient_models_surfaces = [model.extract_surface() for model in patient_models] + combined_patient_model = pv.merge(patient_models_surfaces) + self.patient_model_surface = combined_patient_model.extract_surface() - # Apply median filter to fixed image - median_filter = 
itk.MedianImageFilter.New(Input=fixed_image) - median_filter.SetRadius(2) - median_filter.Update() - self.fixed_image = median_filter.GetOutput() + self.patient_image = patient_image - resampler = ttk.ResampleImage.New(Input=self.fixed_image) + resampler = ttk.ResampleImage.New(Input=self.patient_image) resampler.SetMakeHighResIso(True) resampler.Update() - self.fixed_image = resampler.GetOutput() + self.patient_image = resampler.GetOutput() # Utilities self.transform_tools = TransformTools() self.contour_tools = ContourTools() - self.registrar_pca = None - self.n_pca_modes = -1 - self.use_pca = False - self.registrar_ants = RegisterImagesANTs() self.registrar_ants.set_number_of_iterations([5, 2, 5]) - # Icon registration for final mask-to-image step self.registrar_icon = RegisterImagesICON() self.registrar_icon.set_modality('ct') @@ -209,125 +168,105 @@ def __init__( self.registrar_icon.set_multi_modality(True) self.registrar_icon.set_number_of_iterations(50) - # Mask configuration (to be set by user or auto-generated) - self.moving_mask_image = None - self.fixed_mask_image = None - self.moving_mask_roi_image = None - self.fixed_mask_roi_image = None + # Mask configuration (auto-generated) + self.template_model_mask = None + self.patient_mask = None + self.template_model_roi = None + self.patient_roi = None # Parameters for mask generation and processing self.mask_dilation_mm = 5 # For auto-generated mask dilation self.roi_dilation_mm = 20 # For ROI mask generation # Stage 1: ICP alignment results - self.icp_phi_FM = None - self.icp_phi_MF = None - self.moving_icp_mesh = None - self.moving_icp_mask_image = None - self.moving_icp_mask_roi_image = None + self.icp_registrar = None + self.icp_inverse_point_transform = None + self.icp_forward_point_transform = None + self.icp_template_model_surface = None + self.icp_template_labelmap = None # Stage 1.5: PCA registration results (optional) - self.pca_rigid_transform = None + self.pca_registrar = None + 
self.pca_forward_point_transform = None + self.pca_inverse_point_transform = None + self.pca_json_filename = pca_json_filename + self.pca_number_of_modes = pca_number_of_modes + self.pca_group_key = pca_group_key self.pca_coefficients = None - self.moving_pca_mesh = None + self.pca_template_model_surface = None + self.pca_template_labelmap = None # Stage 2: Mask-to-mask registration results - self.m2m_phi_FM = None - self.m2m_phi_MF = None - self.moving_m2m_mesh = None - self.moving_m2m_mask_image = None - self.moving_m2m_mask_roi_image = None + self.use_m2m_registration = True + self.m2m_inverse_transform = None + self.m2m_forward_transform = None + self.m2m_template_model_surface = None + self.m2m_template_labelmap = None # Stage 3: Mask-to-image registration results - self.m2i_phi_FM = None - self.m2i_phi_MF = None - self.moving_m2i_mesh = None - self.moving_m2i_mask_image = None - self.moving_m2i_mask_roi_image = None - - # Final result - self.moving_registered_mesh = None - - def set_masks( - self, - moving_mask_image: Optional[itk.Image] = None, - fixed_mask_image: Optional[itk.Image] = None, - ): - """Set user-provided masks for registration. + self.use_m2i_registration = True + self.m2i_inverse_transform = None + self.m2i_forward_transform = None + self.m2i_template_model_surface = None + self.m2i_template_labelmap = None - Args: - moving_mask_image: Binary or multi-label mask for moving model - fixed_mask_image: Binary or multi-label mask for fixed model - """ - self.moving_mask_image = moving_mask_image - self.fixed_mask_image = fixed_mask_image + self.use_icon_registration_refinement = False - self.log_info("User-provided masks configured.") + # Final result + self.registered_template_model = None + self.registered_template_model_surface = None - def _auto_generate_masks(self, meshes: list[pv.UnstructuredGrid]) -> itk.Image: - """Auto-generate binary masks from meshes. 
+ def _auto_generate_mask( + self, models: list[pv.UnstructuredGrid], dilate_mm: float | None = None + ) -> itk.Image: + """Auto-generate binary masks from models. - Creates binary masks from moving_mesh and fixed_meshes, with dilation + Creates binary masks from list of models, with dilation according to mask_dilation_mm parameter. """ self.log_info( - f"Auto-generating masks from meshes (dilation: {self.mask_dilation_mm}mm)..." + f"Auto-generating masks from models (dilation: {self.mask_dilation_mm}mm)..." ) - # Generate fixed mask (single mesh or multi-label) - if len(meshes) == 1: - mask_image = self.contour_tools.create_mask_from_mesh( - meshes[0], - self.fixed_image, - resample_to_reference=True, + # Generate patient mask (single model or multi-label) + if len(models) == 1: + mask = self.contour_tools.create_mask_from_mesh( + models[0], + self.patient_image, ) else: # Create multi-label mask mask_arr = None - for i, mesh in enumerate(meshes): + for i, model in enumerate(models): mask = self.contour_tools.create_mask_from_mesh( - mesh, - self.fixed_image, + model, + self.patient_image, ) mask_arr = itk.GetArrayFromImage(mask).astype(np.uint8) if i == 0: mask_arr = mask_arr * (i + 1) # Label 1, 2, 3, ... 
else: mask_arr = np.where(mask_arr > 0, (i + 1) * mask_arr, 0) - mask_image = itk.GetImageFromArray(mask_arr.astype(np.uint8)) - mask_image.CopyInformation(self.fixed_image) + mask = itk.GetImageFromArray(mask_arr.astype(np.uint8)) + mask.CopyInformation(self.patient_image) # Apply dilation if requested - if self.mask_dilation_mm > 0: - imMath = ttk.ImageMath.New(mask_image) - dilation_voxels = int( - self.mask_dilation_mm / self.fixed_image.GetSpacing()[0] - ) + if dilate_mm is None: + dilate_mm = self.mask_dilation_mm + if dilate_mm > 0: + imMath = ttk.ImageMath.New(mask) + dilation_voxels = int(dilate_mm / self.patient_image.GetSpacing()[0]) imMath.Dilate(dilation_voxels, 1, 0) - mask_image = imMath.GetOutputUChar() + mask = imMath.GetOutputUChar() self.log_info("Masks auto-generated successfully.") - return mask_image + return mask - def set_roi_masks( - self, - moving_mask_roi_image: itk.Image, - fixed_mask_roi_image: itk.Image, - ): - """Set user-provided ROI masks. - - Args: - moving_mask_roi_image: Binary ROI mask for moving model - fixed_mask_roi_image: Binary ROI mask for fixed model - """ - self.moving_mask_roi_image = moving_mask_roi_image - self.fixed_mask_roi_image = fixed_mask_roi_image - - self.log_info("User-provided ROI masks configured.") - - def _auto_generate_roi_masks(self, mask_image: itk.Image) -> itk.Image: - """Auto-generate ROI masks from existing masks with dilation. + def _auto_generate_roi_mask( + self, mask: itk.Image, dilate_mm: float | None = None + ) -> itk.Image: + """Auto-generate ROI mask from existing masks with dilation. Uses self.roi_dilation_mm for dilation amount. @@ -338,14 +277,21 @@ def _auto_generate_roi_masks(self, mask_image: itk.Image) -> itk.Image: f"Auto-generating ROI masks (dilation: {self.roi_dilation_mm}mm)..." 
) - # Generate moving ROI mask - imMath = ttk.ImageMath.New(mask_image) - dilation_voxels = int(self.roi_dilation_mm / mask_image.GetSpacing()[0]) - imMath.Dilate(dilation_voxels, 1, 0) - mask_roi_image = imMath.GetOutputUChar() + if dilate_mm is None: + dilate_mm = self.roi_dilation_mm + + # Generate model ROI mask + roi = None + if dilate_mm > 0: + imMath = ttk.ImageMath.New(mask) + dilation_voxels = int(dilate_mm / mask.GetSpacing()[0]) + imMath.Dilate(dilation_voxels, 1, 0) + roi = imMath.GetOutputUChar() + else: + roi = mask self.log_info("ROI masks auto-generated successfully.") - return mask_roi_image + return roi def set_mask_dilation_mm(self, mask_dilation_mm: float): """Set mask dilation amount for auto-generated masks. @@ -365,534 +311,415 @@ def set_roi_dilation_mm(self, roi_dilation_mm: float): """ self.roi_dilation_mm = roi_dilation_mm - def set_pca_data( - self, - eigenvectors: np.ndarray, - eigenvalues: np.ndarray, - n_modes: int = -1, - ): - """Set PCA eigenvalues and eigenvectors for PCA-based registration. - - When this method is called, the workflow will include PCA registration - after ICP alignment and before mask-to-mask registration. + def set_use_mask_to_mask_registration(self, use_mask_to_mask_registration: bool): + """Set whether to use mask-to-mask registration. Args: - eigenvectors: PCA eigenvectors/components array. Shape: (n_pca_modes, n_points*3) - Each row is a flattened eigenmode with 3D displacements: [x1,y1,z1, x2,y2,z2, ...] - eigenvalues: PCA eigenvalues array. Shape: (n_pca_modes,) - These will be converted to standard deviations internally. - n_pca_modes: Number of PCA modes to use in registration. Default: 10 - - Example: - >>> registrar = HeartModelToPatientWorkflow(...) - >>> registrar.set_pca_data( - ... eigenvectors=pca_components, - ... eigenvalues=pca_eigenvalues, - ... n_pca_modes=10 - ... ) + use_mask_to_mask_registration: Whether to use mask-to-mask registration. 
Default: True """ - self.log_info("Creating PCA registrar.") - self.log_info(f" Average mesh n_points: {len(self.moving_mesh.points)}") - self.log_info(f" Number of PCA modes: {n_modes}") - self.log_info(f" Number of eigenvalues: {len(eigenvalues)}") - self.log_info(f" Number of eigenvectors: {eigenvectors.shape[0]}") - self.log_info(f" Number of std deviations: {len(np.sqrt(eigenvalues))}") - self.log_info(f" Reference image shape: {self.fixed_image.shape}") - self.registrar_pca = RegisterModelToImagePCA( - average_mesh=self.moving_mesh, - eigenvectors=eigenvectors, - std_deviations=np.sqrt(eigenvalues), - reference_image=self.fixed_image, - n_modes=n_modes, - ) + self.use_m2m_registration = use_mask_to_mask_registration - self.n_pca_modes = self.registrar_pca.n_pca_modes - self.use_pca = True - - self.log_info( - f"PCA data configured: {len(eigenvalues)} modes available, " - f"using {n_pca_modes} modes for registration." - ) - - def set_pca_data_from_slicersalt( - self, - json_filename: str, - group_key: str = 'All', - n_modes: int = -1, - ): - """Load PCA data from SlicerSALT format and configure for registration. - - This convenience method loads PCA statistical shape model data from a - SlicerSALT JSON file and automatically updates the moving mesh to use - the PCA mean mesh. The workflow will then include PCA registration - after ICP alignment and before mask-to-mask registration. + def set_use_mask_to_image_registration(self, use_mask_to_image_registration: bool): + """Set whether to use mask-to-image registration. Args: - json_filename: Path to the SlicerSALT PCA JSON file (e.g., 'pca.json') - group_key: Key for the PCA group to extract from JSON. Default: 'All' - n_pca_modes: Number of PCA modes to use in registration. Default: 10 - - Raises: - FileNotFoundError: If JSON or VTK mesh file not found - KeyError: If group_key not found in JSON - ValueError: If data format is invalid - - Example: - >>> registrar = HeartModelToPatientWorkflow( - ... 
moving_mesh=generic_heart, - ... fixed_meshes=[lv_mesh, mc_mesh, rv_mesh], - ... fixed_image=patient_ct, - ... ) - >>> registrar.set_pca_data_from_slicersalt( - ... json_filename='path/to/pca.json', - ... group_key='All', - ... n_pca_modes=10 - ... ) - >>> final_mesh = registrar.run_workflow() - - Note: - This method expects the SlicerSALT file structure: - - A JSON file containing eigenvalues and components - - A corresponding VTK mesh file following the pattern: - {json_stem}_{group_key}_mean.vtk + use_m2i: Whether to use mask-to-image registration. Default: True """ - self.log_section("Loading PCA Data from SlicerSALT Format", width=70) - - self.log_info("Creating PCA registrar from SlicerSALT data.") - self.log_info(f" Average mesh n_points: {len(self.moving_mesh.points)}") - self.log_info(f" Number of PCA modes: {n_modes}") - self.log_info(f" Reference image shape: {self.fixed_image.shape}") - # Load PCA data using RegisterModelToImagePCA - self.registrar_pca = RegisterModelToImagePCA.from_slicersalt( - average_mesh=self.moving_mesh, - json_filename=json_filename, - group_key=group_key, - reference_image=self.fixed_image, - n_modes=n_modes, - ) - self.use_pca = True + self.use_m2i_registration = use_mask_to_image_registration - self.n_pca_modes = self.registrar_pca.n_pca_modes + def register_model_to_model_icp(self): + """Perform ICP alignment of template model to patient model. - self.log_section("SlicerSALT PCA Data Loaded and Configured", width=70) - - def register_mesh_to_mesh_icp(self): - """Perform ICP alignment of moving mesh to fixed patient mesh. - - Uses RegisterModelToModelICP class for affine ICP alignment. + Uses RegisterModelsICP class for ICP alignment. 
Returns: dict: Dictionary containing: - - 'phi_FM': Forward transform (fixed to moving) - - 'phi_MF': Reverse transform (moving to fixed) - - 'moving_mesh': Transformed moving mesh - - 'moving_mask_image': Transformed moving mask image - - 'moving_mask_roi_image': Transformed moving ROI mask image + - 'forward_transform': used to warp an image from model to patient space + - 'inverse_transform': used to warp an image from patient to model space + - 'registered_template_model_surface': Transformed model model surface """ - self.log_section("Stage 1: ICP Alignment (RegisterModelToModelICP)", width=70) + self.log_section("Stage 1: ICP Alignment (RegisterModelsICP)", width=70) # Create ICP registrar - icp_registrar = RegisterModelToModelICP( - moving_mesh=self.moving_mesh, - fixed_mesh=self.fixed_mesh, + self.icp_registrar = RegisterModelsICP( + moving_model=self.template_model_surface, + fixed_model=self.patient_model_surface, ) - # Run affine ICP registration - icp_result = icp_registrar.register(mode='affine', max_iterations=2000) + # Run rigid ICP registration + icp_result = self.icp_registrar.register(mode='rigid', max_iterations=2000) # Store results - self.icp_phi_MF = icp_result['phi_MF'] - self.icp_phi_FM = icp_result['phi_FM'] - self.moving_icp_mesh = icp_result['moving_mesh'] - - # Ensure masks exist (auto-generate if needed) - if self.moving_mask_image is None: - self.moving_mask_image = self._auto_generate_masks([self.moving_mesh]) - if self.fixed_mask_image is None: - self.fixed_mask_image = self._auto_generate_masks(self.fixed_meshes) - - if self.moving_mask_roi_image is None: - self.moving_mask_roi_image = self._auto_generate_roi_masks( - self.moving_mask_image - ) - if self.fixed_mask_roi_image is None: - self.fixed_mask_roi_image = self._auto_generate_roi_masks( - self.fixed_mask_image - ) - - # Transform moving mask images to fixed space - self.moving_icp_mask_image = self.transform_tools.transform_image( - self.moving_mask_image, - self.icp_phi_MF, 
- self.fixed_image, + # Note: Point transforms are in opposite direction from image transforms + self.icp_forward_point_transform = icp_result['forward_point_transform'] + self.icp_inverse_point_transform = icp_result['inverse_point_transform'] + self.icp_template_model_surface = icp_result['registered_model'] + + self.icp_template_labelmap = self.transform_tools.transform_image( + self.template_labelmap, + self.icp_inverse_point_transform, + self.patient_image, interpolation_method="nearest", ) - # Now that the moving mask is in the fixed space, we should regenerate the ROI mask - self.moving_icp_mask_roi_image = self._auto_generate_roi_masks( - self.moving_icp_mask_image, - ) - self.log_info("Stage 1 complete: ICP alignment finished.") + self.registered_template_model_surface = self.icp_template_model_surface + return { - 'phi_FM': self.icp_phi_FM, - 'phi_MF': self.icp_phi_MF, - 'moving_mesh': self.moving_icp_mesh, - 'moving_mask_image': self.moving_icp_mask_image, - 'moving_mask_roi_image': self.moving_icp_mask_roi_image, + 'inverse_point_transform': self.icp_inverse_point_transform, + 'forward_point_transform': self.icp_forward_point_transform, + 'registered_template_model_surface': self.icp_template_model_surface, + 'registered_template_labelmap': self.icp_template_labelmap, } - def register_mesh_to_mesh_pca(self): + def register_model_to_model_pca(self): """Perform PCA-based registration after ICP alignment. - Uses RegisterModelToImagePCA class for intensity-based PCA registration. + Uses RegisterModelsPCA class for intensity-based PCA registration. This method requires PCA data to be set via set_pca_data(). - Creates a contour distance map from the fixed mesh, clips it to 100, - and inverts the intensities for use as the target image in PCA registration. 
- Returns: dict: Dictionary containing: - - 'pca_rigid_transform': Rigid transform from PCA registration + - 'forward_point_transform': Rigid transform from PCA registration - 'pca_coefficients': PCA shape coefficients - - 'moving_mesh': PCA-registered mesh + - 'registered_template_model_surface': PCA-registered model surface Raises: - ValueError: If PCA data has not been set via set_pca_data() + ValueError: If PCA data has not been set """ self.log_section( - "Stage 1.5: PCA-Based Registration (RegisterModelToImagePCA)", width=70 + "Stage 2: PCA-Based Registration (RegisterModelsPCA)", + width=70, ) - if not self.use_pca or self.registrar_pca is None: - raise ValueError( - "PCA data not set. Call set_pca_data() before using PCA registration." - ) - - # Create contour distance map from fixed mesh - self.log_info("Creating contour distance map from fixed mesh...") - fixed_mesh_contour_map = ( - self.contour_tools.create_contour_distance_map_from_mesh( - mesh=self.fixed_mesh, - reference_image=self.fixed_image, - max_distance=100.0, - invert_distance_map=True, - ) + if self.pca_json_filename is None: + self.pca_template_model_surface = self.icp_template_model_surface + return { + 'pca_coefficients': None, + 'registered_model_surface': self.pca_template_model_surface, + } + + self.pca_registrar = RegisterModelsPCA.from_slicersalt( + pca_template_model=self.template_model_surface, + pca_json_filename=self.pca_json_filename, + pca_group_key=self.pca_group_key, + pca_number_of_modes=self.pca_number_of_modes, + post_pca_transform=self.icp_forward_point_transform, + fixed_model=self.patient_model_surface, + reference_image=self.patient_image, ) - self.registrar_pca.set_reference_image(fixed_mesh_contour_map) - self.registrar_pca.set_average_mesh(self.moving_icp_mesh) - - # Create initial transform (identity, since ICP already aligned the mesh) - initial_transform = itk.VersorRigid3DTransform[itk.D].New() - initial_transform.SetIdentity() - # Run complete PCA registration - 
result = self.registrar_pca.register( - initial_transform=initial_transform, - n_pca_modes=self.n_pca_modes, - stage1_max_iterations=10, - stage2_max_iterations=200, - pca_coefficient_bounds=3.0, - rigid_refinement_bounds={'versor': 0.1, 'translation_mm': 10.0}, + result = self.pca_registrar.register() + self.pca_coefficients = result['pca_coefficients'] + self.pca_template_model_surface = result['registered_model'] + + pca_transforms = self.pca_registrar.compute_pca_transforms( + reference_image=self.template_labelmap, ) + self.pca_forward_point_transform = pca_transforms['forward_point_transform'] + self.pca_inverse_point_transform = pca_transforms['inverse_point_transform'] # Store results - self.pca_rigid_transform = result['pre_phi_FM'] - self.pca_coefficients = result['pca_coefficients_FM'] - self.moving_pca_mesh = result['registered_mesh'] - self.moving_pca_mask_image = self._auto_generate_masks([self.moving_pca_mesh]) - self.moving_pca_mask_roi_image = self._auto_generate_roi_masks( - self.moving_pca_mask_image + self.registered_template_model_surface = self.pca_template_model_surface + + self.pca_template_labelmap = self.transform_tools.transform_image( + self.template_labelmap, + self.pca_inverse_point_transform, + self.template_labelmap, + ) + self.pca_template_labelmap = self.transform_tools.transform_image( + self.pca_template_labelmap, + self.icp_inverse_point_transform, + self.patient_image, + interpolation_method="nearest", ) - self.log_info("Stage 1.5 complete: PCA registration finished.") + self.log_info("Stage 2 complete: PCA registration finished.") return { - 'pre_phi_FM': self.pca_rigid_transform, - 'pca_coefficients_FM': self.pca_coefficients, - 'moving_mesh': self.moving_pca_mesh, - 'moving_mask_image': self.moving_pca_mask_image, - 'moving_mask_roi_image': self.moving_pca_mask_roi_image, + 'pca_coefficients': self.pca_coefficients, + 'forward_point_transform': self.pca_forward_point_transform, + 'inverse_point_transform': 
self.pca_inverse_point_transform, + 'registered_template_model_surface': self.pca_template_model_surface, + 'registered_template_labelmap': self.pca_template_labelmap, } - def register_mask_to_mask(self): - """Perform mask-based deformable registration of moving to fixed mesh. + def register_mask_to_mask(self, use_icon_refinement: bool = False) -> dict | None: + """Perform mask-based deformable registration of model to patient model. - Uses RegisterModelToModelMasks class for ANTs deformable registration. - If PCA registration was performed, uses the PCA-registered mesh as input. + Uses RegisterModelsDistanceMaps class for ANTs deformable registration. Returns: dict: Dictionary containing: - - 'phi_FM': Forward transform (fixed to moving) - - 'phi_MF': Reverse transform (moving to fixed) - - 'moving_mesh': Transformed moving mesh - - 'moving_mask_image': Transformed moving mask image - - 'moving_mask_roi_image': Transformed moving ROI mask image + - 'forward_transform': model to patient space transform + - 'inverse_transform': patient to model space transform + - 'registered_template_model_surface': Transformed model model + - 'registered_template_labelmap': Transformed model labelmap """ self.log_section( - "Stage 2: Mask-to-Mask Deformable Registration (RegisterModelToModelMasks)", + "Stage 3: Mask-to-Mask Deformable Registration (RegisterModelsDistanceMaps)", width=70, ) - # Use PCA mesh if available, otherwise use ICP mesh - if self.use_pca and self.moving_pca_mesh is not None: - input_mesh = self.moving_pca_mesh - input_mask_image = self.moving_pca_mask_image - input_mask_roi_image = self.moving_pca_mask_roi_image - else: - input_mesh = self.moving_icp_mesh - input_mask_image = self.moving_icp_mask_image - input_mask_roi_image = self.moving_icp_mask_roi_image + if not self.use_m2m_registration: + self.log_info("Mask-to-mask registration is not enabled.") + return None # Create mask-based registrar - mask_registrar = RegisterModelToModelMasks( - 
moving_mesh=input_mesh, - fixed_mesh=self.fixed_mesh, - reference_image=self.fixed_image, + mask_registrar = RegisterModelsDistanceMaps( + moving_model=self.pca_template_model_surface, + fixed_model=self.patient_model_surface, + reference_image=self.patient_image, + roi_dilation_mm=self.roi_dilation_mm, ) # Run deformable registration mask_result = mask_registrar.register( - mode='deformable', - use_icon=False, # No Icon refinement in this stage + transform_type='Deformable', + use_icon=use_icon_refinement, ) # Store results - self.m2m_phi_MF = mask_result['phi_MF'] - self.m2m_phi_FM = mask_result['phi_FM'] - self.moving_m2m_mesh = mask_result['moving_mesh'] - - # Transform mask images to fixed space - self.moving_m2m_mask_image = self.transform_tools.transform_image( - input_mask_image, - self.m2m_phi_MF, - self.fixed_image, - interpolation_method="nearest", - ) - self.moving_m2m_mask_roi_image = self.transform_tools.transform_image( - input_mask_roi_image, - self.m2m_phi_MF, - self.fixed_image, + self.m2m_forward_transform = mask_result['forward_transform'] + self.m2m_inverse_transform = mask_result['inverse_transform'] + self.m2m_template_model_surface = mask_result['registered_model'] + + self.registered_template_model_surface = self.m2m_template_model_surface + + self.m2m_template_labelmap = self.transform_tools.transform_image( + self.pca_template_labelmap, + self.m2m_forward_transform, + self.patient_image, interpolation_method="nearest", ) - self.log_info("Stage 2 complete: Mask-to-mask registration finished.") + self.log_info("Stage 3 complete: Mask-to-mask registration finished.") return { - 'phi_FM': self.m2m_phi_FM, - 'phi_MF': self.m2m_phi_MF, - 'moving_mesh': self.moving_m2m_mesh, - 'moving_mask_image': self.moving_m2m_mask_image, - 'moving_mask_roi_image': self.moving_m2m_mask_roi_image, + 'forward_transform': self.m2m_forward_transform, + 'inverse_transform': self.m2m_inverse_transform, + 'registered_template_model_surface': 
self.m2m_template_model_surface, + 'registered_template_labelmap': self.m2m_template_labelmap, } - def register_mask_to_image(self): - """Perform final mask-to-image refinement using Icon registration. + def register_labelmap_to_image( + self, use_icon_refinement: bool = False + ) -> dict | None: + """Perform labelmap-to-image refinement. - Uses Icon registration to align mask to actual image intensities. + Uses registration to align labelmap to actual image intensities. Returns: dict: Dictionary containing: - - 'phi_FM': Forward transform (fixed to moving) - - 'phi_MF': Reverse transform (moving to fixed) - - 'moving_mesh': Transformed moving mesh - - 'moving_mask_image': Transformed moving mask image - - 'moving_mask_roi_image': Transformed moving ROI mask image + - 'inverse_transform': patient to model space transform + - 'forward_transform': model to patient space transform + - 'registered_template_model_surface': Transformed model model + - 'registered_template_labelmap': Transformed model labelmap """ self.log_section( - "Stage 3: Mask-to-Image Refinement (Icon Registration)", width=70 + "Stage 4: Labelmap-to-Image Refinement (Icon Registration)", width=70 ) - if self.moving_m2m_mask_image is None: - raise ValueError( - "Moving mask image not available for mask-to-image registration. " - "Ensure mask-to-mask registration is run first." - ) - - # Prepare moving mask image for Icon registration (scale to intensity 100) - mmi_arr = ( - itk.GetArrayFromImage(self.moving_m2m_mask_image).astype(np.float32) * 100 + labelmap_arr = itk.GetArrayFromImage(self.m2m_template_labelmap).astype( + np.uint16 ) - if mmi_arr.min() == mmi_arr.max(): - raise ValueError( - "Moving mask image is empty. Ensure mask-to-mask registration is run first." 
- ) - - mmi = itk.GetImageFromArray(mmi_arr) - mmi.CopyInformation(self.moving_m2m_mask_image) - - self.registrar_ants.set_fixed_image(self.fixed_image) - if self.fixed_mask_roi_image is not None: - self.registrar_ants.set_fixed_image_mask(self.fixed_mask_roi_image) - result = self.registrar_ants.register( - moving_image=mmi, moving_image_mask=self.moving_m2m_mask_roi_image - ) - phi_FM_ants = result["phi_FM"] - phi_MF_ants = result["phi_MF"] - - # Configure Icon registration - self.registrar_icon.set_fixed_image(self.fixed_image) - if self.fixed_mask_roi_image is not None: - self.registrar_icon.set_fixed_image_mask(self.fixed_mask_roi_image) - - # Perform Icon registration - result = self.registrar_icon.register( - moving_image=mmi, moving_image_mask=self.moving_m2m_mask_roi_image + labelmap_arr = np.where( + np.isin(labelmap_arr, self.template_labelmap_background_ids), + 0, + labelmap_arr, ) - phi_FM_icon = result["phi_FM"] - phi_MF_icon = result["phi_MF"] - - # Compose ANTS and Icon transforms - phi_FM = self.transform_tools.combine_displacement_field_transforms( - phi_FM_ants, phi_FM_icon, self.fixed_image + labelmap_arr = np.where( + np.isin(labelmap_arr, self.template_labelmap_heart_muscle_ids), + 0, + labelmap_arr, ) - phi_MF = self.transform_tools.combine_displacement_field_transforms( - phi_MF_icon, phi_MF_ants, self.fixed_image + labelmap_arr = np.where( + np.isin(labelmap_arr, self.template_labelmap_chamber_ids), 1, labelmap_arr ) + labelmap = itk.GetImageFromArray(labelmap_arr) + labelmap.CopyInformation(self.m2m_template_labelmap) - self.m2i_phi_FM = phi_FM - self.m2i_phi_MF = phi_MF + labelmap_roi = self._auto_generate_roi_mask(labelmap) - # Transform mesh with Icon result - self.moving_m2i_mesh = self.transform_tools.transform_pvcontour( - self.moving_m2m_mesh, self.m2i_phi_FM, with_deformation_magnitude=True + patient_mask = self._auto_generate_mask( + [self.patient_model_surface], dilate_mm=0 ) + patient_roi = 
self._auto_generate_roi_mask(patient_mask) - # Transform mask images to fixed space - self.moving_m2i_mask_image = self.transform_tools.transform_image( - self.moving_m2m_mask_image, - self.m2i_phi_MF, - self.fixed_image, - interpolation_method="nearest", + self.registrar_ants.set_fixed_image(self.patient_image) + self.registrar_ants.set_fixed_mask(patient_roi) + + result = self.registrar_ants.register( + moving_image=labelmap, moving_mask=labelmap_roi + ) + self.m2i_inverse_transform = result["inverse_transform"] + self.m2i_forward_transform = result["forward_transform"] + + if use_icon_refinement: + # Configure Icon registration + self.registrar_icon.set_fixed_image(self.patient_image) + self.registrar_icon.set_fixed_mask(patient_roi) + + # Perform Icon registration + result = self.registrar_icon.register( + initial_forward_transform=self.m2i_forward_transform, + moving_image=labelmap, + moving_mask=labelmap_roi, + ) + self.m2i_inverse_transform = result["inverse_transform"] + self.m2i_forward_transform = result["forward_transform"] + + # Transform model with Icon result + self.m2i_template_model_surface = self.transform_tools.transform_pvcontour( + self.m2m_template_model_surface, + self.m2i_inverse_transform, + with_deformation_magnitude=True, ) - self.moving_m2i_mask_roi_image = self.transform_tools.transform_image( - self.moving_m2m_mask_roi_image, - self.m2i_phi_MF, - self.fixed_image, + + self.m2i_template_labelmap = self.transform_tools.transform_image( + self.template_labelmap, + self.m2i_forward_transform, + self.patient_image, interpolation_method="nearest", ) - self.log_info("Stage 3 complete: Mask-to-image registration finished.") + self.log_info("Stage 4 complete: Mask-to-image registration finished.") + + self.registered_template_model_surface = self.m2i_template_model_surface return { - 'phi_FM': self.m2i_phi_FM, - 'phi_MF': self.m2i_phi_MF, - 'moving_mesh': self.moving_m2i_mesh, - 'moving_mask_image': self.moving_m2i_mask_image, - 
'moving_mask_roi_image': self.moving_m2i_mask_roi_image, + 'inverse_transform': self.m2i_inverse_transform, + 'forward_transform': self.m2i_forward_transform, + 'registered_template_model_surface': self.m2i_template_model_surface, + 'registered_template_labelmap': self.m2i_template_labelmap, } - def apply_transforms_to_original_mesh(self, include_m2i: bool = True): - """Apply registration transforms to the original mesh. + def transform_model(self, base_model=None) -> pv.UnstructuredGrid | None: + """Apply registration transforms to the model. - Transforms the original mesh through all registration stages. - Note: If PCA registration was used, the PCA transform is already - baked into the mesh, so we don't apply it separately. + Transforms the model through all registration stages. Args: - include_m2i: Whether to include mask-to-image transform. Default: True + base_model: Base model for generating the new model. + If None, the template model is used. Returns: - pv.UnstructuredGrid: Registered mesh + pv.UnstructuredGrid: Registered model """ - self.log_info("Applying transforms to original mesh...") + self.log_info("Applying transforms to model...") - self.moving_registered_mesh = self.moving_original_mesh.copy(deep=True) - new_points = self.moving_registered_mesh.points + new_model = None + if base_model is None: + self.registered_template_model = self.template_model.copy(deep=True) + new_points = self.registered_template_model.points + else: + new_model = base_model.copy(deep=True) + new_points = new_model.points n_points = new_points.shape[0] progress_interval = max(1, n_points // 10) # Report progress every 10% # Transform each point through the complete pipeline - for i in range(n_points): + p = itk.Point[itk.D, 3]() + for i, point in enumerate(new_points): # Report progress if i % progress_interval == 0 or i == n_points - 1: - self.log_progress(i + 1, n_points, prefix="Transforming mesh points") + self.log_progress(i + 1, n_points, prefix="Transforming model 
points") - p = itk.Point[itk.D, 3]() - p[0], p[1], p[2] = ( - float(new_points[i, 0]), - float(new_points[i, 1]), - float(new_points[i, 2]), - ) - - # Apply ICP transform - new_p = self.icp_phi_FM.TransformPoint(p) + p[0] = float(point[0]) + p[1] = float(point[1]) + p[2] = float(point[2]) - # Apply PCA rigid transform (if PCA was used) - if self.use_pca and self.pca_rigid_transform is not None: - new_p = self.pca_rigid_transform.TransformPoint(new_p) + # Apply PCA and ICP transforms + if self.pca_coefficients is not None: + p = self.pca_registrar.transform_point( + p, + include_post_pca_transform=True, + ) # Apply mask-to-mask transform - new_p = self.m2m_phi_FM.TransformPoint(new_p) + if self.use_m2m_registration and self.m2m_inverse_transform is not None: + p = self.m2m_inverse_transform.TransformPoint(p) - # Apply mask-to-image transform (if available and requested) - if include_m2i and self.m2i_phi_FM is not None: - new_p = self.m2i_phi_FM.TransformPoint(new_p) + # Apply mask-to-image transform + if self.use_m2i_registration and self.m2i_inverse_transform is not None: + p = self.m2i_inverse_transform.TransformPoint(p) - new_points[i, 0], new_points[i, 1], new_points[i, 2] = ( - new_p[0], - new_p[1], - new_p[2], - ) - - self.moving_registered_mesh.points = new_points + new_points[i, 0] = p[0] + new_points[i, 1] = p[1] + new_points[i, 2] = p[2] self.log_info("Transform application complete.") - return self.moving_registered_mesh + if base_model is None: + self.registered_template_model.points = new_points + return self.registered_template_model + else: + new_model.points = new_points + return new_model - def run_workflow(self, include_mask_to_image: bool = True): + def run_workflow( + self, + use_mask_to_image_registration: bool = True, + use_mask_to_mask_registration: bool = True, + use_icon_registration_refinement: bool = False, + ) -> dict: """Execute the complete multi-stage registration workflow. Runs registration stages in sequence: - 1. 
ICP alignment (RegisterModelToModelICP) - 1.5. Optional PCA registration (if set_pca_data() was called) - 2. Mask-to-mask deformable registration (RegisterModelToModelMasks) - 3. Optional mask-to-image refinement (Icon) - - Masks are automatically generated if not provided via set_masks(). + 1. ICP alignment (RegisterModelsICP) + 2. PCA registration (PCA data was provided) + 3. Mask-to-mask deformable registration (RegisterModelsDistanceMaps) + 4. Optional mask-to-image refinement (Icon) Args: - include_mask_to_image: Whether to include mask-to-image registration stage. + use_mask_to_image_registration: Whether to include mask-to-image registration stage. + Default: True + use_mask_to_mask_registration: Whether to include mask-to-mask registration stage. Default: True + use_icon_registration_refinement: Whether to include icon registration refinement stage. + Default: False Returns: - pv.PolyData: Final registered surface mesh - - Note: - - Masks are auto-generated from meshes if not provided via set_masks(). - - PCA registration is only performed if set_pca_data() was called. 
+ pv.UnstructuredGrid: Final registered model """ self.log_section( "STARTING COMPLETE MODEL-TO-IMAGE-AND-MODEL REGISTRATION WORKFLOW", width=70 ) + self.use_m2m_registration = use_mask_to_mask_registration + self.use_m2i_registration = use_mask_to_image_registration + self.use_icon_registration_refinement = use_icon_registration_refinement + # Stage 1: ICP alignment - self.register_mesh_to_mesh_icp() + self.register_model_to_model_icp() - # Stage 1.5: Optional PCA registration (if PCA data was set) - if self.use_pca: - self.register_pca() + # Stage 2: Optional PCA registration (if PCA data was set) + self.register_model_to_model_pca() - final_mesh = self.moving_pca_mesh + # Stage 3: Mask-to-mask deformable registration + if self.use_m2m_registration: + self.register_mask_to_mask( + use_icon_refinement=use_icon_registration_refinement + ) - # Stage 2: Mask-to-mask deformable registration - # self.register_mask_to_mask() + # Stage 4: Optional mask-to-image refinement + if self.use_m2i_registration: + self.register_mask_to_image( + use_icon_refinement=use_icon_registration_refinement + ) - # Stage 3: Optional mask-to-image refinement - # if include_mask_to_image: - # self.register_mask_to_image() - # final_mesh = self.moving_m2i_mesh - # else: - # final_mesh = self.moving_m2m_mesh + _ = self.transform_model() self.log_section("REGISTRATION WORKFLOW COMPLETE", width=70) - self.log_info(f"Final registered mesh: {final_mesh.n_points} points") - if self.use_pca: - self.log_info("PCA registration was applied in this workflow.") + self.log_info( + f"Final registered patient model surface: {self.registered_template_model_surface.n_points} points" + ) - return final_mesh + return { + 'registered_template_model': self.registered_template_model, + 'registered_template_model_surface': self.registered_template_model_surface, + } diff --git a/src/physiomotion4d/image_tools.py b/src/physiomotion4d/image_tools.py index e4769fd..4eab033 100644 --- a/src/physiomotion4d/image_tools.py 
+++ b/src/physiomotion4d/image_tools.py @@ -5,12 +5,16 @@ and performing image processing operations. """ +import logging + import itk import numpy as np import SimpleITK as sitk +from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase + -class ImageTools: +class ImageTools(PhysioMotion4DBase): """ Utilities for medical image format conversions and processing. @@ -26,9 +30,63 @@ class ImageTools: >>> itk_image_back = tools.convert_sitk_image_to_itk(sitk_image) """ - def __init__(self): - """Initialize ImageTools.""" - pass + def __init__(self, log_level: int | str = logging.INFO): + """Initialize ImageTools. + + Args: + log_level: Logging level (default: logging.INFO) + """ + super().__init__(class_name=self.__class__.__name__, log_level=log_level) + + def imreadVD3(self, filename: str) -> itk.Image: + """Read an ITK vector image with double precision vectors. + + ITK's imread is not wrapped for itk.Image[itk.Vector[itk.D,3],3], + so this method reads as itk.Image[itk.Vector[itk.F,3],3] and converts + to double precision. + + Args: + filename (str): Path to the image file to read + + Returns: + itk.Image[itk.Vector[itk.D,3],3]: Vector image with double precision + + Example: + >>> displacement_field = ImageTools().imreadVD3("deformation.mha") + """ + # Read as float precision vector image + image = itk.imread(filename) + if "VD" in str(type(image)): + return image + + image_arr = itk.array_from_image(image) + image_double = self.convert_array_to_image_of_vectors(image_arr, image, itk.D) + + return image_double + + def imwriteVD3(self, image: itk.Image, filename: str, compression: bool = True): + """Write an ITK vector image with double precision vectors. + + ITK's imwrite is not wrapped for itk.Image[itk.Vector[itk.D,3],3], + so this method converts to itk.Image[itk.Vector[itk.F,3],3] and writes. 
+ + Args: + image (itk.Image[itk.Vector[itk.D,3],3]): Vector image to write + filename (str): Path to the output file + compression (bool): Whether to use compression (default: True) + + Example: + >>> ImageTools().imwriteVD3(displacement_field, "deformation.mha") + """ + # Convert to float precision for writing + if "VD" not in str(type(image)): + raise ValueError("Image must be a vector image with double precision") + + image_arr = itk.array_from_image(image) + image_float = self.convert_array_to_image_of_vectors(image_arr, image, itk.F) + + # Write the float image + itk.imwrite(image_float, filename, compression=compression) def convert_itk_image_to_sitk(self, itk_image: itk.Image) -> sitk.Image: """ @@ -78,8 +136,8 @@ def convert_itk_image_to_sitk(self, itk_image: itk.Image) -> sitk.Image: # Set metadata # Convert origin and spacing to tuples (reverse order for SimpleITK: x, y, z) - sitk_image.SetOrigin(tuple(reversed(origin))) - sitk_image.SetSpacing(tuple(reversed(spacing))) + sitk_image.SetOrigin(tuple(origin)) + sitk_image.SetSpacing(tuple(spacing)) # Direction matrix needs to be flattened and reversed appropriately # ITK and SimpleITK use the same direction convention, but we need to handle @@ -161,7 +219,10 @@ def convert_sitk_image_to_itk(self, sitk_image: sitk.Image) -> itk.Image: return itk_image def convert_array_to_image_of_vectors( - self, arr_data: np.array, ptype: itk.D, reference_image: itk.Image + self, + arr_data: np.array, + reference_image: itk.Image, + ptype=itk.D, ) -> itk.Image: """ Convert a numpy array to an ITK image of vector type. diff --git a/src/physiomotion4d/physiomotion4d_base.py b/src/physiomotion4d/physiomotion4d_base.py index 6304ee6..1179d4d 100644 --- a/src/physiomotion4d/physiomotion4d_base.py +++ b/src/physiomotion4d/physiomotion4d_base.py @@ -215,11 +215,11 @@ def set_log_classes(cls, class_names: list[str]) -> None: Args: class_names: List of class names to show logs from. 
- Example: ["RegisterModelToImagePCA", "HeartModelToPatientWorkflow"] + Example: ["RegisterModelsPCA", "HeartModelToPatientWorkflow"] Example: - >>> PhysioMotion4DBase.set_log_classes(["RegisterModelToImagePCA"]) - >>> # Now only RegisterModelToImagePCA logs will be shown + >>> PhysioMotion4DBase.set_log_classes(["RegisterModelsPCA"]) + >>> # Now only RegisterModelsPCA logs will be shown """ if cls._class_filter is not None: cls._class_filter.enabled = True @@ -250,7 +250,7 @@ def get_log_classes(cls) -> list[str]: Example: >>> classes = PhysioMotion4DBase.get_log_classes() >>> print(classes) - ['RegisterModelToImagePCA', 'HeartModelToPatientWorkflow'] + ['RegisterModelsPCA', 'HeartModelToPatientWorkflow'] """ if cls._class_filter is not None and cls._class_filter.enabled: return sorted(cls._class_filter.allowed_classes) diff --git a/src/physiomotion4d/register_images_ants.py b/src/physiomotion4d/register_images_ants.py index 254717e..293ef17 100644 --- a/src/physiomotion4d/register_images_ants.py +++ b/src/physiomotion4d/register_images_ants.py @@ -10,11 +10,12 @@ """ import argparse +import logging +import os import ants import itk import numpy as np -from itk import TubeTK as ttk from physiomotion4d.register_images_base import RegisterImagesBase from physiomotion4d.transform_tools import TransformTools @@ -60,18 +61,34 @@ class RegisterImagesANTs(RegisterImagesBase): >>> registrar.set_modality('ct') >>> registrar.set_fixed_image(reference_image) >>> result = registrar.register(moving_image) - >>> phi_FM = result["phi_FM"] + >>> inverse_transform = result["inverse_transform"] """ - def __init__(self): + def __init__(self, log_level: int | str = logging.INFO): """Initialize the ANTs image registration class. Calls the parent RegisterImagesBase constructor to set up common parameters. Default ANTs registration parameters are set to work well for medical images. 
+ + Args: + log_level: Logging level (default: logging.INFO) """ - super().__init__() + super().__init__(log_level=log_level) self.number_of_iterations = [40, 20, 10] + self.transform_type = "Deformable" + + def set_transform_type(self, transform_type): + """Set the type of transform to use for registration. + + Args: + transform_type (str): Type of transform to use for registration. + Options: 'Deformable', 'Affine', 'Rigid' + """ + self.transform_type = transform_type + if transform_type not in ['Deformable', 'Affine', 'Rigid']: + self.log_error("Invalid transform type: %s", transform_type) + raise ValueError(f"Invalid transform type: {transform_type}") def _ants_to_itk_image(self, ants_image): """Convert ANTs image back to ITK format. @@ -247,26 +264,130 @@ def _antsfile_to_itk_displacement_field_transform( self._itk_to_ants_image(ref_image, dtype='float'), ) - disp_field_itk = self._ants_to_itk_image(disp_field_ants) + disp_field_itk_raw = self._ants_to_itk_image(disp_field_ants) + # Convert to the correct Image[Vector[D, 3], 3] type for DisplacementFieldTransform + # Use ImageTools helper to convert array to vector image with correct type + from physiomotion4d.image_tools import ImageTools + + image_tools = ImageTools() + + disp_array = itk.array_from_image(disp_field_itk_raw) + disp_field_itk = image_tools.convert_array_to_image_of_vectors( + disp_array, ref_image, itk.D + ) + + # Create displacement field transform disp_tfm = itk.DisplacementFieldTransform[itk.D, 3].New() disp_tfm.SetDisplacementField(disp_field_itk) return disp_tfm - def itk_transform_to_ants_transform( - self, itk_tfm: itk.Transform, reference_image: itk.Image + def itk_affine_transform_to_ants_transform(self, itk_tfm): + """Convert ITK affine/rigid transform to ANTs affine transform. + + Converts an ITK MatrixOffsetTransformBase-derived transform (such as + AffineTransform or Rigid3DTransform) to an ANTs affine transform object. 
+ + The conversion extracts: + - 3x3 transformation matrix (converted to row-major order for ANTs) + - Translation vector + - Center of rotation (fixed parameters) + + Args: + itk_tfm (itk.Transform): ITK affine or rigid transform derived from + itkMatrixOffsetTransformBase (e.g., itk.AffineTransform[itk.D, 3], + itk.Rigid3DTransform, etc.) + + Returns: + ants.ANTsTransform: ANTs affine transform object + + Raises: + ValueError: If transform dimension is not 3D + + Example: + >>> # Create ITK affine transform + >>> affine_itk = itk.AffineTransform[itk.D, 3].New() + >>> affine_itk.SetIdentity() + >>> # Convert to ANTs + >>> affine_ants = registrar.itk_affine_transform_to_ants_transform(affine_itk) + >>> # Use in ANTs operations + >>> result = ants.apply_ants_transform(affine_ants, moving_image) + """ + # Get dimension of the transform + dimension = itk_tfm.GetInputSpaceDimension() + if dimension != 3: + raise ValueError( + f"Only 3D transforms are supported, got dimension: {dimension}" + ) + + # Check if transform has a matrix (Translation transforms don't) + if hasattr(itk_tfm, 'GetMatrix'): + # Extract matrix (ITK matrix is row-major) + matrix_itk = np.asarray(itk_tfm.GetMatrix()).reshape(3, 3) + else: + # For transforms without matrix (e.g., TranslationTransform), use identity matrix + matrix_itk = np.eye(3, dtype=np.float64) + + # Extract translation and center based on transform type: + # - MatrixOffsetTransformBase (Affine, Rigid): use GetTranslation() WITH GetCenter() + # - TranslationTransform: use GetOffset() WITHOUT GetCenter() + if hasattr(itk_tfm, 'GetTranslation'): + # MatrixOffsetTransformBase-derived transforms + translation_itk = np.asarray(itk_tfm.GetTranslation()) + center_itk = np.asarray(itk_tfm.GetCenter()) + elif hasattr(itk_tfm, 'GetOffset'): + # TranslationTransform - use GetOffset() WITHOUT GetCenter() + translation_itk = np.asarray(itk_tfm.GetOffset()) + center_itk = np.zeros(3, dtype=np.float64) # No center for translation + else: + # 
Fallback for unknown transform types + translation_itk = np.zeros(3, dtype=np.float64) + center_itk = np.zeros(3, dtype=np.float64) + + # ANTs affine transform parameters structure: + # For 3D: 12 parameters + # parameters[0:9]: 3x3 matrix in row-major order + # parameters[9:12]: translation vector + # fixed_parameters[0:3]: center of rotation + + # Flatten matrix to row-major order for ANTs + params = np.zeros(12, dtype=np.float64) + params[0:9] = matrix_itk.flatten() # Already row-major + params[9:12] = translation_itk + + # Ensure fixed_params is also float64 + fixed_params = center_itk.astype(np.float64) + + # Create ANTs affine transform + # Note: dimension must be integer 3, not float + ants_tfm = ants.new_ants_transform( + precision='double', + transform_type='AffineTransform', + dimension=int(3), + parameters=params.tolist(), # Convert to list to ensure proper type + fixed_parameters=fixed_params.tolist(), + ) + + return ants_tfm + + def itk_transform_to_antsfile( + self, + itk_tfm: itk.Transform, + reference_image: itk.Image, + output_filename: str, ): - """Convert ITK transform to ANTsPy transform object. + """Convert ITK transform to ANTs transform file. This method converts any ITK transform (Affine, Rigid, DisplacementField, etc.) - to an ANTsPy transform object that can be used as initial_transform in + to an ANTs transform file that can be used as initial_transform in ants.registration() or ants.label_image_registration(). The conversion process: 1. Uses TransformTools to convert the ITK transform to a displacement field 2. Converts the displacement field image from ITK to ANTs format 3. Creates an ANTsPy transform object from the displacement field + 4. Writes the ANTs transform to a file Args: itk_tfm (itk.Transform): Input ITK transform to convert. Can be any @@ -274,42 +395,56 @@ def itk_transform_to_ants_transform( CompositeTransform, etc.) 
reference_image (itk.Image): Reference image that defines the spatial domain for the displacement field (spacing, size, origin, direction) + output_filename (str): Path where the ANTs transform file will be written. + Typically should have .mat extension for ANTs transforms. Returns: - ants.core.ANTsTransform: ANTsPy transform object suitable for use - as initial_transform parameter in ANTs registration functions + list[str]: List containing the path to the written ANTs transform file Example: - >>> # Convert ITK affine transform to ANTs + >>> # Convert ITK affine transform to ANTs file >>> affine_itk = itk.AffineTransform[itk.D, 3].New() >>> affine_itk.SetIdentity() - >>> affine_ants = registrar.itk_transform_to_ants_transform( - ... affine_itk, reference_image + >>> transform_files = registrar.itk_transform_to_antsfile( + ... affine_itk, reference_image, "initial_transform.mat" ... ) >>> >>> # Use in registration >>> result = ants.registration( ... fixed=fixed_ants, ... moving=moving_ants, - ... initial_transform=affine_ants + ... initial_transform=transform_files ... 
) """ - # Use TransformTools to convert any ITK transform to displacement field - transform_tools = TransformTools() - disp_field_itk = transform_tools.convert_transform_to_displacement_field( - tfm=itk_tfm, - reference_image=reference_image, - np_component_type=np.float64, - use_reference_image_as_mask=False, - ) + if isinstance(itk_tfm, itk.DisplacementFieldTransform) or isinstance( + itk_tfm, itk.CompositeTransform + ): + transform_tools = TransformTools() + disp_field_itk = transform_tools.convert_transform_to_displacement_field( + tfm=itk_tfm, + reference_image=reference_image, + np_component_type=np.float32, # Use float32 for compatibility with ANTs + use_reference_image_as_mask=False, + ) - # Convert ITK displacement field to ANTs image format - disp_field_ants = self._itk_to_ants_image(disp_field_itk, dtype='double') + if "nii.gz" not in output_filename: + output_filename = os.path.splitext(output_filename)[0] + '.nii.gz' - # Create ANTs transform object from displacement field - ants_tfm = ants.transform_from_displacement_field(disp_field_ants) + # Write displacement field directly as nifti (ANTs can read this) + itk.imwrite(disp_field_itk, output_filename, compression=True) + self.log_info("Wrote ANTs displacement field to: %s", output_filename) - return ants_tfm + return [output_filename] + else: + ants_tfm = self.itk_affine_transform_to_ants_transform(itk_tfm) + if ".mat" not in output_filename: + output_filename = os.path.splitext(output_filename)[0] + '.mat' + + # Write transform to file + ants.write_transform(ants_tfm, output_filename) + self.log_info("Wrote ANTs transform to: %s", output_filename) + + return [output_filename] def _antsfiles_to_itk_transforms( self, @@ -342,10 +477,9 @@ def _antsfiles_to_itk_transforms( def registration_method( self, moving_image, - moving_image_mask=None, + moving_mask=None, moving_image_pre=None, - images_are_labelmaps=False, - initial_phi_MF=None, + initial_forward_transform=None, ): """Register moving image to 
fixed image using ANTs registration algorithm. @@ -354,17 +488,12 @@ def registration_method( fixed image using ANTs SyN or other specified algorithms. Args: - moving_image (itk.image): The 3D image to be registered/aligned. When - images_are_labelmaps=True, this should be a label image for label-based - registration - moving_image_mask (itk.image, optional): Binary mask defining the + moving_image (itk.image): The 3D image to be registered/aligned. + moving_mask (itk.image, optional): Binary mask defining the region of interest in the moving image moving_image_pre (ants.core.ANTsImage, optional): Pre-processed moving image in ANTs format. If None, preprocessing is performed automatically - images_are_labelmaps (bool, optional): If True, use label-based registration - instead of intensity-based registration. In this mode, fixed_image and - moving_image are treated as label images - initial_phi_MF (itk.Transform, optional): Initial transform from moving + initial_forward_transform (itk.Transform, optional): Initial transform from moving to fixed space. Can be any ITK transform type (Affine, Rigid, DisplacementField, Composite, etc.). Will be converted to ANTs format automatically. The returned transforms will include this @@ -372,8 +501,8 @@ def registration_method( Returns: dict: Dictionary containing: - - "phi_FM": Forward transformation (fixed to moving) - - "phi_MF": Backward transformation (moving to fixed) + - "forward_transform": Transformation from moving to fixed + - "inverse_transform": Transformation from fixed to moving - "loss": Loss value from the registration Note: @@ -384,7 +513,7 @@ def registration_method( IMPORTANT: ANTs registration does NOT include the initial_transform in its output fwdtransforms/invtransforms. 
This method automatically composes the initial transform with the registration result, so the - returned phi_MF and phi_FM include both the initial alignment and + returned transforms include both the initial alignment and the registration refinement. Implementation details: @@ -397,19 +526,19 @@ def registration_method( Example: >>> # Basic registration >>> result = registrar.register(moving_image) - >>> phi_FM = result["phi_FM"] - >>> phi_MF = result["phi_MF"] + >>> inverse_transform = result["inverse_transform"] + >>> forward_transform = result["forward_transform"] >>> >>> # Masked registration for cardiac structures - >>> registrar.set_fixed_image_mask(heart_mask_fixed) + >>> registrar.set_fixed_mask(heart_mask_fixed) >>> result = registrar.register( - ... moving_image, moving_image_mask=heart_mask_moving + ... moving_image, moving_mask=heart_mask_moving ... ) >>> >>> # Registration with initial transform >>> initial_tfm = itk.AffineTransform[itk.D, 3].New() >>> result = registrar.register( - ... moving_image, initial_phi_MF=initial_tfm + ... moving_image, initial_forward_transform=initial_tfm ... 
) """ if moving_image is not None: @@ -423,60 +552,65 @@ def registration_method( self.moving_image, self.modality ) - if moving_image_mask is not None: - self.moving_image_mask = moving_image_mask + if moving_mask is not None: + self.moving_mask = moving_mask if self.fixed_image_pre is None: self.fixed_image_pre = self.preprocess(self.fixed_image, self.modality) # Convert initial ITK transform to ANTs format if provided initial_transform = "identity" - if initial_phi_MF is not None: - print("Converting initial ITK transform to ANTs format...") - initial_transform = self.itk_transform_to_ants_transform( - itk_tfm=initial_phi_MF, reference_image=self.fixed_image + if initial_forward_transform is not None: + self.log_info("Converting initial ITK transform to ANTs format...") + initial_transform = self.itk_transform_to_antsfile( + itk_tfm=initial_forward_transform, + reference_image=self.fixed_image, + output_filename="initial_transform_temp.mat", ) - print("✓ Initial transform converted successfully") - - if images_are_labelmaps: - registration_result = ants.label_image_registration( - fixed_label_images=self._itk_to_ants_image(self.fixed_image), - moving_label_images=self._itk_to_ants_image(self.moving_image), - initial_transform=initial_transform, + self.log_info("Initial transform converted successfully") + + transform_type = None + if self.transform_type == "Deformable": + transform_type = "antsRegistrationSyNQuick[so]" + elif self.transform_type == "Affine": + transform_type = "antsRegistrationAffineQuick[so]" + elif self.transform_type == "Rigid": + transform_type = "antsRegistrationRigidQuick[so]" + else: + self.log_error("Invalid transform type: %s", self.transform_type) + raise ValueError(f"Invalid transform type: {self.transform_type}") + + if self.fixed_mask is not None and self.moving_mask is not None: + registration_result = ants.registration( + fixed=self._itk_to_ants_image(self.fixed_image_pre), + mask=self._itk_to_ants_image(self.fixed_mask), + 
moving=self._itk_to_ants_image(self.moving_image_pre), + moving_mask=self._itk_to_ants_image(self.moving_mask), + initial_transform=[initial_transform], + type_of_transform=transform_type, + use_histogram_matching=False, + mask_all_stages=True, verbose=True, + reg_iterations=self.number_of_iterations, ) else: - if self.fixed_image_mask is not None and self.moving_image_mask is not None: - registration_result = ants.registration( - fixed=self._itk_to_ants_image(self.fixed_image_pre), - mask=self._itk_to_ants_image(self.fixed_image_mask), - moving=self._itk_to_ants_image(self.moving_image_pre), - moving_mask=self._itk_to_ants_image(self.moving_image_mask), - initial_transform=initial_transform, - type_of_transform="antsRegistrationSyNQuick[so]", - use_histogram_matching=False, - mask_all_stages=True, - verbose=True, - reg_iterations=self.number_of_iterations, - ) - else: - registration_result = ants.registration( - fixed=self._itk_to_ants_image(self.fixed_image_pre), - moving=self._itk_to_ants_image(self.moving_image_pre), - initial_transform=initial_transform, - type_of_transform="antsRegistrationSyNQuick[so]", - use_histogram_matching=False, - verbose=True, - reg_iterations=self.number_of_iterations, - ) + registration_result = ants.registration( + fixed=self._itk_to_ants_image(self.fixed_image_pre), + moving=self._itk_to_ants_image(self.moving_image_pre), + initial_transform=[initial_transform], + type_of_transform=transform_type, + use_histogram_matching=False, + verbose=True, + reg_iterations=self.number_of_iterations, + ) # Convert ANTs transforms to ITK - phi_MF_reg = self._antsfiles_to_itk_transforms( + forward_reg = self._antsfiles_to_itk_transforms( registration_result['fwdtransforms'], inverse=False, reference_image=self.fixed_image, ) - phi_FM_reg = self._antsfiles_to_itk_transforms( + inverse_reg = self._antsfiles_to_itk_transforms( registration_result['invtransforms'], inverse=True, reference_image=self.moving_image, @@ -484,51 +618,55 @@ def 
registration_method( # Important: ANTs does NOT include the initial_transform in the output transforms # We need to manually compose them - if initial_phi_MF is not None: - print("Composing initial transform with registration result...") + if initial_forward_transform is not None: + self.log_info("Composing initial transform with registration result...") - # For phi_MF (Moving -> Fixed): Apply initial_phi_MF first, then registration - # Transform order: point -> initial_phi_MF -> phi_MF_reg - phi_MF = itk.CompositeTransform[itk.D, 3].New() - phi_MF.AddTransform(initial_phi_MF) + # For forward_transform (Moving -> Fixed): Apply initial_forward_transform first, then registration + # Transform order: point -> initial_forward_transform -> forward_reg + forward_transform = itk.CompositeTransform[itk.D, 3].New() + forward_transform.AddTransform(initial_forward_transform) # Add transforms from registration result (may be composite) - if isinstance(phi_MF_reg, itk.CompositeTransform[itk.D, 3]): - for i in range(phi_MF_reg.GetNumberOfTransforms()): - phi_MF.AddTransform(phi_MF_reg.GetNthTransform(i)) + if isinstance(forward_reg, itk.CompositeTransform[itk.D, 3]): + for i in range(forward_reg.GetNumberOfTransforms()): + forward_transform.AddTransform(forward_reg.GetNthTransform(i)) else: - phi_MF.AddTransform(phi_MF_reg) + forward_transform.AddTransform(forward_reg) - # For phi_FM (Fixed -> Moving): Apply registration inverse first, then initial inverse - # Transform order: point -> phi_FM_reg -> initial_phi_MF^(-1) - phi_FM = itk.CompositeTransform[itk.D, 3].New() + # For inverse_transform (Fixed -> Moving): Apply registration inverse first, then initial inverse + # Transform order: point -> inverse_reg -> initial_forward_transform^(-1) + inverse_transform = itk.CompositeTransform[itk.D, 3].New() # Add registration inverse transforms - if isinstance(phi_FM_reg, itk.CompositeTransform[itk.D, 3]): - for i in range(phi_FM_reg.GetNumberOfTransforms()): - 
phi_FM.AddTransform(phi_FM_reg.GetNthTransform(i)) + if isinstance(inverse_reg, itk.CompositeTransform[itk.D, 3]): + for i in range(inverse_reg.GetNumberOfTransforms()): + inverse_transform.AddTransform(inverse_reg.GetNthTransform(i)) else: - phi_FM.AddTransform(phi_FM_reg) + inverse_transform.AddTransform(inverse_reg) # Invert and add initial transform # For displacement field transforms, we need to invert properly transform_tools = TransformTools() - initial_phi_FM = transform_tools.invert_displacement_field_transform( + initial_inverse = transform_tools.invert_displacement_field_transform( transform_tools.convert_transform_to_displacement_field_transform( - initial_phi_MF, self.moving_image + initial_forward_transform, self.moving_image ) ) - phi_FM.AddTransform(initial_phi_FM) + inverse_transform.AddTransform(initial_inverse) - print("✓ Transforms composed successfully") + self.log_info("Transforms composed successfully") else: # No initial transform, use registration results directly - phi_MF = phi_MF_reg - phi_FM = phi_FM_reg + forward_transform = forward_reg + inverse_transform = inverse_reg moving_image_reg = registration_result['warpedmovout'] loss = ants.image_similarity( self._itk_to_ants_image(self.fixed_image), moving_image_reg, ) - return {"phi_FM": phi_FM, "phi_MF": phi_MF, "loss": loss} + return { + "forward_transform": forward_transform, + "inverse_transform": inverse_transform, + "loss": loss, + } def parse_args(): @@ -569,18 +707,19 @@ def parse_args(): registrar.set_fixed_image(itk.imread(args.fixed_image)) moving_image = itk.imread(args.moving_image) result = registrar.register(moving_image=moving_image) - res_phi_FM = result["phi_FM"] - res_phi_MF = result["phi_MF"] + inverse_transform = result["inverse_transform"] + forward_transform = result["forward_transform"] # Apply transform using ANTs moving_image_ants = registrar.preprocess(moving_image, args.modality) - # res_phi_MF contains the forward transform files (moving to fixed) + # 
forward_transform contains the forward transform (moving to fixed) moving_image_reg_ants = ants.apply_transforms( fixed=registrar.fixed_image_pre, moving=moving_image_ants, - transformlist=res_phi_MF, + transformlist=forward_transform, ) # Convert back to ITK and save moving_image_reg = registrar._ants_to_itk_image(moving_image_reg_ants) itk.imwrite(moving_image_reg, args.output_image, compression=True) + itk.imwrite(moving_image_reg, args.output_image, compression=True) diff --git a/src/physiomotion4d/register_images_base.py b/src/physiomotion4d/register_images_base.py index dc4a9fc..ef1da2d 100644 --- a/src/physiomotion4d/register_images_base.py +++ b/src/physiomotion4d/register_images_base.py @@ -15,14 +15,17 @@ the register() method with their specific algorithm (e.g., Icon, ANTs, etc.). """ +import logging + import itk import numpy as np from itk import TubeTK as ttk +from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase from physiomotion4d.transform_tools import TransformTools -class RegisterImagesBase: +class RegisterImagesBase(PhysioMotion4DBase): """Base class for deformable image registration algorithms. This class provides a common interface and shared functionality for @@ -47,41 +50,50 @@ class and implement the register() method. modality (str): Image modality ('ct', 'mri', etc.) for parameter optimization fixed_image (itk.image): The target/reference image fixed_image_pre (itk.image): Preprocessed fixed image - fixed_image_mask (itk.image): Binary mask for fixed image ROI + fixed_mask (itk.image): Binary mask for fixed image ROI mask_dilation_mm (float): Mask dilation amount in millimeters Example: >>> class MyRegistration(RegisterImagesBase): ... def registration_method(self, moving_image, **kwargs): ... # Implement specific registration algorithm - ... return {"phi_FM": tfm_forward, "phi_MF": tfm_backward} + ... return { + ... "forward_transform": tfm_forward, # Moving → Fixed + ... 
"inverse_transform": tfm_inverse, # Fixed → Moving + ... "loss": 0.0 + ... } >>> >>> registrar = MyRegistration() >>> registrar.set_modality('ct') >>> registrar.set_fixed_image(reference_image) >>> result = registrar.register(moving_image) - >>> phi_FM = result["phi_FM"] - >>> phi_MF = result["phi_MF"] + >>> forward_tfm = result["forward_transform"] # Moving → Fixed + >>> inverse_tfm = result["inverse_transform"] # Fixed → Moving """ - def __init__(self): + def __init__(self, log_level: int | str = logging.INFO): """Initialize the base image registration class. Sets up the common registration parameters with default values. Algorithm-specific components (like neural networks or optimization objects) should be initialized in the concrete implementation to avoid unnecessary resource allocation. + + Args: + log_level: Logging level (default: logging.INFO) """ + super().__init__(class_name=self.__class__.__name__, log_level=log_level) + self.net = None self.modality = 'ct' self.fixed_image = None self.fixed_image_pre = None - self.fixed_image_mask = None + self.fixed_mask = None self.moving_image = None self.moving_image_pre = None - self.moving_image_mask = None + self.moving_mask = None self.mask_dilation_mm = 5 @@ -136,7 +148,7 @@ def set_mask_dilation(self, mask_dilation_mm): """ self.mask_dilation_mm = mask_dilation_mm - def set_fixed_image_mask(self, fixed_image_mask): + def set_fixed_mask(self, fixed_mask): """Set a binary mask for the fixed image region of interest. The mask constrains registration to focus on specific anatomical @@ -145,30 +157,30 @@ def set_fixed_image_mask(self, fixed_image_mask): the mask is dilated by the specified amount. Args: - fixed_image_mask (itk.image): Binary or label mask defining the + fixed_mask (itk.image): Binary or label mask defining the region of interest in the fixed image. 
Non-zero values are treated as foreground Example: >>> # Use heart mask to focus registration on cardiac structures - >>> registrar.set_fixed_image_mask(heart_mask) + >>> registrar.set_fixed_mask(heart_mask) """ self.fixed_image_pre = None - if fixed_image_mask is None: - self.fixed_image_mask = None + if fixed_mask is None: + self.fixed_mask = None return - mask_arr = itk.GetArrayFromImage(fixed_image_mask) + mask_arr = itk.GetArrayFromImage(fixed_mask) mask_arr = np.where(mask_arr > 0, 1, 0) - self.fixed_image_mask = itk.GetImageFromArray(mask_arr.astype(np.uint8)) - self.fixed_image_mask.CopyInformation(self.fixed_image) + self.fixed_mask = itk.GetImageFromArray(mask_arr.astype(np.uint8)) + self.fixed_mask.CopyInformation(self.fixed_image) if self.mask_dilation_mm > 0: - imMath = ttk.ImageMath.New(self.fixed_image_mask) + imMath = ttk.ImageMath.New(self.fixed_mask) imMath.Dilate( int(self.fixed_image.GetSpacing()[0] / self.mask_dilation_mm), 1, 0 ) - self.fixed_image_mask = imMath.GetOutputUChar() + self.fixed_mask = imMath.GetOutputUChar() def preprocess(self, image, modality='ct'): """Preprocess the image based on modality-specific requirements. @@ -194,28 +206,31 @@ def preprocess(self, image, modality='ct'): def registration_method( self, moving_image, - moving_image_mask=None, + moving_mask=None, moving_image_pre=None, - images_are_labelmaps=False, - initial_phi_MF=None, + initial_forward_transform=None, ) -> dict: """Main registration method to align moving image to fixed image. This method serves as the primary interface for performing image registration. It takes a moving image and optional mask and - preprocessed image, and returns the registered image along with - forward and backward transformations. + preprocessed image, and returns the forward and backward transformations. + + Note: This is an internal method that should be implemented by subclasses. + The public API is register() which wraps this method. 
Args: moving_image (itk.image): The 3D image to be registered to the fixed image - moving_image_mask (itk.image, optional): Binary mask for moving image ROI + moving_mask (itk.image, optional): Binary mask for moving image ROI moving_image_pre (itk.image, optional): Preprocessed moving image - images_are_labelmaps (bool, optional): Whether the images are labelmaps - initial_phi_MF (itk.Transform, optional): Initial transformation from moving to fixed + initial_forward_transform (itk.Transform, optional): Initial transformation from moving to fixed + Returns: dict: Dictionary containing: - - "phi_FM": Used to warp fixed image into moving space - - "phi_MF": Used to warp moving image into fixed space + - "forward_transform": Transform that warps moving image into fixed space + - "inverse_transform": Transform that warps fixed image into moving space + - "loss": Registration loss/metric value + Raises: ValueError: If fixed image is not set """ @@ -224,10 +239,9 @@ def registration_method( def register( self, moving_image, - moving_image_mask=None, + moving_mask=None, moving_image_pre=None, - images_are_labelmaps=False, - initial_phi_MF=None, + initial_forward_transform=None, ) -> dict: """Register a moving image to the fixed image. 
@@ -237,15 +251,19 @@ def register( Args: moving_image (itk.image): The 3D image to be registered to the fixed image - moving_image_mask (itk.image, optional): Binary mask for moving image ROI + moving_mask (itk.image, optional): Binary mask for moving image ROI moving_image_pre (itk.image, optional): Preprocessed moving image - images_are_labelmaps (bool, optional): Whether the images are labelmaps for label-based registration - initial_phi_MF (itk.Transform, optional): Initial transformation from fixed to moving + initial_forward_transform (itk.Transform, optional): Initial transformation from moving to fixed Returns: - dict: Dictionary containing: - - "phi_FM": Forward transformation from moving to fixed - - "phi_MF": Backward transformation from fixed to moving + dict: Dictionary containing transformation results: + - "forward_transform": Transforms moving image to fixed space (warps moving → fixed) + - "inverse_transform": Transforms fixed image to moving space (warps fixed → moving) + - "loss": Registration loss/metric value + + Note: + - forward_transform: Use this to warp the moving image to match the fixed image + - inverse_transform: Use this to warp the fixed image to match the moving image Raises: NotImplementedError: This method must be implemented by subclasses @@ -262,37 +280,36 @@ def register( modality=self.modality, ) - new_moving_image_mask = moving_image_mask - if moving_image_mask is not None: - mask_arr = itk.GetArrayFromImage(moving_image_mask) + new_moving_mask = moving_mask + if moving_mask is not None: + mask_arr = itk.GetArrayFromImage(moving_mask) mask_arr = np.where(mask_arr > 0, 1, 0) - new_moving_image_mask = itk.GetImageFromArray(mask_arr.astype(np.uint8)) - new_moving_image_mask.CopyInformation(moving_image) + new_moving_mask = itk.GetImageFromArray(mask_arr.astype(np.uint8)) + new_moving_mask.CopyInformation(moving_image) if self.mask_dilation_mm > 0: - imMath = ttk.ImageMath.New(new_moving_image_mask) + imMath = 
ttk.ImageMath.New(new_moving_mask) imMath.Dilate( int(moving_image.GetSpacing()[0] / self.mask_dilation_mm), 1, 0 ) - new_moving_image_mask = imMath.GetOutputUChar() + new_moving_mask = imMath.GetOutputUChar() self.moving_image = moving_image self.moving_image_pre = moving_image_pre - self.moving_image_mask = new_moving_image_mask + self.moving_mask = new_moving_mask result = self.registration_method( moving_image, - moving_image_mask=new_moving_image_mask, + moving_mask=new_moving_mask, moving_image_pre=moving_image_pre, - images_are_labelmaps=images_are_labelmaps, - initial_phi_MF=initial_phi_MF, + initial_forward_transform=initial_forward_transform, ) - phi_FM = result["phi_FM"] - phi_MF = result["phi_MF"] + forward_transform = result["forward_transform"] + inverse_transform = result["inverse_transform"] loss = result["loss"] return { - "phi_FM": phi_FM, - "phi_MF": phi_MF, + "forward_transform": forward_transform, # Warps moving → fixed + "inverse_transform": inverse_transform, # Warps fixed → moving "loss": loss, } diff --git a/src/physiomotion4d/register_images_icon.py b/src/physiomotion4d/register_images_icon.py index 487645d..80bd0ff 100644 --- a/src/physiomotion4d/register_images_icon.py +++ b/src/physiomotion4d/register_images_icon.py @@ -10,6 +10,7 @@ """ import argparse +import logging import icon_registration as icon import icon_registration.itk_wrapper @@ -54,17 +55,20 @@ class RegisterImagesICON(RegisterImagesBase): >>> registrar.set_modality('ct') >>> registrar.set_fixed_image(reference_image) >>> result = registrar.register(moving_image) - >>> phi_FM = result["phi_FM"] + >>> forward_transform = result["forward_transform"] """ - def __init__(self): + def __init__(self, log_level: int | str = logging.INFO): """Initialize the ICON image registration class. Calls the parent RegisterImagesBase constructor to set up common parameters. The ICON deep learning network is initialized lazily on first use to avoid unnecessary GPU memory allocation. 
+ + Args: + log_level: Logging level (default: logging.INFO) """ - super().__init__() + super().__init__(log_level=log_level) self.net = None self.use_multi_modality = False @@ -127,10 +131,9 @@ def preprocess(self, image, modality): def registration_method( self, moving_image, - moving_image_mask=None, + moving_mask=None, moving_image_pre=None, - images_are_labelmaps=False, - initial_phi_MF=None, + initial_forward_transform=None, ): """Register moving image to fixed image using ICON registration algorithm. @@ -141,31 +144,26 @@ def registration_method( Args: moving_image (itk.image): The 3D image to be registered/aligned - moving_image_mask (itk.image, optional): Binary mask defining the + moving_mask (itk.image, optional): Binary mask defining the region of interest in the moving image. If provided along with - fixed_image_mask, enables mask-constrained registration + fixed_mask, enables mask-constrained registration moving_image_pre (itk.image, optional): Pre-processed moving image. If None, preprocessing is performed automatically - images_are_labelmaps (bool, optional): Whether the images are labelmaps. - Currently not used by ICON implementation - initial_phi_MF (itk.Transform, optional): Initial transformation from fixed - to moving. If provided, it is used to transform the moving image before + initial_forward_transform (itk.Transform, optional): Initial transformation from moving + to fixed. If provided, it is used to transform the moving image before registration. 
Returns: dict: Dictionary containing: - - "phi_FM": transform fixed image to moving space - - transform goes from moving space into fixed space - - "phi_MF": transform moving image into fixed space - - transform goes from fixed space into moving space + - "forward_transform": transform moving image into fixed space + - "inverse_transform": transform fixed image to moving space - "loss": Loss value from the registration Note: - The transformations phi_FM and phi_MF - are inverse consistent, meaning phi_MF ≈ inverse( - phi_FM). - The phi_FM transform is used to warp the fixed image - to the moving image space. The phi_MF transform is used + The transformations are inverse consistent, meaning + forward_transform ≈ inverse(inverse_transform). + The inverse_transform is used to warp the fixed image + to the moving image space. The forward_transform is used to warp the moving image to the fixed image space. Implementation details: @@ -177,13 +175,13 @@ def registration_method( Example: >>> # Basic registration >>> result = registrar.register(moving_image) - >>> phi_FM = result["phi_FM"] - >>> phi_MF = result["phi_MF"] + >>> forward_transform = result["forward_transform"] + >>> inverse_transform = result["inverse_transform"] >>> >>> # Masked registration for cardiac structures - >>> registrar.set_fixed_image_mask(heart_mask_fixed) + >>> registrar.set_fixed_mask(heart_mask_fixed) >>> result = registrar.register( - ... moving_image, moving_image_mask=heart_mask_moving + ... moving_image, moving_mask=heart_mask_moving ... 
) """ @@ -193,10 +191,10 @@ def registration_method( moving_image_pre = self.preprocess(moving_image, self.modality) new_moving_image_pre = moving_image_pre - if initial_phi_MF is not None: + if initial_forward_transform is not None: new_moving_image_pre = tfm_tools.transform_image( moving_image_pre, - initial_phi_MF, + initial_forward_transform, self.fixed_image, ) @@ -213,23 +211,23 @@ def registration_method( apply_intensity_conservation_loss=self.use_mass_preservation, ) - phi_FM = None - phi_MF = None + inverse_transform = None + forward_transform = None loss_artifacts = None - if self.fixed_image_mask is not None and moving_image_mask is not None: - phi_FM, phi_MF, loss_artifacts = ( + if self.fixed_mask is not None and moving_mask is not None: + inverse_transform, forward_transform, loss_artifacts = ( icon_registration.itk_wrapper.register_pair_with_mask( self.net, self.fixed_image_pre, new_moving_image_pre, - self.fixed_image_mask, - moving_image_mask, + self.fixed_mask, + moving_mask, finetune_steps=self.number_of_iterations, return_artifacts=True, ) ) else: - phi_FM, phi_MF, loss_artifacts = ( + inverse_transform, forward_transform, loss_artifacts = ( icon_registration.itk_wrapper.register_pair( self.net, self.fixed_image_pre, @@ -241,10 +239,10 @@ def registration_method( loss = loss_artifacts[0] - if initial_phi_MF is not None: - phi_MF = tfm_tools.combine_displacement_field_transforms( - initial_phi_MF, - phi_MF, + if initial_forward_transform is not None: + forward_transform = tfm_tools.combine_displacement_field_transforms( + initial_forward_transform, + forward_transform, self.fixed_image, tfm1_weight=1.0, tfm2_weight=1.0, @@ -252,14 +250,14 @@ def registration_method( ) dftfm = tfm_tools.convert_transform_to_displacement_field_transform( - phi_MF, + forward_transform, self.fixed_image, ) - phi_FM = tfm_tools.invert_displacement_field_transform(dftfm) + inverse_transform = tfm_tools.invert_displacement_field_transform(dftfm) return { - "phi_FM": 
phi_FM, - "phi_MF": phi_MF, + "forward_transform": forward_transform, + "inverse_transform": inverse_transform, "loss": loss, } @@ -298,9 +296,9 @@ def parse_args(): registrar.set_fixed_image(itk.imread(args.fixed_image)) moving_image = itk.imread(args.moving_image) result = registrar.register(moving_image=moving_image) - phi_FM = result["phi_FM"] - phi_MF = result["phi_MF"] + forward_transform = result["forward_transform"] + inverse_transform = result["inverse_transform"] moving_image_reg = TransformTools().transform_image( - moving_image, phi_MF, registrar.fixed_image, "sinc" + moving_image, forward_transform, registrar.fixed_image, "sinc" ) # Final resampling with sinc itk.imwrite(moving_image_reg, args.output_image, compression=True) diff --git a/src/physiomotion4d/register_model_to_image_pca.py b/src/physiomotion4d/register_model_to_image_pca.py deleted file mode 100644 index 52e97d8..0000000 --- a/src/physiomotion4d/register_model_to_image_pca.py +++ /dev/null @@ -1,1211 +0,0 @@ -"""PCA-based model-to-image registration for cardiac anatomical models. - -This module provides the RegisterModelToImagePCA class for registering -parametric anatomical models (VTK format with PCA shape variation) to patient-specific -imaging data. The workflow includes: - -1. Stage 1: Rigid alignment (rotation + translation) to establish initial pose -2. Stage 2: Joint optimization of rigid parameters + PCA coefficients - - Allows small refinements to rigid transform - - Optimizes shape coefficients to maximize mean intensity at model points - -The registration is particularly useful for cardiac modeling where a statistical -shape model (mean + PCA modes) needs to be fitted to contrast-enhanced CT images. 
- -Key Features: - - Two-stage optimization (coarse rigid then joint rigid+shape) - - PCA-based shape model with eigenmode variation - - Intensity-based metric (maximize mean intensity at model points) - - ITK linear interpolation for continuous intensity sampling - - Quaternion-based rigid transform (VersorRigid3DTransform) avoids gimbal lock - - Support for VTK unstructured grids and surface meshes - -Example: - >>> import itk - >>> import numpy as np - >>> from physiomotion4d import RegisterModelToImagePCA - >>> - >>> # Load patient image - >>> image = itk.imread("patient_ct.nrrd") - >>> - >>> # Load PCA model data from SlicerSALT format - >>> average_mesh, eigenvalues, eigenvectors = ( - ... RegisterModelToImagePCA.pca_read_slicersalt("pca.json", group_key='All') - ... ) - >>> std_deviations = np.sqrt(eigenvalues) - >>> - >>> # Create initial transform - >>> initial_transform = itk.VersorRigid3DTransform[itk.D].New() - >>> initial_transform.SetIdentity() - >>> - >>> # Initialize PCA-based registration - >>> registrar = RegisterModelToImagePCA( - ... average_mesh=average_mesh, - ... eigenvectors=eigenvectors, - ... std_deviations=std_deviations, - ... reference_image=image - ... 
) - >>> - >>> # Run complete two-stage registration - >>> result = registrar.register(initial_transform=initial_transform) - >>> - >>> # Access results - >>> registered_mesh = result['registered_mesh'] - >>> pca_coefficients = result['pca_coefficients'] - >>> final_intensity = result['final_intensity'] -""" - -import json -import logging -from pathlib import Path -from typing import Optional - -try: - from typing import Self -except ImportError: - from typing_extensions import Self - -import itk -import numpy as np -import pyvista as pv -from scipy.optimize import minimize -from scipy.spatial import KDTree - -from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase -from physiomotion4d.transform_tools import TransformTools - - -class RegisterModelToImagePCA(PhysioMotion4DBase): - """Register PCA-based shape models to medical images using intensity optimization. - - This class implements a two-stage registration pipeline for fitting statistical - shape models to patient-specific medical images: - - **Stage 1: Coarse Rigid Alignment** - - Optimizes 6 DOF rigid transform (ITK VersorRigid3DTransform) - - Uses quaternion representation to avoid gimbal lock - - Establishes initial pose of the model in image coordinate system - - Uses Nelder-Mead optimization to maximize mean intensity - - **Stage 2: Joint Rigid + PCA Deformable Registration** - - Simultaneously optimizes rigid parameters AND PCA coefficients - - Rigid parameters are allowed small refinements from Stage 1 - - Model equation: P = rigid_transform(mean + Σ(b_i * std_i * eigenvector_i)) - - Maximizes mean intensity at deformed model points - - **Optimization Objective:** - Maximize the mean intensity of the image sampled at model points using - ITK's LinearInterpolateImageFunction. This aligns the model with bright - regions in contrast-enhanced images (e.g., blood pool in cardiac CT). 
- - Attributes: - average_mesh (pv.UnstructuredGrid): Mean shape model - eigenvectors (np.ndarray): PCA eigenvectors/components (modes × n_points*3) - std_deviations (np.ndarray): Standard deviations per mode (modes,) - reference_image (itk.Image): Patient image providing coordinate frame and intensity data - n_points (int): Number of points in the mesh - n_pca_modes (int): Number of PCA modes available - rigid_transform (itk.VersorRigid3DTransform): Optimized rigid transformation - pca_coefficients (np.ndarray): Optimized PCA coefficients - registered_mesh (pv.UnstructuredGrid): Final registered and deformed mesh - - Example: - >>> # Load PCA model data - >>> average_mesh = pv.read("pca_All_mean.vtk") - >>> with open("pca.json", 'r') as f: - ... pca_data = json.load(f) - >>> group_data = pca_data['All'] - >>> std_deviations = np.sqrt(np.array(group_data['eigenvalues'])) - >>> eigenvectors = np.array(group_data['components']) - >>> - >>> # Initialize registrar with loaded data - >>> registrar = RegisterModelToImagePCA( - ... average_mesh=average_mesh, - ... eigenvectors=eigenvectors, - ... std_deviations=std_deviations, - ... reference_image=patient_ct_image - ... ) - >>> - >>> # Create initial transform - >>> initial_transform = itk.VersorRigid3DTransform[itk.D].New() - >>> initial_transform.SetIdentity() - >>> - >>> # Run full registration pipeline - >>> result = registrar.register( - ... initial_transform=initial_transform, - ... n_pca_modes=10 - ... 
) - >>> - >>> # Save registered mesh - >>> result['registered_mesh'].save("registered_heart.vtk") - >>> - >>> # Print optimization results - >>> print(f"Final intensity: {result['final_intensity']:.2f}") - >>> print(f"PCA coefficients: {result['pca_coefficients']}") - """ - - def __init__( - self, - average_mesh: pv.UnstructuredGrid, - eigenvectors: np.ndarray, - std_deviations: np.ndarray, - reference_image: Optional[itk.Image] = None, - n_modes: int = -1, - log_level: int | str = logging.INFO, - point_subsample_step: int = 4, - ): - """Initialize the PCA-based model-to-image registration. - - Args: - average_mesh: PyVista mesh containing the mean 3D shape model - (unstructured grid or polydata) - eigenvectors: Numpy array of PCA eigenvectors/components. Shape: (modes, n_points*3) - Each row is a flattened eigenmode with 3D displacements: [x1,y1,z1, x2,y2,z2, ...] - std_deviations: Numpy array of standard deviations per PCA mode. Shape: (modes,) - These are the square roots of eigenvalues - reference_image: ITK image providing the coordinate frame and intensity values - for registration. If None, must be set later before registration. - n_pca_modes: Number of PCA modes to use. Default: -1 (use all) - log_level: Logging level (logging.DEBUG, logging.INFO, logging.WARNING). 
- Default: logging.INFO - - Raises: - ValueError: If eigenvector dimensions don't match mesh points - """ - # Initialize base class with logging - super().__init__(class_name="RegisterModelToImagePCA", log_level=log_level) - - # Store model data - self.average_mesh: pv.UnstructuredGrid = average_mesh - self.eigenvectors: np.ndarray = eigenvectors - self.std_deviations: np.ndarray = std_deviations - self.reference_image = reference_image - - self.n_pca_modes: int = n_modes - if self.n_pca_modes == -1: - self.n_pca_modes = len(std_deviations) - - # Working transform (reused to avoid repeated memory allocation) - self._working_transform: itk.VersorRigid3DTransform = ( - itk.VersorRigid3DTransform[itk.D].New() - ) - self._transform_center_of_rotation: itk.Point = itk.Point[itk.D, 3]() - self._transform_center_of_rotation[0] = 0.0 - self._transform_center_of_rotation[1] = 0.0 - self._transform_center_of_rotation[2] = 0.0 - - # Registration results - Stage 1 (coarse rigid) - self.stage1_rigid_transform: Optional[itk.VersorRigid3DTransform] = None - - # Registration results - Stage 2 (refined rigid + PCA) - self.rigid_transform: Optional[itk.VersorRigid3DTransform] = None - self.pca_coefficients: Optional[np.ndarray] = None - self.registered_mesh: Optional[pv.UnstructuredGrid] = None - self.final_intensity: float = 0.0 - - # Transform utilities - self.transform_tools = TransformTools() - - # Image interpolator (created when needed) - self._interpolator: Optional[itk.LinearInterpolateImageFunction] = None - - # KDTree for efficient nearest neighbor search (built lazily) - self._kdtree: Optional[KDTree] = None - - # Distance threshold for transform_point method (in mm) - self._transform_point_distance_threshold: float = 10.0 # Default: 10mm - - # Store mean and deformed points for computing displacements - self._average_mesh_points = self.average_mesh.points - self._average_mesh_points_deformed: Optional[np.ndarray] = None - - self._metric_call_count: int = 0 - - # 
Pre-convert mean shape points to ITK format - self.point_subsample_step = point_subsample_step - self._average_mesh_points_itk: Optional[list[itk.Point]] = None - self._create_itk_points() - - @classmethod - def from_slicersalt( - cls, - average_mesh: pv.UnstructuredGrid, - json_filename: str, - group_key: str = 'All', - reference_image: itk.Image = None, - n_modes: int = -1, - log_level: int | str = logging.INFO, - point_subsample_step: int = 4, - ) -> Self: - """Read PCA model data from SlicerSALT format JSON file. - - This method reads PCA statistical shape model data from a JSON file - created by SlicerSALT, including the mean mesh, eigenvalues, and - eigenvector components. - - The method expects: - 1. A JSON file (e.g., 'pca.json') containing eigenvalues and components - 2. A corresponding VTK mesh file (e.g., 'pca_All_mean.vtk') in the same - directory, where the filename follows the pattern: - pca_{group_key}_mean.vtk - - Args: - json_filename: Path to the SlicerSALT PCA JSON file - group_key: Key for the PCA group to extract from JSON. - Default: 'All' - - Returns: - Tuple containing: - - average_mesh: PyVista mesh with mean shape - - eigenvalues: Numpy array of PCA eigenvalues - - eigenvectors: Numpy array of PCA eigenvector components - Shape: (modes, n_points*3) - - Raises: - FileNotFoundError: If JSON or VTK mesh file not found - KeyError: If group_key not found in JSON - ValueError: If data format is invalid - - Example: - >>> mesh, eigenvalues, eigenvectors = ( - ... RegisterModelToImagePCA.pca_read_slicersalt( - ... 'path/to/pca.json', - ... group_key='All' - ... ) - ... ) - >>> std_deviations = np.sqrt(eigenvalues) - >>> registrar = RegisterModelToImagePCA( - ... average_mesh=mesh, - ... eigenvectors=eigenvectors, - ... std_deviations=std_deviations, - ... reference_image=patient_image - ... ) - """ - # Create a logger for the classmethod since superclassclasss hasn'tt - # been initialized yet. 
- logger = logging.getLogger("PhysioMotion4D") - - json_path = Path(json_filename) - - # Check if JSON file exists - if not json_path.exists(): - self.log_error(f"PCA JSON file not found: {json_filename}") - raise FileNotFoundError(f"PCA JSON file not found: {json_filename}") - - logger.info("Loading PCA data from SlicerSALT format...") - logger.info(f" JSON file: {json_path}") - logger.info(f" Group key: {group_key}") - - # Load PCA data from JSON - logger.info("Reading JSON file...") - with open(json_path, 'r', encoding='utf-8') as f: - pca_data = json.load(f) - - # Extract PCA group data - if group_key not in pca_data: - available_keys = list(pca_data.keys()) - raise KeyError( - f"Group key '{group_key}' not found in JSON. " - f"Available keys: {available_keys}" - ) - - group_data = pca_data[group_key] - - # Extract eigenvalues - if 'eigenvalues' not in group_data: - raise ValueError( - f"'eigenvalues' field not found in group '{group_key}' data" - ) - eigenvalues = np.array(group_data['eigenvalues']) - logger.info(" Loaded %d eigenvalues", len(eigenvalues)) - - std_deviations = np.sqrt(eigenvalues) - - # Extract eigenvector components - if 'components' not in group_data: - raise ValueError( - f"'components' field not found in group '{group_key}' data" - ) - eigenvectors = np.array(group_data['components'], dtype=np.float64) - logger.info(f" Loaded eigenvectors with shape {eigenvectors.shape}") - - expected_eigenvector_size = average_mesh.n_points * 3 - actual_eigenvector_size = eigenvectors.shape[1] - - if actual_eigenvector_size != expected_eigenvector_size: - raise ValueError( - f"Eigenvector dimension mismatch: " - f"Expected {expected_eigenvector_size} (3 × {average_mesh.n_points} mesh points), " - f"got {actual_eigenvector_size}" - ) - - logger.info(" ✓ Data validation successful!") - logger.info("SlicerSALT PCA data loaded successfully!") - - return cls( - average_mesh=average_mesh, - eigenvectors=eigenvectors, - std_deviations=std_deviations, - 
reference_image=reference_image, - n_modes=n_modes, - log_level=log_level, - ) - - def _create_itk_points(self) -> None: - """Pre-convert mean shape points to ITK Point format for efficiency. - - This method creates ITK Point objects once at initialization, avoiding - repeated conversions during optimization iterations. - """ - self.log_info("Converting mean shape points to ITK format...") - - self._average_mesh_points_itk = [] - for i in range(len(self._average_mesh_points)): - itk_point = itk.Point[itk.D, 3]() - itk_point[0] = float(self._average_mesh_points[i, 0]) - itk_point[1] = float(self._average_mesh_points[i, 1]) - itk_point[2] = float(self._average_mesh_points[i, 2]) - self._average_mesh_points_itk.append(itk_point) - - self.log_info( - f" Converted {len(self._average_mesh_points_itk)} points to ITK format" - ) - - def set_reference_image(self, reference_image: itk.Image) -> None: - """Set the reference image for registration. - - Args: - reference_image: ITK image providing coordinate frame and intensity data - """ - self.reference_image = reference_image - # Clear interpolator to force recreation with new image - self._interpolator = None - - def set_average_mesh(self, average_mesh: pv.UnstructuredGrid) -> None: - """Set the average mesh for registration. - - Args: - average_mesh: PyVista mesh containing the mean 3D shape model - (unstructured grid or polydata) - """ - self.average_mesh = average_mesh - - self._kdtree = None - self._average_mesh_points = self.average_mesh.points - self._average_mesh_points_itk = None - self._average_mesh_points_deformed = None - - self._create_itk_points() - self.log_info(" ✓ Average mesh set successfully!") - - def set_transform_point_distance_threshold(self, distance_mm: float) -> None: - """Set the distance threshold for transform_point method. - - Args: - distance_mm: Distance threshold in millimeters. Points within this - distance will be used for weighted averaging when transforming - arbitrary points. 
Default: 10.0 mm - """ - if distance_mm <= 0: - raise ValueError("Distance threshold must be positive") - self._transform_point_distance_threshold = distance_mm - - def _evaluate_intensity_metric( - self, - pca_deformation: Optional[np.ndarray] = None, - transform_params: Optional[np.ndarray] = None, - ) -> float: - """Evaluate the optimization metric (mean intensity) at model points. - - This is the objective function to be MAXIMIZED during optimization. - Higher values indicate better alignment with bright regions. - - Args: - pca_deformation: Nx3 numpy array of PCA deformation vectors to add to points. - If None, no deformation is applied. - transform_params: 6-element array of rigid transform parameters. - If None, no rigid transformation is applied. - - Returns: - Mean intensity value across all points - """ - # Create interpolator if not already cached (inline creation) - if self._interpolator is None: - if self.reference_image is None: - self.log_error("Reference image is not set") - raise ValueError( - "Reference image must be set before creating interpolator" - ) - - ImageType = type(self.reference_image) - self._interpolator = itk.LinearInterpolateImageFunction[ - ImageType, itk.D - ].New() - self._interpolator.SetInputImage(self.reference_image) - self.log_debug(" Interpolator created") - - # Update working transform if parameters provided - if transform_params is not None: - itk_params = itk.OptimizerParameters[itk.D](6) - for i in range(6): - itk_params[i] = transform_params[i] - self._working_transform.SetParameters(itk_params) - - # Sample intensities at each point - n_valid_points = 0 - n_invalid_points = 0 - total_intensity = 0.0 - center = np.zeros(3) - image_size = self.reference_image.GetBufferedRegion().GetSize() - for i, base_point in enumerate(self._average_mesh_points_itk): - if i % self.point_subsample_step != 0: - continue - - # Start with base point - point = itk.Point[itk.D, 3]() - point[0] = base_point[0] - point[1] = base_point[1] - 
point[2] = base_point[2] - - # Add PCA deformation if provided - if pca_deformation is not None: - point[0] += pca_deformation[i, 0] - point[1] += pca_deformation[i, 1] - point[2] += pca_deformation[i, 2] - - # Apply rigid transform if parameters provided - if transform_params is not None: - point = self._working_transform.TransformPoint(point) - - # Check if point is inside image bounds - coord_index = self.reference_image.TransformPhysicalPointToContinuousIndex( - point - ) - if ( - 0 <= coord_index[0] < image_size[0] - and 0 <= coord_index[1] < image_size[1] - and 0 <= coord_index[2] < image_size[2] - ): - intensity = self._interpolator.EvaluateAtContinuousIndex(coord_index) - total_intensity += intensity - center[0] += point[0] - center[1] += point[1] - center[2] += point[2] - n_valid_points += 1 - else: - # Point is outside image bounds, skip - n_invalid_points += 1 - continue - - # Compute mean intensity - if n_valid_points > 0: - mean_intensity = total_intensity / n_valid_points - center /= n_valid_points - else: - mean_intensity = 0.0 - self.log_warning(" No valid points found") - - if n_invalid_points > 0: - self.log_warning(" %d points are outside image bounds", n_invalid_points) - self.log_warning(" Parameters: %s", transform_params) - self.log_warning(" Center: %s", center) - self.log_warning(" Mean intensity: %f", mean_intensity) - - if self.log_level <= logging.DEBUG or self._metric_call_count % 100 == 0: - self.log_info( - " Metric %d: %s -> %f", - (self._metric_call_count + 1), - center, - mean_intensity, - ) - self._metric_call_count += 1 - - return mean_intensity - - def _compute_pca_deformation( - self, pca_coefficients: np.ndarray, n_pca_modes: Optional[int] = None - ) -> np.ndarray: - """Compute PCA deformation vectors for all points. - - Deformation is computed as: - displacement = Σ(b_i * std_i * eigenvector_i) - - Args: - pca_coefficients: Array of PCA coefficients b_i (one per mode) - n_pca_modes: Number of PCA modes to use. 
Default: use all available modes - - Returns: - Nx3 array of deformation vectors (displacement from mean shape) - """ - if n_pca_modes is None: - n_pca_modes = len(pca_coefficients) - - if n_pca_modes > len(pca_coefficients): - raise ValueError( - f"Number of PCA modes to use ({n_pca_modes}) exceeds available modes ({self.n_pca_modes})" - ) - - # Initialize deformation to zero - deformation = np.zeros((self.average_mesh.n_points, 3), dtype=np.float64) - - # Add contribution from each PCA mode - for i in range(n_pca_modes): - # Get eigenvector for this mode (flattened: [x1,y1,z1, x2,y2,z2, ...]) - eigenvector_flat = self.eigenvectors[i, :] - - # Reshape to (N, 3) - eigenvector_3d = eigenvector_flat.reshape(-1, 3) - - # Add weighted deformation: b_i * std_i * eigenvector_i - deformation += pca_coefficients[i] * self.std_deviations[i] * eigenvector_3d - - return deformation - - def _rigid_objective_function(self, params: np.ndarray) -> float: - """Objective function for coarse rigid alignment optimization (Stage 1). - - This function is MINIMIZED by the optimizer, so we return negative mean intensity. - - Args: - params: 6-element array of transform parameters - - First 3: Versor rotation (rotation vector) - - Last 3: Translation - - Returns: - Negative mean intensity (to be minimized) - """ - # Evaluate intensity metric (no PCA deformation in Stage 1) - mean_intensity = self._evaluate_intensity_metric( - pca_deformation=None, transform_params=params - ) - - # Return negative (optimizer minimizes, we want to maximize intensity) - return -mean_intensity - - def optimize_rigid_alignment( - self, - initial_transform: itk.VersorRigid3DTransform, - method: str = 'Nelder-Mead', - max_iterations: int = 500, - ) -> tuple[itk.VersorRigid3DTransform, float]: - """Optimize coarse rigid alignment (Stage 1) to maximize mean intensity. - - This method optimizes 6 parameters (versor rotation + translation) - to align the mean shape model with bright regions in the image. 
- - Args: - initial_transform: Initial ITK VersorRigid3DTransform for starting point - method: Optimization method for scipy.optimize.minimize. - Default: 'Nelder-Mead' - max_iterations: Maximum number of optimization iterations. - Default: 500 - - Returns: - Tuple of (transform, mean_intensity): - - transform: Optimized ITK VersorRigid3DTransform - - mean_intensity: Final mean intensity metric value - - Raises: - ValueError: If reference image is not set - """ - if self.reference_image is None: - raise ValueError("Reference image must be set before optimization") - - self.log_section("Stage 1: Coarse Rigid Alignment Optimization", width=60) - - # Get initial parameters from transform - itk_params = initial_transform.GetParameters() - initial_params = np.array([itk_params[i] for i in range(len(itk_params))]) - - self.log_info(f"Initial parameters: {initial_params}") - self.log_info(f"Optimization method: {method}") - self.log_info(f"Max iterations: {max_iterations}") - - # Run optimization - self.log_info("Running optimization...") - if self.log_level <= logging.INFO: - disp = True - else: - disp = False - - result_rigid = minimize( - self._rigid_objective_function, - initial_params, - method=method, - options={'maxiter': max_iterations, 'disp': disp}, - ) - self.log_info(f"Optimization result: {result_rigid.x} -> {result_rigid.fun}") - - # Create optimized transform - optimized_transform = itk.VersorRigid3DTransform[itk.D].New() - opt_itk_params = itk.OptimizerParameters[itk.D](6) - for i in range(6): - opt_itk_params[i] = result_rigid.x[i] - optimized_transform.SetParameters(opt_itk_params) - - final_mean_intensity = -result_rigid.fun # Convert back from negative - - self.log_info("Stage 1 optimization completed!") - self.log_info(f"Final parameters: {result_rigid.x}") - self.log_info(f"Final mean intensity: {final_mean_intensity:.2f}") - - # Store Stage 1 result - self.stage1_rigid_transform = optimized_transform - - return optimized_transform, 
final_mean_intensity - - def _joint_objective_function(self, params: np.ndarray, n_pca_modes: int) -> float: - """Objective function for joint rigid + PCA optimization (Stage 2). - - This function is MINIMIZED by the optimizer, so we return negative mean intensity. - The first 6 parameters are rigid transformation (versor + translation), - followed by n_pca_modes PCA coefficients. - - Args: - params: (6 + n_pca_modes)-element array [v1, v2, v3, tx, ty, tz, b1, b2, ..., bn] - - v1, v2, v3: Versor rotation parameters - - tx, ty, tz: Translation in physical units - - b1, ..., bn: PCA coefficients (in units of std deviations) - n_pca_modes: Number of PCA modes being optimized - - Returns: - Negative mean intensity (to be minimized) - """ - # Extract rigid parameters - rigid_params = params[:6] - - # Extract PCA coefficients and compute deformation - pca_coefficients = params[6:] - pca_deformation = self._compute_pca_deformation( - pca_coefficients, n_pca_modes=n_pca_modes - ) - - # Evaluate intensity metric - mean_intensity = self._evaluate_intensity_metric( - pca_deformation=pca_deformation, - transform_params=rigid_params, - ) - - # Return negative (optimizer minimizes, we want to maximize intensity) - return -mean_intensity - - def optimize_joint_rigid_and_pca( - self, - initial_transform: itk.VersorRigid3DTransform, - n_pca_modes: int = -1, - method: str = 'L-BFGS-B', - pca_coefficient_bounds: float = 3.0, - rigid_refinement_bounds: Optional[dict[str, float]] = None, - max_iterations: int = 50, - ) -> tuple[itk.VersorRigid3DTransform, np.ndarray, float]: - """Optimize joint rigid parameters + PCA coefficients (Stage 2). - - This method simultaneously optimizes rigid transformation refinements - and PCA mode coefficients to deform the model to better match bright - regions in the image. The rigid parameters from Stage 1 are used as - the initial guess, and bounds constrain them to small refinements. 
- - Args: - n_pca_modes: Number of PCA modes to use in optimization. Using fewer - modes provides smoother deformations. Default: 10 - method: Optimization method for scipy.optimize.minimize. - Default: 'L-BFGS-B' (supports bounds) - initial_transform: Initial ITK VersorRigid3DTransform for starting point - pca_coefficient_bounds: Bound on PCA coefficients in units of std deviations. - Default: 3.0 (±3 std deviations per mode) - rigid_refinement_bounds: Dictionary specifying bounds on rigid parameter - refinements from Stage 1 initial values: - - 'versor': Max change in versor parameters (default: 0.2) - - 'translation_mm': Max translation change in mm (default: 20mm) - If None, uses defaults. - max_iterations: Maximum number of optimization iterations. - Default: 50 - - Returns: - Tuple of (transform, pca_coefficients, mean_intensity): - - transform: Final optimized ITK VersorRigid3DTransform - - pca_coefficients: Optimized PCA coefficients - - mean_intensity: Final mean intensity metric value - - Raises: - ValueError: If Stage 1 rigid alignment has not been performed - """ - self.log_section("Stage 2: Joint Rigid + PCA Deformable Registration", width=60) - - if n_pca_modes == -1: - n_pca_modes = len(self.eigenvectors) - if n_pca_modes > len(self.eigenvectors): - raise ValueError( - f"Number of PCA modes to use ({n_pca_modes}) exceeds available modes ({len(self.std_deviations)})" - ) - self.n_pca_modes = n_pca_modes - - # Set default rigid refinement bounds if not provided - if rigid_refinement_bounds is None: - rigid_refinement_bounds = { - 'versor': 0.2, # Max change in versor parameters - 'translation_mm': 20.0, # ±10 mm - } - - versor_bound = rigid_refinement_bounds['versor'] - translation_bound_mm = rigid_refinement_bounds['translation_mm'] - - self.log_info(f"Number of PCA modes: {n_pca_modes}") - self.log_info( - f"PCA coefficient bounds: ±{pca_coefficient_bounds} std deviations" - ) - self.log_info(f"Rigid versor refinement bounds: ±{versor_bound}") - 
self.log_info( - f"Rigid translation refinement bounds: ±{translation_bound_mm} mm" - ) - self.log_info(f"Optimization method: {method}") - self.log_info(f"Max iterations: {max_iterations}") - - # Get Stage 1 rigid parameters - itk_params = initial_transform.GetParameters() - initial_rigid_params = np.array([itk_params[i] for i in range(6)]) - - # Set initial parameters: Start from Stage 1 rigid + zero PCA coefficients - initial_params = np.concatenate( - [ - initial_rigid_params, - np.zeros(n_pca_modes), - ] # Start with mean shape (no deformation) - ) - - # Set bounds: constrained rigid refinement + PCA coefficient bounds - bounds = [] - - # Versor rotation bounds (first 3 parameters - constrained around Stage 1 values) - for v_rigid in initial_rigid_params[:3]: - bounds.append((v_rigid - versor_bound, v_rigid + versor_bound)) - - # Rigid translation bounds (last 3 rigid parameters - constrained around Stage 1 values) - for trans_rigid in initial_rigid_params[3:6]: - bounds.append( - ( - trans_rigid - translation_bound_mm, - trans_rigid + translation_bound_mm, - ) - ) - - # PCA coefficient bounds (±3 std deviations typically) - for _ in range(n_pca_modes): - bounds.append((-pca_coefficient_bounds, pca_coefficient_bounds)) - - # Run optimization - self.log_info("Running joint optimization...") - result_joint = minimize( - lambda params: self._joint_objective_function(params, n_pca_modes), - initial_params, - method=method, - bounds=bounds, - options={'maxiter': max_iterations, 'disp': False}, - ) - - # Create optimized transform - optimized_rigid_params = result_joint.x[:6] - optimized_transform = itk.VersorRigid3DTransform[itk.D].New() - opt_itk_params = itk.OptimizerParameters[itk.D](6) - for i in range(6): - opt_itk_params[i] = optimized_rigid_params[i] - optimized_transform.SetParameters(opt_itk_params) - - optimized_pca_coefficients = result_joint.x[6:] - - final_mean_intensity = -result_joint.fun # Convert back from negative - - # Compute changes from Stage 
1 - param_change = optimized_rigid_params - initial_rigid_params - - self.log_info("Stage 2 optimization completed!") - self.log_info(f"Final rigid parameters: {optimized_rigid_params}") - self.log_info("Rigid refinement from initial parameters:") - self.log_info(f" Versor change: {param_change[:3]}") - self.log_info(f" Translation change (mm): {param_change[3:6]}") - self.log_info(f"Optimized PCA coefficients: {optimized_pca_coefficients}") - self.log_info(f"Final mean intensity: {final_mean_intensity:.2f}") - - # Store final results - self.rigid_transform = optimized_transform - self.pca_coefficients = optimized_pca_coefficients - self.final_intensity = final_mean_intensity - - return optimized_transform, optimized_pca_coefficients, final_mean_intensity - - def transform_mesh(self) -> pv.UnstructuredGrid: - """Create the final registered mesh by applying rigid + PCA transformations. - - This method combines the rigid transformation and PCA deformation to - create the final registered mesh with all point data and cell data - preserved from the original average mesh. 
- - Returns: - Final registered and deformed mesh as PyVista UnstructuredGrid - - Raises: - ValueError: If registration has not been performed - """ - if self.rigid_transform is None or self.pca_coefficients is None: - self.log_error("Must complete registration before creating registered mesh") - raise ValueError( - "Must complete registration before creating registered mesh" - ) - - self.log_info("Creating final registered mesh...") - - # Compute PCA deformation - pca_deformation = self._compute_pca_deformation( - self.pca_coefficients, - n_pca_modes=self.n_pca_modes, - ) - - # Apply deformation and rigid transform to each point - final_points = np.zeros((self.average_mesh.n_points, 3), dtype=np.float64) - - n_points = self.average_mesh.n_points - progress_interval = max(1, n_points // 10) # Report progress every 10% - - for i in range(n_points): - # Report progress - if i % progress_interval == 0 or i == n_points - 1: - self.log_progress(i + 1, n_points, prefix="Transforming points") - - # Start with mean shape point - point = itk.Point[itk.D, 3]() - point[0] = self._average_mesh_points[i][0] - point[1] = self._average_mesh_points[i][1] - point[2] = self._average_mesh_points[i][2] - - # Add PCA deformation - point[0] += pca_deformation[i, 0] - point[1] += pca_deformation[i, 1] - point[2] += pca_deformation[i, 2] - - # Apply rigid transform - transformed_point = self.rigid_transform.TransformPoint(point) - - # Store result - final_points[i, 0] = transformed_point[0] - final_points[i, 1] = transformed_point[1] - final_points[i, 2] = transformed_point[2] - - # Create new mesh with transformed points - self.registered_mesh = self.average_mesh.copy(deep=True) - self.registered_mesh.points = final_points.copy() - - # Store deformed points for transform_point method - self._average_mesh_points_deformed = final_points.copy() - - # Build KDTree from mean points for efficient nearest neighbor search - self._kdtree = KDTree(self._average_mesh_points) - - self.log_info( - 
f"Registered mesh created with {self.registered_mesh.n_points} points" - ) - - return self.registered_mesh - - def transform_point( - self, point: itk.Point, distance_threshold: Optional[float] = None - ) -> itk.Point: - """Transform an arbitrary point using distance-weighted interpolation. - - Finds all mesh points within a specified distance threshold and applies - a distance-weighted average of their displacements to transform the input - point. - - Args: - point: ITK point to transform (itk.Point[itk.D, 3]) - distance_threshold: Distance threshold in millimeters. If None, uses - the value set by set_transform_point_distance_threshold() - (default: 10.0 mm) - - Returns: - Transformed ITK point - - Raises: - ValueError: If registration has not been completed yet, or if no - points are found within the distance threshold - - Example: - >>> p = itk.Point[itk.D, 3]() - >>> p[0], p[1], p[2] = 10.0, 20.0, 30.0 - >>> # Use default distance threshold - >>> transformed_p = registrar.transform_point(p) - >>> # Or specify a custom distance threshold - >>> transformed_p = registrar.transform_point(p, distance_threshold=15.0) - """ - if ( - self._kdtree is None - or self._average_mesh_points is None - or self._average_mesh_points_deformed is None - ): - self.log_error( - "Must complete registration and create registered mesh before " - "calling transform_point(). Call transform_mesh() first." - ) - raise ValueError( - "Must complete registration and create registered mesh before " - "calling transform_point(). Call transform_mesh() first." 
- ) - - # Use provided distance threshold or default - if distance_threshold is None: - distance_threshold = self._transform_point_distance_threshold - else: - self._transform_point_distance_threshold = distance_threshold - - # Convert ITK point to numpy array for distance calculations - point_coords = np.array([float(point[0]), float(point[1]), float(point[2])]) - - # Query KDTree for all points within distance threshold - nn_indices = self._kdtree.query_ball_point(point_coords, r=distance_threshold) - - # Check if any points were found - if len(nn_indices) == 0: - raise ValueError( - f"No mesh points found within distance threshold of {distance_threshold} mm. " - f"Consider increasing the threshold using set_transform_point_distance_threshold() " - f"or passing a larger distance_threshold parameter." - ) - - # Compute distances for these points - nn_distances = np.linalg.norm( - self._average_mesh_points[nn_indices] - point_coords, axis=1 - ) - - # Compute distance weights (inverse distance weighting) - # Add small epsilon to avoid division by zero - epsilon = 1e-6 - weights = 1.0 / (nn_distances + epsilon) - weights = weights / weights.sum() # Normalize to sum to 1 - - # Compute weighted average displacement - weighted_displacement = np.zeros(3) - for i, idx in enumerate(nn_indices): - # Displacement is: deformed_point - mean_point - displacement = ( - self._average_mesh_points_deformed[idx] - self._average_mesh_points[idx] - ) - weighted_displacement += weights[i] * displacement - - # Apply displacement to input point - transformed_point = itk.Point[itk.D, 3]() - transformed_point[0] = point[0] + weighted_displacement[0] - transformed_point[1] = point[1] + weighted_displacement[1] - transformed_point[2] = point[2] + weighted_displacement[2] - - return transformed_point - - def register( - self, - initial_transform: itk.VersorRigid3DTransform = None, - n_pca_modes: int = -1, - stage1_max_iterations: int = 500, - stage2_max_iterations: int = 50, - 
pca_coefficient_bounds: float = 3.0, - rigid_refinement_bounds: Optional[dict[str, float]] = None, - ) -> dict: - """Execute the complete two-stage registration workflow. - - This method runs both Stage 1 (coarse rigid alignment) and Stage 2 - (joint rigid refinement + PCA deformable registration) in sequence. - - Args: - initial_transform: Initial ITK VersorRigid3DTransform for starting point - n_pca_modes: Number of PCA modes to use. Default: 10 - stage1_max_iterations: Max iterations for Stage 1 rigid. Default: 500 - stage2_max_iterations: Max iterations for Stage 2 joint. Default: 50 - pca_coefficient_bounds: PCA coefficient bounds (±std devs). Default: 3.0 - rigid_refinement_bounds: Dictionary with 'versor' and 'translation_mm' - bounds for Stage 2 rigid refinement. - Default: {'versor': 0.2, 'translation_mm': 20} - - Returns: - Dictionary containing: - - 'registered_mesh': Final registered PyVista mesh - - 'stage1_transform': Stage 1 ITK VersorRigid3DTransform - - 'rigid_transform': Final ITK VersorRigid3DTransform - - 'pca_coefficients': Optimized PCA coefficients - - 'final_intensity': Final mean intensity metric value - - Raises: - ValueError: If reference image is not set - - Example: - >>> initial_tfm = itk.VersorRigid3DTransform[itk.D].New() - >>> initial_tfm.SetIdentity() - >>> result = registrar.register( - ... initial_transform=initial_tfm, - ... n_pca_modes=10 - ... 
) - >>> result['registered_mesh'].save("registered_heart.vtk") - """ - if self.reference_image is None: - raise ValueError("Reference image must be set before registration") - - if initial_transform is None: - initial_transform = itk.VersorRigid3DTransform[itk.D].New() - initial_transform.SetIdentity() - - if n_pca_modes == -1: - n_pca_modes = self.n_pca_modes - - self.log_section("PCA-BASED MODEL-TO-IMAGE REGISTRATION", width=70) - self.log_info(f"Number of points: {self.average_mesh.n_points}") - self.log_info(f"Modes to use: {n_pca_modes}") - - # Stage 1: Coarse rigid alignment - stage1_rigid_transform, stage1_intensity = self.optimize_rigid_alignment( - initial_transform=initial_transform, max_iterations=stage1_max_iterations - ) - - # Stage 2: Joint rigid + PCA optimization - final_rigid_transform, final_pca_coefficients, final_intensity = ( - self.optimize_joint_rigid_and_pca( - initial_transform=stage1_rigid_transform, - n_pca_modes=n_pca_modes, - max_iterations=stage2_max_iterations, - pca_coefficient_bounds=pca_coefficient_bounds, - rigid_refinement_bounds=rigid_refinement_bounds, - ) - ) - - # Create final registered mesh - final_registered_mesh = self.transform_mesh() - - self.log_section("REGISTRATION COMPLETE", width=70) - self.log_info(f"Stage 1 intensity (coarse rigid): {stage1_intensity:.2f}") - self.log_info(f"Stage 2 intensity (rigid+PCA): {final_intensity:.2f}") - intensity_improvement = final_intensity - stage1_intensity - self.log_info(f"Overall intensity improvement: {intensity_improvement:.2f}") - - # Return results as dictionary - return { - 'registered_mesh': final_registered_mesh, - 'pre_phi_FM': final_rigid_transform, - 'pca_coefficients_FM': final_pca_coefficients, - 'intensity': final_intensity, - } - - -# Example usage -if __name__ == "__main__": - """Example demonstrating PCA-based model-to-image registration.""" - - # This example shows how to use the RegisterModelToImagePCA class - # to register a statistical shape model to a 
patient-specific CT image. - - import json - - import itk - import numpy as np - import pyvista as pv - - # ========================================================================= - # Setup: Load patient image - # ========================================================================= - print("Loading patient image...") - patient_image = itk.imread("patient_cardiac_ct.nrrd") - print(f" Image size: {patient_image.GetBufferedRegion().GetSize()}") - print(f" Image spacing: {patient_image.GetSpacing()}") - - # ========================================================================= - # Load PCA model data - # ========================================================================= - print("\nLoading PCA model data...") - - # Load average mesh - average_mesh = pv.read("pca_All_mean.vtk") - print(f" Loaded mesh with {average_mesh.n_points} points") - - # Load PCA data from JSON - with open("pca.json", 'r') as f: - pca_data = json.load(f) - group_data = pca_data['All'] - - # Extract eigenvalues and convert to standard deviations - eigenvalues = np.array(group_data['eigenvalues']) - std_deviations = np.sqrt(eigenvalues) - print(f" Loaded {len(std_deviations)} eigenvalues") - - # Extract eigenvector components - eigenvectors = np.array(group_data['components'], dtype=np.float64) - print(f" Loaded eigenvectors with shape {eigenvectors.shape}") - - # ========================================================================= - # Initialize registration with loaded data - # ========================================================================= - print("\nInitializing PCA-based registration...") - registrar = RegisterModelToImagePCA( - average_mesh=average_mesh, - eigenvectors=eigenvectors, - std_deviations=std_deviations, - reference_image=patient_image, - ) - - # ========================================================================= - # Create initial transform (user-provided) - # ========================================================================= - 
print("\nCreating initial transform...") - initial_transform = itk.VersorRigid3DTransform[itk.D].New() - initial_transform.SetIdentity() - - # User can set initial rotation and translation here - # For example: - # params = itk.OptimizerParameters[itk.D](6) - # params[0] = 0.1 # Versor component 1 - # params[1] = 0.0 # Versor component 2 - # params[2] = 0.0 # Versor component 3 - # params[3] = 10.0 # Translation X - # params[4] = 0.0 # Translation Y - # params[5] = 0.0 # Translation Z - # initial_transform.SetParameters(params) - - # ========================================================================= - # Run complete registration (Stage 1: coarse rigid, Stage 2: joint rigid+PCA) - # ========================================================================= - print("\nRunning complete registration...") - result = registrar.register( - initial_transform=initial_transform, - n_pca_modes=10, # Use first 10 PCA modes - stage1_max_iterations=500, # Max iterations for Stage 1 (coarse rigid) - stage2_max_iterations=50, # Max iterations for Stage 2 (joint) - pca_coefficient_bounds=3.0, # Limit PCA coeffs to ±3 std deviations - rigid_refinement_bounds={ # Small refinements to rigid in Stage 2 - 'versor': 0.2, # ±0.2 in versor parameters - 'translation_mm': 20.0, # ±20 mm - }, - ) - - # ========================================================================= - # Access and save results - # ========================================================================= - print("\nSaving results...") - - # Save registered mesh - registered_mesh = result['registered_mesh'] - registered_mesh.save("registered_heart_pca.vtk") - print(f" Saved registered mesh: registered_heart_pca.vtk") - - # Print Stage 1 rigid transformation - stage1_itk_params = result['stage1_transform'].GetParameters() - stage1_params = np.array([stage1_itk_params[i] for i in range(6)]) - print("\nStage 1 (coarse rigid) transformation parameters:") - print(f" {stage1_params}") - - # Print final rigid 
transformation - final_itk_params = result['final_rigid_transform'].GetParameters() - final_params = np.array([final_itk_params[i] for i in range(6)]) - print("\nFinal rigid transformation parameters (after Stage 2 refinement):") - print(f" {final_params}") - - # Print PCA coefficients - print("\nPCA coefficients (in units of std deviations):") - print(f" {result['final_pca_coefficients']}") - - # Print final metric - print(f"\nFinal mean intensity: {result['final_intensity']:.2f}") - - print("\nRegistration complete!") diff --git a/src/physiomotion4d/register_model_to_model_icp.py b/src/physiomotion4d/register_model_to_model_icp.py deleted file mode 100644 index a8c985d..0000000 --- a/src/physiomotion4d/register_model_to_model_icp.py +++ /dev/null @@ -1,294 +0,0 @@ -"""ICP-based model-to-model registration for anatomical models. - -This module provides the RegisterModelToModelICP class for aligning anatomical -models using Iterative Closest Point (ICP) algorithm. The workflow includes: -1. Initial centroid alignment -2. Rigid or affine ICP alignment - -The registration is particularly useful for initial rough alignment of generic -models to patient-specific anatomical data. - -Key Features: - - Centroid-based initial alignment - - VTK ICP algorithm with rigid or affine transformation modes - - Three-stage affine pipeline: centroid → rigid ICP → affine ICP - - Support for PyVista meshes - - Automatic transform composition - -Example: - >>> import pyvista as pv - >>> from physiomotion4d import RegisterModelToModelICP - >>> - >>> # Load meshes - >>> moving_mesh = pv.read("generic_model.vtu") - >>> fixed_mesh = pv.read("patient_surface.stl") - >>> - >>> # Run affine registration - >>> registrar = RegisterModelToModelICP( - ... moving_mesh=moving_mesh, - ... fixed_mesh=fixed_mesh - ... 
) - >>> result = registrar.register(mode='affine') - >>> - >>> # Access results - >>> aligned_mesh = result['moving_mesh'] - >>> phi_MF = result['phi_MF'] # Moving to fixed transform -""" - -import itk -import numpy as np -import pyvista as pv -import vtk - -from physiomotion4d.transform_tools import TransformTools - - -class RegisterModelToModelICP: - """Register anatomical models using Iterative Closest Point (ICP) algorithm. - - This class provides ICP-based alignment of 3D surface meshes with support for - both rigid and affine transformation modes. The registration pipeline uses - centroid alignment for initialization followed by VTK's ICP algorithm. - - **Registration Pipelines:** - - **Rigid mode**: Centroid alignment → Rigid ICP - - **Affine mode**: Centroid alignment → Rigid ICP → Affine ICP - - **Transform Convention:** - - phi_MF: Backward transform (moving → fixed space) - - phi_FM: Forward transform (fixed → moving space) - - Attributes: - moving_mesh (pv.PolyData): Surface mesh to be aligned - fixed_mesh (pv.PolyData): Target surface mesh - transform_tools (TransformTools): Transform utility instance - phi_MF (itk.AffineTransform): Optimized backward transform - phi_FM (itk.AffineTransform): Optimized forward transform - registered_mesh (pv.PolyData): Aligned moving mesh - - Example: - >>> # Initialize with meshes - >>> registrar = RegisterModelToModelICP( - ... moving_mesh=model_surface, - ... fixed_mesh=patient_surface - ... ) - >>> - >>> # Run rigid registration - >>> result = registrar.register(mode='rigid', max_iterations=2000) - >>> - >>> # Or run affine registration - >>> result = registrar.register(mode='affine', max_iterations=2000) - >>> - >>> # Get aligned mesh and transforms - >>> aligned_mesh = result['moving_mesh'] - >>> phi_MF = result['phi_MF'] - """ - - def __init__( - self, - moving_mesh: pv.PolyData, - fixed_mesh: pv.PolyData, - ): - """Initialize ICP-based model registration. 
- - Args: - moving_mesh: PyVista surface mesh to be aligned to fixed mesh - fixed_mesh: PyVista target surface mesh - - Note: - The moving_mesh is typically extracted from a VTU model using - mesh.extract_surface() before passing to this class. - """ - self.moving_mesh = moving_mesh - self.fixed_mesh = fixed_mesh - - # Transform utilities - self.transform_tools = TransformTools() - - # Registration results - self.phi_MF: itk.AffineTransform = None # Backward (moving→fixed) - self.phi_FM: itk.AffineTransform = None # Forward (fixed→moving) - self.registered_mesh: pv.PolyData = None - - def register(self, mode: str = 'affine', max_iterations: int = 2000) -> dict: - """Perform ICP alignment of moving mesh to fixed mesh. - - This method executes alignment with either rigid or affine transformations: - - **Rigid mode:** - 1. Centroid alignment: Translate moving mesh to align mass centers - 2. Rigid ICP: Refine with rigid-body transformation (rotation + translation) - - **Affine mode:** - 1. Centroid alignment: Translate moving mesh to align mass centers - 2. Rigid ICP: Refine with rigid-body transformation - 3. Affine ICP: Further refine with affine transformation (includes scaling/shearing) - - Args: - mode: Registration mode, either 'rigid' or 'affine'. Default: 'affine' - max_iterations: Maximum number of ICP iterations per stage. Default: 2000 - - Returns: - Dictionary containing: - - 'moving_mesh': Aligned moving mesh (PyVista PolyData) - - 'phi_MF': Backward transform (moving→fixed, ITK AffineTransform) - - 'phi_FM': Forward transform (fixed→moving, ITK AffineTransform) - - Raises: - ValueError: If mode is not 'rigid' or 'affine' - - Example: - >>> # Rigid registration - >>> result = registrar.register(mode='rigid', max_iterations=5000) - >>> - >>> # Affine registration - >>> result = registrar.register(mode='affine', max_iterations=2000) - """ - if mode not in ['rigid', 'affine']: - raise ValueError(f"Invalid mode '{mode}'. 
Must be 'rigid' or 'affine'.") - - print( - f"Performing {mode.upper()} ICP alignment of moving mesh to fixed mesh..." - ) - - # Step 1: Centroid alignment (common to both modes) - self.registered_mesh = self.moving_mesh.copy(deep=True) - - moving_centroid = np.array(self.registered_mesh.center) - print(f" Moving mesh centroid: {moving_centroid}") - fixed_centroid = np.array(self.fixed_mesh.center) - print(f" Fixed mesh centroid: {fixed_centroid}") - translation = fixed_centroid - moving_centroid - print(f" Step 1: Translating by {translation} to align centroids...") - - # Create ITK affine transform with translation - phi_ICP = itk.AffineTransform[itk.D, 3].New() - phi_ICP.SetIdentity() - phi_ICP.SetOffset(translation) - - # Apply centroid alignment to mesh - self.registered_mesh = self.transform_tools.transform_pvcontour( - self.registered_mesh, - phi_ICP, - with_deformation_magnitude=False, - ) - - print(f" Center after Step 1: {self.registered_mesh.center}") - - # Step 2: Rigid ICP (common to both modes) - print(f" Step 2: Performing rigid ICP (max iterations: {max_iterations})...") - icp_rigid = vtk.vtkIterativeClosestPointTransform() - icp_rigid.SetSource(self.registered_mesh) - icp_rigid.SetTarget(self.fixed_mesh) - icp_rigid.GetLandmarkTransform().SetModeToRigidBody() # Rigid mode - icp_rigid.SetMaximumNumberOfIterations(max_iterations) - icp_rigid.Update() - - # Convert VTK transform to ITK and compose with centroid transform - rigid_transform = self.transform_tools.convert_vtk_matrix_to_itk_transform( - icp_rigid.GetMatrix() - ) - phi_ICP.Compose(rigid_transform) - - # Apply rigid ICP transform to mesh - self.registered_mesh = self.transform_tools.transform_pvcontour( - self.registered_mesh, - rigid_transform, - with_deformation_magnitude=False, - ) - - print(f" Center after Step 2: {self.registered_mesh.center}") - - # Step 3: Affine ICP (only if affine mode) - if mode == 'affine': - print( - f" Step 3: Performing affine ICP (max iterations: 
{max_iterations})..." - ) - icp_affine = vtk.vtkIterativeClosestPointTransform() - icp_affine.SetSource(self.registered_mesh) - icp_affine.SetTarget(self.fixed_mesh) - icp_affine.GetLandmarkTransform().SetModeToAffine() # Affine mode - icp_affine.SetMaximumNumberOfIterations(max_iterations) - icp_affine.Update() - - # Convert VTK transform to ITK and compose - affine_transform = self.transform_tools.convert_vtk_matrix_to_itk_transform( - icp_affine.GetMatrix() - ) - phi_ICP.Compose(affine_transform) - - # Apply affine ICP transform to mesh - self.registered_mesh = self.transform_tools.transform_pvcontour( - self.registered_mesh, - affine_transform, - with_deformation_magnitude=False, - ) - - print(f" Center after Step 3: {self.registered_mesh.center}") - - # Compute inverse transform - self.phi_MF = phi_ICP.GetInverseTransform() - self.phi_FM = phi_ICP - - print(f" {mode.upper()} ICP registration complete!") - - # Return results as dictionary - return { - 'moving_mesh': self.registered_mesh, - 'phi_MF': self.phi_MF, - 'phi_FM': self.phi_FM, - } - - -# Example usage -if __name__ == "__main__": - """Example demonstrating ICP-based model-to-model registration.""" - - import pyvista as pv - - # ========================================================================= - # Setup: Load meshes - # ========================================================================= - print("Loading meshes...") - moving_mesh = pv.read("generic_model.vtu").extract_surface() - fixed_mesh = pv.read("patient_surface.stl") - print(f" Moving mesh: {moving_mesh.n_points} points") - print(f" Fixed mesh: {fixed_mesh.n_points} points") - - # ========================================================================= - # Example 1: Rigid ICP registration - # ========================================================================= - print("\n" + "=" * 60) - print("Example 1: Rigid ICP Registration") - print("=" * 60) - - registrar_rigid = RegisterModelToModelICP( - moving_mesh=moving_mesh, - 
fixed_mesh=fixed_mesh, - ) - - result_rigid = registrar_rigid.register(mode='rigid', max_iterations=2000) - - # Save rigid result - result_rigid['moving_mesh'].save("aligned_mesh_rigid_icp.vtk") - print(f"\n Saved: aligned_mesh_rigid_icp.vtk") - - # ========================================================================= - # Example 2: Affine ICP registration - # ========================================================================= - print("\n" + "=" * 60) - print("Example 2: Affine ICP Registration") - print("=" * 60) - - registrar_affine = RegisterModelToModelICP( - moving_mesh=moving_mesh, - fixed_mesh=fixed_mesh, - ) - - result_affine = registrar_affine.register(mode='affine', max_iterations=2000) - - # Save affine result - result_affine['moving_mesh'].save("aligned_mesh_affine_icp.vtk") - print(f"\n Saved: aligned_mesh_affine_icp.vtk") - - print("\nICP registration examples complete!") diff --git a/src/physiomotion4d/register_model_to_model_masks.py b/src/physiomotion4d/register_model_to_model_masks.py deleted file mode 100644 index e8d8671..0000000 --- a/src/physiomotion4d/register_model_to_model_masks.py +++ /dev/null @@ -1,443 +0,0 @@ -"""Mask-based model-to-model registration for anatomical models. - -This module provides the RegisterModelToModelMasks class for aligning anatomical -models using mask-based deformable registration. The workflow includes: -1. Generate binary masks from moving and fixed meshes -2. Generate ROI masks with dilation -4. Progressive registration stages: - - rigid: ANTs rigid registration - - affine: ANTs rigid → affine registration - - deformable: ANTs rigid → affine → deformable (SyN) registration -5. Optional ICON refinement at end - -The registration is particularly useful for aligning anatomical models where -shape differences require deformable transformations beyond rigid/affine ICP. 
- -Key Features: - - Automatic mask generation from PyVista meshes - - Multi-stage ANTs registration (rigid/affine/deformable) - - Optional ICON deep learning refinement - - Automatic transform composition - - Support for PyVista meshes - -Example: - >>> import itk - >>> import pyvista as pv - >>> from physiomotion4d import RegisterModelToModelMasks - >>> - >>> # Load meshes and reference image - >>> moving_mesh = pv.read("generic_model.vtu").extract_surface() - >>> fixed_mesh = pv.read("patient_surface.stl") - >>> reference_image = itk.imread("patient_ct.nii.gz") - >>> - >>> # Run deformable registration with ICON refinement - >>> registrar = RegisterModelToModelMasks( - ... moving_mesh=moving_mesh, - ... fixed_mesh=fixed_mesh, - ... reference_image=reference_image, - ... roi_dilation_mm=20, - ... ) - >>> result = registrar.register( - ... mode='deformable', - ... use_icon=True, - ... icon_iterations=50 - ... ) - >>> - >>> # Access results - >>> aligned_mesh = result['moving_mesh'] - >>> phi_MF = result['phi_MF'] # Moving to fixed transform -""" - -import itk -import numpy as np -import pyvista as pv -from itk import TubeTK as ttk - -from physiomotion4d.contour_tools import ContourTools -from physiomotion4d.register_images_ants import RegisterImagesANTs -from physiomotion4d.register_images_icon import RegisterImagesICON -from physiomotion4d.transform_tools import TransformTools - - -class RegisterModelToModelMasks: - """Register anatomical models using mask-based deformable registration. - - This class provides mask-based alignment of 3D surface meshes with support for - rigid, affine, and deformable transformation modes. The registration pipeline - generates masks from meshes, applies optional dilation, and uses ANTs for - progressive multi-stage registration with optional ICON refinement. 
- - **Registration Pipelines:** - - **Rigid mode**: ANTs rigid registration - - **Affine mode**: ANTs rigid → affine registration - - **Deformable mode**: ANTs rigid → affine → deformable (SyN) registration - - **Optional**: ICON deep learning refinement after any mode - - **Transform Convention:** - - phi_MF: Backward transform (moving → fixed space) - - phi_FM: Forward transform (fixed → moving space) - - Attributes: - moving_mesh (pv.PolyData): Surface mesh to be aligned - fixed_mesh (pv.PolyData): Target surface mesh - reference_image (itk.Image): Reference image for coordinate frame - roi_dilation_mm (float): Dilation amount in mm for ROI mask - transform_tools (TransformTools): Transform utility instance - contour_tools (ContourTools): Mesh utility instance - registrar_ants (RegisterImagesANTs): ANTs registration instance - registrar_icon (RegisterImagesICON): ICON registration instance - phi_MF (itk.CompositeTransform): Optimized backward transform - phi_FM (itk.CompositeTransform): Optimized forward transform - registered_mesh (pv.PolyData): Aligned moving mesh - - Example: - >>> # Initialize with meshes and reference image - >>> registrar = RegisterModelToModelMasks( - ... moving_mesh=model_surface, - ... fixed_mesh=patient_surface, - ... reference_image=patient_ct, - ... roi_dilation_mm=20, - ... ) - >>> - >>> # Run rigid registration - >>> result = registrar.register(mode='rigid') - >>> - >>> # Or run affine registration - >>> result = registrar.register(mode='affine') - >>> - >>> # Or run deformable with ICON refinement - >>> result = registrar.register( - ... mode='deformable', - ... use_icon=True, - ... icon_iterations=100 - ... ) - >>> - >>> # Get aligned mesh and transforms - >>> aligned_mesh = result['moving_mesh'] - >>> phi_MF = result['phi_MF'] - """ - - def __init__( - self, - moving_mesh: pv.PolyData, - fixed_mesh: pv.PolyData, - reference_image: itk.Image, - roi_dilation_mm: float = 10, - ): - """Initialize mask-based model registration. 
- - Args: - moving_mesh: PyVista surface mesh to be aligned to fixed mesh - fixed_mesh: PyVista target surface mesh - reference_image: ITK image providing coordinate frame (origin, spacing, direction) - for mask generation. Typically the patient CT/MRI image. - roi_dilation_mm: Dilation amount in millimeters for ROI mask generation. - Default: 20mm - - Note: - The moving_mesh and fixed_mesh are typically extracted from VTU models - using mesh.extract_surface() before passing to this class. - """ - self.moving_mesh = moving_mesh - self.fixed_mesh = fixed_mesh - self.reference_image = reference_image - self.roi_dilation_mm = roi_dilation_mm - - # Utilities - self.transform_tools = TransformTools() - self.contour_tools = ContourTools() - - # Registration instances - self.registrar_ants = RegisterImagesANTs() - self.registrar_icon = RegisterImagesICON() - self.registrar_icon.set_modality('ct') - self.registrar_icon.set_multi_modality(True) # For mask-based registration - - # Generated masks (will be created during registration) - self.fixed_mask_image: itk.Image = None - self.fixed_mask_roi_image: itk.Image = None - self.moving_mask_image: itk.Image = None - self.moving_mask_roi_image: itk.Image = None - - # Registration results - self.phi_MF: itk.CompositeTransform = None # Backward (moving→fixed) - self.phi_FM: itk.CompositeTransform = None # Forward (fixed→moving) - self.registered_mesh: pv.PolyData = None - - def _create_masks_from_meshes(self): - """Generate binary mask images from moving and fixed meshes. - - Creates: - - fixed_mask_image: Binary mask from fixed mesh - - fixed_mask_roi_image: Dilated ROI mask from fixed mesh - - moving_mask_image: Binary mask from moving mesh - - moving_mask_roi_image: Dilated ROI mask from moving mesh - - Uses self.reference_image for coordinate frame (origin, spacing, direction). 
- """ - print("Generating binary masks from meshes...") - - # Create fixed mask - self.fixed_mask_image = ( - self.contour_tools.create_contour_distance_map_from_mesh( - self.fixed_mesh, - self.reference_image, - max_distance=100.0, - invert_distance_map=True, - ) - ) - - # Create fixed ROI mask with dilation - print(f" Dilating fixed mask by {self.roi_dilation_mm}mm for ROI...") - mask = self.contour_tools.create_mask_from_mesh( - self.fixed_mesh, self.reference_image - ) - imMath = ttk.ImageMath.New(mask) - dilation_voxels = int( - self.roi_dilation_mm / self.reference_image.GetSpacing()[0] - ) - imMath.Dilate(dilation_voxels, 1, 0) - self.fixed_mask_roi_image = imMath.GetOutput() - - # Create moving mask - self.moving_mask_image = ( - self.contour_tools.create_contour_distance_map_from_mesh( - self.moving_mesh, - self.reference_image, - max_distance=100.0, - invert_distance_map=True, - ) - ) - - # Create moving ROI mask with dilation - print(f" Dilating moving mask by {self.roi_dilation_mm}mm for ROI...") - mask = self.contour_tools.create_mask_from_mesh( - self.moving_mesh, self.reference_image - ) - imMath = ttk.ImageMath.New(self.moving_mask_image) - imMath.Dilate(dilation_voxels, 1, 0) - self.moving_mask_roi_image = imMath.GetOutputUChar() - - print(" Mask generation complete.") - - def register( - self, - mode: str = 'affine', - use_icon: bool = False, - icon_iterations: int = 50, - ) -> dict: - """Perform mask-based registration of moving mesh to fixed mesh. - - This method executes progressive multi-stage registration: - - **Rigid mode:** - 1. Generate masks from meshes - 3. ANTs rigid registration - - **Affine mode:** - 1. Generate masks from meshes - 3. ANTs affine registration (includes rigid stage) - - **Deformable mode:** - 1. Generate masks from meshes - 3. ANTs SyN deformable registration (includes rigid + affine + deformable stages) - - **Optional ICON refinement** (all modes): - 4. 
ICON deep learning registration for fine-tuning - - Args: - mode: Registration mode - 'rigid', 'affine', or 'deformable'. Default: 'affine' - use_icon: Whether to apply ICON registration refinement after ANTs. Default: False - icon_iterations: Number of ICON optimization iterations if use_icon=True. Default: 50 - - Returns: - Dictionary containing: - - 'moving_mesh': Aligned moving mesh (PyVista PolyData) - - 'phi_MF': Backward transform (moving→fixed, ITK CompositeTransform) - - 'phi_FM': Forward transform (fixed→moving, ITK CompositeTransform) - - Raises: - ValueError: If mode is not 'rigid', 'affine', or 'deformable' - - Example: - >>> # Rigid registration - >>> result = registrar.register(mode='rigid') - >>> - >>> # Affine registration - >>> result = registrar.register(mode='affine') - >>> - >>> # Deformable registration with ICON refinement - >>> result = registrar.register(mode='deformable', use_icon=True, icon_iterations=100) - """ - if mode not in ['rigid', 'affine', 'deformable']: - raise ValueError( - f"Invalid mode '{mode}'. Must be 'rigid', 'affine', or 'deformable'." 
- ) - - print(f"Performing {mode.upper()} mask-based registration...") - - # Step 1: Generate masks from meshes - self._create_masks_from_meshes() - - # Step 3: ANTs registration with appropriate type_of_transform - if mode == 'rigid': - transform_type = "Rigid" - elif mode == 'affine': - transform_type = "Affine" # Includes rigid stage - else: # deformable - transform_type = "SyN" # Includes rigid + affine + deformable stages - - print(f" Performing ANTs {mode} registration (type: {transform_type})...") - - self.registrar_ants.set_fixed_image(self.fixed_mask_image) - self.registrar_ants.set_fixed_image_mask(self.fixed_mask_roi_image) - - result_ants = self.registrar_ants.register( - moving_image=self.moving_mask_image, - moving_image_mask=self.moving_mask_roi_image, - ) - phi_FM_ants = result_ants["phi_FM"] - phi_MF_ants = result_ants["phi_MF"] - - # Initialize composite transforms - self.phi_MF = phi_MF_ants - self.phi_FM = phi_FM_ants - - # Step 4: Optional ICON refinement - if use_icon: - print( - f" Performing ICON refinement registration ({icon_iterations} iterations)..." 
- ) - - # Transform masks with ANTs result for ICON input - moving_mask_ants_transformed = self.transform_tools.transform_image( - self.moving_mask_image, - phi_FM_ants, - self.reference_image, - interpolation_method="linear", - ) - - # Configure ICON - self.registrar_icon.set_number_of_iterations(icon_iterations) - self.registrar_icon.set_fixed_image(self.fixed_mask_image) - self.registrar_icon.set_fixed_image_mask(self.fixed_mask_roi_image) - - # ICON registration - result_icon = self.registrar_icon.register( - moving_image=moving_mask_ants_transformed, - moving_image_mask=self.moving_mask_roi_image, - ) - phi_FM_icon = result_icon["phi_FM"] - phi_MF_icon = result_icon["phi_MF"] - - # Compose ANTs and ICON transforms - composed_phi_MF = itk.CompositeTransform[itk.D, 3].New() - composed_phi_MF.AddTransform(phi_MF_ants) - composed_phi_MF.AddTransform(phi_MF_icon) - - composed_phi_FM = itk.CompositeTransform[itk.D, 3].New() - composed_phi_FM.AddTransform(phi_FM_icon) - composed_phi_FM.AddTransform(phi_FM_ants) - - self.phi_MF = composed_phi_MF - self.phi_FM = composed_phi_FM - - # Apply final transform to moving mesh - print(" Transforming moving mesh...") - self.registered_mesh = self.transform_tools.transform_pvcontour( - self.moving_mesh, - self.phi_MF, - with_deformation_magnitude=True, - ) - - print(f" {mode.upper()} mask-based registration complete!") - - # Return results as dictionary - return { - 'moving_mesh': self.registered_mesh, - 'phi_MF': self.phi_MF, - 'phi_FM': self.phi_FM, - } - - -# Example usage -if __name__ == "__main__": - """Example demonstrating mask-based model-to-model registration.""" - - import itk - import pyvista as pv - - # ========================================================================= - # Setup: Load meshes and reference image - # ========================================================================= - print("Loading meshes and reference image...") - moving_mesh = pv.read("generic_model.vtu").extract_surface() - 
fixed_mesh = pv.read("patient_surface.stl") - reference_image = itk.imread("patient_ct.nii.gz") - print(f" Moving mesh: {moving_mesh.n_points} points") - print(f" Fixed mesh: {fixed_mesh.n_points} points") - print(f" Reference image: {reference_image.GetLargestPossibleRegion().GetSize()}") - - # ========================================================================= - # Example 1: Rigid registration - # ========================================================================= - print("\n" + "=" * 60) - print("Example 1: Rigid Mask-based Registration") - print("=" * 60) - - registrar_rigid = RegisterModelToModelMasks( - moving_mesh=moving_mesh, - fixed_mesh=fixed_mesh, - reference_image=reference_image, - roi_dilation_mm=20, - ) - - result_rigid = registrar_rigid.register(mode='rigid') - - # Save rigid result - result_rigid['moving_mesh'].save("aligned_mesh_rigid_masks.vtk") - print(f"\n Saved: aligned_mesh_rigid_masks.vtk") - - # ========================================================================= - # Example 2: Affine registration - # ========================================================================= - print("\n" + "=" * 60) - print("Example 2: Affine Mask-based Registration") - print("=" * 60) - - registrar_affine = RegisterModelToModelMasks( - moving_mesh=moving_mesh, - fixed_mesh=fixed_mesh, - reference_image=reference_image, - roi_dilation_mm=20, - ) - - result_affine = registrar_affine.register(mode='affine') - - # Save affine result - result_affine['moving_mesh'].save("aligned_mesh_affine_masks.vtk") - print(f"\n Saved: aligned_mesh_affine_masks.vtk") - - # ========================================================================= - # Example 3: Deformable registration with ICON refinement - # ========================================================================= - print("\n" + "=" * 60) - print("Example 3: Deformable Mask-based Registration with ICON") - print("=" * 60) - - registrar_deformable = RegisterModelToModelMasks( - 
moving_mesh=moving_mesh, - fixed_mesh=fixed_mesh, - reference_image=reference_image, - roi_dilation_mm=20, - ) - - result_deformable = registrar_deformable.register( - mode='deformable', use_icon=True, icon_iterations=50 - ) - - # Save deformable result - result_deformable['moving_mesh'].save("aligned_mesh_deformable_masks.vtk") - print(f"\n Saved: aligned_mesh_deformable_masks.vtk") - - print("\nMask-based registration examples complete!") diff --git a/src/physiomotion4d/register_models_distance_maps.py b/src/physiomotion4d/register_models_distance_maps.py new file mode 100644 index 0000000..7353fd2 --- /dev/null +++ b/src/physiomotion4d/register_models_distance_maps.py @@ -0,0 +1,380 @@ +"""Mask-based model-to-model registration for anatomical models. + +This module provides the RegisterModelsDistanceMaps class for aligning anatomical +models using mask-based deformable registration. The workflow includes: +1. Generate binary masks from moving and fixed models +2. Generate ROI masks with dilation +3. Progressive registration stages: + - rigid: ANTs rigid registration + - affine: ANTs rigid → affine registration + - deformable: ANTs rigid → affine → deformable (SyN) registration +4. Optional ICON refinement at end + +The registration is particularly useful for aligning anatomical models where +shape differences require deformable transformations beyond rigid/affine ICP. 
+ +Key Features: + - Automatic mask generation from PyVista models + - Multi-stage ANTs registration (rigid/affine/deformable) + - Optional ICON deep learning refinement + - Automatic transform composition + - Support for PyVista models + +Example: + >>> import itk + >>> import pyvista as pv + >>> from physiomotion4d import RegisterModelsDistanceMaps + >>> + >>> # Load models and reference image + >>> moving_model = pv.read("generic_model.vtu").extract_surface() + >>> fixed_model = pv.read("patient_surface.stl") + >>> reference_image = itk.imread("patient_ct.nii.gz") + >>> + >>> # Run deformable registration with ICON refinement + >>> registrar = RegisterModelsDistanceMaps( + ... moving_model=moving_model, + ... fixed_model=fixed_model, + ... reference_image=reference_image, + ... roi_dilation_mm=20, + ... ) + >>> result = registrar.register( + ... mode='deformable', + ... use_icon=True, + ... icon_iterations=50 + ... ) + >>> + >>> # Access results + >>> aligned_model = result['registered_model'] + >>> forward_transform = result['forward_transform'] # Moving to fixed transform +""" + +import logging + +import itk +import pyvista as pv +from itk import TubeTK as ttk + +from physiomotion4d.contour_tools import ContourTools +from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase +from physiomotion4d.register_images_ants import RegisterImagesANTs +from physiomotion4d.register_images_icon import RegisterImagesICON +from physiomotion4d.transform_tools import TransformTools + + +class RegisterModelsDistanceMaps(PhysioMotion4DBase): + """Register anatomical models using mask-based deformable registration. + + This class provides mask-based alignment of 3D surface models with support for + rigid, affine, and deformable transformation modes. The registration pipeline + generates masks from models, applies optional dilation, and uses ANTs for + progressive multi-stage registration with optional ICON refinement. 
+ + **Registration Pipelines:** + - **None mode**: No ANTs registration + - **Rigid mode**: ANTs rigid registration + - **Affine mode**: ANTs rigid → affine registration + - **Deformable mode**: ANTs rigid → affine → deformable (SyN) registration + - **Optional**: ICON deep learning refinement after any mode + + **Transform Convention:** + - forward_transform: Moving → fixed space transformation + - inverse_transform: Fixed → moving space transformation + + Attributes: + moving_model (pv.PolyData): Surface model to be aligned + fixed_model (pv.PolyData): Target surface model + reference_image (itk.Image): Reference image for coordinate frame + roi_dilation_mm (float): Dilation amount in mm for ROI mask + transform_tools (TransformTools): Transform utility instance + contour_tools (ContourTools): Model utility instance + registrar_ants (RegisterImagesANTs): ANTs registration instance + registrar_icon (RegisterImagesICON): ICON registration instance + forward_transform (itk.CompositeTransform): Optimized moving→fixed transform + inverse_transform (itk.CompositeTransform): Optimized fixed→moving transform + registered_model (pv.PolyData): Aligned moving model + + Example: + >>> # Initialize with models and reference image + >>> registrar = RegisterModelsDistanceMaps( + ... moving_model=model_surface, + ... fixed_model=patient_surface, + ... reference_image=patient_ct, + ... roi_dilation_mm=20, + ... ) + >>> + >>> # Run rigid registration + >>> result = registrar.register(mode='rigid') + >>> + >>> # Or run affine registration + >>> result = registrar.register(mode='affine') + >>> + >>> # Or run deformable with ICON refinement + >>> result = registrar.register( + ... mode='deformable', + ... use_ants=False, + ... use_icon=True, + ... icon_iterations=50 + ... 
) + >>> + >>> # Get aligned model and transforms + >>> aligned_model = result['registered_model'] + >>> forward_transform = result['forward_transform'] + """ + + def __init__( + self, + moving_model: pv.PolyData, + fixed_model: pv.PolyData, + reference_image: itk.Image, + roi_dilation_mm: float = 10, + log_level: int | str = logging.INFO, + ): + """Initialize mask-based model registration. + + Args: + moving_model: PyVista surface model to be aligned to fixed model + fixed_model: PyVista target surface model + reference_image: ITK image providing coordinate frame (origin, spacing, direction) + for mask generation. Typically the patient CT/MRI image. + roi_dilation_mm: Dilation amount in millimeters for ROI mask generation. + Default: 20mm + log_level: Logging level (default: logging.INFO) + + Note: + The moving_model and fixed_model are typically extracted from VTU models + using model.extract_surface() before passing to this class. + """ + super().__init__(class_name=self.__class__.__name__, log_level=log_level) + + self.moving_model = moving_model + self.fixed_model = fixed_model + self.reference_image = reference_image + self.roi_dilation_mm = roi_dilation_mm + + # Utilities + self.transform_tools = TransformTools() + self.contour_tools = ContourTools() + + # Registration instances + self.registrar_ants = RegisterImagesANTs(log_level=log_level) + self.registrar_icon = RegisterImagesICON(log_level=log_level) + self.registrar_icon.set_modality('ct') + self.registrar_icon.set_multi_modality(False) + + # Generated masks (will be created during registration) + self.fixed_mask_image: itk.Image = None + self.fixed_mask_roi_image: itk.Image = None + self.moving_mask_image: itk.Image = None + self.moving_mask_roi_image: itk.Image = None + + # Registration results + self.forward_transform: itk.CompositeTransform = None # Moving→fixed + self.inverse_transform: itk.CompositeTransform = None # Fixed→moving + self.registered_model: pv.PolyData = None + + def 
_create_masks_from_models(self): + """Generate binary mask images from moving and fixed models. + + Creates: + - fixed_mask_image: Binary mask from fixed model + - fixed_mask_roi_image: Dilated ROI mask from fixed model + - moving_mask_image: Binary mask from moving model + - moving_mask_roi_image: Dilated ROI mask from moving model + + Uses self.reference_image for coordinate frame (origin, spacing, direction). + """ + self.log_info("Generating binary masks from models...") + + # Create fixed mask + self.fixed_mask_image = self.contour_tools.create_distance_map( + self.fixed_model, + self.reference_image, + squared_distance=True, + invert_distance_map=True, + ) + + # Create fixed ROI mask with dilation + self.log_info("Dilating fixed mask by %.1fmm for ROI...", self.roi_dilation_mm) + mask = self.contour_tools.create_mask_from_mesh( + self.fixed_model, self.reference_image + ) + imMath = ttk.ImageMath.New(mask) + dilation_voxels = int( + self.roi_dilation_mm / self.reference_image.GetSpacing()[0] + ) + imMath.Dilate(dilation_voxels, 1, 0) + self.fixed_mask_roi_image = imMath.GetOutput() + + # Create moving mask + self.moving_mask_image = self.contour_tools.create_distance_map( + self.moving_model, + self.reference_image, + squared_distance=True, + invert_distance_map=True, + ) + + # Create moving ROI mask with dilation + self.log_info("Dilating moving mask by %.1fmm for ROI...", self.roi_dilation_mm) + mask = self.contour_tools.create_mask_from_mesh( + self.moving_model, self.reference_image + ) + imMath = ttk.ImageMath.New(self.moving_mask_image) + imMath.Dilate(dilation_voxels, 1, 0) + self.moving_mask_roi_image = imMath.GetOutputUChar() + + self.log_info("Mask generation complete") + + def register( + self, + transform_type: str = 'Deformable', + use_icon: bool = False, + icon_iterations: int = 50, + ) -> dict: + """Perform mask-based registration of moving model to fixed model. 

        This method executes progressive multi-stage registration:

        **None transform type:**
        1. No ANTs registration

        **Rigid transform type:**
        1. ANTs rigid registration

        **Affine transform type:**
        1. ANTs affine registration (includes rigid stage)

        **Deformable transform type:**
        1. ANTs SyN deformable registration (includes rigid + affine + deformable stages)

        **Optional ICON refinement** (all transform types):
        1. ICON deep learning registration for fine-tuning

        Args:
            transform_type: Registration transform type - 'None', 'Rigid', 'Affine', or 'Deformable'. Default: 'Deformable'
            use_icon: Whether to apply ICON registration refinement after ANTs. Default: False
            icon_iterations: Number of ICON optimization iterations if use_icon=True. Default: 50

        Returns:
            Dictionary containing:
            - 'registered_model': Aligned moving model (PyVista PolyData)
            - 'forward_transform': Moving→fixed transform (ITK CompositeTransform)
            - 'inverse_transform': Fixed→moving transform (ITK CompositeTransform)

        Raises:
            ValueError: If transform_type is not 'None', 'Rigid', 'Affine', or 'Deformable'

        Example:
            >>> # Rigid registration
            >>> result = registrar.register(transform_type='Rigid')
            >>>
            >>> # Affine registration
            >>> result = registrar.register(transform_type='Affine')
            >>>
            >>> # Deformable registration with ICON refinement
            >>> result = registrar.register(transform_type='Deformable', use_icon=True, icon_iterations=100)
        """
        if transform_type not in ['None', 'Rigid', 'Affine', 'Deformable']:
            raise ValueError(
                f"Invalid transform type '{transform_type}'. Must be 'None', 'Rigid', 'Affine', or 'Deformable'."
            )

        self.log_section("%s Mask-based Registration", transform_type.upper())

        # Step 1: Generate masks from models
        self._create_masks_from_models()

        self.log_info(
            "Performing ANTs %s registration...",
            transform_type,
        )

        inverse_transform_ants = None
        forward_transform_ants = None
        if transform_type != 'None':
            self.registrar_ants.set_fixed_image(self.fixed_mask_image)
            self.registrar_ants.set_fixed_mask(self.fixed_mask_roi_image)

            self.registrar_ants.set_transform_type(transform_type)

            result_ants = self.registrar_ants.register(
                moving_image=self.moving_mask_image,
                moving_mask=self.moving_mask_roi_image,
            )
            inverse_transform_ants = result_ants["inverse_transform"]
            forward_transform_ants = result_ants["forward_transform"]
        else:
            # 'None' mode: use an identity transform so downstream composition
            # and model warping use the same code path as the ANTs modes.
            identity_transform = itk.AffineTransform[itk.D, 3].New()
            identity_transform.SetIdentity()
            inverse_transform_ants = identity_transform
            forward_transform_ants = identity_transform

        # Initialize composite transforms
        self.forward_transform = forward_transform_ants
        self.inverse_transform = inverse_transform_ants

        # Optional ICON refinement
        if use_icon:
            self.log_info(
                "Performing ICON refinement registration (%d iterations)...",
                icon_iterations,
            )

            # Transform masks with ANTs result for ICON input
            moving_mask_ants_transformed = self.transform_tools.transform_image(
                self.moving_mask_image,
                forward_transform_ants,
                self.reference_image,
                interpolation_method="linear",
            )

            # Configure ICON
            self.registrar_icon.set_number_of_iterations(icon_iterations)
            self.registrar_icon.set_fixed_image(self.fixed_mask_image)
            self.registrar_icon.set_fixed_mask(self.fixed_mask_roi_image)

            # ICON registration
            result_icon = self.registrar_icon.register(
                moving_image=moving_mask_ants_transformed,
                moving_mask=self.moving_mask_roi_image,
            )
            inverse_transform_icon = result_icon["inverse_transform"]
            forward_transform_icon = result_icon["forward_transform"]

            # Compose ANTs and ICON transforms
            # Forward: ANTs first, then ICON (ICON was run on the
            # ANTs-transformed moving mask).
            composed_forward = (
                self.transform_tools.combine_displacement_field_transforms(
                    forward_transform_ants,
                    forward_transform_icon,
                    reference_image=self.reference_image,
                    mode="compose",
                )
            )

            # Inverse: reverse order (ICON inverse first, then ANTs inverse).
            composed_inverse = (
                self.transform_tools.combine_displacement_field_transforms(
                    inverse_transform_icon,
                    inverse_transform_ants,
                    reference_image=self.reference_image,
                    mode="compose",
                )
            )

            self.forward_transform = composed_forward
            self.inverse_transform = composed_inverse

        # Apply final transform to moving model.
        # NOTE: points are warped with the inverse (image-resampling) transform,
        # consistent with the image registration transform convention.
        self.log_info("Transforming moving model...")
        self.registered_model = self.transform_tools.transform_pvcontour(
            self.moving_model,
            self.inverse_transform,
            with_deformation_magnitude=True,
        )

        self.log_info("%s mask-based registration complete!", transform_type.upper())

        # Return results as dictionary
        return {
            'forward_transform': self.forward_transform,
            'inverse_transform': self.inverse_transform,
            'registered_model': self.registered_model,
        }
diff --git a/src/physiomotion4d/register_models_icp.py b/src/physiomotion4d/register_models_icp.py
new file mode 100644
index 0000000..281d33d
--- /dev/null
+++ b/src/physiomotion4d/register_models_icp.py
@@ -0,0 +1,250 @@
"""ICP-based model-to-model registration for anatomical models.

This module provides the RegisterModelsICP class for aligning anatomical
models using Iterative Closest Point (ICP) algorithm. The workflow includes:
1. Initial centroid alignment
2. Rigid or affine ICP alignment

The registration is particularly useful for initial rough alignment of generic
models to patient-specific anatomical data.

Key Features:
    - Centroid-based initial alignment
    - VTK ICP algorithm with rigid or affine transformation modes
    - Three-stage affine pipeline: centroid → rigid ICP → affine ICP
    - Support for PyVista models
    - Automatic transform composition

Example:
    >>> import pyvista as pv
    >>> from physiomotion4d import RegisterModelsICP
    >>>
    >>> # Load models
    >>> moving_model = pv.read("generic_model.vtu")
    >>> fixed_model = pv.read("patient_surface.stl")
    >>>
    >>> # Run affine registration
    >>> registrar = RegisterModelsICP(
    ...     moving_model=moving_model,
    ...     fixed_model=fixed_model
    ... )
    >>> result = registrar.register(mode='affine')
    >>>
    >>> # Access results
    >>> aligned_model = result['registered_model']
    >>> forward_point_transform = result['forward_point_transform']  # Moving to fixed transform
"""

import logging

import itk
import numpy as np
import pyvista as pv
import vtk

from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase
from physiomotion4d.transform_tools import TransformTools


class RegisterModelsICP(PhysioMotion4DBase):
    """Register anatomical models using Iterative Closest Point (ICP) algorithm.

    This class provides ICP-based alignment of 3D surface models with support for
    both rigid and affine transformation modes. The registration pipeline uses
    centroid alignment for initialization followed by VTK's ICP algorithm.

    **Registration Pipelines:**
    - **Rigid mode**: Centroid alignment → Rigid ICP
    - **Affine mode**: Centroid alignment → Rigid ICP → Affine ICP

    **Transform Convention:**
    - forward_point_transform: moving → fixed space transformation
      (This is the inverse of the transform used to warp the moving image to the fixed image)
    - inverse_point_transform: fixed → moving space transformation

    Attributes:
        moving_model (pv.PolyData): Surface model to be aligned
        fixed_model (pv.PolyData): Target surface model
        transform_tools (TransformTools): Transform utility instance
        forward_point_transform (itk.AffineTransform): Optimized moving→fixed transform
        inverse_point_transform (itk.AffineTransform): Optimized fixed→moving transform
        registered_model (pv.PolyData): Aligned moving model

    Example:
        >>> # Initialize with model
        >>> registrar = RegisterModelsICP(
        ...     moving_model=model_surface,
        ...     fixed_model=patient_surface
        ... )
        >>>
        >>> # Run rigid registration
        >>> result = registrar.register(mode='rigid', max_iterations=2000)
        >>>
        >>> # Or run affine registration
        >>> result = registrar.register(mode='affine', max_iterations=2000)
        >>>
        >>> # Get aligned model and transforms
        >>> aligned_model = result['registered_model']
        >>> forward_point_transform = result['forward_point_transform']
    """

    def __init__(
        self,
        moving_model: pv.PolyData,
        fixed_model: pv.PolyData,
        log_level: int | str = logging.INFO,
    ):
        """Initialize ICP-based model registration.

        Args:
            moving_model: PyVista surface model to be aligned to fixed model
            fixed_model: PyVista target surface model
            log_level: Logging level (default: logging.INFO)

        Note:
            The moving_model is typically extracted from a VTU model using
            model.extract_surface() before passing to this class.
        """
        super().__init__(class_name=self.__class__.__name__, log_level=log_level)

        self.moving_model = moving_model
        self.fixed_model = fixed_model

        # Transform utilities
        self.transform_tools = TransformTools()

        # Registration results
        self.forward_point_transform: itk.AffineTransform = None
        self.inverse_point_transform: itk.AffineTransform = None
        self.registered_model: pv.PolyData = None

    def register(self, mode: str = 'affine', max_iterations: int = 2000) -> dict:
        """Perform ICP alignment of moving model to fixed model.

        This method executes alignment with either rigid or affine transformations:

        **Rigid mode:**
        1. Centroid alignment: Translate moving model to align mass centers
        2. Rigid ICP: Refine with rigid-body transformation (rotation + translation)

        **Affine mode:**
        1. Centroid alignment: Translate moving model to align mass centers
        2. Rigid ICP: Refine with rigid-body transformation
        3. Affine ICP: Further refine with affine transformation (includes scaling/shearing)

        Args:
            mode: Registration mode, either 'rigid' or 'affine'. Default: 'affine'
            max_iterations: Maximum number of ICP iterations per stage. Default: 2000

        Returns:
            Dictionary containing:
            - 'registered_model': Aligned moving model (PyVista PolyData)
            - 'forward_point_transform': Moving→fixed transform (ITK AffineTransform)
            - 'inverse_point_transform': Fixed→moving transform (ITK AffineTransform)

        Raises:
            ValueError: If mode is not 'rigid' or 'affine'

        Example:
            >>> # Rigid registration
            >>> result = registrar.register(mode='rigid', max_iterations=5000)
            >>>
            >>> # Affine registration
            >>> result = registrar.register(mode='affine', max_iterations=2000)
        """
        if mode not in ['rigid', 'affine']:
            raise ValueError(f"Invalid mode '{mode}'. Must be 'rigid' or 'affine'.")

        self.log_section("%s ICP Alignment", mode.upper())

        # Step 1: Centroid alignment (common to both modes)
        self.registered_model = self.moving_model.copy(deep=True)

        moving_centroid = np.array(self.registered_model.center)
        self.log_debug("Moving model centroid: %s", moving_centroid)
        fixed_centroid = np.array(self.fixed_model.center)
        self.log_debug("Fixed model centroid: %s", fixed_centroid)
        translation = fixed_centroid - moving_centroid
        self.log_info("Step 1: Translating by %s to align centroids...", translation)

        # Create ITK affine transform with translation
        forward_point_transform = itk.AffineTransform[itk.D, 3].New()
        forward_point_transform.SetIdentity()
        forward_point_transform.SetOffset(translation)

        # Apply centroid alignment to model
        self.registered_model = self.transform_tools.transform_pvcontour(
            self.registered_model,
            forward_point_transform,
            with_deformation_magnitude=False,
        )

        self.log_debug("Center after Step 1: %s", self.registered_model.center)

        # Step 2: Rigid ICP (common to both modes)
        self.log_info(
            "Step 2: Performing rigid ICP (max iterations: %d)...", max_iterations
        )
        icp_rigid = vtk.vtkIterativeClosestPointTransform()
        icp_rigid.SetSource(self.registered_model)
        icp_rigid.SetTarget(self.fixed_model)
        icp_rigid.GetLandmarkTransform().SetModeToRigidBody()  # Rigid mode
        icp_rigid.SetMaximumNumberOfIterations(max_iterations)
        icp_rigid.Update()

        # Convert VTK transform to ITK and compose with centroid transform
        rigid_transform = self.transform_tools.convert_vtk_matrix_to_itk_transform(
            icp_rigid.GetMatrix()
        )
        forward_point_transform.Compose(rigid_transform)

        # Apply rigid ICP transform to model
        self.registered_model = self.transform_tools.transform_pvcontour(
            self.registered_model,
            rigid_transform,
            with_deformation_magnitude=False,
        )

        self.log_debug("Center after Step 2: %s", self.registered_model.center)

        # Step 3: Affine ICP (only if affine mode)
        if mode == 'affine':
            self.log_info(
                "Step 3: Performing affine ICP (max iterations: %d)...", max_iterations
            )
            icp_affine = vtk.vtkIterativeClosestPointTransform()
            icp_affine.SetSource(self.registered_model)
            icp_affine.SetTarget(self.fixed_model)
            icp_affine.GetLandmarkTransform().SetModeToAffine()  # Affine mode
            icp_affine.SetMaximumNumberOfIterations(max_iterations)
            icp_affine.Update()

            # Convert VTK transform to ITK and compose
            affine_transform = self.transform_tools.convert_vtk_matrix_to_itk_transform(
                icp_affine.GetMatrix()
            )
            forward_point_transform.Compose(affine_transform)

            # Apply affine ICP transform to model
            self.registered_model = self.transform_tools.transform_pvcontour(
                self.registered_model,
                affine_transform,
                with_deformation_magnitude=False,
            )

            self.log_debug("Center after Step 3: %s", self.registered_model.center)

        # Compute inverse transform
        # The forward transform for ICP is consistent with the transform convention
        # used with images-to-images registration.
        self.forward_point_transform = forward_point_transform
        self.inverse_point_transform = forward_point_transform.GetInverseTransform()

        self.log_info("%s ICP registration complete!", mode.upper())

        # Return results as dictionary
        return {
            'registered_model': self.registered_model,
            'forward_point_transform': self.forward_point_transform,
            'inverse_point_transform': self.inverse_point_transform,
        }
diff --git a/src/physiomotion4d/register_models_icp_itk.py b/src/physiomotion4d/register_models_icp_itk.py
new file mode 100644
index 0000000..54ac781
--- /dev/null
+++ b/src/physiomotion4d/register_models_icp_itk.py
@@ -0,0 +1,381 @@
import logging
from typing import Optional

import itk
import numpy as np
import pyvista as pv
from scipy.optimize import minimize

from physiomotion4d.contour_tools import ContourTools
from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase
from physiomotion4d.transform_tools import TransformTools


class RegisterModelsICPITK(PhysioMotion4DBase):
    """Register shape models using model to distance map minimization.

    **Optimization Objective:**
    Minimize the mean distance of the distance map sampled at model points using
    ITK's LinearInterpolateImageFunction. This aligns the model with bright
    regions in target image.

    Attributes:
        fixed_model (pv.PolyData)
        moving_model (pv.PolyData)
        reference_image (itk.Image): Patient image providing coordinate frame and distance data
        transform_type: Rigid or Affine
        forward_point_transform (itk.ComposeScaleSkewVersor3DTransform): Optimized transformation
        registered_model (pv.PolyData): Final registered model

    Note:
        The fixed_model and moving_model are typically extracted from VTU models
        using model.extract_surface() before passing to this class.
+ """ + + def __init__( + self, + fixed_model: pv.PolyData, + reference_image: Optional[itk.Image] = None, + point_subsample_step: int = 4, + log_level: int | str = logging.INFO, + ): + # Initialize base class with logging + super().__init__(class_name="RegisterModelsICPITK", log_level=log_level) + + # Store model data + self.fixed_model: pv.Polydata = fixed_model + self.reference_image = reference_image + + self.moving_model: Optional[pv.Polydata] = None + + # Working transform (reused to avoid repeated memory allocation) + self._working_transform: itk.ComposeScaleSkewVersor3DTransform[itk.D] = ( + itk.ComposeScaleSkewVersor3DTransform[itk.D].New() + ) + + self.transform_type: str = "Affine" + + # outputs + self.forward_point_transform: ( + Optional[itk.ComposeScaleSkewVersor3DTransform] | None + ) = None + self.registered_model: Optional[pv.PolyData] = None + self.final_mean_distance = 0 + + # Transform utilities + self._contour_tools = ContourTools() + self._transform_tools = TransformTools() + + # Image interpolator (created when needed) + self.fixed_distance_map: itk.Image | None = None + self._interpolator: Optional[itk.LinearInterpolateImageFunction] = None + self._max_distance: float = 0.0 + + self._metric_call_count: int = 0 + + # Pre-convert mean shape points to ITK format + self.point_subsample_step = point_subsample_step + self._moving_model_points_itk: Optional[list[itk.Point]] = None + + def _create_itk_points(self) -> None: + """Pre-convert mean shape points to ITK Point format for efficiency. + + This method creates ITK Point objects once at initialization, avoiding + repeated conversions during optimization iterations. 
+ """ + self.log_info("Converting mean shape points to ITK format...") + + self._moving_model_points_itk = [] + for point in self.moving_model.points: + itk_point = itk.Point[itk.D, 3]() + itk_point[0] = float(point[0]) + itk_point[1] = float(point[1]) + itk_point[2] = float(point[2]) + self._moving_model_points_itk.append(itk_point) + + self.log_info( + f" Converted {len(self._moving_model_points_itk)} points to ITK format" + ) + + def set_reference_image(self, reference_image: itk.Image) -> None: + """Set the reference image for registration. + + Args: + reference_image: ITK image providing coordinate frame and distance data + """ + self.reference_image = reference_image + # Clear interpolator to force recreation with new image + self._interpolator = None + self.fixed_distance_map = None + + def set_fixed_model(self, fixed_model: pv.PolyData) -> None: + """Set the average model for registration. + + Args: + fixed_model: PyVista model containing the mean 3D shape model + (unstructured grid or polydata) + """ + self.fixed_model = fixed_model + self.fixed_distance_map = None + self._interpolator = None + + self.log_info(" ✓ Fixed model set successfully!") + + def _evaluate_distance_metric( + self, + transform_params: np.ndarray, + ) -> float: + """Evaluate the optimization metric (mean distance) at model points. + + This is the objective function to be minimized during optimization. + Higher values indicate better alignment with bright regions. + + Args: + pca_deformation: Nx3 numpy array of PCA deformation vectors to add to points. + If None, no deformation is applied. + transform_params: 12-element array of affine transform parameters. + If None, no affine transformation is applied. 
+ + Returns: + Mean distance value across all points + """ + if self._interpolator is None: + if self.fixed_distance_map is None: + self.fixed_distance_map = self._contour_tools.create_distance_map( + self.fixed_model, + self.reference_image, + ) + self.log_debug(" Distance map created") + ImageType = type(self.fixed_distance_map) + self._interpolator = itk.LinearInterpolateImageFunction[ + ImageType, itk.D + ].New() + self._interpolator.SetInputImage(self.fixed_distance_map) + fixed_distance_map_array = itk.GetArrayFromImage(self.fixed_distance_map) + self._max_distance = fixed_distance_map_array.max() + + self.log_debug(" Interpolator created") + + if self._moving_model_points_itk is None: + self._create_itk_points() + + # Update working transform if parameters provided + if self.transform_type == "Rigid": + itk_params = itk.OptimizerParameters[itk.D](12) + for i in range(6): + itk_params[i] = transform_params[i] + for i in range(6, 9): + itk_params[i] = 1 + for i in range(9, 12): + itk_params[i] = 0 + self._working_transform.SetParameters(itk_params) + else: + itk_params = itk.OptimizerParameters[itk.D](12) + for i in range(12): + itk_params[i] = transform_params[i] + self._working_transform.SetParameters(itk_params) + + # Sample intensities at each point + n_valid_points = 0 + n_invalid_points = 0 + total_distance = 0.0 + center = np.zeros(3) + point = itk.Point[itk.D, 3]() + image_size = self.reference_image.GetBufferedRegion().GetSize() + for i, base_point in enumerate(self._moving_model_points_itk): + if i % self.point_subsample_step != 0: + continue + + point[0] = base_point[0] + point[1] = base_point[1] + point[2] = base_point[2] + + point = self._working_transform.TransformPoint(point) + + # Check if point is inside image bounds + coord_index = self.reference_image.TransformPhysicalPointToContinuousIndex( + point + ) + if ( + 0 <= coord_index[0] < image_size[0] + and 0 <= coord_index[1] < image_size[1] + and 0 <= coord_index[2] < image_size[2] + ): + 
center[0] += point[0] + center[1] += point[1] + center[2] += point[2] + distance = self._interpolator.EvaluateAtContinuousIndex(coord_index) + total_distance += distance + n_valid_points += 1 + else: + self.log_warning(" Point %d is outside image bounds (%s)", i, point) + return self._max_distance + + if n_valid_points > n_invalid_points: + mean_distance = total_distance / n_valid_points + center /= n_valid_points + else: + mean_distance = 0.0 + self.log_warning(" *** No valid points found") + + if n_invalid_points > 0: + self.log_warning(" %d points are outside image bounds", n_invalid_points) + self.log_warning(" Parameters: %s", transform_params) + if n_valid_points > n_invalid_points: + self.log_warning(" Center: %s", center) + self.log_warning(" Mean distance: %f", mean_distance) + + if self.log_level <= logging.DEBUG or self._metric_call_count % 100 == 0: + self.log_info( + " Metric %d: %s -> %f", + (self._metric_call_count + 1), + center, + mean_distance, + ) + self._metric_call_count += 1 + + return mean_distance + + def register( + self, + moving_model: pv.PolyData, + initial_transform: Optional[itk.MatrixOffsetTransformBase] = None, + transform_type: str = 'Affine', # or 'Rigid' + method: str = 'L-BFGS-B', # or 'Nelder-Mead' + scale_bound: float = 0.20, + skew_bound: float = 0.03, + versor_bound: float = 0.15, + translation_bound: float = 15, + max_iterations: int = 500, + ) -> dict: + """Optimize affine alignment to minimize mean distance. + + to align the mean shape model with bright regions in the image. + + Args: + initial_transform: Initial ITK ComposeScaleSkewVersor3DTransform for starting point + method: Optimization method for scipy.optimize.minimize. + Default: 'Nelder-Mead' + max_iterations: Maximum number of optimization iterations. 
+ Default: 500 + + Returns: + Tuple of (transform, mean_distance): + - transform: Optimized ITK ComposeScaleSkewVersor3DTransform + - mean_distance: Final mean distance metric value + + Raises: + ValueError: If reference image is not set + """ + if self.reference_image is None: + raise ValueError("Reference image must be set before optimization") + + self.log_section("Affine Alignment Optimization", width=60) + + self.moving_model = moving_model + + self.transform_type = transform_type + + # Get initial parameters from transform + initial_params = None + if initial_transform is not None: + self.log_info("Using initial transform...") + self._working_transform.SetIdentity() + self._working_transform.SetMatrix(initial_transform.GetMatrix()) + self._working_transform.SetOffset(initial_transform.GetOffset()) + self._working_transform.SetCenter(initial_transform.GetCenter()) + else: + self.log_info( + "No initial transform provided, performing centroid alignment..." + ) + moving_centroid = np.array(self.moving_model.center) + self.log_debug("Moving model centroid: %s", moving_centroid) + fixed_centroid = np.array(self.fixed_model.center) + self.log_debug("Fixed model centroid: %s", fixed_centroid) + translation = fixed_centroid - moving_centroid + self._working_transform.SetIdentity() + self._working_transform.SetOffset(translation) + self._working_transform.SetCenter(moving_centroid) + + if self.transform_type == "Rigid": + initial_params = [ + self._working_transform.GetParameters()[i] for i in range(6) + ] + elif self.transform_type == "Affine": + initial_params = [ + self._working_transform.GetParameters()[i] for i in range(12) + ] + else: + self.log_error("Invalid transform type: %s", self.transform_type) + raise ValueError(f"Invalid transform type: {self.transform_type}") + self.log_info("Initial parameters: %s", initial_params) + + bounds = [] + # Scale, Skew, Versor rotation bounds + for v_affine in initial_params[:3]: + bounds.append((v_affine - versor_bound, 
v_affine + versor_bound)) + for trans_affine in initial_params[3:6]: + bounds.append( + ( + trans_affine - translation_bound, + trans_affine + translation_bound, + ) + ) + if self.transform_type == "Affine": + for s_affine in initial_params[6:9]: + bounds.append((s_affine - scale_bound, s_affine + scale_bound)) + for k_affine in initial_params[9:12]: + bounds.append((k_affine - skew_bound, k_affine + skew_bound)) + + # Run optimization + self.log_info("Running optimization...") + if self.log_level <= logging.INFO: + disp = True + else: + disp = False + + result_affine = minimize( + self._evaluate_distance_metric, + initial_params, + method=method, + bounds=bounds, + options={'maxiter': max_iterations, 'disp': disp}, + ) + self.log_info( + "Optimization result: %s -> %f", result_affine.x, result_affine.fun + ) + + # Create optimized transform + self.forward_point_transform = itk.ComposeScaleSkewVersor3DTransform[ + itk.D + ].New() + opt_itk_params = itk.OptimizerParameters[itk.D](12) + if self.transform_type == "Rigid": + for i in range(6): + opt_itk_params[i] = result_affine.x[i] + for i in range(6, 9): + opt_itk_params[i] = 1 + for i in range(9, 12): + opt_itk_params[i] = 0 + elif self.transform_type == "Affine": + for i in range(12): + opt_itk_params[i] = result_affine.x[i] + self.forward_point_transform.SetParameters(opt_itk_params) + + self.final_mean_distance = result_affine.fun + + self.registered_model = self._transform_tools.transform_pvcontour( + self.moving_model, + self.forward_point_transform, + with_deformation_magnitude=False, + ) + + self.log_info("Optimization completed!") + self.log_info(f"Final parameters: {result_affine.x}") + self.log_info(f"Final mean distance: {self.final_mean_distance:.2f}") + + return { + 'registered_model': self.registered_model, + 'forward_point_transform': self.forward_point_transform, + 'mean_distance': self.final_mean_distance, + } diff --git a/src/physiomotion4d/register_models_pca.py 
b/src/physiomotion4d/register_models_pca.py new file mode 100644 index 0000000..093b448 --- /dev/null +++ b/src/physiomotion4d/register_models_pca.py @@ -0,0 +1,818 @@ +import json +import logging +from pathlib import Path +from typing import Optional + +try: + from typing import Self +except ImportError: + from typing_extensions import Self + +import itk +import numpy as np +import pyvista as pv +from scipy.optimize import minimize + +from physiomotion4d.contour_tools import ContourTools +from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase +from physiomotion4d.transform_tools import TransformTools + + +class RegisterModelsPCA(PhysioMotion4DBase): + """Register PCA-based shape models to medical images using mean distance optimization. + + This class implements a registration pipeline for fitting statistical + shape models to patient-specific medical images: + + **PCA Deformable Registration** + - Optimizes PCA coefficients + - Model equation: P = mean + Σ(b_i * std_i * pca_eigenvector_i) + - Maximizes mean distance at deformed model points P + + **Optimization Objective:** + Maximize the mean distance of the image sampled at model points using + ITK's LinearInterpolateImageFunction. This aligns the model with bright + regions in contrast-enhanced images (e.g., blood pool in cardiac CT). 
+ + Attributes: + pca_template_model (pv.UnstructuredGrid): Mean shape model + pca_eigenvectors (np.ndarray): PCA eigenvectors/components (modes × n_points*3) + pca_std_deviations (np.ndarray): Standard deviations per mode (modes,) + fixed_distance_map (itk.Image): Patient image providing distance data + n_points (int): Number of points in the model + pca_number_of_modes (int): Number of PCA modes available + pca_coefficients (np.ndarray): Optimized PCA coefficients + registered_model (pv.UnstructuredGrid): Final registered and deformed model + post_pca_transform (itk.Transform): Transform to apply after PCA registration + forward_point_transform (itk.DisplacementFieldTransform): Forward displacement field transform + (Does not include the post-PCA transform) + inverse_point_transform (itk.DisplacementFieldTransform): Inverse displacement field transform + (Does not include the post-PCA transform) + + Example: + >>> # Load PCA model data + >>> pca_template_model = pv.read("pca_All_mean.vtk") + >>> with open("pca.json", 'r') as f: + ... pca_data = json.load(f) + >>> pca_group_data = pca_data['All'] + >>> pca_std_deviations = np.sqrt(np.array(pca_group_data['eigenvalues'])) + >>> pca_eigenvectors = np.array(pca_group_data['components']) + >>> + >>> # Initialize registrar with loaded data + >>> registrar = RegisterModelsPCA( + ... pca_template_model=pca_template_model, + ... pca_eigenvectors=pca_eigenvectors, + ... pca_std_deviations=pca_std_deviations, + ... ) + >>> + >>> # Run full registration pipeline + >>> result = registrar.register( + ... pca_number_of_modes=10 + ... 
) + >>> + >>> # Save registered model + >>> result['registered_model'].save("registered_heart.vtk") + >>> + >>> # Print optimization results + >>> print(f"Final mean distance: {result['mean_distance']:.2f}") + >>> print(f"PCA coefficients: {result['pca_coefficients']}") + """ + + def __init__( + self, + pca_template_model: pv.UnstructuredGrid, + pca_eigenvectors: np.ndarray, + pca_std_deviations: np.ndarray, + pca_number_of_modes: int = 0, + pca_template_model_point_subsample: int = 4, + post_pca_transform: Optional[itk.Transform] = None, + fixed_distance_map: Optional[itk.Image] = None, + fixed_model: Optional[pv.UnstructuredGrid] = None, + reference_image: Optional[itk.Image] = None, + log_level: int | str = logging.INFO, + ): + """Initialize the PCA-based model-to-image registration. + + Args: + pca_template_model: PyVista model containing the mean 3D shape model + (unstructured grid or polydata) + pca_eigenvectors: Numpy array of PCA eigenvectors/components. Shape: (modes, n_points*3) + Each row is a flattened eigenmode with 3D displacements: [x1,y1,z1, x2,y2,z2, ...] + pca_std_deviations: Numpy array of standard deviations per PCA mode. Shape: (modes,) + These are the square roots of pca_eigenvalues + pca_number_of_modes: Number of PCA modes to use. Default: -1 (use all) + pca_template_model_point_subsample: Step size for subsampling model points. Default: 4 + post_pca_transform: Optional ITK transform to apply after PCA registration. + Default: None + fixed_distance_map: ITK image providing the distance map. + Default: None + fixed_model: PyVista model used to compute the distance map, if one isn't provided. + reference_image: ITK image providing coordinate frame for computing the distance map. + log_level: Logging level (logging.DEBUG, logging.INFO, logging.WARNING). 
+ Default: logging.INFO + + Raises: + ValueError: If pca_eigenvector dimensions don't match model points + """ + # Initialize base class with logging + super().__init__(class_name="RegisterModelsPCA", log_level=log_level) + + # Store model data + self.pca_template_model: pv.UnstructuredGrid = pca_template_model + self.pca_eigenvectors: np.ndarray = pca_eigenvectors + self.pca_std_deviations: np.ndarray = pca_std_deviations + + self.post_pca_transform = post_pca_transform + + self._contour_tools = ContourTools() + + self.fixed_distance_map = fixed_distance_map + if ( + self.fixed_distance_map is None + and fixed_model is not None + and reference_image is not None + ): + self.fixed_model = fixed_model + self.fixed_distance_map = self._contour_tools.create_distance_map( + fixed_model, + reference_image, + squared_distance=True, + ) + elif self.fixed_distance_map is not None and ( + fixed_model is not None or reference_image is not None + ): + self.log_warning( + "Fixed model and reference image will be ignored because a distance map is provided." + ) + elif self.fixed_distance_map is None and ( + fixed_model is None or reference_image is None + ): + self.log_error( + "Fixed model and reference image must be provided if no distance map is provided." + ) + raise ValueError( + "Fixed model and reference image must be provided if no distance map is provided." 
+ ) + + self.pca_number_of_modes: int = pca_number_of_modes + if self.pca_number_of_modes <= 0: + self.pca_number_of_modes = len(pca_std_deviations) + + self.pca_template_model_point_subsample = pca_template_model_point_subsample + + # outputs + self.registered_model_pca_coefficients: np.ndarray | None = None + self.registered_model: pv.UnstructuredGrid | None = None + self.registered_model_mean_distance: float = 0.0 + self.forward_point_transform: itk.DisplacementFieldTransform | None = None + self.inverse_point_transform: itk.DisplacementFieldTransform | None = None + + self._template_model_pca_deformation_field_image: itk.Image | None = None + self._deformation_field_interpolator_x = None + self._deformation_field_interpolator_y = None + self._deformation_field_interpolator_z = None + + # Image interpolator (created when needed) + self._interpolator: Optional[itk.LinearInterpolateImageFunction] = None + self._max_distance: float = 0.0 + + self._metric_call_count: int = 0 + + # Pre-convert mean shape points to ITK format + self._pca_template_model_points_itk: Optional[list[itk.Point]] = None + self._create_itk_points() + + @classmethod + def from_slicersalt( + cls, + pca_template_model: pv.UnstructuredGrid, + pca_json_filename: str, + pca_group_key: str = 'All', + pca_number_of_modes: int = 0, + pca_template_model_point_subsample: int = 4, + post_pca_transform: Optional[itk.Transform] = None, + fixed_distance_map: Optional[itk.Image] = None, + fixed_model: Optional[pv.UnstructuredGrid] = None, + reference_image: Optional[itk.Image] = None, + log_level: int | str = logging.INFO, + ) -> Self: + """Read PCA model data from SlicerSALT format JSON file. + + This method reads PCA statistical shape model data from a JSON file + created by SlicerSALT, including the mean model, pca_eigenvalues, and + pca_eigenvector components. 
+ + The method expects: + - A JSON file (e.g., 'pca.json') containing eigenvalues and components + + Args: + pca_json_filename: Path to the SlicerSALT PCA JSON file + pca_group_key: Key for the PCA group to extract from JSON. Default: 'All' + pca_number_of_modes: Number of PCA modes to use. Default: 0 (use all) + pca_template_model_point_subsample: Step size for subsampling model points. Default: 4 + post_pca_transform: Optional ITK transform to apply after PCA registration. + Default: None + fixed_distance_map: ITK image providing the distance values + for registration. If None, must be set later before registration. + log_level: Logging level (logging.DEBUG, logging.INFO, logging.WARNING). + Default: logging.INFO + + Returns: + RegisterModelsPCA instance + + Raises: + FileNotFoundError: If JSON or VTK model file not found + KeyError: If pca_group_key not found in JSON + ValueError: If data format is invalid + + Example: + >>> registrar = RegisterModelsPCA.from_slicersalt( + ... pca_template_model=pca_template_model, + ... pca_json_filename='path/to/pca.json', + ... pca_group_key='All', + ... fixed_model=fixed_model, + ... reference_image=reference_image + ... ) + """ + # Create a logger for the classmethod since superclassclasss hasn'tt + # been initialized yet. 
+ logger = logging.getLogger("PhysioMotion4D") + + json_path = Path(pca_json_filename) + + # Check if JSON file exists + if not json_path.exists(): + self.log_error(f"PCA JSON file not found: {pca_json_filename}") + raise FileNotFoundError(f"PCA JSON file not found: {pca_json_filename}") + + logger.info("Loading PCA data from SlicerSALT format...") + logger.info(f" JSON file: {json_path}") + logger.info(f" Group key: {pca_group_key}") + + # Load PCA data from JSON + logger.info("Reading JSON file...") + with open(json_path, 'r', encoding='utf-8') as f: + pca_data = json.load(f) + + # Extract PCA group data + if pca_group_key not in pca_data: + available_keys = list(pca_data.keys()) + raise KeyError( + f"Group key '{pca_group_key}' not found in JSON. " + f"Available keys: {available_keys}" + ) + + pca_group_data = pca_data[pca_group_key] + + # Extract data_projection_std + if 'data_projection_std' not in pca_group_data: + raise ValueError( + f"'data_projection_std' field not found in group '{pca_group_key}' data" + ) + pca_std_deviations = np.array(pca_group_data['data_projection_std']) + logger.info(" Loaded %d standard deviations", len(pca_std_deviations)) + + # Extract pca_eigenvector components + if 'components' not in pca_group_data: + raise ValueError( + f"'components' field not found in group '{pca_group_key}' data" + ) + pca_eigenvectors = np.array(pca_group_data['components'], dtype=np.float64) + logger.info(f" Loaded pca_eigenvectors with shape {pca_eigenvectors.shape}") + + expected_pca_eigenvector_size = pca_template_model.n_points * 3 + actual_pca_eigenvector_size = pca_eigenvectors.shape[1] + if actual_pca_eigenvector_size != expected_pca_eigenvector_size: + raise ValueError( + f"pca_Eigenvector dimension mismatch: " + f"Expected {expected_pca_eigenvector_size} (3 × {pca_template_model.n_points} model points), " + f"got {actual_pca_eigenvector_size}" + ) + + logger.info(" ✓ Data validation successful!") + logger.info("SlicerSALT PCA data loaded 
successfully!") + + return cls( + pca_template_model=pca_template_model, + pca_eigenvectors=pca_eigenvectors, + pca_std_deviations=pca_std_deviations, + pca_number_of_modes=pca_number_of_modes, + pca_template_model_point_subsample=pca_template_model_point_subsample, + post_pca_transform=post_pca_transform, + fixed_distance_map=fixed_distance_map, + fixed_model=fixed_model, + reference_image=reference_image, + log_level=log_level, + ) + + def _create_itk_points(self) -> None: + """Pre-convert mean shape points to ITK Point format for efficiency. + + This method creates ITK Point objects once at initialization, avoiding + repeated conversions during optimization iterations. + """ + self.log_info("Converting mean shape points to ITK format...") + + self._pca_template_model_points_itk = [] + itk_point = itk.Point[itk.D, 3]() + for point in self.pca_template_model.points: + itk_point[0] = float(point[0]) + itk_point[1] = float(point[1]) + itk_point[2] = float(point[2]) + self._pca_template_model_points_itk.append(itk_point) + + self.log_info( + f" Converted {len(self._pca_template_model_points_itk)} points to ITK format" + ) + + def set_fixed_model( + self, fixed_model: pv.UnstructuredGrid, reference_image: itk.Image + ) -> None: + """Set the fixed model for registration. + + If this is set, the fixed distance map will be set to None. + + Args: + fixed_model: PyVista model used to compute the distance map, if one isn't provided. + reference_image: ITK image providing coordinate frame for computing the distance map. + """ + self.fixed_distance_map = self._contour_tools.create_distance_map( + fixed_model, + reference_image, + squared_distance=True, + ) + self._interpolator = None + + def set_fixed_distance_map(self, fixed_distance_map: itk.Image) -> None: + """Set the reference image for registration. + + If this is set, the fixed model will be set to None. 
        Args:
            fixed_distance_map: ITK image providing distance data
        """
        self.fixed_distance_map = fixed_distance_map
        self._interpolator = None

    def set_pca_template_model(self, pca_template_model: pv.UnstructuredGrid) -> None:
        """Set the average model for registration.

        Args:
            pca_template_model: PyVista model containing the mean 3D shape model
                (unstructured grid or polydata)
        """
        self.pca_template_model = pca_template_model

        # Drop the cached ITK points; they are rebuilt for the new model.
        self._pca_template_model_points_itk = None

        self._create_itk_points()
        self.log_info(" ✓ Average model set successfully!")

    def _mean_distance_metric(
        self,
        params: np.ndarray,
    ) -> float:
        """Evaluate the optimization metric (mean distance-map value) at model points.

        This is the objective function MINIMIZED by scipy.optimize.minimize
        in _optimize_pca_coefficients: the fixed image is a (squared)
        distance map, so lower sampled values mean the deformed model points
        lie closer to the fixed model surface.

        Args:
            params: PCA coefficients (one per mode); converted to per-point
                deformation vectors via _compute_pca_deformation().

        Returns:
            Mean distance-map value across the sampled (subsampled) points.
            If any point falls outside the image, the maximum value of the
            distance map is returned as a penalty.
        """
        pca_deformation = self._compute_pca_deformation(params)

        # Create interpolator if not already cached (inline creation)
        if self._interpolator is None:
            if self.fixed_distance_map is None:
                self.log_error("Distance map is not set.")
                raise ValueError("Distance map must be set before registering.")
            ImageType = type(self.fixed_distance_map)
            self._interpolator = itk.LinearInterpolateImageFunction[
                ImageType, itk.D
            ].New()
            self._interpolator.SetInputImage(self.fixed_distance_map)
            # Cache the image maximum: used as the out-of-bounds penalty.
            fixed_distance_map_array = itk.GetArrayFromImage(self.fixed_distance_map)
            self._max_distance = fixed_distance_map_array.max()
            self.log_debug("Interpolator created")
            self.log_debug(" Max distance = %s", self._max_distance)

        self.log_debug("Evaluating params = %s", params)
        self.log_debug(" Max displacement = %s", pca_deformation.max(axis=0))

        # Sample distance at each point
        n_valid_points = 0
        total_distance = 0.0
        center = np.zeros(3)
        point = itk.Point[itk.D, 3]()
        image_size = self.fixed_distance_map.GetBufferedRegion().GetSize()
        for i, base_point in enumerate(self._pca_template_model_points_itk):
            # Only every pca_template_model_point_subsample-th point is sampled.
            if i % self.pca_template_model_point_subsample != 0:
                continue

            # Start with base point
            point[0] = base_point[0]
            point[1] = base_point[1]
            point[2] = base_point[2]

            # Add PCA deformation if provided
            point[0] += pca_deformation[i, 0]
            point[1] += pca_deformation[i, 1]
            point[2] += pca_deformation[i, 2]

            if self.post_pca_transform is not None:
                point = self.post_pca_transform.TransformPoint(point)

            # Check if point is inside image bounds

            coord_index = (
                self.fixed_distance_map.TransformPhysicalPointToContinuousIndex(point)
            )
            if (
                0 <= coord_index[0] < image_size[0]
                and 0 <= coord_index[1] < image_size[1]
                and 0 <= coord_index[2] < image_size[2]
            ):
                center[0] += point[0]
                center[1] += point[1]
                center[2] += point[2]
                distance = self._interpolator.EvaluateAtContinuousIndex(coord_index)
                total_distance += distance
                n_valid_points += 1
            else:
                self.log_warning(" Point %d is outside image bounds (%s)", i, point)
                # Early exit: any out-of-bounds point yields the worst
                # (maximum) distance as a penalty value.
                return self._max_distance

        # Compute mean distance
        mean_distance = total_distance / n_valid_points
        center /= n_valid_points

        # Periodic progress logging (every 100th call unless DEBUG).
        if self.log_level <= logging.DEBUG or self._metric_call_count % 100 == 0:
            self.log_info(
                " Metric %d: %s -> %f",
                (self._metric_call_count + 1),
                center,
                mean_distance,
            )
            self.log_info(
                " Params %s",
                params,
            )
        self._metric_call_count += 1

        return mean_distance

    def _compute_pca_deformation(self, pca_coefficients: np.ndarray) -> np.ndarray:
        """Compute PCA deformation vectors for all points.

        Deformation is computed as:
            displacement = Σ(b_i * std_i * pca_eigenvector_i)

        Args:
            pca_coefficients: Array of PCA coefficients b_i (one per mode)
            pca_number_of_modes: Number of PCA modes to use.
Default: use all available modes + + Returns: + Nx3 array of deformation vectors (displacement from mean shape) + """ + # Initialize deformation to zero + deformation = np.zeros((self.pca_template_model.n_points, 3), dtype=np.float64) + + # Add contribution from each PCA mode + for i in range(self.pca_number_of_modes): + pca_eigenvector_flat = self.pca_eigenvectors[i, :] + + # Reshape to (N, 3) + pca_eigenvector_3d = pca_eigenvector_flat.reshape(-1, 3) + + # Add weighted deformation: b_i * std_i * pca_eigenvector_i + deformation += ( + pca_coefficients[i] * self.pca_std_deviations[i] * pca_eigenvector_3d + ) + + return deformation + + def _optimize_pca_coefficients( + self, + pca_number_of_modes: int = 0, + pca_coefficient_bounds: float = 3.0, + method: str = 'L-BFGS-B', + max_iterations: int = 50, + ) -> tuple[np.ndarray, float]: + """Optimize PCA coefficients + + This method optimizes PCA mode coefficients to deform the model to better match + low values in the distance map. + + Args: + pca_number_of_modes: Number of PCA modes to use in optimization. Using fewer + modes provides smoother deformations. Default: 10 + pca_coefficient_bounds: Bound on PCA coefficients in units of std deviations. + Default: 3.0 (±3 std deviations per mode) + method: Optimization method for scipy.optimize.minimize. + Default: 'L-BFGS-B' (supports bounds) + max_iterations: Maximum number of optimization iterations. 
+ Default: 50 + + Returns: + Tuple of (pca_coefficients, mean_distance): + - pca_coefficients: Optimized PCA coefficients + - mean_distance: Final mean distance metric value + + Raises: + ValueError: If number of PCA modes to use exceeds available modes + """ + if pca_number_of_modes <= 0: + pca_number_of_modes = len(self.pca_eigenvectors) + if pca_number_of_modes > len(self.pca_eigenvectors): + raise ValueError( + f"Number of PCA modes to use ({pca_number_of_modes}) exceeds available modes ({len(self.pca_std_deviations)})" + ) + self.pca_number_of_modes = pca_number_of_modes + + self.log_info(f"Number of PCA modes: {pca_number_of_modes}") + self.log_info( + f"PCA coefficient bounds: ±{pca_coefficient_bounds} std deviations" + ) + self.log_info(f"Optimization method: {method}") + self.log_info(f"Max iterations: {max_iterations}") + + bounds = [] + for _ in range(pca_number_of_modes): + bounds.append((-pca_coefficient_bounds, pca_coefficient_bounds)) + + disp = self.log_level <= logging.INFO + + self.log_info("Running optimization...") + result_pca = minimize( + lambda params: self._mean_distance_metric(params), + np.zeros(self.pca_number_of_modes), + method=method, + bounds=bounds, + options={'maxiter': max_iterations, 'disp': disp}, + ) + + optimized_pca_coefficients = result_pca.x + optimized_mean_distance = result_pca.fun + + self.log_info("Optimization completed!") + self.log_info(f"Optimized PCA coefficients: {optimized_pca_coefficients}") + self.log_info(f"Final mean intensity: {optimized_mean_distance:.2f}") + + return optimized_pca_coefficients, optimized_mean_distance + + def transform_template_model(self) -> pv.UnstructuredGrid: + """Create the final registered model by applying PCA deformation. 
+ + Returns: + Final registered and deformed model as PyVista UnstructuredGrid + + Raises: + ValueError: If registration has not been performed + """ + if self.registered_model_pca_coefficients is None: + self.log_error("PCA coefficients are not set.") + raise ValueError( + "PCA coefficients must be set before creating registered model" + ) + + self.log_info("Creating final registered model...") + + # Compute PCA deformation + if self.register_model_pca_deformation is None: + self.register_model_pca_deformation = self._compute_pca_deformation( + self.registered_model_pca_coefficients, + ) + + # Apply deformation and affine transform to each point + final_points = np.zeros((self.pca_template_model.n_points, 3), dtype=np.float64) + + n_points = self.pca_template_model.n_points + progress_interval = max(1, n_points // 10) # Report progress every 10% + + point = itk.Point[itk.D, 3]() + for i in range(n_points): + # Report progress + if i % progress_interval == 0 or i == n_points - 1: + self.log_progress(i + 1, n_points, prefix="Transforming points") + + # Start with mean shape point + point[0] = float(self.pca_template_model.points[i][0]) + point[1] = float(self.pca_template_model.points[i][1]) + point[2] = float(self.pca_template_model.points[i][2]) + + # Add PCA deformation + point[0] += self.register_model_pca_deformation[i, 0] + point[1] += self.register_model_pca_deformation[i, 1] + point[2] += self.register_model_pca_deformation[i, 2] + + if self.post_pca_transform is not None: + point = self.post_pca_transform.TransformPoint(point) + + # Store result + final_points[i, 0] = point[0] + final_points[i, 1] = point[1] + final_points[i, 2] = point[2] + + # Create new model with transformed points + self.registered_model = self.pca_template_model.copy(deep=True) + self.registered_model.points = final_points.copy() + + self.log_info( + f"Registered model created with {self.registered_model.n_points} points" + ) + + return self.registered_model + + def transform_point( + 
self, + point: itk.Point, + include_post_pca_transform: bool = True, + ) -> itk.Point: + """Transform an arbitrary point using nearest neighbor interpolation. + + Args: + point: ITK point to transform (itk.Point[itk.D, 3]) + + Returns: + Transformed ITK point + + Raises: + ValueError: If registration has not been completed yet + + Example: + >>> p = itk.Point[itk.D, 3]() + >>> p[0], p[1], p[2] = 10.0, 20.0, 30.0 + >>> transformed_p = registrar.transform_point(p) + """ + + if self._deformation_field_interpolator_x is None: + field_array = itk.GetArrayFromImage( + self._template_model_pca_deformation_field_image + ) + field_x_image = itk.GetImageFromArray(field_array[:, :, :, 0]) + field_x_image.CopyInformation( + self._template_model_pca_deformation_field_image + ) + self._deformation_field_interpolator_x = itk.LinearInterpolateImageFunction[ + itk.Image[itk.D, 3], itk.D + ].New() + self._deformation_field_interpolator_x.SetInputImage(field_x_image) + + field_y_image = itk.GetImageFromArray(field_array[:, :, :, 1]) + field_y_image.CopyInformation( + self._template_model_pca_deformation_field_image + ) + self._deformation_field_interpolator_y = itk.LinearInterpolateImageFunction[ + itk.Image[itk.D, 3], itk.D + ].New() + self._deformation_field_interpolator_y.SetInputImage(field_y_image) + + field_z_image = itk.GetImageFromArray(field_array[:, :, :, 2]) + field_z_image.CopyInformation( + self._template_model_pca_deformation_field_image + ) + self._deformation_field_interpolator_z = itk.LinearInterpolateImageFunction[ + itk.Image[itk.D, 3], itk.D + ].New() + self._deformation_field_interpolator_z.SetInputImage(field_z_image) + + cindx = self._template_model_pca_deformation_field_image.TransformPhysicalPointToContinuousIndex( + point + ) + size = ( + self._template_model_pca_deformation_field_image.GetLargestPossibleRegion().GetSize() + ) + if ( + cindx[0] < 0 + or cindx[0] >= size[0] + or cindx[1] < 0 + or cindx[1] >= size[1] + or cindx[2] < 0 + or cindx[2] >= size[2] 
+ ): + self.log_error("Point is outside deformation field bounds") + return point + + deformation_x = ( + self._deformation_field_interpolator_x.EvaluateAtContinuousIndex(cindx) + ) + deformation_y = ( + self._deformation_field_interpolator_y.EvaluateAtContinuousIndex(cindx) + ) + deformation_z = ( + self._deformation_field_interpolator_z.EvaluateAtContinuousIndex(cindx) + ) + + transformed_point = itk.Point[itk.D, 3]() + transformed_point[0] = float(point[0] + deformation_x) + transformed_point[1] = float(point[1] + deformation_y) + transformed_point[2] = float(point[2] + deformation_z) + + if include_post_pca_transform: + transformed_point = self.post_pca_transform.TransformPoint( + transformed_point + ) + + return transformed_point + + def compute_pca_transforms(self, reference_image: itk.Image) -> dict: + """Compute PCA transforms. + + Returns: + Dictionary containing: + - 'forward_point_transform': Forward displacement field transform + - 'inverse_point_transform': Inverse displacement field transform + """ + self._template_model_pca_deformation_field_image = ( + self._contour_tools.create_deformation_field( + np.array(self.pca_template_model.points), + self.register_model_pca_deformation, + reference_image=reference_image, + blur_sigma=2.5, + ptype=itk.D, + ) + ) + + self.forward_point_transform = itk.DisplacementFieldTransform[itk.D, 3].New() + self.forward_point_transform.SetDisplacementField( + self._template_model_pca_deformation_field_image + ) + + transform_tools = TransformTools() + self.inverse_point_transform = ( + transform_tools.invert_displacement_field_transform( + self.forward_point_transform + ) + ) + return { + 'forward_point_transform': self.forward_point_transform, + 'inverse_point_transform': self.inverse_point_transform, + } + + def register( + self, + pca_number_of_modes: int = 0, + pca_coefficient_bounds: float = 3.0, + method: str = 'L-BFGS-B', + max_iterations: int = 50, + ) -> dict: + """Optimize PCA coefficients to deform the model 
 to better match
        low values in the distance map.

        Args:
            pca_number_of_modes: Number of PCA modes to use. Default: 0 (use all available modes)
            pca_coefficient_bounds: PCA coefficient bounds (±std devs). Default: 3.0
            method: Optimization method for scipy.optimize.minimize.
                Default: 'L-BFGS-B' (supports bounds)
            max_iterations: Maximum number of optimization iterations.
                Default: 50

        Returns:
            Dictionary containing:
                - 'registered_model': Final registered PyVista model
                - 'pca_coefficients': Optimized PCA coefficients
                - 'mean_distance': Final mean distance metric value

        Raises:
            ValueError: If reference image is not set

        Example:
            >>> result = registrar.register(
            ...     pca_number_of_modes=10
            ... )
            >>> result['registered_model'].save("registered_heart.vtk")
        """
        if self.fixed_distance_map is None:
            raise ValueError("Reference image must be set before registration")

        # 0 (or negative) falls back to the mode count chosen at construction.
        if pca_number_of_modes <= 0:
            pca_number_of_modes = self.pca_number_of_modes

        self.log_section("PCA-BASED MODEL-TO-IMAGE REGISTRATION", width=70)
        self.log_info(f"Number of points: {self.pca_template_model.n_points}")
        self.log_info(f"Modes to use: {pca_number_of_modes}")

        self.registered_model_pca_coefficients, self.registered_model_mean_distance = (
            self._optimize_pca_coefficients(
                pca_number_of_modes=pca_number_of_modes,
                pca_coefficient_bounds=pca_coefficient_bounds,
                method=method,
                max_iterations=max_iterations,
            )
        )

        # Create final registered model
        # Clearing the cached deformation forces transform_template_model()
        # to recompute it from the newly optimized coefficients.
        self.register_model_pca_deformation = None
        self.registered_model = self.transform_template_model()

        # Return results as dictionary
        return {
            'registered_model': self.registered_model,
            'pca_coefficients': self.registered_model_pca_coefficients,
            'mean_distance': self.registered_model_mean_distance,
        }
diff --git a/src/physiomotion4d/register_time_series_images.py b/src/physiomotion4d/register_time_series_images.py
index 03dae08..c3cfede 100644
---
a/src/physiomotion4d/register_time_series_images.py +++ b/src/physiomotion4d/register_time_series_images.py @@ -9,6 +9,8 @@ CT where sequential frames need to be registered to a common frame. """ +import logging + import itk from physiomotion4d.register_images_ants import RegisterImagesANTs @@ -25,9 +27,9 @@ class RegisterTimeSeriesImages(RegisterImagesBase): ANTs and ICON registration methods and can propagate information from prior registrations to initialize subsequent ones. - The registration proceeds in two passes from a starting index: - 1. Forward pass: from starting_index to the end of the series - 2. Backward pass: from starting_index-1 to the beginning + The registration proceeds in two passes from a reference frame: + 1. Forward pass: from reference_frame to the end of the series + 2. Backward pass: from reference_frame-1 to the beginning This bidirectional approach helps maintain temporal coherence in the registration results. @@ -54,32 +56,33 @@ class RegisterTimeSeriesImages(RegisterImagesBase): >>> # Register all time points to fixed image >>> result = registrar.register_time_series( ... moving_images=time_series_images, - ... starting_index=5, # Start from middle of cardiac cycle - ... register_start_to_fixed_image=True, - ... portion_of_prior_transform_to_init_next_transform=0.5 + ... reference_frame=5, # Start from middle of cardiac cycle + ... register_reference=True, + ... prior_weight=0.5 ... ) >>> - >>> phi_MF_list = result["phi_MF_list"] - >>> phi_FM_list = result["phi_FM_list"] + >>> forward_tfms = result["forward_transforms"] # Moving → Fixed + >>> inverse_tfms = result["inverse_transforms"] # Fixed → Moving >>> losses = result["losses"] """ - def __init__(self, registration_method='ants'): + def __init__(self, registration_method='ants', log_level: int | str = logging.INFO): """Initialize the time series image registration class. Args: registration_method (str): Registration method to use. Options: 'ants' or 'icon'. 
Default: 'ants' + log_level: Logging level (default: logging.INFO) Raises: ValueError: If registration_method is not 'ants' or 'icon' """ - super().__init__() + super().__init__(log_level=log_level) self.registration_method = registration_method.lower() - self.registrar_ants = RegisterImagesANTs() - self.registrar_icon = RegisterImagesICON() + self.registrar_ants = RegisterImagesANTs(log_level=log_level) + self.registrar_icon = RegisterImagesICON(log_level=log_level) if self.registration_method == 'ants': self.number_of_iterations = [40, 20, 10] elif self.registration_method == 'icon': @@ -146,98 +149,96 @@ def set_fixed_image(self, fixed_image): """ self.fixed_image = fixed_image - def set_fixed_image_mask(self, fixed_image_mask): + def set_fixed_mask(self, fixed_mask): """Set a binary mask for the fixed image region of interest. This passes through to the underlying registration method. Args: - fixed_image_mask (itk.Image): Binary mask defining ROI + fixed_mask (itk.Image): Binary mask defining ROI """ - self.fixed_image_mask = fixed_image_mask + self.fixed_mask = fixed_mask def register_time_series( self, moving_images, - moving_images_masks=None, - starting_index=0, - register_start_to_fixed_image=True, - portion_of_prior_transform_to_init_next_transform=0.0, - images_are_labelmaps=False, + moving_masks=None, + reference_frame=0, + register_reference=True, + prior_weight=0.0, ): """Register a time series of images to the fixed image. This method registers an ordered sequence of images to a common fixed - frame. Registration proceeds bidirectionally from a starting index: + frame. Registration proceeds bidirectionally from a reference frame: forward to the end and backward to the beginning. - For each image after the starting image, the method can optionally use + For each image after the reference image, the method can optionally use the transform from the previous image to initialize the registration, which can improve convergence and temporal coherence. 
Args: moving_images (list[itk.Image]): List of 3D images to register - moving_images_masks (list[itk.Image], optional): List of binary masks, + moving_masks (list[itk.Image], optional): List of binary masks, one for each moving image. If None, no masks are used. If provided, must have the same length as moving_images. Default: None - starting_index (int, optional): Index of the first image to register. + reference_frame (int, optional): Index of the reference image to register first. Registration proceeds forward from this index to the end, then backward from this index to the beginning. Default: 0 - register_start_to_fixed_image (bool, optional): If True, register the - starting image to the fixed image. If False, use identity transform - for the starting image. Default: True - portion_of_prior_transform_to_init_next_transform (float, optional): + register_reference (bool, optional): If True, register the + reference image to the fixed image. If False, use identity transform + for the reference image. Default: True + prior_weight (float, optional): Weight (0.0 to 1.0) for using the prior image's transform to initialize the next registration. 0.0 means no prior information is used (each registration starts from identity). Higher values provide more temporal smoothness but may propagate errors. Default: 0.0 - images_are_labelmaps (bool, optional): If True, treat images as label maps - and use appropriate label-based registration. 
Default: False Returns: - dict: Dictionary containing: - - "phi_MF_list" (list[itk.Transform]): Transforms from moving to fixed - space for each image in moving_images - - "phi_FM_list" (list[itk.Transform]): Transforms from fixed to moving - space for each image in moving_images + dict: Dictionary containing results: + - "forward_transforms" (list[itk.Transform]): Transforms from moving to fixed + space for each image (warps moving → fixed) + - "inverse_transforms" (list[itk.Transform]): Transforms from fixed to moving + space for each image (warps fixed → moving) - "losses" (list[float]): Registration loss value for each image Raises: ValueError: If fixed_image is not set - ValueError: If starting_index is out of range - ValueError: If portion_of_prior_transform_to_init_next_transform not in [0, 1] - ValueError: If moving_images_masks length doesn't match moving_images length + ValueError: If reference_frame is out of range + ValueError: If prior_weight not in [0, 1] + ValueError: If moving_masks length doesn't match moving_images length Note: The method compares registration with identity initialization versus prior transform initialization and selects the result with lower loss. This helps prevent error propagation in the temporal sequence. - The fixed image mask can be set using set_fixed_image_mask() before + The fixed image mask can be set using set_fixed_mask() before calling this method. Example: >>> registrar = RegisterTimeSeriesImages(registration_method='ants') >>> registrar.set_fixed_image(fixed_image) - >>> registrar.set_fixed_image_mask(fixed_mask) # Optional + >>> registrar.set_fixed_mask(fixed_mask) # Optional >>> registrar.set_number_of_iterations([30, 15, 5]) >>> + >>> # Use new intuitive parameter names >>> result = registrar.register_time_series( ... moving_images=image_list, - ... moving_images_masks=mask_list, # Optional - ... starting_index=5, - ... register_start_to_fixed_image=True, - ... 
portion_of_prior_transform_to_init_next_transform=0.0 + ... moving_masks=mask_list, # Optional + ... reference_frame=5, + ... register_reference=True, + ... prior_weight=0.5 ... ) >>> - >>> # Access results - >>> for i, (phi_MF, loss) in enumerate(zip( - ... result["phi_MF_list"], result["losses"] + >>> # Access results using new intuitive names + >>> for i, (forward_tfm, loss) in enumerate(zip( + ... result["forward_transforms"], result["losses"] ... )): - ... # Apply transform to image i + ... # Apply forward transform to align moving image i to fixed ... registered = transform_tools.transform_image( - ... moving_images[i], phi_MF, fixed_image + ... moving_images[i], forward_tfm, fixed_image ... ) """ if self.fixed_image is None: @@ -248,46 +249,44 @@ def register_time_series( self.registrar_ants.set_modality(self.modality) self.registrar_ants.set_mask_dilation(self.mask_dilation_mm) self.registrar_ants.set_number_of_iterations(self.number_of_iterations) - self.registrar_ants.set_fixed_image_mask(self.fixed_image_mask) + self.registrar_ants.set_fixed_mask(self.fixed_mask) elif self.registration_method == 'icon': self.registrar_icon.set_fixed_image(self.fixed_image) self.registrar_icon.set_modality(self.modality) self.registrar_icon.set_mask_dilation(self.mask_dilation_mm) self.registrar_icon.set_number_of_iterations(self.number_of_iterations) - self.registrar_icon.set_fixed_image_mask(self.fixed_image_mask) + self.registrar_icon.set_fixed_mask(self.fixed_mask) elif self.registration_method == 'ants_icon': self.registrar_ants.set_fixed_image(self.fixed_image) self.registrar_ants.set_modality(self.modality) self.registrar_ants.set_mask_dilation(self.mask_dilation_mm) self.registrar_ants.set_number_of_iterations(self.number_of_iterations[0]) - self.registrar_ants.set_fixed_image_mask(self.fixed_image_mask) + self.registrar_ants.set_fixed_mask(self.fixed_mask) self.registrar_icon.set_fixed_image(self.fixed_image) self.registrar_icon.set_modality(self.modality) 
self.registrar_icon.set_mask_dilation(self.mask_dilation_mm) self.registrar_icon.set_number_of_iterations(self.number_of_iterations[1]) - self.registrar_icon.set_fixed_image_mask(self.fixed_image_mask) + self.registrar_icon.set_fixed_mask(self.fixed_mask) num_images = len(moving_images) - if starting_index < 0 or starting_index >= num_images: + if reference_frame < 0 or reference_frame >= num_images: raise ValueError( - f"starting_index {starting_index} out of range [0, {num_images-1}]" + f"reference_frame {reference_frame} out of range [0, {num_images-1}]" ) - if not 0.0 <= portion_of_prior_transform_to_init_next_transform <= 1.0: - raise ValueError( - "portion_of_prior_transform_to_init_next_transform must be in [0.0, 1.0]" - ) + if not 0.0 <= prior_weight <= 1.0: + raise ValueError("prior_weight must be in [0.0, 1.0]") - if moving_images_masks is not None and len(moving_images_masks) != num_images: + if moving_masks is not None and len(moving_masks) != num_images: raise ValueError( - f"moving_images_masks length ({len(moving_images_masks)}) must match " + f"moving_masks length ({len(moving_masks)}) must match " f"moving_images length ({num_images})" ) # Initialize result lists - phi_MF_list = [None] * num_images - phi_FM_list = [None] * num_images + forward_transforms = [None] * num_images + inverse_transforms = [None] * num_images losses = [0.0] * num_images # Create identity transform for fixed image @@ -298,207 +297,192 @@ def register_time_series( ) ) - # Register the starting image - if register_start_to_fixed_image: - starting_mask = ( - moving_images_masks[starting_index] - if moving_images_masks is not None - else None + # Register the reference frame image + if register_reference: + reference_mask = ( + moving_masks[reference_frame] if moving_masks is not None else None ) if self.registration_method == 'ants': result = self.registrar_ants.register( - moving_images[starting_index], - moving_image_mask=starting_mask, - 
images_are_labelmaps=images_are_labelmaps, + moving_images[reference_frame], + moving_mask=reference_mask, ) elif self.registration_method == 'icon': result = self.registrar_icon.register( - moving_images[starting_index], - moving_image_mask=starting_mask, - images_are_labelmaps=images_are_labelmaps, + moving_images[reference_frame], + moving_mask=reference_mask, ) elif self.registration_method == 'ants_icon': result = self.registrar_ants.register( - moving_images[starting_index], - moving_image_mask=starting_mask, - images_are_labelmaps=images_are_labelmaps, + moving_images[reference_frame], + moving_mask=reference_mask, ) - phi_MF_ants = result["phi_MF"] + forward_ants = result["forward_transform"] result = self.registrar_icon.register( - moving_images[starting_index], - moving_image_mask=starting_mask, - initial_phi_MF=phi_MF_ants, - images_are_labelmaps=images_are_labelmaps, + moving_images[reference_frame], + moving_mask=reference_mask, + initial_forward_transform=forward_ants, ) else: raise ValueError( f"Invalid registration method: {self.registration_method}" ) - phi_MF = result["phi_MF"] - phi_FM = result["phi_FM"] + forward_transform = result["forward_transform"] + inverse_transform = result["inverse_transform"] loss = result["loss"] else: - # Use identity transform for starting image - phi_MF = identity_tfm - phi_FM = identity_tfm + # Use identity transform for reference frame + forward_transform = identity_tfm + inverse_transform = identity_tfm loss = 0.0 - phi_MF_list[starting_index] = phi_MF - phi_FM_list[starting_index] = phi_FM - losses[starting_index] = loss + forward_transforms[reference_frame] = forward_transform + inverse_transforms[reference_frame] = inverse_transform + losses[reference_frame] = loss - # Compute prior transform for starting image if needed - prior_phi_MF_ref = None - if portion_of_prior_transform_to_init_next_transform > 0.0: - prior_phi_MF_ref = ( + # Compute prior transform for reference frame if needed + prior_forward_ref = 
None + if prior_weight > 0.0: + prior_forward_ref = ( self.transform_tools.combine_displacement_field_transforms( identity_tfm, - phi_MF, + forward_transform, self.fixed_image, tfm1_weight=1.0, - tfm2_weight=portion_of_prior_transform_to_init_next_transform, + tfm2_weight=prior_weight, tfm1_blur_sigma=0.0, tfm2_blur_sigma=0.5, mode="add", ) ) - # Register forward and backward from starting index + # Register forward and backward from reference frame for step, start_idx, end_idx in [ - (1, starting_index + 1, num_images), # Forward pass - (-1, starting_index - 1, -1), # Backward pass + (1, reference_frame + 1, num_images), # Forward pass + (-1, reference_frame - 1, -1), # Backward pass ]: - prior_phi_MF = prior_phi_MF_ref + prior_forward = prior_forward_ref for img_idx in range(start_idx, end_idx, step): moving_image = moving_images[img_idx] moving_mask = ( - moving_images_masks[img_idx] - if moving_images_masks is not None - else None + moving_masks[img_idx] if moving_masks is not None else None ) # Try registration with identity initialization if self.registration_method == 'ants': result_init_identity = self.registrar_ants.register( moving_image=moving_image, - moving_image_mask=moving_mask, - images_are_labelmaps=images_are_labelmaps, + moving_mask=moving_mask, ) elif self.registration_method == 'icon': result_init_identity = self.registrar_icon.register( moving_image=moving_image, - moving_image_mask=moving_mask, - images_are_labelmaps=images_are_labelmaps, + moving_mask=moving_mask, ) elif self.registration_method == 'ants_icon': result_init_identity = self.registrar_ants.register( moving_image=moving_image, - moving_image_mask=moving_mask, - images_are_labelmaps=images_are_labelmaps, + moving_mask=moving_mask, ) - phi_MF_ants = result_init_identity["phi_MF"] + forward_ants = result_init_identity["forward_transform"] result_init_identity = self.registrar_icon.register( moving_image=moving_image, - moving_image_mask=moving_mask, - initial_phi_MF=phi_MF_ants, - 
images_are_labelmaps=images_are_labelmaps, + moving_mask=moving_mask, + initial_forward_transform=forward_ants, ) else: raise ValueError( f"Invalid registration method: {self.registration_method}" ) - phi_MF_init_identity = result_init_identity["phi_MF"] - phi_FM_init_identity = result_init_identity["phi_FM"] + forward_init_identity = result_init_identity["forward_transform"] + inverse_init_identity = result_init_identity["inverse_transform"] loss_init_identity = result_init_identity["loss"] # Select best result based on prior usage - if portion_of_prior_transform_to_init_next_transform > 0.0: + if prior_weight > 0.0: # Try with prior transform initialization if self.registration_method == 'ants': result_init_prior = self.registrar_ants.register( moving_image=moving_image, - moving_image_mask=moving_mask, - initial_phi_MF=prior_phi_MF, - images_are_labelmaps=images_are_labelmaps, + moving_mask=moving_mask, + initial_forward_transform=prior_forward, ) elif self.registration_method == 'icon': result_init_prior = self.registrar_icon.register( moving_image=moving_image, - moving_image_mask=moving_mask, - initial_phi_MF=prior_phi_MF, - images_are_labelmaps=images_are_labelmaps, + moving_mask=moving_mask, + initial_forward_transform=prior_forward, ) elif self.registration_method == 'ants_icon': result_init_prior = self.registrar_ants.register( moving_image=moving_image, - moving_image_mask=moving_mask, - initial_phi_MF=prior_phi_MF, - images_are_labelmaps=images_are_labelmaps, + moving_mask=moving_mask, + initial_forward_transform=prior_forward, ) - phi_MF_ants = result_init_prior["phi_MF"] + forward_ants = result_init_prior["forward_transform"] result_init_prior = self.registrar_icon.register( moving_image=moving_image, - moving_image_mask=moving_mask, - initial_phi_MF=phi_MF_ants, - images_are_labelmaps=images_are_labelmaps, + moving_mask=moving_mask, + initial_forward_transform=forward_ants, ) else: raise ValueError( f"Invalid registration method: 
{self.registration_method}" ) - phi_MF_init_prior = result_init_prior["phi_MF"] - phi_FM_init_prior = result_init_prior["phi_FM"] + forward_init_prior = result_init_prior["forward_transform"] + inverse_init_prior = result_init_prior["inverse_transform"] loss_init_prior = result_init_prior["loss"] # Select result with lower loss if loss_init_identity < loss_init_prior: # Identity initialization was better - prior_phi_MF = identity_tfm - phi_MF = phi_MF_init_identity - phi_FM = phi_FM_init_identity + prior_forward = identity_tfm + forward_transform = forward_init_identity + inverse_transform = inverse_init_identity loss = loss_init_identity else: # Prior initialization was better - phi_MF = phi_MF_init_prior - phi_FM = phi_FM_init_prior + forward_transform = forward_init_prior + inverse_transform = inverse_init_prior loss = loss_init_prior # Update prior for next iteration - prior_phi_MF = self.transform_tools.combine_displacement_field_transforms( - identity_tfm, - phi_MF, - self.fixed_image, - tfm1_weight=1.0, - tfm2_weight=portion_of_prior_transform_to_init_next_transform, - tfm1_blur_sigma=0.0, - tfm2_blur_sigma=self.smooth_prior_transform_sigma, - mode="add", + prior_forward = ( + self.transform_tools.combine_displacement_field_transforms( + identity_tfm, + forward_transform, + self.fixed_image, + tfm1_weight=1.0, + tfm2_weight=prior_weight, + tfm1_blur_sigma=0.0, + tfm2_blur_sigma=self.smooth_prior_transform_sigma, + mode="add", + ) ) else: # No prior usage, just use identity result - phi_MF = phi_MF_init_identity - phi_FM = phi_FM_init_identity + forward_transform = forward_init_identity + inverse_transform = inverse_init_identity loss = loss_init_identity # Store results - phi_MF_list[img_idx] = phi_MF - phi_FM_list[img_idx] = phi_FM + forward_transforms[img_idx] = forward_transform + inverse_transforms[img_idx] = inverse_transform losses[img_idx] = loss return { - "phi_MF_list": phi_MF_list, - "phi_FM_list": phi_FM_list, + "forward_transforms": 
forward_transforms, # List of transforms: moving → fixed + "inverse_transforms": inverse_transforms, # List of transforms: fixed → moving "losses": losses, } def registration_method( self, moving_image, - moving_image_mask=None, + moving_mask=None, moving_image_pre=None, - images_are_labelmaps=False, - initial_phi_MF=None, + initial_forward_transform=None, ): """Registration method required by RegisterImagesBase. @@ -507,43 +491,38 @@ def registration_method( Args: moving_image (itk.Image): Image to register - moving_image_mask (itk.Image, optional): Binary mask + moving_mask (itk.Image, optional): Binary mask moving_image_pre (itk.Image, optional): Preprocessed image - images_are_labelmaps (bool, optional): Whether to use label-based registration - initial_phi_MF (itk.Transform, optional): Initial transform + initial_forward_transform (itk.Transform, optional): Initial transform Returns: - dict: Registration result with phi_FM, phi_MF, and loss + dict: Registration result with forward_transform, inverse_transform, and loss """ if self.registration_method == 'ants': return self.registrar_ants.registration_method( moving_image=moving_image, - moving_image_mask=moving_image_mask, + moving_mask=moving_mask, moving_image_pre=moving_image_pre, - images_are_labelmaps=images_are_labelmaps, - initial_phi_MF=initial_phi_MF, + initial_forward_transform=initial_forward_transform, ) elif self.registration_method == 'icon': return self.registrar_icon.registration_method( moving_image=moving_image, - moving_image_mask=moving_image_mask, + moving_mask=moving_mask, moving_image_pre=moving_image_pre, - images_are_labelmaps=images_are_labelmaps, - initial_phi_MF=initial_phi_MF, + initial_forward_transform=initial_forward_transform, ) elif self.registration_method == 'ants_icon': - phi_MF_ants = self.registrar_ants.registration_method( + forward_ants = self.registrar_ants.registration_method( moving_image=moving_image, - moving_image_mask=moving_image_mask, + moving_mask=moving_mask, 
moving_image_pre=moving_image_pre, - images_are_labelmaps=images_are_labelmaps, - )["phi_MF"] + )["forward_transform"] return self.registrar_icon.registration_method( moving_image=moving_image, - moving_image_mask=moving_image_mask, + moving_mask=moving_mask, moving_image_pre=moving_image_pre, - images_are_labelmaps=images_are_labelmaps, - initial_phi_MF=phi_MF_ants, + initial_forward_transform=forward_ants, ) else: raise ValueError(f"Invalid registration method: {self.registration_method}") diff --git a/src/physiomotion4d/segment_chest_base.py b/src/physiomotion4d/segment_chest_base.py index 2b76da8..c03c7ae 100644 --- a/src/physiomotion4d/segment_chest_base.py +++ b/src/physiomotion4d/segment_chest_base.py @@ -5,12 +5,16 @@ preprocessing, postprocessing, and anatomical structure organization tasks. """ +import logging + import itk import numpy as np from itk import TubeTK as tube +from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase + -class SegmentChestBase: +class SegmentChestBase(PhysioMotion4DBase): """Base class for chest segmentation that provides common functionality for segmenting chest CT images. @@ -37,13 +41,18 @@ class SegmentChestBase: other_mask_ids (dict): Dictionary of remaining structure IDs """ - def __init__(self): + def __init__(self, log_level: int | str = logging.INFO): """Initialize the SegmentChestBase class. Sets up default parameters for image preprocessing and anatomical structure ID mappings. Subclasses should call this constructor and then override the mask ID dictionaries with their specific mappings. 
+ + Args: + log_level: Logging level (default: logging.INFO) """ + super().__init__(class_name=self.__class__.__name__, log_level=log_level) + self.target_spacing = 0 self.rescale_intensity_range = False @@ -138,7 +147,7 @@ def preprocess_input( # Check the input image assert len(input_image.GetSpacing()) == 3, "The input image must be 3D" - resale_image = False + rescale_image = False results_image = None if self.target_spacing > 0.0: if ( @@ -146,26 +155,26 @@ def preprocess_input( or input_image.GetSpacing()[1] != self.target_spacing or input_image.GetSpacing()[2] != self.target_spacing ): - resale_image = True + rescale_image = True else: isotropy = ( (input_image.GetSpacing()[1] / input_image.GetSpacing()[0]) + (input_image.GetSpacing()[2] / input_image.GetSpacing()[0]) ) / 2 if isotropy < 0.9 or isotropy > 1.1: - resale_image = True + rescale_image = True self.target_spacing = ( input_image.GetSpacing()[0] + input_image.GetSpacing()[1] + input_image.GetSpacing()[2] ) / 3 - print( - " Resampling to", self.target_spacing, "isotropic spacing." 
+ self.log_info( + "Resampling to %.3f isotropic spacing", self.target_spacing ) - if resale_image: - print("WARNING: The input image should have isotropic spacing.") - print(" The input image has spacing:", input_image.GetSpacing()) - print(" Resampling to isotropic:", self.target_spacing) + if rescale_image: + self.log_warning("The input image should have isotropic spacing") + self.log_info("Input image has spacing: %s", str(input_image.GetSpacing())) + self.log_info("Resampling to isotropic: %.3f", self.target_spacing) interpolator = itk.LinearInterpolateImageFunction.New(input_image) results_image = itk.ResampleImageFilter( input_image, @@ -204,7 +213,7 @@ def preprocess_input( minv = results_image_arr.min() maxv = results_image_arr.max() if self.rescale_intensity_range: - print("Rescaling intensity range...") + self.log_info("Rescaling intensity range...") if ( self.input_intensity_scale_range is None or self.output_intensity_scale_range is None diff --git a/src/physiomotion4d/segment_chest_ensemble.py b/src/physiomotion4d/segment_chest_ensemble.py index 76969a9..d16e506 100644 --- a/src/physiomotion4d/segment_chest_ensemble.py +++ b/src/physiomotion4d/segment_chest_ensemble.py @@ -8,6 +8,7 @@ # -v /tmp/data:/home/aylward/tmp/data nvcr.io/nim/nvidia/vista3d:latest import argparse +import logging import itk import numpy as np @@ -23,9 +24,13 @@ class SegmentChestEnsemble(SegmentChestBase): segmentation method using VISTA3D. """ - def __init__(self): - """Initialize the vista3d class.""" - super().__init__() + def __init__(self, log_level: int | str = logging.INFO): + """Initialize the vista3d class. + + Args: + log_level: Logging level (default: logging.INFO) + """ + super().__init__(log_level=log_level) self.target_spacing = 0.0 @@ -316,15 +321,15 @@ def ensemble_segmentation( itk.image: The combined segmentation result. 
""" - print("Running ensemble segmentation: combining results") + self.log_info("Running ensemble segmentation: combining results") labelmap_vista_arr = itk.GetArrayFromImage(labelmap_vista) labelmap_totseg_arr = itk.GetArrayFromImage(labelmap_totseg) - print("Segs done", flush=True) + self.log_info("Segmentations loaded") results_arr = np.zeros_like(labelmap_vista_arr) - print("Setting interpolators") + self.log_info("Setting interpolators") labelmap_vista_interp = itk.LabelImageGaussianInterpolateImageFunction.New( labelmap_vista ) @@ -332,11 +337,15 @@ def ensemble_segmentation( labelmap_totseg ) - print("Iterating through labelmaps", flush=True) + self.log_info("Iterating through labelmaps") lastidx0 = -1 + total_slices = labelmap_vista_arr.shape[0] for idx in np.ndindex(labelmap_vista_arr.shape): if idx[0] != lastidx0: - print("Processing slice", idx[0], flush=True) + if idx[0] % 10 == 0 or idx[0] == total_slices - 1: + self.log_progress( + idx[0] + 1, total_slices, prefix="Processing slices" + ) lastidx0 = idx[0] # Skip if both are zero vista_label = labelmap_vista_arr[idx] diff --git a/src/physiomotion4d/segment_chest_total_segmentator.py b/src/physiomotion4d/segment_chest_total_segmentator.py index 6b0e551..e1b04d0 100644 --- a/src/physiomotion4d/segment_chest_total_segmentator.py +++ b/src/physiomotion4d/segment_chest_total_segmentator.py @@ -7,6 +7,7 @@ """ import argparse +import logging import os import tempfile @@ -54,14 +55,17 @@ class SegmentChestTotalSegmentator(SegmentChestBase): >>> heart_mask = result["heart"] """ - def __init__(self): + def __init__(self, log_level: int | str = logging.INFO): """Initialize the TotalSegmentator-based chest segmentation. Sets up the TotalSegmentator-specific anatomical structure ID mappings and processing parameters. The target spacing is set to 1.5mm which provides a good balance between accuracy and processing speed. 
+ + Args: + log_level: Logging level (default: logging.INFO) """ - super().__init__() + super().__init__(log_level=log_level) self.target_spacing = 1.5 @@ -232,10 +236,10 @@ def segmentation_method(self, preprocessed_image: itk.image) -> itk.image: # For higher performance, you can use fast=True, which uses a # faster but less accurate model. - output_nib_image1 = totalsegmentator(nib_image, task="total", device="cuda") + output_nib_image1 = totalsegmentator(nib_image, task="total", device="gpu") labelmap_arr1 = output_nib_image1.get_fdata().astype(np.uint8) - output_nib_image2 = totalsegmentator(nib_image, task="body", device="cuda") + output_nib_image2 = totalsegmentator(nib_image, task="body", device="gpu") labelmap_arr2 = output_nib_image2.get_fdata().astype(np.uint8) # The data from nibabel is in RAS orientation with xyz axis order. diff --git a/src/physiomotion4d/segment_chest_vista_3d.py b/src/physiomotion4d/segment_chest_vista_3d.py index f23829e..669f0ce 100644 --- a/src/physiomotion4d/segment_chest_vista_3d.py +++ b/src/physiomotion4d/segment_chest_vista_3d.py @@ -17,12 +17,12 @@ # -v /tmp/data:/home/aylward/tmp/data nvcr.io/nim/nvidia/vista3d:latest import argparse +import logging import os import shutil import tempfile import itk -import numpy as np import torch from huggingface_hub import snapshot_download @@ -65,7 +65,7 @@ class SegmentChestVista3D(SegmentChestBase): >>> result = segmenter.segment(ct_image) """ - def __init__(self): + def __init__(self, log_level: int | str = logging.INFO): """Initialize the VISTA-3D based chest segmentation. Sets up the VISTA-3D model including downloading weights from Hugging Face, @@ -75,11 +75,14 @@ def __init__(self): The initialization automatically downloads the VISTA-3D model weights from the MONAI/VISTA3D-HF repository on Hugging Face if not already present. 
+ Args: + log_level: Logging level (default: logging.INFO) + Raises: RuntimeError: If CUDA is not available for GPU acceleration ConnectionError: If model weights cannot be downloaded """ - super().__init__() + super().__init__(log_level=log_level) self.target_spacing = 0.0 self.resale_intensity_range = False diff --git a/src/physiomotion4d/segment_chest_vista_3d_nim.py b/src/physiomotion4d/segment_chest_vista_3d_nim.py index 9a22b6d..47fea7a 100644 --- a/src/physiomotion4d/segment_chest_vista_3d_nim.py +++ b/src/physiomotion4d/segment_chest_vista_3d_nim.py @@ -9,12 +9,12 @@ import argparse import io +import logging import os import tempfile import zipfile import itk -import numpy as np import requests from physiomotion4d.segment_chest_vista_3d import SegmentChestVista3D @@ -26,9 +26,13 @@ class SegmentChestVista3DNIM(SegmentChestVista3D): segmentation method using VISTA3D. """ - def __init__(self): - """Initialize the vista3d class.""" - super().__init__() + def __init__(self, log_level: int | str = logging.INFO): + """Initialize the vista3d class. + + Args: + log_level: Logging level (default: logging.INFO) + """ + super().__init__(log_level=log_level) self.invoke_url = "http://localhost:8000/v1/vista3d/inference" self.wsl_docker_tmp_file = ( @@ -66,7 +70,7 @@ def segmentation_method(self, preprocessed_image: itk.image) -> itk.image: z.extractall(temp_dir) file_list = os.listdir(temp_dir) for filename in file_list: - print(filename) + self.log_debug("Found file: %s", filename) filepath = os.path.join(temp_dir, filename) if os.path.isfile(filepath) and filename.endswith(".nii.gz"): # SUCCESS: Return the results diff --git a/src/physiomotion4d/transform_tools.py b/src/physiomotion4d/transform_tools.py index 23469df..8830eed 100644 --- a/src/physiomotion4d/transform_tools.py +++ b/src/physiomotion4d/transform_tools.py @@ -11,16 +11,20 @@ are used to track anatomical motion over time. 
""" +import logging + +import cupy as cp import itk import numpy as np import pyvista as pv import SimpleITK as sitk from pxr import Gf, Usd, UsdGeom -from .image_tools import ImageTools +from physiomotion4d.image_tools import ImageTools +from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase -class TransformTools: +class TransformTools(PhysioMotion4DBase): """ Utilities for transforming and manipulating ITK transforms. @@ -53,13 +57,13 @@ class TransformTools: >>> field = transform_tools.generate_field(transform, reference_image) """ - def __init__(self): + def __init__(self, log_level: int | str = logging.INFO): """Initialize the TransformTools class. - No parameters are required for initialization as all methods - operate on provided transforms and images. + Args: + log_level: Logging level (default: logging.INFO) """ - pass + super().__init__(class_name=self.__class__.__name__, log_level=log_level) def combine_displacement_field_transforms( self, @@ -309,13 +313,18 @@ def transform_pvcontour( f"Expected single transform or list with one transform, got list with {len(tfm)} transforms" ) - new_pnts = [tfm.TransformPoint(p) for p in pnts] + pnts = np.array(pnts) + new_pnts = [ + tfm.TransformPoint((float(p[0]), float(p[1]), float(p[2]))) for p in pnts + ] new_contour.points = new_pnts + new_pnts = cp.array(new_pnts) + pnts = cp.array(pnts) if with_deformation_magnitude: - new_contour.point_data["DeformationMagnitude"] = np.linalg.norm( + new_contour.point_data["DeformationMagnitude"] = cp.linalg.norm( new_pnts - pnts, axis=1 - ) + ).get() return new_contour @@ -557,15 +566,10 @@ def combine_transforms_with_masks( combined_field_arr = sum_fields_arr / denom - # Create displacement field by duplicating and updating - # This preserves the exact image type - duplicator = itk.ImageDuplicator.New(field1) - duplicator.Update() - combined_field = duplicator.GetOutput() - # Copy array data to ITK image - combined_field_view = 
itk.array_view_from_image(combined_field) - combined_field_view[:] = combined_field_arr + combined_field = ImageTools().convert_array_to_image_of_vectors( + combined_field_arr, field1, itk.F + ) # Correct spatial folding iteratively for _ in range(max_iter): @@ -602,6 +606,11 @@ def compute_jacobian_determinant_from_field(self, field: itk.Image) -> itk.Image ... deformation_field ... ) """ + if "VF" not in str(type(field)): + field_arr = itk.array_from_image(field) + field = ImageTools().convert_array_to_image_of_vectors( + field_arr, field, itk.F + ) jac_filter = itk.DisplacementFieldJacobianDeterminantFilter.New(field) jac_filter.SetUseImageSpacing(True) jac_filter.Update() @@ -865,10 +874,10 @@ def convert_itk_transform_to_usd_visualization( # Save the stage stage.Save() - print(f"Created USD visualization: {output_filename}") - print(f" Type: {visualization_type}") - print(f" Points: {np.prod(subsampled_size)}") - print(f" Subsample factor: {subsample_factor}") + self.log_info("Created USD visualization: %s", output_filename) + self.log_info(" Type: %s", visualization_type) + self.log_info(" Points: %d", np.prod(subsampled_size)) + self.log_info(" Subsample factor: %d", subsample_factor) return output_filename @@ -921,7 +930,7 @@ def _create_arrow_visualization( ) arrow_count += 1 - print(f" Created {arrow_count} arrows") + self.log_info(" Created %d arrows", arrow_count) def _create_arrow_prim( self, stage, prim_path, position, displacement, magnitude, arrow_scale @@ -1050,7 +1059,7 @@ def _create_flowline_visualization( self._create_curve_prim(stage, curve_path, streamline_points) flowline_count += 1 - print(f" Created {flowline_count} flowlines") + self.log_info(" Created %d flowlines", flowline_count) def _trace_streamline( self, @@ -1138,4 +1147,3 @@ def _create_curve_prim(self, stage, prim_path, points): curve.GetDisplayColorAttr().Set([Gf.Vec3f(0.0, 1.0, 1.0)]) return curve - return curve diff --git a/src/physiomotion4d/usd_anatomy_tools.py 
b/src/physiomotion4d/usd_anatomy_tools.py index de5c891..923907a 100644 --- a/src/physiomotion4d/usd_anatomy_tools.py +++ b/src/physiomotion4d/usd_anatomy_tools.py @@ -4,18 +4,28 @@ """ import argparse +import logging import os from pxr import Sdf, Usd, UsdGeom, UsdShade +from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase -class USDAnatomyTools: + +class USDAnatomyTools(PhysioMotion4DBase): """ This class is used to enhance the appearance of anatomy meshes in a USD file. """ - def __init__(self, stage): + def __init__(self, stage, log_level: int | str = logging.INFO): + """Initialize USDAnatomyTools. + + Args: + stage: USD stage to work with + log_level: Logging level (default: logging.INFO) + """ + super().__init__(class_name=self.__class__.__name__, log_level=log_level) self.stage = stage self.heart_params = { @@ -306,9 +316,9 @@ def enhance_meshes(self, segmentator): if transform_prim and not mesh_prim: current_prim_path = str(prim.GetPrimPath()) root_prim_path = "/".join(current_prim_path.split("/")[:-1]) - print(f"Root prim path: {root_prim_path}") + self.log_debug("Root prim path: %s", root_prim_path) anatomy_prim_path = "/".join([root_prim_path, "Anatomy"]) - print(f" Anatomy prim path: {anatomy_prim_path}") + self.log_debug(" Anatomy prim path: %s", anatomy_prim_path) if not self.stage.GetPrimAtPath(anatomy_prim_path): UsdGeom.Xform.Define(self.stage, anatomy_prim_path) anatomy_prim_path = "/".join( @@ -326,8 +336,8 @@ def enhance_meshes(self, segmentator): ] ) anatomy_prim_path = Sdf.Path(anatomy_prim_path) - print(f" Current prim path: {current_prim_path}") - print(f" Anatomy prim path: {anatomy_prim_path}") + self.log_debug(" Current prim path: %s", current_prim_path) + self.log_debug(" Anatomy prim path: %s", anatomy_prim_path) editor.MovePrimAtPath( current_prim_path, anatomy_prim_path, diff --git a/src/physiomotion4d/usd_tools.py b/src/physiomotion4d/usd_tools.py index 1bcc72a..6065afd 100644 --- a/src/physiomotion4d/usd_tools.py +++ 
b/src/physiomotion4d/usd_tools.py @@ -10,13 +10,15 @@ anatomical structures need to be organized and visualized together. """ -import os +import logging import numpy as np -from pxr import Gf, Sdf, Usd, UsdGeom, UsdShade, UsdUtils +from pxr import Gf, Usd, UsdGeom, UsdShade +from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase -class USDTools: + +class USDTools(PhysioMotion4DBase): """ Utilities for manipulating Universal Scene Description (USD) files. @@ -54,13 +56,13 @@ class USDTools: ... ) """ - def __init__(self): + def __init__(self, log_level: int | str = logging.INFO): """Initialize the USDTools class. - No parameters are required for initialization as all methods - operate on provided USD files and stages. + Args: + log_level: Logging level (default: logging.INFO) """ - pass + super().__init__(class_name=self.__class__.__name__, log_level=log_level) def get_subtree_bounding_box( self, prim: UsdGeom.Xform @@ -152,7 +154,7 @@ def save_usd_file_arrangement(self, new_stage_name: str, usd_file_names: list[st n_objects = len(usd_file_names) n_rows = int(np.floor(np.sqrt(n_objects))) n_cols = int(np.ceil(n_objects / n_rows)) - print(f"n_rows: {n_rows}, n_cols: {n_cols}") + self.log_info("Grid layout: %d rows x %d cols", n_rows, n_cols) x_spacing = 400.0 y_spacing = 400.0 x_offset = -x_spacing * (n_cols - 1) / 2 @@ -165,17 +167,17 @@ def save_usd_file_arrangement(self, new_stage_name: str, usd_file_names: list[st source_root = source_stage.GetPrimAtPath("/World") children = source_root.GetChildren() for child in children: - print(f"Copying {usd_file_name}:{child.GetPrimPath()}") + self.log_info("Copying %s:%s", usd_file_name, child.GetPrimPath()) new_stage.DefinePrim(child.GetPrimPath()).GetReferences().AddReference( assetPath=usd_file_name, primPath=child.GetPrimPath(), ) # Apply translation to t for grandchild in child.GetAllChildren(): - print(f" Bounding box of {grandchild.GetPrimPath()}") + self.log_debug(" Bounding box of %s", 
grandchild.GetPrimPath()) bbox_min, bbox_max = self.get_subtree_bounding_box(grandchild) bbox_center = (bbox_min + bbox_max) / 2 - print(f" Bounding box center: {bbox_center}") + self.log_debug(" Bounding box center: %s", bbox_center) xform = UsdGeom.Xformable(grandchild) if not xform.GetOrderedXformOps(): @@ -192,7 +194,9 @@ def save_usd_file_arrangement(self, new_stage_name: str, usd_file_names: list[st grid_y - bbox_center[1], -bbox_center[2], ) - print(f" Translating {grandchild.GetPrimPath()} to {translate}") + self.log_debug( + " Translating %s to %s", grandchild.GetPrimPath(), translate + ) xform_op.Set(translate, Usd.TimeCode.Default()) for prim in source_stage.Traverse(): @@ -206,8 +210,10 @@ def save_usd_file_arrangement(self, new_stage_name: str, usd_file_names: list[st and len(mesh_material) > 0 else str(mesh_material.GetPath()) ) - print( - f" Mesh {prim.GetPrimPath()} has material {material_path}" + self.log_debug( + " Mesh %s has material %s", + prim.GetPrimPath(), + material_path, ) new_prim = new_stage.GetPrimAtPath(prim.GetPrimPath()) material = UsdShade.Material.Get(new_stage, material_path) @@ -215,11 +221,12 @@ def save_usd_file_arrangement(self, new_stage_name: str, usd_file_names: list[st binding_api = UsdShade.MaterialBindingAPI.Apply(new_prim) binding_api.Bind(material) else: - print( - f" Cannot bind. No new prim found for {prim.GetPrimPath()}" + self.log_warning( + " Cannot bind. 
No new prim found for %s", + prim.GetPrimPath(), ) - print("Exporting stage...") + self.log_info("Exporting stage...") new_stage.Export(new_stage_name) def merge_usd_files(self, output_filename: str, input_filenames_list: list[str]): @@ -290,7 +297,7 @@ def merge_usd_files(self, output_filename: str, input_filenames_list: list[str]) # Copy all root prims from input for prim in input_stage.GetPseudoRoot().GetAllChildren(): new_path = "/" + prim.GetName() - print(f"Copying {prim.GetPrimPath()} to {new_path}") + self.log_info("Copying %s to %s", prim.GetPrimPath(), new_path) # Recursively copy prim hierarchy with all attributes and time samples def _copy_prim(src_prim, target_path): @@ -356,23 +363,29 @@ def _copy_prim(src_prim, target_path): and len(mesh_material) > 0 else str(mesh_material.GetPath()) ) - print( - f" Binding material {material_path} to {prim.GetPrimPath()}" + self.log_debug( + " Binding material %s to %s", + material_path, + prim.GetPrimPath(), ) # Get corresponding mesh prim and material in target stage new_prim = stage.GetPrimAtPath(prim.GetPrimPath()) material = UsdShade.Material.Get(stage, material_path) if new_prim is not None and new_prim.IsValid(): if material and material.GetPrim().IsValid(): - binding_api = UsdShade.MaterialBindingAPI.Apply(new_prim) + binding_api = UsdShade.MaterialBindingAPI.Apply( + new_prim + ) binding_api.Bind(material) else: - print( - f" Warning: Material not found at {material_path} in target stage" + self.log_warning( + " Material not found at %s in target stage", + material_path, ) else: - print( - f" Warning: Cannot bind material. No mesh prim found at {prim.GetPrimPath()}" + self.log_warning( + " Cannot bind material. 
No mesh prim found at %s", + prim.GetPrimPath(), ) # Set stage time range metadata for animation playback @@ -383,8 +396,14 @@ def _copy_prim(src_prim, target_path): stage.SetTimeCodesPerSecond(time_codes_per_second) if frames_per_second is not None: stage.SetFramesPerSecond(frames_per_second) - print(f"\nSet stage time range: {global_start_time} to {global_end_time}") - print(f"Time codes per second: {time_codes_per_second}, Frames per second: {frames_per_second}") + self.log_info( + "Set stage time range: %.1f to %.1f", global_start_time, global_end_time + ) + self.log_info( + "Time codes per second: %s, Frames per second: %s", + time_codes_per_second, + frames_per_second, + ) # Save with USDA format # stage.GetRootLayer().Export(output_path, args=['--usdFormat', 'usda']) @@ -445,8 +464,9 @@ def merge_usd_files_flattened( frames_per_second = None # Add references to all input files - for input_path in input_filenames_list: - print(f"Referencing {input_path}") + num_files = len(input_filenames_list) + for idx, input_path in enumerate(input_filenames_list): + self.log_progress(idx + 1, num_files, prefix="Referencing files") input_stage = Usd.Stage.Open(input_path, Usd.Stage.LoadAll) # Track time range from this input file @@ -463,12 +483,13 @@ def merge_usd_files_flattened( # Reference each top-level prim from the input file for prim in input_stage.GetPseudoRoot().GetAllChildren(): new_path = "/" + prim.GetName() - print(f" Adding reference: {prim.GetPrimPath()} -> {new_path}") + self.log_debug( + " Adding reference: %s -> %s", prim.GetPrimPath(), new_path + ) # Create prim and add reference to source file temp_stage.DefinePrim(new_path).GetReferences().AddReference( - assetPath=input_path, - primPath=prim.GetPrimPath() + assetPath=input_path, primPath=prim.GetPrimPath() ) # Set time range metadata on temporary stage before flattening @@ -479,12 +500,18 @@ def merge_usd_files_flattened( temp_stage.SetTimeCodesPerSecond(time_codes_per_second) if frames_per_second is 
not None: temp_stage.SetFramesPerSecond(frames_per_second) - print(f"Time range: {global_start_time} to {global_end_time}") - print(f"Time codes per second: {time_codes_per_second}, Frames per second: {frames_per_second}") + self.log_info( + "Time range: %.1f to %.1f", global_start_time, global_end_time + ) + self.log_info( + "Time codes per second: %s, Frames per second: %s", + time_codes_per_second, + frames_per_second, + ) # Flatten the composed stage into a single layer # This resolves all references and bakes everything into one file - print("Flattening composed stage...") + self.log_info("Flattening composed stage...") flattened_layer = temp_stage.Flatten() # Create output stage from flattened layer @@ -497,11 +524,13 @@ def merge_usd_files_flattened( output_stage.SetEndTimeCode(global_end_time) if time_codes_per_second is not None: output_stage.SetTimeCodesPerSecond(time_codes_per_second) - print(f"Set output TimeCodesPerSecond: {time_codes_per_second}") + self.log_info( + "Set output TimeCodesPerSecond: %s", time_codes_per_second + ) if frames_per_second is not None: output_stage.SetFramesPerSecond(frames_per_second) - print(f"Set output FramesPerSecond: {frames_per_second}") + self.log_info("Set output FramesPerSecond: %s", frames_per_second) # Export the flattened layer with corrected metadata - print(f"Exporting to {output_filename}") + self.log_info("Exporting to %s", output_filename) output_stage.Export(output_filename) diff --git a/tests/GITHUB_WORKFLOWS.md b/tests/GITHUB_WORKFLOWS.md index a7f9f03..9964857 100644 --- a/tests/GITHUB_WORKFLOWS.md +++ b/tests/GITHUB_WORKFLOWS.md @@ -2,29 +2,52 @@ This document describes the GitHub Actions workflows configured for running tests on pull requests. -## Workflow Configuration +## Overview -File: `.github/workflows/tests.yml` +The project has two main workflow files: + +1. 
**`.github/workflows/ci.yml`** (Main CI Workflow) + - Runs on every PR and push to main/master/develop + - Unit tests (cross-platform: Ubuntu + Windows) + - Integration tests with external data + - GPU tests (self-hosted runners) + - Code quality checks + +2. **`.github/workflows/test-slow.yml`** (Scheduled Slow Tests) + - Runs nightly at 2 AM UTC + - Slow/GPU-intensive tests (registration, segmentation) + - Self-hosted GPU runners only + +3. **`.github/workflows/docs.yml`** (Documentation) + - Builds Sphinx documentation + - Deploys to GitHub Pages + +## Main CI Workflow + +File: `.github/workflows/ci.yml` + +This comprehensive CI workflow combines unit tests, integration tests, GPU tests, and code quality checks. ## Workflow Jobs -### 1. Unit Tests (`test` job) -**Trigger**: All pull requests and pushes to main -**Platforms**: Ubuntu, Windows, macOS +### 1. Unit Tests (`unit-tests` job) +**Trigger**: All pull requests and pushes to main, master, and develop branches +**Platforms**: Ubuntu, Windows **Python versions**: 3.10, 3.11, 3.12 **What runs**: - Tests that don't require external data - Tests that aren't marked as slow - Coverage reporting to Codecov +- Cross-platform validation **Command**: ```bash -pytest tests/ -v -m "not requires_data and not slow" --cov=src/physiomotion4d --cov-report=xml +pytest tests/ -v -m "not slow and not requires_data" --cov=physiomotion4d --cov-report=xml ``` **Tests included**: -- Basic unit tests +- Basic unit tests (e.g., ImageTools conversions) - Fast integration tests with mocked data **Tests excluded**: @@ -33,10 +56,11 @@ pytest tests/ -v -m "not requires_data and not slow" --cov=src/physiomotion4d -- --- -### 2. Integration Tests with Data (`test-with-data` job) +### 2. 
Integration Tests with Data (`integration-tests` job) **Trigger**: Pull requests only **Platform**: Ubuntu only **Python version**: 3.10 +**Dependencies**: Requires `unit-tests` to pass first **What runs**: Tests that require downloading external data, executed in sequence with caching. @@ -97,13 +121,76 @@ All data-dependent test steps use `continue-on-error: true` to prevent CI failur Coverage is still uploaded even if tests fail. +### 3. GPU Tests (`gpu-tests` job) +**Trigger**: All pull requests and pushes to main, master, and develop branches +**Platform**: Self-hosted Linux runners with GPU +**Python versions**: 3.10, 3.11 +**Dependencies**: Requires `unit-tests` to pass first + +**What runs**: +- All non-slow tests with GPU support +- PyTorch with CUDA 12.6 +- GPU-accelerated deep learning tests +- Requires self-hosted runners with NVIDIA GPUs + +**Command**: +```bash +pytest tests/ -v -m "not slow" --cov=physiomotion4d --cov-report=xml +``` + +**Environment**: +- `CUDA_VISIBLE_DEVICES: 0` +- Self-hosted runners with NVIDIA GPU +- PyTorch with CUDA support + +**Tests included**: +- GPU-accelerated model inference (when available) +- Tests that benefit from GPU but don't require hours of compute +- Fast integration tests on GPU hardware + +**Note**: This job continues even on error (`continue-on-error: true`) since self-hosted GPU runners may not always be available. + +### 4. Code Quality Checks (`code-quality` job) +**Trigger**: All pull requests and pushes to main, master, and develop branches +**Platform**: Ubuntu only +**Python version**: 3.10 + +**What runs**: +- Code formatting checks (Black) +- Import sorting checks (isort) +- Linting (Ruff, Flake8) +- Style enforcement + +**Tools**: +1. **Black**: Code formatting + ```bash + black --check src/ tests/ + ``` + +2. **isort**: Import sorting + ```bash + isort --check-only src/ tests/ + ``` + +3. **Ruff**: Fast Python linter + ```bash + ruff check src/ tests/ + ``` + +4. 
**Flake8**: Additional style checks + ```bash + flake8 src/ tests/ + ``` + +**Note**: All checks use `continue-on-error: true` to avoid blocking PRs on style issues while still providing feedback. + --- ## Tests Excluded from CI -### Slow and GPU-Dependent Tests +### Slow Tests (Marked with `@pytest.mark.slow`) -These tests are **NOT** run in CI because they are slow or require CUDA-enabled GPUs: +These tests are **NOT** run in CI (even on GPU runners) because they require extended compute time: 1. **ANTs Registration Tests** (`test_register_images_ants.py`) - Requires: ANTsPy library @@ -136,11 +223,11 @@ These tests are **NOT** run in CI because they are slow or require CUDA-enabled - Why excluded: Requires GPU, model inference **Why excluded**: -- GitHub Actions runners don't have CUDA-enabled GPUs -- Model inference requires significant GPU memory and time +- These tests take 5-15 minutes each, even with GPU acceleration - Registration algorithms are computationally intensive -- Tests would timeout or fail without GPU/sufficient compute -- CI runtime would exceed reasonable limits +- Deep learning model inference requires significant GPU memory and time +- CI runtime would exceed reasonable limits (even on self-hosted GPU runners) +- Better suited for nightly/scheduled testing or local development **Local testing**: ```bash @@ -154,6 +241,14 @@ pytest tests/test_register_images_ants.py tests/test_register_images_icon.py -v pytest tests/test_segment_chest_total_segmentator.py tests/test_segment_chest_vista_3d.py -v -s ``` +**Scheduled slow tests**: +These tests run automatically on a schedule via `.github/workflows/test-slow.yml`: +- **Schedule**: Nightly at 2 AM UTC +- **Trigger**: Also available via manual workflow dispatch +- **Platform**: Self-hosted Linux GPU runners +- **Command**: `pytest tests/ -v -m "slow"` +- **Purpose**: Regular validation of computationally intensive tests without blocking PRs + --- ## Coverage Reporting @@ -168,6 +263,16 @@ pytest 
tests/test_segment_chest_total_segmentator.py tests/test_segment_chest_vi - **When**: After all data-dependent tests complete - **Upload**: Even if tests fail (`fail_ci_if_error: false`) +### GPU Tests Coverage +- **Flag**: `gpu-tests` +- **When**: Python 3.10 on self-hosted GPU runners +- **Upload**: After GPU tests complete + +### Slow Tests Coverage +- **Flag**: `slow-tests-gpu` +- **When**: Nightly scheduled runs on self-hosted GPU runners +- **Upload**: After slow tests complete + ### Viewing Coverage Coverage reports are uploaded to Codecov and can be viewed at: - Repository Codecov dashboard @@ -247,7 +352,7 @@ Currently not configured, but can be added for: - Consider parallel execution ### Workflow Files -- Main workflow: `.github/workflows/tests.yml` +- Main workflow: `.github/workflows/ci.yml` - Test configuration: `tests/conftest.py` - Test documentation: `tests/TEST_ORGANIZATION.md` diff --git a/tests/README.md b/tests/README.md index 629e950..0eb03fe 100644 --- a/tests/README.md +++ b/tests/README.md @@ -140,7 +140,7 @@ Tests automatically run on pull requests via GitHub Actions. 
The CI workflow: **Problem: Test timeout** - Global timeout: 900 seconds (15 minutes) - Registration tests need GPU for reasonable speed -- Override with: `pytest --timeout=1800` +- Override with: `@pytest.mark.timeout(1800)` decorator or use `-o timeout=1800` **Problem: Fixture naming errors** - ✅ **Fixed!** Use correct fixture names from `conftest.py` diff --git a/tests/TESTING_GUIDE.md b/tests/TESTING_GUIDE.md index d37a153..f20b0ef 100644 --- a/tests/TESTING_GUIDE.md +++ b/tests/TESTING_GUIDE.md @@ -176,7 +176,7 @@ tests/ **Solutions**: - Ensure GPU is available (much faster than CPU) - Run slow tests individually: `pytest tests/test_register_images_ants.py -v` -- Override timeout: `pytest tests/test_name.py --timeout=1800` +- Override timeout: Use `@pytest.mark.timeout(1800)` decorator in test or `-o timeout=1800` on command line ### ITK Size Indexing Errors @@ -351,7 +351,7 @@ rm -rf tests/data/Slicer-Heart-CT/ ## GitHub Actions Integration -Tests automatically run on pull requests via `.github/workflows/tests.yml`. The workflow: +Tests automatically run on pull requests via `.github/workflows/ci.yml`. 
The workflow: - ✅ Runs fast tests (USD, conversion, basic validation) - ❌ Skips slow tests (registration, segmentation) diff --git a/tests/conftest.py b/tests/conftest.py index 3168e34..d0593c8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -112,18 +112,29 @@ def converted_3d_images(download_truncal_valve_data, test_directories): @pytest.fixture(scope="session") def test_images(converted_3d_images): - """Load two time points from the converted 3D data for testing.""" + """Load time points from the converted 3D data for testing.""" data_dir = converted_3d_images - # Load two time points (slice_000 and slice_001) + # Load time points slice_000 = data_dir / "slice_000.mha" slice_001 = data_dir / "slice_001.mha" + slice_002 = data_dir / "slice_002.mha" + slice_003 = data_dir / "slice_003.mha" + slice_004 = data_dir / "slice_004.mha" + slice_005 = data_dir / "slice_005.mha" # Ensure the files exist - if not slice_000.exists() or not slice_001.exists(): + if not slice_000.exists() or not slice_001.exists() or not slice_002.exists(): pytest.skip("Converted 3D slice files not found. 
Run conversion test first.") - images = [itk.imread(str(slice_000)), itk.imread(str(slice_001))] + images = [ + itk.imread(str(slice_000)), + itk.imread(str(slice_001)), + itk.imread(str(slice_002)), + itk.imread(str(slice_003)), + itk.imread(str(slice_004)), + itk.imread(str(slice_005)), + ] for i, img in enumerate(images): resampler = ttk.ResampleImage.New(Input=img) @@ -131,7 +142,7 @@ def test_images(converted_3d_images): resampler.Update() images[i] = resampler.GetOutput() - print(f"\nLoaded 2 time points for testing") + print(f"\nLoaded {len(images)} time points for testing") return images @@ -247,21 +258,24 @@ def ants_registration_results(registrar_ants, test_images, test_directories): reg_output_dir = output_dir / "registration_ants" reg_output_dir.mkdir(exist_ok=True) - phi_FM_path = reg_output_dir / "ants_phi_FM_no_mask.hdf" - phi_MF_path = reg_output_dir / "ants_phi_MF_no_mask.hdf" + inverse_transform_path = reg_output_dir / "ants_inverse_transform_no_mask.hdf" + forward_transform_path = reg_output_dir / "ants_forward_transform_no_mask.hdf" - if phi_FM_path.exists() and phi_MF_path.exists(): + if inverse_transform_path.exists() and forward_transform_path.exists(): print("\nLoading existing ANTs registration results...") try: - phi_FM = itk.transformread(str(phi_FM_path)) - phi_MF = itk.transformread(str(phi_MF_path)) - return {"phi_FM": phi_FM, "phi_MF": phi_MF} + inverse_transform = itk.transformread(str(inverse_transform_path)) + forward_transform = itk.transformread(str(forward_transform_path)) + return { + "inverse_transform": inverse_transform, + "forward_transform": forward_transform, + } except (RuntimeError, Exception) as e: print(f"Error loading transforms: {e}") print("Regenerating registration results...") # Delete corrupt files - phi_FM_path.unlink(missing_ok=True) - phi_MF_path.unlink(missing_ok=True) + inverse_transform_path.unlink(missing_ok=True) + forward_transform_path.unlink(missing_ok=True) # Perform registration if files don't exist 
or loading failed print("\nPerforming ANTs registration...") @@ -271,12 +285,15 @@ def ants_registration_results(registrar_ants, test_images, test_directories): registrar_ants.set_fixed_image(fixed_image) result = registrar_ants.register(moving_image=moving_image) - phi_FM = result["phi_FM"] - phi_MF = result["phi_MF"] + inverse_transform = result["inverse_transform"] + forward_transform = result["forward_transform"] - itk.transformwrite(phi_FM, str(phi_FM_path), compression=True) - itk.transformwrite(phi_MF, str(phi_MF_path), compression=True) - return {"phi_FM": phi_FM, "phi_MF": phi_MF} + itk.transformwrite(inverse_transform, str(inverse_transform_path), compression=True) + itk.transformwrite(forward_transform, str(forward_transform_path), compression=True) + return { + "inverse_transform": inverse_transform, + "forward_transform": forward_transform, + } # ============================================================================ diff --git a/tests/test_contour_tools.py b/tests/test_contour_tools.py index 5cd3320..6db0f25 100644 --- a/tests/test_contour_tools.py +++ b/tests/test_contour_tools.py @@ -6,15 +6,11 @@ segmentation results to test contour extraction and manipulation. 
""" -from pathlib import Path - import itk import numpy as np import pytest import pyvista as pv -from physiomotion4d.contour_tools import ContourTools - @pytest.mark.requires_data @pytest.mark.slow @@ -132,9 +128,7 @@ def test_create_mask_from_mesh( # Create mask from the extracted mesh print("\nCreating mask from extracted heart contours...") reference_image = test_images[0] - recreated_mask = contour_tools.create_mask_from_mesh( - contours, reference_image, resample_to_reference=True - ) + recreated_mask = contour_tools.create_mask_from_mesh(contours, reference_image) # Verify recreated mask assert recreated_mask is not None, "Mask not created from mesh" diff --git a/tests/test_convert_vtk_4d_to_usd_polymesh.py b/tests/test_convert_vtk_4d_to_usd_polymesh.py index cf22c5c..4426fcd 100644 --- a/tests/test_convert_vtk_4d_to_usd_polymesh.py +++ b/tests/test_convert_vtk_4d_to_usd_polymesh.py @@ -9,8 +9,8 @@ from pathlib import Path import itk -import pyvista as pv import pytest +import pyvista as pv from pxr import Usd, UsdGeom from physiomotion4d.convert_vtk_4d_to_usd_polymesh import ConvertVTK4DToUSDPolyMesh @@ -26,62 +26,58 @@ def contour_meshes(self, contour_tools, segmentation_results, test_directories): """Extract or load contour meshes for USD conversion testing.""" output_dir = test_directories["output"] contour_output_dir = output_dir / "contour_tools" - + # Check if contour files exist heart_contour_000 = contour_output_dir / "heart_contours_slice000.vtp" heart_contour_001 = contour_output_dir / "heart_contours_slice001.vtp" - + if not heart_contour_000.exists() or not heart_contour_001.exists(): # Extract contours if they don't exist print("\nContour files not found, extracting them...") contour_output_dir.mkdir(parents=True, exist_ok=True) - + meshes = [] for i, result in enumerate(segmentation_results): heart_mask = result["heart"] contours = contour_tools.extract_contours(heart_mask) meshes.append(contours) - + # Save contours output_file = 
contour_output_dir / f"heart_contours_slice{i:03d}.vtp" contours.save(str(output_file)) - + return meshes else: # Load existing contours print("\nLoading existing contour files...") meshes = [ pv.read(str(contour_output_dir / "heart_contours_slice000.vtp")), - pv.read(str(contour_output_dir / "heart_contours_slice001.vtp")) + pv.read(str(contour_output_dir / "heart_contours_slice001.vtp")), ] return meshes def test_converter_initialization(self): """Test that ConvertVTK4DToUSDPolyMesh initializes correctly.""" converter = ConvertVTK4DToUSDPolyMesh( - data_basename="TestModel", - input_polydata=[], - mask_ids=None + data_basename="TestModel", input_polydata=[], mask_ids=None ) - + assert converter is not None, "Converter not initialized" assert converter.data_basename == "TestModel", "Data basename not set correctly" - + print("\nConverter initialized successfully") def test_supports_mesh_type(self, contour_meshes): """Test that converter correctly identifies supported mesh types.""" mesh = contour_meshes[0] - + converter = ConvertVTK4DToUSDPolyMesh( - data_basename="TestModel", - input_polydata=[mesh], - mask_ids=None + data_basename="TestModel", input_polydata=[mesh], mask_ids=None ) - + # PolyData should be supported assert converter.supports_mesh_type(mesh), "PolyData should be supported" - + print("\nMesh type support check passed") def test_convert_single_time_point(self, contour_meshes, test_directories): @@ -89,30 +85,28 @@ def test_convert_single_time_point(self, contour_meshes, test_directories): output_dir = test_directories["output"] usd_output_dir = output_dir / "usd_polymesh" usd_output_dir.mkdir(exist_ok=True) - + # Use first time point only mesh = contour_meshes[0] - + print("\nConverting single time point to USD...") print(f" Mesh: {mesh.n_points} points, {mesh.n_cells} cells") - + converter = ConvertVTK4DToUSDPolyMesh( - data_basename="HeartSingle", - input_polydata=[mesh], - mask_ids=None + data_basename="HeartSingle", input_polydata=[mesh], 
mask_ids=None ) - + output_file = usd_output_dir / "heart_single_time.usd" stage = converter.convert(str(output_file)) - + # Verify USD stage was created assert stage is not None, "USD stage not created" assert output_file.exists(), f"USD file not created: {output_file}" - + # Verify stage contents (actual path includes /World prefix) prim = stage.GetPrimAtPath("/World/HeartSingle") assert prim.IsValid(), f"Root prim not found at /World/HeartSingle" - + print(f"Single time point converted to USD") print(f" Output: {output_file}") print(f" File size: {output_file.stat().st_size / 1024:.2f} KB") @@ -122,70 +116,69 @@ def test_convert_multiple_time_points(self, contour_meshes, test_directories): output_dir = test_directories["output"] usd_output_dir = output_dir / "usd_polymesh" usd_output_dir.mkdir(exist_ok=True) - + print("\nConverting multiple time points to USD...") print(f" Time points: {len(contour_meshes)}") - + converter = ConvertVTK4DToUSDPolyMesh( - data_basename="HeartMulti", - input_polydata=contour_meshes, - mask_ids=None + data_basename="HeartMulti", input_polydata=contour_meshes, mask_ids=None ) - + output_file = usd_output_dir / "heart_multi_time.usd" stage = converter.convert(str(output_file)) - + # Verify USD stage assert stage is not None, "USD stage not created" assert output_file.exists(), f"USD file not created: {output_file}" - + # Verify time samples (actual path includes /World prefix) prim = stage.GetPrimAtPath("/World/HeartMulti") assert prim.IsValid(), f"Root prim not found at /World/HeartMulti" - + # Check that mesh exists (checking the Transform group) transform_path = "/World/HeartMulti/Transform_heart_multi_time" transform_prim = stage.GetPrimAtPath(transform_path) assert transform_prim.IsValid(), f"Transform not found at {transform_path}" - + print(f"Multiple time points converted to USD") print(f" Output: {output_file}") print(f" File size: {output_file.stat().st_size / 1024:.2f} KB") - def test_convert_with_deformation(self, 
contour_tools, segmentation_results, test_directories): + def test_convert_with_deformation( + self, contour_tools, segmentation_results, test_directories + ): """Test converting meshes with deformation magnitude.""" output_dir = test_directories["output"] usd_output_dir = output_dir / "usd_polymesh" usd_output_dir.mkdir(exist_ok=True) - + # Extract contours heart_mask = segmentation_results[0]["heart"] contours = contour_tools.extract_contours(heart_mask) - + # Add deformation magnitude (simulate with random values) import numpy as np + deformation = np.random.uniform(0, 5, contours.n_points) contours["DeformationMagnitude"] = deformation - + print("\nConverting mesh with deformation magnitude...") - + converter = ConvertVTK4DToUSDPolyMesh( - data_basename="HeartDeformation", - input_polydata=[contours], - mask_ids=None + data_basename="HeartDeformation", input_polydata=[contours], mask_ids=None ) - + output_file = usd_output_dir / "heart_with_deformation.usd" stage = converter.convert(str(output_file)) - + assert stage is not None, "USD stage not created" assert output_file.exists(), "USD file not created" - + # Check that transform was created (actual path includes /World prefix) transform_path = "/World/HeartDeformation/Transform_heart_with_deformation" transform_prim = stage.GetPrimAtPath(transform_path) assert transform_prim.IsValid(), f"Transform not found at {transform_path}" - + print(f"Mesh with deformation converted to USD") print(f" Output: {output_file}") @@ -194,35 +187,34 @@ def test_convert_with_colormap(self, contour_meshes, test_directories): output_dir = test_directories["output"] usd_output_dir = output_dir / "usd_polymesh" usd_output_dir.mkdir(exist_ok=True) - + # Add scalar field to mesh for colormapping import numpy as np + mesh = contour_meshes[0] scalar_values = np.random.uniform(0, 100, mesh.n_points) mesh["pressure"] = scalar_values - + print("\nConverting mesh with colormap...") - + converter = ConvertVTK4DToUSDPolyMesh( - 
data_basename="HeartColormap", - input_polydata=[mesh], - mask_ids=None + data_basename="HeartColormap", input_polydata=[mesh], mask_ids=None ) - + # Set colormap converter.set_colormap(color_by_array="pressure", colormap="plasma") - + output_file = usd_output_dir / "heart_with_colormap.usd" stage = converter.convert(str(output_file)) - + assert stage is not None, "USD stage not created" assert output_file.exists(), "USD file not created" - + # Verify transform was created (actual path includes /World prefix) transform_path = "/World/HeartColormap/Transform_heart_with_colormap" transform_prim = stage.GetPrimAtPath(transform_path) assert transform_prim.IsValid(), f"Transform not found at {transform_path}" - + print(f"Mesh with colormap converted to USD") print(f" Colormap: plasma") print(f" Output: {output_file}") @@ -232,36 +224,44 @@ def test_convert_unstructured_grid_to_surface(self, test_directories): output_dir = test_directories["output"] usd_output_dir = output_dir / "usd_polymesh" usd_output_dir.mkdir(exist_ok=True) - + # Create a simple UnstructuredGrid (cube) import numpy as np import pyvista as pv - - points = np.array([ - [0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0], - [0, 0, 1], [1, 0, 1], [1, 1, 1], [0, 1, 1] - ]) + + points = np.array( + [ + [0, 0, 0], + [1, 0, 0], + [1, 1, 0], + [0, 1, 0], + [0, 0, 1], + [1, 0, 1], + [1, 1, 1], + [0, 1, 1], + ] + ).astype(np.float32) cells = [8, 0, 1, 2, 3, 4, 5, 6, 7] cell_types = [12] # VTK_HEXAHEDRON - + ugrid = pv.UnstructuredGrid(cells, cell_types, points) - + print("\nConverting UnstructuredGrid to USD...") print(f" Grid: {ugrid.n_points} points, {ugrid.n_cells} cells") - + converter = ConvertVTK4DToUSDPolyMesh( data_basename="CubeSurface", input_polydata=[ugrid], mask_ids=None, - convert_to_surface=True + convert_to_surface=True, ) - + output_file = usd_output_dir / "cube_surface.usd" stage = converter.convert(str(output_file)) - + assert stage is not None, "USD stage not created" assert output_file.exists(), 
"USD file not created" - + print(f"UnstructuredGrid converted to surface USD") print(f" Output: {output_file}") @@ -270,27 +270,31 @@ def test_usd_file_structure(self, contour_meshes, test_directories): output_dir = test_directories["output"] usd_output_dir = output_dir / "usd_polymesh" usd_output_dir.mkdir(exist_ok=True) - + converter = ConvertVTK4DToUSDPolyMesh( data_basename="HeartStructure", input_polydata=[contour_meshes[0]], - mask_ids=None + mask_ids=None, ) - + output_file = usd_output_dir / "heart_structure_test.usd" stage = converter.convert(str(output_file)) - + print("\nVerifying USD file structure...") - + # Check root prim (actual path includes /World prefix) root_prim = stage.GetPrimAtPath("/World/HeartStructure") assert root_prim.IsValid(), f"Root prim not found at /World/HeartStructure" assert UsdGeom.Xform(root_prim), "Root should be an Xform" - + # Check transform/mesh structure - transform_prim = stage.GetPrimAtPath("/World/HeartStructure/Transform_heart_structure_test") - assert transform_prim.IsValid(), f"Transform prim not found at /World/HeartStructure/Transform_heart_structure_test" - + transform_prim = stage.GetPrimAtPath( + "/World/HeartStructure/Transform_heart_structure_test" + ) + assert ( + transform_prim.IsValid() + ), f"Transform prim not found at /World/HeartStructure/Transform_heart_structure_test" + print(f"USD file structure verified") print(f" Root: {root_prim.GetPath()}") print(f" Transform: {transform_prim.GetPath()}") @@ -300,80 +304,81 @@ def test_time_varying_topology(self, contour_meshes, test_directories): output_dir = test_directories["output"] usd_output_dir = output_dir / "usd_polymesh" usd_output_dir.mkdir(exist_ok=True) - + # Modify one mesh to have different topology mesh1 = contour_meshes[0].copy() mesh2 = contour_meshes[1].copy() - + # Decimate second mesh to change topology mesh2 = mesh2.decimate(0.5) - + print("\nConverting meshes with varying topology...") print(f" Mesh 1: {mesh1.n_points} points, 
{mesh1.n_cells} cells") print(f" Mesh 2: {mesh2.n_points} points, {mesh2.n_cells} cells") - + converter = ConvertVTK4DToUSDPolyMesh( - data_basename="HeartVarying", - input_polydata=[mesh1, mesh2], - mask_ids=None + data_basename="HeartVarying", input_polydata=[mesh1, mesh2], mask_ids=None ) - + output_file = usd_output_dir / "heart_varying_topology.usd" stage = converter.convert(str(output_file)) - + assert stage is not None, "USD stage not created" assert output_file.exists(), "USD file not created" - + # Check for time-varying meshes (separate mesh prims) parent_path = "/HeartVarying/default" parent_prim = stage.GetPrimAtPath(parent_path) - + # Should have child meshes for each time step children = parent_prim.GetChildren() if parent_prim.IsValid() else [] - + print(f"Time-varying topology handled") print(f" Parent prim: {parent_path}") print(f" Child prims: {len(children)}") print(f" Output: {output_file}") - def test_batch_conversion(self, contour_tools, segmentation_results, test_directories): + def test_batch_conversion( + self, contour_tools, segmentation_results, test_directories + ): """Test converting multiple anatomy structures in batch.""" output_dir = test_directories["output"] usd_output_dir = output_dir / "usd_polymesh" usd_output_dir.mkdir(exist_ok=True) - + # Extract contours from multiple anatomies anatomy_groups = ["lung", "heart"] meshes_dict = {} - + for group in anatomy_groups: mask = segmentation_results[0][group] mask_arr = itk.array_from_image(mask) - + import numpy as np + if np.sum(mask_arr > 0) > 100: contours = contour_tools.extract_contours(mask) meshes_dict[group] = contours - + if len(meshes_dict) >= 2: print(f"\nConverting {len(meshes_dict)} anatomy structures...") - + # Convert each anatomy separately for anatomy, mesh in meshes_dict.items(): converter = ConvertVTK4DToUSDPolyMesh( data_basename=f"{anatomy.capitalize()}", input_polydata=[mesh], - mask_ids=None + mask_ids=None, ) - + output_file = usd_output_dir / 
f"{anatomy}_anatomy.usd" stage = converter.convert(str(output_file)) - + assert stage is not None, f"USD stage not created for {anatomy}" assert output_file.exists(), f"USD file not created for {anatomy}" - + print(f" {anatomy}: {output_file}") - + print(f"Batch conversion complete") else: pytest.skip("Not enough anatomies with sufficient voxels") @@ -381,4 +386,3 @@ def test_batch_conversion(self, contour_tools, segmentation_results, test_direct if __name__ == "__main__": pytest.main([__file__, "-v", "-s"]) - diff --git a/tests/test_image_tools.py b/tests/test_image_tools.py index bb9683f..f378c45 100644 --- a/tests/test_image_tools.py +++ b/tests/test_image_tools.py @@ -16,222 +16,286 @@ class TestImageTools: """Test suite for ImageTools conversions.""" - + @pytest.fixture def image_tools(self): """Create ImageTools instance.""" return ImageTools() - + def test_itk_to_sitk_scalar_image(self, image_tools): """Test conversion of scalar ITK image to SimpleITK.""" # Create a simple 3D scalar ITK image size = [10, 20, 30] spacing = [1.0, 2.0, 3.0] origin = [0.0, 0.0, 0.0] - + # Create ITK image with known values ImageType = itk.Image[itk.F, 3] itk_image = ImageType.New() - + region = itk.ImageRegion[3]() region.SetSize(size) itk_image.SetRegions(region) itk_image.SetSpacing(spacing) itk_image.SetOrigin(origin) itk_image.Allocate() - + # Fill with test pattern itk_image.FillBuffer(42.0) - + # Convert to SimpleITK sitk_image = image_tools.convert_itk_image_to_sitk(itk_image) - + # Verify metadata - assert sitk_image.GetSize() == tuple(reversed(size)) - assert sitk_image.GetSpacing() == tuple(reversed(spacing)) - assert sitk_image.GetOrigin() == tuple(reversed(origin)) - + assert sitk_image.GetSize() == tuple(size) + assert sitk_image.GetSpacing() == tuple(spacing) + assert sitk_image.GetOrigin() == tuple(origin) + # Verify data array_sitk = sitk.GetArrayFromImage(sitk_image) assert np.allclose(array_sitk, 42.0) - + print("✓ ITK to SimpleITK scalar conversion 
successful") - + def test_sitk_to_itk_scalar_image(self, image_tools): """Test conversion of scalar SimpleITK image to ITK.""" # Create a simple 3D scalar SimpleITK image size = [10, 20, 30] spacing = [1.0, 2.0, 3.0] origin = [0.0, 0.0, 0.0] - + # Create SimpleITK image sitk_image = sitk.Image(size, sitk.sitkFloat32) sitk_image.SetSpacing(spacing) sitk_image.SetOrigin(origin) - + # Fill with test pattern array = np.ones((size[2], size[1], size[0]), dtype=np.float32) * 99.0 sitk_image = sitk.GetImageFromArray(array) sitk_image.SetSpacing(spacing) sitk_image.SetOrigin(origin) - + # Convert to ITK itk_image = image_tools.convert_sitk_image_to_itk(sitk_image) - + # Verify metadata - assert itk.size(itk_image) == tuple(reversed(size)) + assert itk.size(itk_image) == tuple(size) assert itk.spacing(itk_image) == tuple(spacing) assert itk.origin(itk_image) == tuple(origin) - + # Verify data array_itk = itk.array_from_image(itk_image) assert np.allclose(array_itk, 99.0) - + print("✓ SimpleITK to ITK scalar conversion successful") - + def test_roundtrip_scalar_image(self, image_tools): """Test roundtrip conversion: ITK -> SimpleITK -> ITK.""" # Create ITK image size = [15, 25, 35] spacing = [0.5, 1.5, 2.5] origin = [10.0, 20.0, 30.0] - + ImageType = itk.Image[itk.F, 3] itk_image_original = ImageType.New() - + region = itk.ImageRegion[3]() region.SetSize(size) itk_image_original.SetRegions(region) itk_image_original.SetSpacing(spacing) itk_image_original.SetOrigin(origin) itk_image_original.Allocate() - + # Fill with test data array_original = np.random.rand(size[2], size[1], size[0]).astype(np.float32) itk.array_view_from_image(itk_image_original)[:] = array_original - + # Roundtrip conversion sitk_image = image_tools.convert_itk_image_to_sitk(itk_image_original) itk_image_final = image_tools.convert_sitk_image_to_itk(sitk_image) - + # Verify metadata preserved assert itk.size(itk_image_final) == tuple(size) assert np.allclose(itk.spacing(itk_image_final), spacing) assert 
np.allclose(itk.origin(itk_image_final), origin) - + # Verify data preserved array_final = itk.array_from_image(itk_image_final) assert np.allclose(array_original, array_final) - + print("✓ Roundtrip scalar conversion successful") - + def test_itk_to_sitk_vector_image(self, image_tools): """Test conversion of vector ITK image to SimpleITK.""" # Create a 3D vector ITK image (like a displacement field) size = [8, 12, 16] spacing = [1.0, 1.0, 1.0] origin = [0.0, 0.0, 0.0] - + # Create vector image with 3 components VectorImageType = itk.Image[itk.Vector[itk.F, 3], 3] itk_image = VectorImageType.New() - + region = itk.ImageRegion[3]() region.SetSize(size) itk_image.SetRegions(region) itk_image.SetSpacing(spacing) itk_image.SetOrigin(origin) itk_image.Allocate() - + # Fill with test vector data array = np.random.rand(size[2], size[1], size[0], 3).astype(np.float32) itk.array_view_from_image(itk_image)[:] = array - + # Convert to SimpleITK sitk_image = image_tools.convert_itk_image_to_sitk(itk_image) - + # Verify it's a vector image assert sitk_image.GetNumberOfComponentsPerPixel() == 3 - + # Verify metadata - assert sitk_image.GetSize() == tuple(reversed(size)) - assert sitk_image.GetSpacing() == tuple(reversed(spacing)) - + assert sitk_image.GetSize() == tuple(size) + assert sitk_image.GetSpacing() == tuple(spacing) + # Verify data array_sitk = sitk.GetArrayFromImage(sitk_image) assert np.allclose(array, array_sitk) - + print("✓ ITK to SimpleITK vector conversion successful") - + def test_sitk_to_itk_vector_image(self, image_tools): """Test conversion of vector SimpleITK image to ITK.""" # Create a 3D vector SimpleITK image size = [8, 12, 16] spacing = [1.0, 1.0, 1.0] origin = [0.0, 0.0, 0.0] - + # Create vector data array = np.random.rand(size[2], size[1], size[0], 3).astype(np.float32) - + # Create SimpleITK vector image sitk_image = sitk.GetImageFromArray(array, isVector=True) sitk_image.SetSpacing(spacing) sitk_image.SetOrigin(origin) - + # Convert to ITK itk_image 
= image_tools.convert_sitk_image_to_itk(sitk_image) - + # Verify it's a vector image assert itk_image.GetNumberOfComponentsPerPixel() == 3 - + # Verify metadata - assert itk.size(itk_image) == tuple(reversed(size)) + assert itk.size(itk_image) == tuple(size) assert itk.spacing(itk_image) == tuple(spacing) - + # Verify data array_itk = itk.array_from_image(itk_image) assert np.allclose(array, array_itk) - + print("✓ SimpleITK to ITK vector conversion successful") - + def test_roundtrip_vector_image(self, image_tools): """Test roundtrip conversion for vector images: ITK -> SimpleITK -> ITK.""" # Create ITK vector image size = [10, 15, 20] spacing = [0.8, 1.2, 1.6] origin = [5.0, 10.0, 15.0] - + VectorImageType = itk.Image[itk.Vector[itk.F, 3], 3] itk_image_original = VectorImageType.New() - + region = itk.ImageRegion[3]() region.SetSize(size) itk_image_original.SetRegions(region) itk_image_original.SetSpacing(spacing) itk_image_original.SetOrigin(origin) itk_image_original.Allocate() - + # Fill with test vector data array_original = np.random.rand(size[2], size[1], size[0], 3).astype(np.float32) itk.array_view_from_image(itk_image_original)[:] = array_original - + # Roundtrip conversion sitk_image = image_tools.convert_itk_image_to_sitk(itk_image_original) itk_image_final = image_tools.convert_sitk_image_to_itk(sitk_image) - + # Verify metadata preserved assert itk.size(itk_image_final) == tuple(size) assert np.allclose(itk.spacing(itk_image_final), spacing) assert np.allclose(itk.origin(itk_image_final), origin) assert itk_image_final.GetNumberOfComponentsPerPixel() == 3 - + # Verify data preserved array_final = itk.array_from_image(itk_image_final) assert np.allclose(array_original, array_final) - + print("✓ Roundtrip vector conversion successful") + @pytest.mark.requires_data + @pytest.mark.slow + def test_imwrite_imread_vd3( + self, image_tools, ants_registration_results, test_images, test_directories + ): + """Test reading and writing double precision vector 
images.""" + from physiomotion4d.transform_tools import TransformTools + + output_dir = test_directories["output"] + img_output_dir = output_dir / "image_tools" + img_output_dir.mkdir(exist_ok=True) + + fixed_image = test_images[0] + forward_transform = ants_registration_results["forward_transform"] + + print("\nTesting imwriteVD3 and imreadVD3...") + + # Generate a deformation field using TransformTools + transform_tools = TransformTools() + deformation_field = transform_tools.convert_transform_to_displacement_field( + forward_transform, fixed_image + ) + + # Verify it's double precision vector image + field_type = str(type(deformation_field)) + print(f" Original field type: {field_type}") + assert "VD" in field_type, "Expected double precision vector image" + + # Get original data for comparison + original_arr = itk.array_from_image(deformation_field) + + # Write using imwriteVD3 + output_path = str(img_output_dir / "test_vector_field_vd3.mha") + image_tools.imwriteVD3(deformation_field, output_path, compression=True) + + print(f" Wrote to: {output_path}") + + # Read back using imreadVD3 + field_read = image_tools.imreadVD3(output_path) + + # Verify read field + assert field_read is not None, "Read field is None" + assert itk.size(field_read) == itk.size(deformation_field), "Size mismatch" + + # Verify it's double precision + read_type = str(type(field_read)) + print(f" Read field type: {read_type}") + assert "VD" in read_type, "Expected double precision vector image after reading" + + # Compare data + read_arr = itk.array_from_image(field_read) + assert read_arr.shape == original_arr.shape, "Array shape mismatch" + + # Check numerical accuracy (should be very close, small float precision loss) + max_diff = np.max(np.abs(read_arr - original_arr)) + mean_diff = np.mean(np.abs(read_arr - original_arr)) + + print(f"✓ Vector field I/O test complete") + print(f" Max difference: {max_diff:.6e}") + print(f" Mean difference: {mean_diff:.6e}") + + # Differences should be 
very small (float precision conversion) + assert max_diff < 1e-5, f"Max difference too large: {max_diff}" + assert mean_diff < 1e-6, f"Mean difference too large: {mean_diff}" + if __name__ == "__main__": pytest.main([__file__, "-v", "-s"]) - diff --git a/tests/test_register_images_ants.py b/tests/test_register_images_ants.py index a491f6a..782fd08 100644 --- a/tests/test_register_images_ants.py +++ b/tests/test_register_images_ants.py @@ -26,9 +26,7 @@ def test_registrar_initialization(self, registrar_ants): """Test that RegisterImagesANTs initializes correctly.""" assert registrar_ants is not None, "Registrar not initialized" assert hasattr(registrar_ants, 'fixed_image'), "Missing fixed_image attribute" - assert hasattr( - registrar_ants, 'fixed_image_mask' - ), "Missing fixed_image_mask attribute" + assert hasattr(registrar_ants, 'fixed_mask'), "Missing fixed_mask attribute" print("\n✓ ANTs registrar initialized successfully") @@ -75,26 +73,30 @@ def test_register_without_mask(self, registrar_ants, test_images, test_directori # Verify result is a dictionary assert isinstance(result, dict), "Result should be a dictionary" - assert "phi_FM" in result, "Missing phi_FM in result" - assert "phi_MF" in result, "Missing phi_MF in result" + assert "inverse_transform" in result, "Missing inverse_transform in result" + assert "forward_transform" in result, "Missing forward_transform in result" - phi_FM = result["phi_FM"] - phi_MF = result["phi_MF"] + inverse_transform = result["inverse_transform"] + forward_transform = result["forward_transform"] # Verify transforms are valid - assert phi_FM is not None, "phi_FM is None" - assert phi_MF is not None, "phi_MF is None" + assert inverse_transform is not None, "inverse_transform is None" + assert forward_transform is not None, "forward_transform is None" print("✓ Registration complete without mask") - print(f" phi_FM type: {type(phi_FM).__name__}") - print(f" phi_MF type: {type(phi_MF).__name__}") + print(f" inverse_transform 
type: {type(inverse_transform).__name__}") + print(f" forward_transform type: {type(forward_transform).__name__}") # Save transforms itk.transformwrite( - [phi_FM], str(reg_output_dir / "ants_phi_FM_no_mask.hdf"), compression=True + [inverse_transform], + str(reg_output_dir / "ants_inverse_transform_no_mask.hdf"), + compression=True, ) itk.transformwrite( - [phi_MF], str(reg_output_dir / "ants_phi_MF_no_mask.hdf"), compression=True + [forward_transform], + str(reg_output_dir / "ants_forward_transform_no_mask.hdf"), + compression=True, ) print(f" Saved transforms to: {reg_output_dir}") @@ -154,35 +156,35 @@ def test_register_with_mask(self, registrar_ants, test_images, test_directories) # Set up registration with masks registrar_ants.set_modality('ct') registrar_ants.set_fixed_image(fixed_image) - registrar_ants.set_fixed_image_mask(fixed_mask) + registrar_ants.set_fixed_mask(fixed_mask) # Register result = registrar_ants.register( - moving_image=moving_image, moving_image_mask=moving_mask + moving_image=moving_image, moving_mask=moving_mask ) # Verify result assert isinstance(result, dict), "Result should be a dictionary" - assert "phi_FM" in result, "Missing phi_FM in result" - assert "phi_MF" in result, "Missing phi_MF in result" + assert "inverse_transform" in result, "Missing inverse_transform in result" + assert "forward_transform" in result, "Missing forward_transform in result" - phi_FM = result["phi_FM"] - phi_MF = result["phi_MF"] + inverse_transform = result["inverse_transform"] + forward_transform = result["forward_transform"] - assert phi_FM is not None, "phi_FM is None" - assert phi_MF is not None, "phi_MF is None" + assert inverse_transform is not None, "inverse_transform is None" + assert forward_transform is not None, "forward_transform is None" print("✓ Registration complete with masks") # Save transforms itk.transformwrite( - [phi_FM], - str(reg_output_dir / "ants_phi_FM_with_mask.hdf"), + [inverse_transform], + str(reg_output_dir / 
"ants_inverse_transform_with_mask.hdf"), compression=True, ) itk.transformwrite( - [phi_MF], - str(reg_output_dir / "ants_phi_MF_with_mask.hdf"), + [forward_transform], + str(reg_output_dir / "ants_forward_transform_with_mask.hdf"), compression=True, ) @@ -200,14 +202,14 @@ def test_transform_application(self, registrar_ants, test_images, test_directori registrar_ants.set_fixed_image(fixed_image) result = registrar_ants.register(moving_image=moving_image) - phi_MF = result["phi_MF"] + forward_transform = result["forward_transform"] print("\nApplying transform to moving image...") # Apply transform transform_tools = TransformTools() registered_image = transform_tools.transform_image( - moving_image, phi_MF, fixed_image, interpolation_method="linear" + moving_image, forward_transform, fixed_image, interpolation_method="linear" ) # Verify registered image @@ -263,27 +265,23 @@ def test_registration_with_initial_transform( moving_image = test_images[1] # Create initial translation transform - initial_tfm_FM = itk.TranslationTransform[itk.D, 3].New() - initial_tfm_FM.SetOffset([5.0, 5.0, 5.0]) - - initial_tfm_MF = itk.TranslationTransform[itk.D, 3].New() - initial_tfm_MF.SetOffset([-5.0, -5.0, -5.0]) + initial_tfm_forward = itk.TranslationTransform[itk.D, 3].New() + initial_tfm_forward.SetOffset([-5.0, -5.0, -5.0]) print("\nRegistering with initial transform...") - print(" Initial offset: [5.0, 5.0, 5.0]") + print(" Initial offset: [-5.0, -5.0, -5.0]") registrar_ants.set_modality('ct') registrar_ants.set_fixed_image(fixed_image) result = registrar_ants.register( moving_image=moving_image, - initial_phi_FM=initial_tfm_FM, - initial_phi_MF=initial_tfm_MF, + initial_forward_transform=initial_tfm_forward, ) assert isinstance(result, dict), "Result should be a dictionary" - assert result["phi_FM"] is not None, "phi_FM is None" - assert result["phi_MF"] is not None, "phi_MF is None" + assert result["inverse_transform"] is not None, "inverse_transform is None" + assert 
result["forward_transform"] is not None, "forward_transform is None" print("✓ Registration with initial transform complete") @@ -304,8 +302,12 @@ def test_multiple_registrations(self, registrar_ants, test_images): results.append(result) assert isinstance(result, dict), f"Result {i+1} should be a dictionary" - assert "phi_FM" in result, f"Missing phi_FM in result {i+1}" - assert "phi_MF" in result, f"Missing phi_MF in result {i+1}" + assert ( + "inverse_transform" in result + ), f"Missing inverse_transform in result {i+1}" + assert ( + "forward_transform" in result + ), f"Missing forward_transform in result {i+1}" print(f"✓ Multiple registrations complete: {len(results)} runs") @@ -318,22 +320,22 @@ def test_transform_types(self, registrar_ants, test_images): registrar_ants.set_fixed_image(fixed_image) result = registrar_ants.register(moving_image=moving_image) - phi_FM = result["phi_FM"] - phi_MF = result["phi_MF"] + inverse_transform = result["inverse_transform"] + forward_transform = result["forward_transform"] print("\nVerifying transform types...") # Check that transforms are CompositeTransform (ANTs returns composite) assert isinstance( - phi_FM, itk.CompositeTransform - ), f"phi_FM should be CompositeTransform, got {type(phi_FM)}" + inverse_transform, itk.CompositeTransform + ), f"inverse_transform should be CompositeTransform, got {type(inverse_transform)}" assert isinstance( - phi_MF, itk.CompositeTransform - ), f"phi_MF should be CompositeTransform, got {type(phi_MF)}" + forward_transform, itk.CompositeTransform + ), f"forward_transform should be CompositeTransform, got {type(forward_transform)}" print("✓ Transform types verified") - print(f" phi_FM: {type(phi_FM).__name__}") - print(f" phi_MF: {type(phi_MF).__name__}") + print(f" inverse_transform: {type(inverse_transform).__name__}") + print(f" forward_transform: {type(forward_transform).__name__}") def test_image_conversion_cycle_scalar(self, registrar_ants, test_images): """Test round-trip conversion: 
ITK image -> ANTs -> ITK for scalar images.""" @@ -519,27 +521,36 @@ def test_transform_conversion_cycle_affine(self, registrar_ants, test_images): print(f" Original center: {[center[i] for i in range(3)]}") print(f" Original translation: {[translation[i] for i in range(3)]}") - # Convert ITK -> ANTs - ants_tfm = registrar_ants.itk_transform_to_ants_transform( - affine_tfm, reference_image - ) - assert ants_tfm is not None, "ANTs transform is None" - print(f" ANTs transform type: {ants_tfm.transform_type}") - - # Convert back ANTs -> ITK via displacement field - # (ANTs stores as displacement field, so we convert back through that) - + # Convert ITK -> ANTs file with tempfile.TemporaryDirectory() as tmpdir: - # Save ANTs transform to file temp_tfm_file = os.path.join(tmpdir, "temp_transform.mat") - ants.write_transform(ants_tfm, temp_tfm_file) - - # Read back as displacement field - recovered_tfm = ( - registrar_ants._antsfile_to_itk_displacement_field_transform( - temp_tfm_file, reference_image - ) + transform_files = registrar_ants.itk_transform_to_antsfile( + affine_tfm, reference_image, temp_tfm_file ) + assert len(transform_files) == 1, "Should return one transform file" + # Note: The returned filename may have extension added/modified + assert os.path.exists( + transform_files[0] + ), f"Transform file not found: {transform_files[0]}" + print(f" ANTs transform written to: {transform_files[0]}") + + # Read the transform to verify (use the actual written file) + ants_tfm = ants.read_transform(transform_files[0]) + print(f" ANTs transform type: {ants_tfm.transform_type}") + + # Convert back ANTs -> ITK + # Affine transforms are stored as affine in ANTs, so read back as affine + if ants_tfm.transform_type == "AffineTransform": + recovered_tfm = registrar_ants._antsfile_to_itk_affine_transform( + transform_files[0] + ) + else: + # For displacement field transforms + recovered_tfm = ( + registrar_ants._antsfile_to_itk_displacement_field_transform( + 
transform_files[0], reference_image + ) + ) assert recovered_tfm is not None, "Recovered transform is None" print(f" Recovered transform type: {type(recovered_tfm).__name__}") @@ -587,18 +598,25 @@ def test_transform_conversion_cycle_displacement_field( print("\nTesting displacement field transform conversion cycle...") - # Create a simple displacement field - VectorImageType = itk.Image[itk.Vector[itk.F, 3], 3] - disp_field = VectorImageType.New() - disp_field.CopyInformation(reference_image) - disp_field.SetRegions(reference_image.GetLargestPossibleRegion()) - disp_field.Allocate() + # Create a simple displacement field with double precision + # Use ImageTools to create the correct type + from physiomotion4d.image_tools import ImageTools + + image_tools = ImageTools() - # Fill with a simple displacement pattern - disp_array = itk.array_from_image(disp_field) - # Create a smooth displacement field (small random displacements) - for i in range(3): - disp_array[..., i] = np.random.randn(*disp_array.shape[:-1]) * 0.5 + # Create displacement array (small random displacements) + ref_size = itk.size(reference_image) + disp_array = ( + np.random.randn( + int(ref_size[2]), int(ref_size[1]), int(ref_size[0]), 3 + ).astype(np.float64) + * 0.5 + ) + + # Convert to ITK image with correct vector type + disp_field = image_tools.convert_array_to_image_of_vectors( + disp_array, reference_image, itk.D + ) # Create displacement field transform disp_tfm = itk.DisplacementFieldTransform[itk.D, 3].New() @@ -609,22 +627,24 @@ def test_transform_conversion_cycle_displacement_field( f" Max displacement magnitude: {np.max(np.linalg.norm(disp_array, axis=-1)):.3f}" ) - # Convert ITK -> ANTs - ants_tfm = registrar_ants.itk_transform_to_ants_transform( - disp_tfm, reference_image - ) - assert ants_tfm is not None, "ANTs transform is None" - print(" ANTs transform created successfully") - - # Convert back ANTs -> ITK - + # Convert ITK -> ANTs file with tempfile.TemporaryDirectory() as tmpdir: 
temp_tfm_file = os.path.join(tmpdir, "temp_disp_transform.mat") - ants.write_transform(ants_tfm, temp_tfm_file) + transform_files = registrar_ants.itk_transform_to_antsfile( + disp_tfm, reference_image, temp_tfm_file + ) + assert len(transform_files) == 1, "Should return one transform file" + # Note: The returned filename may have extension added/modified + assert os.path.exists( + transform_files[0] + ), f"Transform file not found: {transform_files[0]}" + print(f" ANTs transform written successfully to: {transform_files[0]}") + + # Convert back ANTs -> ITK recovered_tfm = ( registrar_ants._antsfile_to_itk_displacement_field_transform( - temp_tfm_file, reference_image + transform_files[0], reference_image ) ) @@ -689,11 +709,18 @@ def test_transform_conversion_with_composite(self, registrar_ants, test_images): f" Composite transform with {composite_tfm.GetNumberOfTransforms()} transforms" ) - # Convert to ANTs - ants_tfm = registrar_ants.itk_transform_to_ants_transform( - composite_tfm, reference_image - ) - assert ants_tfm is not None, "ANTs transform is None" + # Convert to ANTs file + with tempfile.TemporaryDirectory() as tmpdir: + temp_tfm_file = os.path.join(tmpdir, "temp_composite_transform.mat") + transform_files = registrar_ants.itk_transform_to_antsfile( + composite_tfm, reference_image, temp_tfm_file + ) + assert len(transform_files) == 1, "Should return one transform file" + # Note: The returned filename may have extension added/modified + assert os.path.exists( + transform_files[0] + ), f"Transform file not found: {transform_files[0]}" + print(f" ANTs transform written to: {transform_files[0]}") # Test on sample points test_points = [ diff --git a/tests/test_register_images_icon.py b/tests/test_register_images_icon.py index 1708ca4..97b7689 100644 --- a/tests/test_register_images_icon.py +++ b/tests/test_register_images_icon.py @@ -26,16 +26,14 @@ def test_registrar_initialization(self, registrar_icon): """Test that RegisterImagesICON initializes 
correctly.""" assert registrar_icon is not None, "Registrar not initialized" assert hasattr(registrar_icon, 'fixed_image'), "Missing fixed_image attribute" + assert hasattr(registrar_icon, 'fixed_mask'), "Missing fixed_mask attribute" assert hasattr( - registrar_icon, 'fixed_image_mask' - ), "Missing fixed_image_mask attribute" - assert hasattr( - registrar_icon, 'num_iterations' - ), "Missing num_iterations attribute" + registrar_icon, 'number_of_iterations' + ), "Missing number_of_iterations attribute" assert hasattr(registrar_icon, 'net'), "Missing net attribute (ICON network)" print("\nICON registrar initialized successfully") - print(f" Default iterations: {registrar_icon.num_iterations}") + print(f" Default iterations: {registrar_icon.number_of_iterations}") def test_set_modality(self, registrar_icon): """Test setting imaging modality.""" @@ -50,10 +48,12 @@ def test_set_modality(self, registrar_icon): def test_set_number_of_iterations(self, registrar_icon): """Test setting number of iterations.""" registrar_icon.set_number_of_iterations(10) - assert registrar_icon.num_iterations == 10, "Number of iterations not set" + assert registrar_icon.number_of_iterations == 10, "Number of iterations not set" registrar_icon.set_number_of_iterations(5) - assert registrar_icon.num_iterations == 5, "Number of iterations update failed" + assert ( + registrar_icon.number_of_iterations == 5 + ), "Number of iterations update failed" print("\nNumber of iterations setting works correctly") @@ -115,26 +115,30 @@ def test_register_without_mask(self, registrar_icon, test_images, test_directori # Verify result is a dictionary assert isinstance(result, dict), "Result should be a dictionary" - assert "phi_FM" in result, "Missing phi_FM in result" - assert "phi_MF" in result, "Missing phi_MF in result" + assert "inverse_transform" in result, "Missing inverse_transform in result" + assert "forward_transform" in result, "Missing forward_transform in result" - phi_FM = result["phi_FM"] - 
phi_MF = result["phi_MF"] + inverse_transform = result["inverse_transform"] + forward_transform = result["forward_transform"] # Verify transforms are valid - assert phi_FM is not None, "phi_FM is None" - assert phi_MF is not None, "phi_MF is None" + assert inverse_transform is not None, "inverse_transform is None" + assert forward_transform is not None, "forward_transform is None" print(f"ICON registration complete without mask") - print(f" phi_FM type: {type(phi_FM).__name__}") - print(f" phi_MF type: {type(phi_MF).__name__}") + print(f" inverse_transform type: {type(inverse_transform).__name__}") + print(f" forward_transform type: {type(forward_transform).__name__}") # Save transforms itk.transformwrite( - [phi_FM], str(reg_output_dir / "icon_phi_FM_no_mask.hdf"), compression=True + [inverse_transform], + str(reg_output_dir / "icon_inverse_transform_no_mask.hdf"), + compression=True, ) itk.transformwrite( - [phi_MF], str(reg_output_dir / "icon_phi_MF_no_mask.hdf"), compression=True + [forward_transform], + str(reg_output_dir / "icon_forward_transform_no_mask.hdf"), + compression=True, ) print(f" Saved transforms to: {reg_output_dir}") @@ -194,36 +198,36 @@ def test_register_with_mask(self, registrar_icon, test_images, test_directories) # Set up registration with masks registrar_icon.set_modality('ct') registrar_icon.set_fixed_image(fixed_image) - registrar_icon.set_fixed_image_mask(fixed_mask) + registrar_icon.set_fixed_mask(fixed_mask) registrar_icon.set_number_of_iterations(2) # Register result = registrar_icon.register( - moving_image=moving_image, moving_image_mask=moving_mask + moving_image=moving_image, moving_mask=moving_mask ) # Verify result assert isinstance(result, dict), "Result should be a dictionary" - assert "phi_FM" in result, "Missing phi_FM in result" - assert "phi_MF" in result, "Missing phi_MF in result" + assert "inverse_transform" in result, "Missing inverse_transform in result" + assert "forward_transform" in result, "Missing 
forward_transform in result" - phi_FM = result["phi_FM"] - phi_MF = result["phi_MF"] + inverse_transform = result["inverse_transform"] + forward_transform = result["forward_transform"] - assert phi_FM is not None, "phi_FM is None" - assert phi_MF is not None, "phi_MF is None" + assert inverse_transform is not None, "inverse_transform is None" + assert forward_transform is not None, "forward_transform is None" print(f"ICON registration complete with masks") # Save transforms itk.transformwrite( - [phi_FM], - str(reg_output_dir / "icon_phi_FM_with_mask.hdf"), + [inverse_transform], + str(reg_output_dir / "icon_inverse_transform_with_mask.hdf"), compression=True, ) itk.transformwrite( - [phi_MF], - str(reg_output_dir / "icon_phi_MF_with_mask.hdf"), + [forward_transform], + str(reg_output_dir / "icon_forward_transform_with_mask.hdf"), compression=True, ) @@ -242,14 +246,14 @@ def test_transform_application(self, registrar_icon, test_images, test_directori registrar_icon.set_number_of_iterations(2) result = registrar_icon.register(moving_image=moving_image) - phi_MF = result["phi_MF"] + forward_transform = result["forward_transform"] print("\nApplying ICON transform to moving image...") # Apply transform transform_tools = TransformTools() registered_image = transform_tools.transform_image( - moving_image, phi_MF, fixed_image, interpolation_method="linear" + moving_image, forward_transform, fixed_image, interpolation_method="linear" ) # Verify registered image @@ -288,8 +292,8 @@ def test_inverse_consistency(self, registrar_icon, test_images): registrar_icon.set_number_of_iterations(2) result = registrar_icon.register(moving_image=moving_image) - phi_FM = result["phi_FM"] - phi_MF = result["phi_MF"] + inverse_transform = result["inverse_transform"] + forward_transform = result["forward_transform"] # Test point transformation test_point = itk.Point[itk.D, 3]() @@ -298,8 +302,8 @@ def test_inverse_consistency(self, registrar_icon, test_images): test_point[2] = 
float(itk.size(fixed_image)[2] / 2) # Forward then backward - transformed_point = phi_MF.TransformPoint(test_point) - back_transformed_point = phi_FM.TransformPoint(transformed_point) + transformed_point = forward_transform.TransformPoint(test_point) + back_transformed_point = inverse_transform.TransformPoint(transformed_point) # Calculate error error = np.sqrt( @@ -348,11 +352,11 @@ def test_registration_with_initial_transform( moving_image = test_images[1] # Create initial translation transform - initial_tfm_FM = itk.TranslationTransform[itk.D, 3].New() - initial_tfm_FM.SetOffset([5.0, 5.0, 5.0]) + initial_tfm_inverse = itk.TranslationTransform[itk.D, 3].New() + initial_tfm_inverse.SetOffset([5.0, 5.0, 5.0]) - initial_tfm_MF = itk.TranslationTransform[itk.D, 3].New() - initial_tfm_MF.SetOffset([-5.0, -5.0, -5.0]) + initial_tfm_forward = itk.TranslationTransform[itk.D, 3].New() + initial_tfm_forward.SetOffset([-5.0, -5.0, -5.0]) print("\nRegistering with initial transform...") print(f" Initial offset: [5.0, 5.0, 5.0]") @@ -363,13 +367,12 @@ def test_registration_with_initial_transform( result = registrar_icon.register( moving_image=moving_image, - initial_phi_FM=initial_tfm_FM, - initial_phi_MF=initial_tfm_MF, + initial_forward_transform=initial_tfm_forward, ) assert isinstance(result, dict), "Result should be a dictionary" - assert result["phi_FM"] is not None, "phi_FM is None" - assert result["phi_MF"] is not None, "phi_MF is None" + assert result["inverse_transform"] is not None, "inverse_transform is None" + assert result["forward_transform"] is not None, "forward_transform is None" print(f"Registration with initial transform complete") @@ -383,34 +386,34 @@ def test_transform_types(self, registrar_icon, test_images): registrar_icon.set_number_of_iterations(2) result = registrar_icon.register(moving_image=moving_image) - phi_FM = result["phi_FM"] - phi_MF = result["phi_MF"] + inverse_transform = result["inverse_transform"] + forward_transform = 
result["forward_transform"] print("\nVerifying ICON transform types...") # ICON returns transforms (either DisplacementFieldTransform or CompositeTransform wrapping it) # The important thing is that they are valid ITK transforms - assert phi_FM is not None, "phi_FM is None" - assert phi_MF is not None, "phi_MF is None" + assert inverse_transform is not None, "inverse_transform is None" + assert forward_transform is not None, "forward_transform is None" # Check if it's either a DisplacementFieldTransform or CompositeTransform - valid_fm = isinstance( - phi_FM, (itk.DisplacementFieldTransform, itk.CompositeTransform) + valid_inverse = isinstance( + inverse_transform, (itk.DisplacementFieldTransform, itk.CompositeTransform) ) - valid_mf = isinstance( - phi_MF, (itk.DisplacementFieldTransform, itk.CompositeTransform) + valid_forward = isinstance( + forward_transform, (itk.DisplacementFieldTransform, itk.CompositeTransform) ) assert ( - valid_fm - ), f"phi_FM should be DisplacementFieldTransform or CompositeTransform, got {type(phi_FM)}" + valid_inverse + ), f"inverse_transform should be DisplacementFieldTransform or CompositeTransform, got {type(inverse_transform)}" assert ( - valid_mf - ), f"phi_MF should be DisplacementFieldTransform or CompositeTransform, got {type(phi_MF)}" + valid_forward + ), f"forward_transform should be DisplacementFieldTransform or CompositeTransform, got {type(forward_transform)}" print(f"Transform types verified") - print(f" phi_FM: {type(phi_FM).__name__}") - print(f" phi_MF: {type(phi_MF).__name__}") + print(f" inverse_transform: {type(inverse_transform).__name__}") + print(f" forward_transform: {type(forward_transform).__name__}") def test_different_iteration_counts(self, registrar_icon, test_images): """Test ICON with different iteration counts.""" @@ -432,8 +435,8 @@ def test_different_iteration_counts(self, registrar_icon, test_images): results.append(result) assert isinstance(result, dict), "Result should be a dictionary" - assert 
"phi_FM" in result, "Missing phi_FM" - assert "phi_MF" in result, "Missing phi_MF" + assert "inverse_transform" in result, "Missing inverse_transform" + assert "forward_transform" in result, "Missing forward_transform" print(f"Tested {len(iteration_counts)} different iteration counts") diff --git a/tests/test_register_time_series_images.py b/tests/test_register_time_series_images.py index b0aa15f..1bdbab5 100644 --- a/tests/test_register_time_series_images.py +++ b/tests/test_register_time_series_images.py @@ -24,7 +24,12 @@ def test_registrar_initialization_ants(self): registrar = RegisterTimeSeriesImages(registration_method='ants') assert registrar is not None, "Registrar not initialized" assert registrar.registration_method == 'ants', "Method not set correctly" - assert registrar.registrar is not None, "Internal registrar not created" + assert ( + registrar.registrar_ants is not None + ), "Internal ANTs registrar not created" + assert ( + registrar.registrar_icon is not None + ), "Internal ICON registrar not created" print("\n✓ Time series registrar initialized with ANTs") @@ -33,7 +38,12 @@ def test_registrar_initialization_icon(self): registrar = RegisterTimeSeriesImages(registration_method='icon') assert registrar is not None, "Registrar not initialized" assert registrar.registration_method == 'icon', "Method not set correctly" - assert registrar.registrar is not None, "Internal registrar not created" + assert ( + registrar.registrar_ants is not None + ), "Internal ANTs registrar not created" + assert ( + registrar.registrar_icon is not None + ), "Internal ICON registrar not created" print("\n✓ Time series registrar initialized with ICON") @@ -49,9 +59,6 @@ def test_set_modality(self): registrar = RegisterTimeSeriesImages(registration_method='ants') registrar.set_modality('ct') assert registrar.modality == 'ct', "Modality not set correctly" - assert ( - registrar.registrar.modality == 'ct' - ), "Modality not passed to internal registrar" print("\n✓ Modality 
setting works correctly") @@ -62,9 +69,6 @@ def test_set_fixed_image(self, test_images): registrar.set_fixed_image(fixed_image) assert registrar.fixed_image is not None, "Fixed image not set" - assert ( - registrar.registrar.fixed_image is not None - ), "Fixed image not passed through" print("\n✓ Fixed image set successfully") print(f" Image size: {itk.size(registrar.fixed_image)}") @@ -102,39 +106,45 @@ def test_register_time_series_basic(self, test_images, test_directories): result = registrar.register_time_series( moving_images=moving_images, - starting_index=0, - register_start_to_reference=True, - portion_of_prior_transform_to_init_next_transform=0.0, + reference_frame=0, + register_reference=True, + prior_weight=0.0, ) # Verify result structure assert isinstance(result, dict), "Result should be a dictionary" - assert "phi_MF_list" in result, "Missing phi_MF_list in result" - assert "phi_FM_list" in result, "Missing phi_FM_list in result" + assert "forward_transforms" in result, "Missing forward_transforms in result" + assert "inverse_transforms" in result, "Missing inverse_transforms in result" assert "losses" in result, "Missing losses in result" - phi_MF_list = result["phi_MF_list"] - phi_FM_list = result["phi_FM_list"] + forward_transforms = result["forward_transforms"] + inverse_transforms = result["inverse_transforms"] losses = result["losses"] # Verify list lengths - assert len(phi_MF_list) == len(moving_images), "phi_MF_list length mismatch" - assert len(phi_FM_list) == len(moving_images), "phi_FM_list length mismatch" + assert len(forward_transforms) == len( + moving_images + ), "forward_transforms length mismatch" + assert len(inverse_transforms) == len( + moving_images + ), "inverse_transforms length mismatch" assert len(losses) == len(moving_images), "losses length mismatch" # Verify all transforms are valid - for i, (phi_MF, phi_FM) in enumerate(zip(phi_MF_list, phi_FM_list)): - assert phi_MF is not None, f"phi_MF[{i}] is None" - assert phi_FM is 
not None, f"phi_FM[{i}] is None" + for i, (forward_transform, inverse_transform) in enumerate( + zip(forward_transforms, inverse_transforms, strict=False) + ): + assert forward_transform is not None, f"forward_transform[{i}] is None" + assert inverse_transform is not None, f"inverse_transform[{i}] is None" print("✓ Time series registration complete") - print(f" Transforms generated: {len(phi_MF_list)}") + print(f" Transforms generated: {len(forward_transforms)}") print(f" Average loss: {np.mean(losses):.6f}") # Save first transform for verification itk.transformwrite( - [phi_MF_list[0]], - str(reg_output_dir / "time_series_phi_MF_0.hdf"), + [forward_transforms[0]], + str(reg_output_dir / "time_series_forward_transform_0.hdf"), compression=True, ) print(f" Saved sample transform to: {reg_output_dir}") @@ -150,7 +160,7 @@ def test_register_time_series_with_prior(self, test_images, test_directories): print("\nRegistering time series (with prior)...") print(f" Number of moving images: {len(moving_images)}") - print(f" Using prior transform weight: 0.5") + print(" Using prior transform weight: 0.5") registrar = RegisterTimeSeriesImages(registration_method='ants') registrar.set_modality('ct') @@ -159,17 +169,17 @@ def test_register_time_series_with_prior(self, test_images, test_directories): result = registrar.register_time_series( moving_images=moving_images, - starting_index=1, # Start from middle - register_start_to_reference=True, - portion_of_prior_transform_to_init_next_transform=0.5, + reference_frame=1, # Start from middle + register_reference=True, + prior_weight=0.5, ) - phi_MF_list = result["phi_MF_list"] + forward_transforms = result["forward_transforms"] losses = result["losses"] # Verify all transforms generated - for i, phi_MF in enumerate(phi_MF_list): - assert phi_MF is not None, f"phi_MF[{i}] is None" + for i, forward_transform in enumerate(forward_transforms): + assert forward_transform is not None, f"forward_transform[{i}] is None" print("✓ Time 
series registration with prior complete") print(f" Losses: {[f'{loss:.6f}' for loss in losses]}") @@ -188,9 +198,9 @@ def test_register_time_series_identity_start(self, test_images): result = registrar.register_time_series( moving_images=moving_images, - starting_index=0, - register_start_to_reference=False, # Use identity - portion_of_prior_transform_to_init_next_transform=0.0, + reference_frame=0, + register_reference=False, # Use identity + prior_weight=0.0, ) # Starting image should have very low/zero loss @@ -217,14 +227,14 @@ def test_register_time_series_different_starting_indices(self, test_images): print(f" Starting index: {starting_index}") result = registrar.register_time_series( moving_images=moving_images, - starting_index=starting_index, - register_start_to_reference=True, - portion_of_prior_transform_to_init_next_transform=0.0, + reference_frame=starting_index, + register_reference=True, + prior_weight=0.0, ) - assert len(result["phi_MF_list"]) == len( + assert len(result["forward_transforms"]) == len( moving_images - ), f"Wrong number of transforms for starting_index={starting_index}" + ), f"Wrong number of transforms for reference_frame={starting_index}" print("✓ Different starting indices work correctly") @@ -247,15 +257,15 @@ def test_register_time_series_error_invalid_starting_index(self, test_images): moving_images = test_images[1:4] # Test negative index - with pytest.raises(ValueError, match="starting_index.*out of range"): + with pytest.raises(ValueError, match="reference_frame.*out of range"): registrar.register_time_series( - moving_images=moving_images, starting_index=-1 + moving_images=moving_images, reference_frame=-1 ) # Test index too large - with pytest.raises(ValueError, match="starting_index.*out of range"): + with pytest.raises(ValueError, match="reference_frame.*out of range"): registrar.register_time_series( - moving_images=moving_images, starting_index=10 + moving_images=moving_images, reference_frame=10 ) print("\n✓ Invalid 
starting index correctly rejected") @@ -271,14 +281,14 @@ def test_register_time_series_error_invalid_prior_portion(self, test_images): with pytest.raises(ValueError, match="must be in"): registrar.register_time_series( moving_images=moving_images, - portion_of_prior_transform_to_init_next_transform=-0.1, + prior_weight=-0.1, ) # Test value > 1 with pytest.raises(ValueError, match="must be in"): registrar.register_time_series( moving_images=moving_images, - portion_of_prior_transform_to_init_next_transform=1.5, + prior_weight=1.5, ) print("\n✓ Invalid prior portion correctly rejected") @@ -301,17 +311,20 @@ def test_transform_application_time_series(self, test_images, test_directories): result = registrar.register_time_series( moving_images=moving_images, - starting_index=0, - register_start_to_reference=True, - portion_of_prior_transform_to_init_next_transform=0.0, + reference_frame=0, + register_reference=True, + prior_weight=0.0, ) - phi_MF_list = result["phi_MF_list"] + forward_transforms = result["forward_transforms"] # Apply transform to first moving image transform_tools = TransformTools() registered_image = transform_tools.transform_image( - moving_images[0], phi_MF_list[0], fixed_image, interpolation_method="linear" + moving_images[0], + forward_transforms[0], + fixed_image, + interpolation_method="linear", ) assert registered_image is not None, "Registered image is None" @@ -341,21 +354,19 @@ def test_register_time_series_icon(self, test_images): result = registrar.register_time_series( moving_images=moving_images, - starting_index=0, - register_start_to_reference=True, - portion_of_prior_transform_to_init_next_transform=0.0, + reference_frame=0, + register_reference=True, + prior_weight=0.0, ) - assert len(result["phi_MF_list"]) == len(moving_images) - assert len(result["phi_FM_list"]) == len(moving_images) + assert len(result["forward_transforms"]) == len(moving_images) + assert len(result["inverse_transforms"]) == len(moving_images) assert 
len(result["losses"]) == len(moving_images) print("✓ ICON time series registration complete") def test_register_time_series_with_mask(self, test_images, test_directories): """Test time series registration with fixed image mask.""" - output_dir = test_directories["output"] - fixed_image = test_images[0] moving_images = test_images[1:3] @@ -383,17 +394,17 @@ def test_register_time_series_with_mask(self, test_images, test_directories): registrar = RegisterTimeSeriesImages(registration_method='ants') registrar.set_modality('ct') registrar.set_fixed_image(fixed_image) - registrar.set_fixed_image_mask(fixed_mask) + registrar.set_fixed_mask(fixed_mask) registrar.set_number_of_iterations([20, 10, 2]) result = registrar.register_time_series( moving_images=moving_images, - starting_index=0, - register_start_to_reference=True, - portion_of_prior_transform_to_init_next_transform=0.0, + reference_frame=0, + register_reference=True, + prior_weight=0.0, ) - assert len(result["phi_MF_list"]) == len(moving_images) + assert len(result["forward_transforms"]) == len(moving_images) print("✓ Masked time series registration complete") @@ -404,7 +415,7 @@ def test_bidirectional_registration(self, test_images): print("\nTesting bidirectional registration...") print(f" Total images: {len(moving_images)}") - print(f" Starting from middle (index 2)") + print(" Starting from middle (index 2)") registrar = RegisterTimeSeriesImages(registration_method='ants') registrar.set_modality('ct') @@ -413,19 +424,19 @@ def test_bidirectional_registration(self, test_images): result = registrar.register_time_series( moving_images=moving_images, - starting_index=2, # Middle image - register_start_to_reference=True, - portion_of_prior_transform_to_init_next_transform=0.0, + reference_frame=2, # Middle image + register_reference=True, + prior_weight=0.0, ) - phi_MF_list = result["phi_MF_list"] + forward_transforms = result["forward_transforms"] # All transforms should be generated - for i, phi_MF in 
enumerate(phi_MF_list): - assert phi_MF is not None, f"Transform {i} is None" + for i, forward_transform in enumerate(forward_transforms): + assert forward_transform is not None, f"Transform {i} is None" print("✓ Bidirectional registration successful") - print(f" All {len(phi_MF_list)} transforms generated") + print(f" All {len(forward_transforms)} transforms generated") if __name__ == "__main__": diff --git a/tests/test_segment_chest_total_segmentator.py b/tests/test_segment_chest_total_segmentator.py index 4a83db6..00f19b4 100644 --- a/tests/test_segment_chest_total_segmentator.py +++ b/tests/test_segment_chest_total_segmentator.py @@ -6,14 +6,10 @@ functionality on two time points from the converted 3D data. """ -from pathlib import Path - import itk import numpy as np import pytest -from physiomotion4d.segment_chest_total_segmentator import SegmentChestTotalSegmentator - @pytest.mark.requires_data @pytest.mark.slow @@ -23,110 +19,155 @@ class TestSegmentChestTotalSegmentator: def test_segmenter_initialization(self, segmenter_total_segmentator): """Test that SegmentChestTotalSegmentator initializes correctly.""" assert segmenter_total_segmentator is not None, "Segmenter not initialized" - assert segmenter_total_segmentator.target_spacing == 1.5, "Target spacing not set correctly" - + assert ( + segmenter_total_segmentator.target_spacing == 1.5 + ), "Target spacing not set correctly" + # Check that anatomical structure ID mappings are defined - assert len(segmenter_total_segmentator.heart_mask_ids) > 0, "Heart mask IDs not defined" - assert len(segmenter_total_segmentator.major_vessels_mask_ids) > 0, "Major vessels mask IDs not defined" - assert len(segmenter_total_segmentator.lung_mask_ids) > 0, "Lung mask IDs not defined" - assert len(segmenter_total_segmentator.bone_mask_ids) > 0, "Bone mask IDs not defined" - assert len(segmenter_total_segmentator.soft_tissue_mask_ids) > 0, "Soft tissue mask IDs not defined" - + assert ( + 
len(segmenter_total_segmentator.heart_mask_ids) > 0 + ), "Heart mask IDs not defined" + assert ( + len(segmenter_total_segmentator.major_vessels_mask_ids) > 0 + ), "Major vessels mask IDs not defined" + assert ( + len(segmenter_total_segmentator.lung_mask_ids) > 0 + ), "Lung mask IDs not defined" + assert ( + len(segmenter_total_segmentator.bone_mask_ids) > 0 + ), "Bone mask IDs not defined" + assert ( + len(segmenter_total_segmentator.soft_tissue_mask_ids) > 0 + ), "Soft tissue mask IDs not defined" + print("\n✓ Segmenter initialized with correct parameters") print(f" Heart structures: {len(segmenter_total_segmentator.heart_mask_ids)}") - print(f" Major vessels: {len(segmenter_total_segmentator.major_vessels_mask_ids)}") + print( + f" Major vessels: {len(segmenter_total_segmentator.major_vessels_mask_ids)}" + ) print(f" Lung structures: {len(segmenter_total_segmentator.lung_mask_ids)}") print(f" Bone structures: {len(segmenter_total_segmentator.bone_mask_ids)}") - print(f" Soft tissue structures: {len(segmenter_total_segmentator.soft_tissue_mask_ids)}") + print( + f" Soft tissue structures: {len(segmenter_total_segmentator.soft_tissue_mask_ids)}" + ) - def test_segment_single_image(self, segmenter_total_segmentator, test_images, test_directories): + def test_segment_single_image( + self, segmenter_total_segmentator, test_images, test_directories + ): """Test segmentation on a single time point.""" output_dir = test_directories["output"] - + # Test on first time point only input_image = test_images[0] - + print(f"\nSegmenting time point 0...") print(f" Input image size: {itk.size(input_image)}") - + # Run segmentation - result = segmenter_total_segmentator.segment(input_image, contrast_enhanced_study=False) - + result = segmenter_total_segmentator.segment( + input_image, contrast_enhanced_study=False + ) + # Verify result is a dictionary with expected keys assert isinstance(result, dict), "Result should be a dictionary" - expected_keys = ["labelmap", "lung", 
"heart", "major_vessels", "bone", - "soft_tissue", "other", "contrast"] + expected_keys = [ + "labelmap", + "lung", + "heart", + "major_vessels", + "bone", + "soft_tissue", + "other", + "contrast", + ] for key in expected_keys: assert key in result, f"Missing key '{key}' in result" assert result[key] is not None, f"Result['{key}'] is None" - + # Verify labelmap properties labelmap = result["labelmap"] assert itk.size(labelmap) == itk.size(input_image), "Labelmap size mismatch" - + # Check that labels are present labelmap_arr = itk.array_from_image(labelmap) unique_labels = np.unique(labelmap_arr) assert len(unique_labels) > 1, "Labelmap should contain multiple labels" - + print(f"✓ Segmentation complete for time point 0") print(f" Labelmap size: {itk.size(labelmap)}") print(f" Unique labels: {len(unique_labels)}") - + # Save results seg_output_dir = output_dir / "segmentation_total_segmentator" seg_output_dir.mkdir(exist_ok=True) - - itk.imwrite(labelmap, str(seg_output_dir / "slice_000_labelmap.mha"), compression=True) + + itk.imwrite( + labelmap, str(seg_output_dir / "slice_000_labelmap.mha"), compression=True + ) print(f" Saved labelmap to: {seg_output_dir / 'slice_000_labelmap.mha'}") - def test_segment_multiple_images(self, segmenter_total_segmentator, test_images, test_directories): + def test_segment_multiple_images( + self, segmenter_total_segmentator, test_images, test_directories + ): """Test segmentation on two time points.""" output_dir = test_directories["output"] seg_output_dir = output_dir / "segmentation_total_segmentator" seg_output_dir.mkdir(exist_ok=True) - + results = [] - for i, input_image in enumerate(test_images): + for i, input_image in enumerate(test_images[0:2]): print(f"\nSegmenting time point {i}...") - - result = segmenter_total_segmentator.segment(input_image, contrast_enhanced_study=False) + + result = segmenter_total_segmentator.segment( + input_image, contrast_enhanced_study=False + ) results.append(result) - + # Save labelmap for 
each time point labelmap = result["labelmap"] output_file = seg_output_dir / f"slice_{i:03d}_labelmap.mha" itk.imwrite(labelmap, str(output_file), compression=True) - + print(f"✓ Time point {i} complete") print(f" Saved to: {output_file}") - + assert len(results) == 2, "Expected 2 segmentation results" print(f"\n✓ Successfully segmented {len(results)} time points") def test_anatomy_group_masks(self, segmenter_total_segmentator, test_images): """Test that anatomy group masks are created correctly.""" input_image = test_images[0] - + # Run segmentation - result = segmenter_total_segmentator.segment(input_image, contrast_enhanced_study=False) - + result = segmenter_total_segmentator.segment( + input_image, contrast_enhanced_study=False + ) + # Check each anatomy group mask - anatomy_groups = ["lung", "heart", "major_vessels", "bone", "soft_tissue", "other"] - + anatomy_groups = [ + "lung", + "heart", + "major_vessels", + "bone", + "soft_tissue", + "other", + ] + for group in anatomy_groups: mask = result[group] assert mask is not None, f"{group} mask is None" - + # Check that mask is binary mask_arr = itk.array_from_image(mask) unique_values = np.unique(mask_arr) assert len(unique_values) <= 2, f"{group} mask should be binary" assert 0 in unique_values, f"{group} mask should contain background" - + # Check that mask has same size as input - assert itk.size(mask) == itk.size(input_image), f"{group} mask size mismatch" - + assert itk.size(mask) == itk.size( + input_image + ), f"{group} mask size mismatch" + print("\n✓ All anatomy group masks created correctly") for group in anatomy_groups: mask_arr = itk.array_from_image(result[group]) @@ -136,21 +177,25 @@ def test_anatomy_group_masks(self, segmenter_total_segmentator, test_images): def test_contrast_detection(self, segmenter_total_segmentator, test_images): """Test contrast detection functionality.""" input_image = test_images[0] - + # Test without contrast - result_no_contrast = 
segmenter_total_segmentator.segment(input_image, contrast_enhanced_study=False) + result_no_contrast = segmenter_total_segmentator.segment( + input_image, contrast_enhanced_study=False + ) contrast_mask_no = result_no_contrast["contrast"] - + # Test with contrast flag - result_with_contrast = segmenter_total_segmentator.segment(input_image, contrast_enhanced_study=True) + result_with_contrast = segmenter_total_segmentator.segment( + input_image, contrast_enhanced_study=True + ) contrast_mask_yes = result_with_contrast["contrast"] - + # Both should return valid masks assert contrast_mask_no is not None, "Contrast mask (no flag) is None" assert contrast_mask_yes is not None, "Contrast mask (with flag) is None" - + print("\n✓ Contrast detection tested") - + contrast_arr_no = itk.array_from_image(contrast_mask_no) contrast_arr_yes = itk.array_from_image(contrast_mask_yes) print(f" Without contrast flag: {np.sum(contrast_arr_no > 0)} voxels") @@ -159,41 +204,48 @@ def test_contrast_detection(self, segmenter_total_segmentator, test_images): def test_preprocessing(self, segmenter_total_segmentator, test_images): """Test preprocessing functionality.""" input_image = test_images[0] - + # Get original properties original_spacing = itk.spacing(input_image) - + # Preprocessing is done internally by segment(), not exposed as public method # Just verify that segment() works (which includes preprocessing) - result = segmenter_total_segmentator.segment(input_image, contrast_enhanced_study=False) - + result = segmenter_total_segmentator.segment( + input_image, contrast_enhanced_study=False + ) + # Check that segmentation was successful (which means preprocessing worked) assert result is not None, "Segmentation result is None" assert "labelmap" in result, "Labelmap not in result" - + print("\n✓ Preprocessing tested (via successful segmentation)") print(f" Original image spacing: {original_spacing}") def test_postprocessing(self, segmenter_total_segmentator, test_images): """Test 
postprocessing functionality.""" input_image = test_images[0] - + # Run full segmentation to get labelmap - result = segmenter_total_segmentator.segment(input_image, contrast_enhanced_study=False) + result = segmenter_total_segmentator.segment( + input_image, contrast_enhanced_study=False + ) labelmap = result["labelmap"] - + # Postprocessing is part of segment(), verify output is properly sized - assert itk.size(labelmap) == itk.size(input_image), "Postprocessing failed: size mismatch" - + assert itk.size(labelmap) == itk.size( + input_image + ), "Postprocessing failed: size mismatch" + # Check that labelmap has been resampled to original spacing original_spacing = itk.spacing(input_image) labelmap_spacing = itk.spacing(labelmap) - + # Spacing should match (within floating point tolerance) for i in range(3): - assert abs(labelmap_spacing[i] - original_spacing[i]) < 0.01, \ - f"Spacing mismatch at dimension {i}" - + assert ( + abs(labelmap_spacing[i] - original_spacing[i]) < 0.01 + ), f"Spacing mismatch at dimension {i}" + print("\n✓ Postprocessing tested") print(f" Original spacing: {original_spacing}") print(f" Labelmap spacing: {labelmap_spacing}") @@ -201,4 +253,3 @@ def test_postprocessing(self, segmenter_total_segmentator, test_images): if __name__ == "__main__": pytest.main([__file__, "-v", "-s"]) - diff --git a/tests/test_segment_chest_vista_3d.py b/tests/test_segment_chest_vista_3d.py index b55c7d3..aed8331 100644 --- a/tests/test_segment_chest_vista_3d.py +++ b/tests/test_segment_chest_vista_3d.py @@ -24,94 +24,118 @@ def test_segmenter_initialization(self, segmenter_vista_3d): """Test that SegmentChestVista3D initializes correctly.""" assert segmenter_vista_3d is not None, "Segmenter not initialized" assert segmenter_vista_3d.device is not None, "CUDA device not initialized" - + # Check that anatomical structure ID mappings are defined assert len(segmenter_vista_3d.heart_mask_ids) > 0, "Heart mask IDs not defined" - assert 
len(segmenter_vista_3d.major_vessels_mask_ids) > 0, "Major vessels mask IDs not defined" + assert ( + len(segmenter_vista_3d.major_vessels_mask_ids) > 0 + ), "Major vessels mask IDs not defined" assert len(segmenter_vista_3d.lung_mask_ids) > 0, "Lung mask IDs not defined" assert len(segmenter_vista_3d.bone_mask_ids) > 0, "Bone mask IDs not defined" - assert len(segmenter_vista_3d.soft_tissue_mask_ids) > 0, "Soft tissue mask IDs not defined" - + assert ( + len(segmenter_vista_3d.soft_tissue_mask_ids) > 0 + ), "Soft tissue mask IDs not defined" + # Check VISTA-3D specific attributes assert segmenter_vista_3d.bundle_path is not None, "Bundle path not set" - assert segmenter_vista_3d.label_prompt is None, "Label prompt should be None initially" - + assert ( + segmenter_vista_3d.label_prompt is None + ), "Label prompt should be None initially" + print("\n✓ Segmenter initialized with correct parameters") print(f" Heart structures: {len(segmenter_vista_3d.heart_mask_ids)}") print(f" Major vessels: {len(segmenter_vista_3d.major_vessels_mask_ids)}") print(f" Lung structures: {len(segmenter_vista_3d.lung_mask_ids)}") print(f" Bone structures: {len(segmenter_vista_3d.bone_mask_ids)}") - print(f" Soft tissue structures: {len(segmenter_vista_3d.soft_tissue_mask_ids)}") + print( + f" Soft tissue structures: {len(segmenter_vista_3d.soft_tissue_mask_ids)}" + ) print(f" Bundle path: {segmenter_vista_3d.bundle_path}") - def test_segment_single_image(self, segmenter_vista_3d, test_images, test_directories): + def test_segment_single_image( + self, segmenter_vista_3d, test_images, test_directories + ): """Test automatic segmentation on a single time point.""" output_dir = test_directories["output"] - + # Ensure we're in automatic segmentation mode segmenter_vista_3d.set_whole_image_segmentation() - + # Test on first time point only input_image = test_images[0] - + print(f"\nSegmenting time point 0 (automatic mode)...") print(f" Input image size: {itk.size(input_image)}") - + # Run 
segmentation result = segmenter_vista_3d.segment(input_image, contrast_enhanced_study=False) - + # Verify result is a dictionary with expected keys assert isinstance(result, dict), "Result should be a dictionary" - expected_keys = ["labelmap", "lung", "heart", "major_vessels", "bone", - "soft_tissue", "other", "contrast"] + expected_keys = [ + "labelmap", + "lung", + "heart", + "major_vessels", + "bone", + "soft_tissue", + "other", + "contrast", + ] for key in expected_keys: assert key in result, f"Missing key '{key}' in result" assert result[key] is not None, f"Result['{key}'] is None" - + # Verify labelmap properties labelmap = result["labelmap"] assert itk.size(labelmap) == itk.size(input_image), "Labelmap size mismatch" - + # Check that labels are present labelmap_arr = itk.array_from_image(labelmap) unique_labels = np.unique(labelmap_arr) assert len(unique_labels) > 1, "Labelmap should contain multiple labels" - + print(f"✓ Segmentation complete for time point 0") print(f" Labelmap size: {itk.size(labelmap)}") print(f" Unique labels: {len(unique_labels)}") - + # Save results seg_output_dir = output_dir / "segmentation_vista3d" seg_output_dir.mkdir(exist_ok=True) - - itk.imwrite(labelmap, str(seg_output_dir / "slice_000_labelmap.mha"), compression=True) + + itk.imwrite( + labelmap, str(seg_output_dir / "slice_000_labelmap.mha"), compression=True + ) print(f" Saved labelmap to: {seg_output_dir / 'slice_000_labelmap.mha'}") - def test_segment_multiple_images(self, segmenter_vista_3d, test_images, test_directories): + def test_segment_multiple_images( + self, segmenter_vista_3d, test_images, test_directories + ): """Test automatic segmentation on two time points.""" output_dir = test_directories["output"] seg_output_dir = output_dir / "segmentation_vista3d" seg_output_dir.mkdir(exist_ok=True) - + # Ensure automatic segmentation mode segmenter_vista_3d.set_whole_image_segmentation() - + results = [] - for i, input_image in enumerate(test_images): + for i, 
input_image in enumerate(test_images[0:2]): print(f"\nSegmenting time point {i}...") - - result = segmenter_vista_3d.segment(input_image, contrast_enhanced_study=False) + + result = segmenter_vista_3d.segment( + input_image, contrast_enhanced_study=False + ) results.append(result) - + # Save labelmap for each time point labelmap = result["labelmap"] output_file = seg_output_dir / f"slice_{i:03d}_labelmap.mha" itk.imwrite(labelmap, str(output_file), compression=True) - + print(f"✓ Time point {i} complete") print(f" Saved to: {output_file}") - + assert len(results) == 2, "Expected 2 segmentation results" print(f"\n✓ Successfully segmented {len(results)} time points") @@ -119,63 +143,74 @@ def test_anatomy_group_masks(self, segmenter_vista_3d, test_images): """Test that anatomy group masks are created correctly.""" segmenter_vista_3d.set_whole_image_segmentation() input_image = test_images[0] - + # Run segmentation result = segmenter_vista_3d.segment(input_image, contrast_enhanced_study=False) - + # Check each anatomy group mask - anatomy_groups = ["lung", "heart", "major_vessels", "bone", "soft_tissue", "other"] - + anatomy_groups = [ + "lung", + "heart", + "major_vessels", + "bone", + "soft_tissue", + "other", + ] + for group in anatomy_groups: mask = result[group] assert mask is not None, f"{group} mask is None" - + # Check that mask is binary mask_arr = itk.array_from_image(mask) unique_values = np.unique(mask_arr) assert len(unique_values) <= 2, f"{group} mask should be binary" assert 0 in unique_values, f"{group} mask should contain background" - + # Check that mask has same size as input - assert itk.size(mask) == itk.size(input_image), f"{group} mask size mismatch" - + assert itk.size(mask) == itk.size( + input_image + ), f"{group} mask size mismatch" + print("\n✓ All anatomy group masks created correctly") for group in anatomy_groups: mask_arr = itk.array_from_image(result[group]) num_voxels = np.sum(mask_arr > 0) print(f" {group}: {num_voxels} voxels") - def 
test_label_prompt_segmentation(self, segmenter_vista_3d, test_images, test_directories): + def test_label_prompt_segmentation( + self, segmenter_vista_3d, test_images, test_directories + ): """Test segmentation with specific label prompts.""" output_dir = test_directories["output"] seg_output_dir = output_dir / "segmentation_vista3d" seg_output_dir.mkdir(exist_ok=True) - + input_image = test_images[0] - + # Test with heart and aorta labels only heart_aorta_labels = [115, 6] # Heart and aorta segmenter_vista_3d.set_label_prompt(heart_aorta_labels) - + print(f"\nSegmenting with label prompts: {heart_aorta_labels}") result = segmenter_vista_3d.segment(input_image, contrast_enhanced_study=False) - + # Verify result assert isinstance(result, dict), "Result should be a dictionary" labelmap = result["labelmap"] - + # Check that only prompted labels are present (plus background and soft tissue fill) labelmap_arr = itk.array_from_image(labelmap) unique_labels = np.unique(labelmap_arr) - + print(f"✓ Label prompt segmentation complete") print(f" Unique labels: {unique_labels}") - + # Save result output_file = seg_output_dir / "slice_000_label_prompt.mha" itk.imwrite(labelmap, str(output_file), compression=True) print(f" Saved to: {output_file}") - + # Reset to whole image segmentation segmenter_vista_3d.set_whole_image_segmentation() @@ -183,21 +218,25 @@ def test_contrast_detection(self, segmenter_vista_3d, test_images): """Test contrast detection functionality.""" segmenter_vista_3d.set_whole_image_segmentation() input_image = test_images[0] - + # Test without contrast - result_no_contrast = segmenter_vista_3d.segment(input_image, contrast_enhanced_study=False) + result_no_contrast = segmenter_vista_3d.segment( + input_image, contrast_enhanced_study=False + ) contrast_mask_no = result_no_contrast["contrast"] - + # Test with contrast flag - result_with_contrast = segmenter_vista_3d.segment(input_image, contrast_enhanced_study=True) + result_with_contrast = 
segmenter_vista_3d.segment( + input_image, contrast_enhanced_study=True + ) contrast_mask_yes = result_with_contrast["contrast"] - + # Both should return valid masks assert contrast_mask_no is not None, "Contrast mask (no flag) is None" assert contrast_mask_yes is not None, "Contrast mask (with flag) is None" - + print("\n✓ Contrast detection tested") - + contrast_arr_no = itk.array_from_image(contrast_mask_no) contrast_arr_yes = itk.array_from_image(contrast_mask_yes) print(f" Without contrast flag: {np.sum(contrast_arr_no > 0)} voxels") @@ -207,18 +246,18 @@ def test_preprocessing(self, segmenter_vista_3d, test_images): """Test preprocessing functionality.""" segmenter_vista_3d.set_whole_image_segmentation() input_image = test_images[0] - + # Get original properties original_spacing = itk.spacing(input_image) - + # Preprocessing is done internally by segment(), not exposed as public method # Just verify that segment() works (which includes preprocessing) result = segmenter_vista_3d.segment(input_image, contrast_enhanced_study=False) - + # Check that segmentation was successful (which means preprocessing worked) assert result is not None, "Segmentation result is None" assert "labelmap" in result, "Labelmap not in result" - + print("\n✓ Preprocessing tested (via successful segmentation)") print(f" Original image spacing: {original_spacing}") @@ -226,23 +265,26 @@ def test_postprocessing(self, segmenter_vista_3d, test_images): """Test postprocessing functionality.""" segmenter_vista_3d.set_whole_image_segmentation() input_image = test_images[0] - + # Run full segmentation to get labelmap result = segmenter_vista_3d.segment(input_image, contrast_enhanced_study=False) labelmap = result["labelmap"] - + # Postprocessing is part of segment(), verify output is properly sized - assert itk.size(labelmap) == itk.size(input_image), "Postprocessing failed: size mismatch" - + assert itk.size(labelmap) == itk.size( + input_image + ), "Postprocessing failed: size mismatch" + # 
Check that labelmap has been resampled to original spacing original_spacing = itk.spacing(input_image) labelmap_spacing = itk.spacing(labelmap) - + # Spacing should match (within floating point tolerance) for i in range(3): - assert abs(labelmap_spacing[i] - original_spacing[i]) < 0.01, \ - f"Spacing mismatch at dimension {i}" - + assert ( + abs(labelmap_spacing[i] - original_spacing[i]) < 0.01 + ), f"Spacing mismatch at dimension {i}" + print("\n✓ Postprocessing tested") print(f" Original spacing: {original_spacing}") print(f" Labelmap spacing: {labelmap_spacing}") @@ -250,19 +292,25 @@ def test_postprocessing(self, segmenter_vista_3d, test_images): def test_set_and_reset_prompts(self, segmenter_vista_3d): """Test setting and resetting label prompt mode.""" # Initially should be in automatic mode - assert segmenter_vista_3d.label_prompt is None, "Label prompt should be None initially" - + assert ( + segmenter_vista_3d.label_prompt is None + ), "Label prompt should be None initially" + # Set label prompt segmenter_vista_3d.set_label_prompt([115, 6]) - assert segmenter_vista_3d.label_prompt == [115, 6], "Label prompt not set correctly" - + assert segmenter_vista_3d.label_prompt == [ + 115, + 6, + ], "Label prompt not set correctly" + # Reset to whole image segmenter_vista_3d.set_whole_image_segmentation() - assert segmenter_vista_3d.label_prompt is None, "Label prompt should be None after reset" - + assert ( + segmenter_vista_3d.label_prompt is None + ), "Label prompt should be None after reset" + print("\n✓ Prompt setting and resetting works correctly") if __name__ == "__main__": pytest.main([__file__, "-v", "-s"]) - diff --git a/tests/test_transform_tools.py b/tests/test_transform_tools.py index 4b32cf8..3cb9bc7 100644 --- a/tests/test_transform_tools.py +++ b/tests/test_transform_tools.py @@ -14,6 +14,7 @@ import pyvista as pv import vtk +from physiomotion4d.image_tools import ImageTools from physiomotion4d.transform_tools import TransformTools @@ -44,12 +45,12 
@@ def test_transform_image_linear( moving_image = test_images[1] fixed_image = test_images[0] - phi_MF = ants_registration_results["phi_MF"] + forward_transform = ants_registration_results["forward_transform"] print("\nTransforming image with linear interpolation...") transformed_image = transform_tools.transform_image( - moving_image, phi_MF, fixed_image, interpolation_method="linear" + moving_image, forward_transform, fixed_image, interpolation_method="linear" ) # Verify result @@ -80,12 +81,12 @@ def test_transform_image_nearest( moving_image = test_images[1] fixed_image = test_images[0] - phi_MF = ants_registration_results["phi_MF"] + forward_transform = ants_registration_results["forward_transform"] print("\nTransforming image with nearest neighbor interpolation...") transformed_image = transform_tools.transform_image( - moving_image, phi_MF, fixed_image, interpolation_method="nearest" + moving_image, forward_transform, fixed_image, interpolation_method="nearest" ) assert transformed_image is not None, "Transformed image is None" @@ -110,12 +111,12 @@ def test_transform_image_sinc( moving_image = test_images[1] fixed_image = test_images[0] - phi_MF = ants_registration_results["phi_MF"] + forward_transform = ants_registration_results["forward_transform"] print("\nTransforming image with sinc interpolation...") transformed_image = transform_tools.transform_image( - moving_image, phi_MF, fixed_image, interpolation_method="sinc" + moving_image, forward_transform, fixed_image, interpolation_method="sinc" ) assert transformed_image is not None, "Transformed image is None" @@ -136,13 +137,16 @@ def test_transform_image_invalid_method( """Test that invalid interpolation method raises error.""" moving_image = test_images[1] fixed_image = test_images[0] - phi_MF = ants_registration_results["phi_MF"] + forward_transform = ants_registration_results["forward_transform"] print("\nTesting invalid interpolation method...") with pytest.raises(ValueError): 
transform_tools.transform_image( - moving_image, phi_MF, fixed_image, interpolation_method="invalid" + moving_image, + forward_transform, + fixed_image, + interpolation_method="invalid", ) print(f"✓ Invalid method correctly raises ValueError") @@ -151,13 +155,13 @@ def test_transform_pvcontour_without_deformation( self, transform_tools, test_contour, ants_registration_results ): """Test transforming PyVista contour without deformation magnitude.""" - phi_MF = ants_registration_results["phi_MF"] + forward_transform = ants_registration_results["forward_transform"] print("\nTransforming contour without deformation magnitude...") print(f" Original contour points: {test_contour.n_points}") transformed_contour = transform_tools.transform_pvcontour( - test_contour, phi_MF, with_deformation_magnitude=False + test_contour, forward_transform, with_deformation_magnitude=False ) # Verify result @@ -187,12 +191,12 @@ def test_transform_pvcontour_with_deformation( tfm_output_dir = output_dir / "transform_tools" tfm_output_dir.mkdir(exist_ok=True) - phi_MF = ants_registration_results["phi_MF"] + forward_transform = ants_registration_results["forward_transform"] print("\nTransforming contour with deformation magnitude...") transformed_contour = transform_tools.transform_pvcontour( - test_contour, phi_MF, with_deformation_magnitude=True + test_contour, forward_transform, with_deformation_magnitude=True ) # Verify result @@ -227,12 +231,12 @@ def test_convert_transform_to_displacement_field( tfm_output_dir.mkdir(exist_ok=True) fixed_image = test_images[0] - phi_MF = ants_registration_results["phi_MF"] + forward_transform = ants_registration_results["forward_transform"] print("\nConverting transform to deformation field...") deformation_field = transform_tools.convert_transform_to_displacement_field( - phi_MF, fixed_image + forward_transform, fixed_image ) # Verify deformation field @@ -247,8 +251,9 @@ def test_convert_transform_to_displacement_field( print(f" Field size: 
{itk.size(deformation_field)}") print(f" Field shape: {field_arr.shape}") - # Save deformation field - itk.imwrite( + # Save deformation field using imwriteVD3 (for double precision vector images) + image_tools = ImageTools() + image_tools.imwriteVD3( deformation_field, str(tfm_output_dir / "deformation_field.mha"), compression=True, @@ -293,13 +298,13 @@ def test_compute_jacobian_determinant_from_field( tfm_output_dir.mkdir(exist_ok=True) fixed_image = test_images[0] - phi_MF = ants_registration_results["phi_MF"] + forward_transform = ants_registration_results["forward_transform"] # First convert transform to field print("\nComputing Jacobian determinant from deformation field...") deformation_field = transform_tools.convert_transform_to_displacement_field( - phi_MF, fixed_image + forward_transform, fixed_image ) jacobian_det = transform_tools.compute_jacobian_determinant_from_field( @@ -336,13 +341,13 @@ def test_detect_folding_in_field( ): """Test detecting spatial folding in deformation field.""" fixed_image = test_images[0] - phi_MF = ants_registration_results["phi_MF"] + forward_transform = ants_registration_results["forward_transform"] # Convert transform to field print("\nDetecting folding in deformation field...") deformation_field = transform_tools.convert_transform_to_displacement_field( - phi_MF, fixed_image + forward_transform, fixed_image ) # Compute jacobian determinant from field @@ -362,7 +367,7 @@ def test_interpolate_transforms( self, transform_tools, ants_registration_results, test_images ): """Test temporal interpolation between transforms.""" - phi_MF = ants_registration_results["phi_MF"] + forward_transform = ants_registration_results["forward_transform"] # Create an identity transform as second transform identity_tfm = itk.AffineTransform[itk.D, 3].New() @@ -374,7 +379,7 @@ def test_interpolate_transforms( # Interpolate at midpoint (portion=0.5) interpolated_tfm = transform_tools.combine_displacement_field_transforms( - phi_MF, + 
forward_transform, identity_tfm, fixed_image, tfm1_weight=0.5, @@ -396,7 +401,7 @@ def test_combine_displacement_field_transforms( self, transform_tools, ants_registration_results, test_images ): """Test composing two transforms with various weights.""" - phi_MF = ants_registration_results["phi_MF"] + forward_transform = ants_registration_results["forward_transform"] fixed_image = test_images[0] # Create an identity transform as second transform @@ -408,7 +413,7 @@ def test_combine_displacement_field_transforms( # Test 1: Equal weights (should be similar to interpolation at 0.5) print(" Test 1: Equal weights (0.5, 0.5)") composed_tfm1 = transform_tools.combine_displacement_field_transforms( - phi_MF, + forward_transform, identity_tfm, fixed_image, tfm1_weight=0.5, @@ -425,7 +430,7 @@ def test_combine_displacement_field_transforms( # Test 2: First transform only (weight 1.0, 0.0) print(" Test 2: First transform only (1.0, 0.0)") composed_tfm2 = transform_tools.combine_displacement_field_transforms( - phi_MF, + forward_transform, identity_tfm, fixed_image, tfm1_weight=1.0, @@ -438,7 +443,7 @@ def test_combine_displacement_field_transforms( # Test 3: Second transform only (weight 0.0, 1.0) print(" Test 3: Second transform only (0.0, 1.0)") composed_tfm3 = transform_tools.combine_displacement_field_transforms( - phi_MF, + forward_transform, identity_tfm, fixed_image, tfm1_weight=0.0, @@ -451,7 +456,7 @@ def test_combine_displacement_field_transforms( # Test 4: Custom weights print(" Test 4: Custom weights (0.75, 0.25)") composed_tfm4 = transform_tools.combine_displacement_field_transforms( - phi_MF, + forward_transform, identity_tfm, fixed_image, tfm1_weight=0.75, @@ -464,7 +469,7 @@ def test_combine_displacement_field_transforms( # Test 5: With blur sigma print(" Test 5: With blur sigma (1.0, 1.0)") composed_tfm5 = transform_tools.combine_displacement_field_transforms( - phi_MF, + forward_transform, identity_tfm, fixed_image, tfm1_weight=0.5, @@ -500,21 +505,21 @@ def 
test_combine_displacement_field_transforms( print(f" Field magnitude (0.0, 1.0): {mag3:.3f} mm") print(f" Difference between (1.0,0.0) and (0.0,1.0): {diff_2_3:.3f} mm") - # The difference should be non-zero since phi_MF is not identity + # The difference should be non-zero since forward_transform is not identity assert diff_2_3 > 0, "Different weights should produce different results" def test_smooth_transform( self, transform_tools, ants_registration_results, test_images ): """Test smoothing a transform.""" - phi_MF = ants_registration_results["phi_MF"] + forward_transform = ants_registration_results["forward_transform"] fixed_image = test_images[0] print("\nSmoothing transform...") # Smooth the transform smoothed_tfm = transform_tools.smooth_transform( - phi_MF, sigma=2.0, reference_image=fixed_image + forward_transform, sigma=2.0, reference_image=fixed_image ) # Verify result @@ -530,7 +535,7 @@ def test_combine_transforms_with_masks( self, transform_tools, ants_registration_results, test_images ): """Test combining transforms with spatial masks.""" - phi_MF = ants_registration_results["phi_MF"] + forward_transform = ants_registration_results["forward_transform"] fixed_image = test_images[0] # Create identity transform @@ -559,7 +564,7 @@ def test_combine_transforms_with_masks( # Combine transforms combined_tfm = transform_tools.combine_transforms_with_masks( - phi_MF, identity_tfm, mask1, mask2, fixed_image + forward_transform, identity_tfm, mask1, mask2, fixed_image ) # Verify result @@ -576,18 +581,18 @@ def test_multiple_transform_applications( """Test applying multiple transforms in sequence.""" moving_image = test_images[1] fixed_image = test_images[0] - phi_MF = ants_registration_results["phi_MF"] + forward_transform = ants_registration_results["forward_transform"] print("\nApplying transforms multiple times...") # Apply transform once result1 = transform_tools.transform_image( - moving_image, phi_MF, fixed_image, interpolation_method="linear" + 
moving_image, forward_transform, fixed_image, interpolation_method="linear" ) # Apply transform again (should work even though it's already transformed) result2 = transform_tools.transform_image( - result1, phi_MF, fixed_image, interpolation_method="linear" + result1, forward_transform, fixed_image, interpolation_method="linear" ) assert result1 is not None, "First transform result is None"