From 416846e3f03ffbc5acc740f04734a9cf02908cfb Mon Sep 17 00:00:00 2001 From: Stephen Aylward Date: Tue, 16 Dec 2025 10:10:06 -0500 Subject: [PATCH 01/13] ENH: Use logging in project classes. BUG: Fixed tests. Still failing tests: test_register_images_icon test_time_series_images test_transform_tools --- .github/workflows/README.md | 194 ++++++++++++ .github/workflows/test-slow.yml | 69 +++++ .github/workflows/test.yml | 278 ++++++++++++++++++ .gitignore | 4 +- pyproject.toml | 8 +- src/physiomotion4d/contour_tools.py | 18 +- src/physiomotion4d/convert_nrrd_4d_to_3d.py | 13 +- src/physiomotion4d/convert_vtk_4d_to_usd.py | 18 +- .../convert_vtk_4d_to_usd_base.py | 33 ++- .../convert_vtk_4d_to_usd_polymesh.py | 16 +- .../convert_vtk_4d_to_usd_tetmesh.py | 16 +- .../heart_gated_ct_to_usd_workflow.py | 57 ++-- src/physiomotion4d/image_tools.py | 20 +- src/physiomotion4d/register_images_ants.py | 197 ++++++++++--- src/physiomotion4d/register_images_base.py | 12 +- src/physiomotion4d/register_images_icon.py | 8 +- .../register_model_to_model_icp.py | 33 ++- .../register_model_to_model_masks.py | 33 ++- .../register_time_series_images.py | 11 +- src/physiomotion4d/segment_chest_base.py | 25 +- src/physiomotion4d/segment_chest_ensemble.py | 23 +- .../segment_chest_total_segmentator.py | 12 +- src/physiomotion4d/segment_chest_vista_3d.py | 8 +- .../segment_chest_vista_3d_nim.py | 13 +- src/physiomotion4d/transform_tools.py | 83 +++++- src/physiomotion4d/usd_anatomy_tools.py | 22 +- src/physiomotion4d/usd_tools.py | 70 ++--- tests/test_image_tools.py | 113 ++++--- tests/test_register_images_ants.py | 121 ++++---- tests/test_register_images_icon.py | 11 +- tests/test_register_time_series_images.py | 46 +-- tests/test_transform_tools.py | 66 ++++- 32 files changed, 1295 insertions(+), 356 deletions(-) create mode 100644 .github/workflows/README.md create mode 100644 .github/workflows/test-slow.yml create mode 100644 .github/workflows/test.yml diff --git 
a/.github/workflows/README.md b/.github/workflows/README.md new file mode 100644 index 0000000..a781129 --- /dev/null +++ b/.github/workflows/README.md @@ -0,0 +1,194 @@ +# GitHub Actions Workflows + +This directory contains GitHub Actions workflows for automated testing and CI/CD. + +## Workflows + +### `test.yml` - Main Test Suite + +Runs on every push and pull request to main branches. Includes: + +- **test-cpu**: Unit tests on CPU across Python 3.10, 3.11, and 3.12 + - Uses PyTorch CPU version to avoid GPU dependencies + - Runs tests marked with `unit` and excludes GPU-requiring tests + - Generates coverage reports + +- **test-gpu**: Tests on self-hosted GPU runners (if available) + - Uses PyTorch with CUDA 12.6 support + - Runs all tests except those marked as slow + - Requires self-hosted runner with `[self-hosted, linux, gpu]` labels + +- **test-integration**: Integration tests on CPU + - Runs after CPU tests pass + - Tests marked with `integration` marker + +- **code-quality**: Static code analysis + - Black, isort, ruff, flake8 checks + - Does not fail the build (continue-on-error: true) + +### `test-slow.yml` - Long-Running Tests + +Runs nightly or on manual trigger. Includes: + +- **test-slow-gpu**: Slow tests requiring GPU + - Tests marked with `slow` marker + - Extended timeout (3600 seconds) + - Uses self-hosted GPU runners + +## Caching Strategy + +The workflows use multiple caching layers to speed up builds: + +1. **Python package cache** via `setup-python` action + - Caches pip packages based on `pyproject.toml` hash + +2. **Additional pip cache** via `actions/cache` + - Caches `~/.cache/pip` directory + - Separate keys for CPU, GPU, and integration tests + - Hierarchical restore keys for fallback + +## GPU Support + +### Self-Hosted Runners + +GPU tests require self-hosted runners with: +- Linux OS +- NVIDIA GPU with CUDA 12.6+ support +- Runner labels: `[self-hosted, linux, gpu]` + +### Setting Up Self-Hosted GPU Runners + +1. 
**Install GitHub Actions Runner**: + ```bash + # Download and configure runner from GitHub repository Settings > Actions > Runners + ``` + +2. **Install NVIDIA Drivers and CUDA**: + ```bash + # Install NVIDIA drivers + sudo apt-get install nvidia-driver-535 + + # Install CUDA toolkit 12.6 + wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.0-1_all.deb + sudo dpkg -i cuda-keyring_1.0-1_all.deb + sudo apt-get update + sudo apt-get install cuda-toolkit-12-6 + ``` + +3. **Configure Runner Labels**: + - Add labels: `self-hosted`, `linux`, `gpu` + - Verify GPU is accessible: `nvidia-smi` + +4. **Start the Runner**: + ```bash + ./run.sh + ``` + +### GitHub-Hosted Runners + +GitHub-hosted runners do **not** have GPU support. GPU tests will be skipped automatically if no self-hosted runners are available (`continue-on-error: true`). + +## Test Dependencies + +Test dependencies are installed from `pyproject.toml`: + +```bash +pip install -e ".[test]" +``` + +This installs: +- pytest >= 7.0.0 +- pytest-cov >= 4.0.0 +- pytest-xdist >= 3.0.0 +- pytest-timeout >= 2.0.0 +- coverage[toml] >= 7.0.0 + +## Test Markers + +Tests should be marked appropriately: + +```python +import pytest + +@pytest.mark.unit +def test_simple_function(): + """Fast unit test""" + pass + +@pytest.mark.integration +def test_full_pipeline(): + """Integration test""" + pass + +@pytest.mark.slow +def test_long_running(): + """Long-running test""" + pass + +@pytest.mark.requires_gpu +def test_gpu_function(): + """Test requiring GPU""" + if not torch.cuda.is_available(): + pytest.skip("GPU not available") + pass +``` + +## Running Tests Locally + +### CPU Tests +```bash +# Install dependencies +pip install -e ".[test]" + +# Run unit tests +pytest tests/ -m "unit and not requires_gpu" + +# Run with coverage +pytest tests/ -m "unit and not requires_gpu" --cov=physiomotion4d +``` + +### GPU Tests +```bash +# Install with CUDA support +pip install torch torchvision 
torchaudio --index-url https://download.pytorch.org/whl/cu126 +pip install -e ".[test]" + +# Run all tests (including GPU) +pytest tests/ -m "not slow" + +# Run slow tests +pytest tests/ -m "slow" +``` + +## Coverage Reports + +Coverage reports are: +1. Uploaded to Codecov (if configured) +2. Stored as artifacts for 7 days +3. Available as HTML reports in the `htmlcov/` directory + +## Troubleshooting + +### GPU Tests Not Running + +If GPU tests are not running: +1. Verify self-hosted runner is online: Settings > Actions > Runners +2. Check runner labels include `gpu` +3. Verify `nvidia-smi` works on the runner +4. Check workflow logs for runner assignment + +### Cache Not Working + +If builds are slow: +1. Check cache hit/miss in workflow logs +2. Verify `pyproject.toml` hasn't changed unexpectedly +3. Try clearing caches: Settings > Actions > Caches + +### Test Failures + +For test failures: +1. Check individual test logs in the workflow run +2. Run tests locally to reproduce +3. Use `pytest -v --tb=long` for detailed error traces +4. 
Check if tests are marked correctly (unit/integration/slow/requires_gpu) + diff --git a/.github/workflows/test-slow.yml b/.github/workflows/test-slow.yml new file mode 100644 index 0000000..e39787f --- /dev/null +++ b/.github/workflows/test-slow.yml @@ -0,0 +1,69 @@ +name: Slow Tests + +on: + schedule: + # Run nightly at 2 AM UTC + - cron: '0 2 * * *' + workflow_dispatch: + +jobs: + test-slow-gpu: + name: Slow Tests on GPU + runs-on: [self-hosted, linux, gpu] + continue-on-error: true + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python 3.10 + uses: actions/setup-python@v5 + with: + python-version: "3.10" + + - name: Check GPU availability + run: | + nvidia-smi + python -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}')" || echo "PyTorch not installed yet" + + - name: Cache pip packages + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-slow-gpu-${{ hashFiles('pyproject.toml') }} + restore-keys: | + ${{ runner.os }}-pip-slow-gpu- + ${{ runner.os }}-pip-gpu- + ${{ runner.os }}-pip- + + - name: Upgrade pip and build tools + run: | + python -m pip install --upgrade pip setuptools wheel + + - name: Install PyTorch (CUDA 12.6) + run: | + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126 + + - name: Install package with test dependencies + run: | + pip install -e ".[test]" + + - name: Run slow tests + run: | + pytest tests/ -v \ + -m "slow" \ + --cov=physiomotion4d \ + --cov-report=xml \ + --cov-report=term \ + --timeout=3600 + env: + CUDA_VISIBLE_DEVICES: 0 + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v4 + with: + file: ./coverage.xml + flags: slow-tests-gpu + name: codecov-slow-gpu + fail_ci_if_error: false + diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..bf24cfc --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,278 @@ +name: Tests + +on: + 
push: + branches: [ main, master, develop ] + pull_request: + branches: [ main, master, develop ] + workflow_dispatch: + +jobs: + test-cpu: + name: Test on CPU (Python ${{ matrix.python-version }}) + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python-version: ["3.10", "3.11", "3.12"] + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: | + pyproject.toml + + - name: Cache pip packages + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('pyproject.toml') }} + restore-keys: | + ${{ runner.os }}-pip-${{ matrix.python-version }}- + ${{ runner.os }}-pip- + + - name: Install system dependencies + run: | + sudo apt-get update + sudo apt-get install -y \ + libgl1-mesa-glx \ + libxrender1 \ + libgomp1 \ + libglib2.0-0 \ + libsm6 \ + libxext6 \ + libxrandr2 \ + libxi6 + + - name: Upgrade pip and build tools + run: | + python -m pip install --upgrade pip setuptools wheel + + - name: Install PyTorch (CPU) + run: | + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu + + - name: Install package with test dependencies + run: | + pip install -e ".[test]" + + - name: List installed packages + run: | + pip list + + - name: Run unit tests + run: | + pytest tests/ -v \ + -m "unit and not requires_gpu" \ + --cov=physiomotion4d \ + --cov-report=xml \ + --cov-report=term \ + --cov-report=html \ + --timeout=900 + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v4 + if: matrix.python-version == '3.10' + with: + file: ./coverage.xml + flags: unittests-cpu + name: codecov-cpu-py${{ matrix.python-version }} + fail_ci_if_error: false + + - name: Upload coverage artifacts + uses: actions/upload-artifact@v4 + if: 
matrix.python-version == '3.10' + with: + name: coverage-report-cpu + path: htmlcov/ + retention-days: 7 + + test-gpu: + name: Test on GPU (Python ${{ matrix.python-version }}) + runs-on: [self-hosted, linux, gpu] + strategy: + fail-fast: false + matrix: + python-version: ["3.10", "3.11"] + # Only run GPU tests if self-hosted runners are available + continue-on-error: true + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Check GPU availability + run: | + nvidia-smi || echo "No GPU available" + python -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}')" || echo "PyTorch not installed yet" + + - name: Cache pip packages + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-gpu-${{ matrix.python-version }}-${{ hashFiles('pyproject.toml') }} + restore-keys: | + ${{ runner.os }}-pip-gpu-${{ matrix.python-version }}- + ${{ runner.os }}-pip-gpu- + ${{ runner.os }}-pip- + + - name: Upgrade pip and build tools + run: | + python -m pip install --upgrade pip setuptools wheel + + - name: Install PyTorch (CUDA 12.6) + run: | + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126 + + - name: Install package with test dependencies + run: | + pip install -e ".[test]" + + - name: Verify GPU setup + run: | + python -c "import torch; print(f'PyTorch version: {torch.__version__}')" + python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')" + python -c "import torch; print(f'CUDA version: {torch.version.cuda}')" || echo "CUDA not available" + python -c "import torch; print(f'Number of GPUs: {torch.cuda.device_count()}')" || echo "No GPUs" + + - name: List installed packages + run: | + pip list + + - name: Run GPU tests + run: | + pytest tests/ -v \ + -m "not slow" \ + --cov=physiomotion4d \ + 
--cov-report=xml \ + --cov-report=term \ + --cov-report=html \ + --timeout=900 + env: + CUDA_VISIBLE_DEVICES: 0 + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v4 + if: matrix.python-version == '3.10' + with: + file: ./coverage.xml + flags: unittests-gpu + name: codecov-gpu-py${{ matrix.python-version }} + fail_ci_if_error: false + + - name: Upload coverage artifacts + uses: actions/upload-artifact@v4 + if: matrix.python-version == '3.10' + with: + name: coverage-report-gpu + path: htmlcov/ + retention-days: 7 + + test-integration: + name: Integration Tests (CPU) + runs-on: ubuntu-latest + needs: test-cpu + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python 3.10 + uses: actions/setup-python@v5 + with: + python-version: "3.10" + cache: 'pip' + cache-dependency-path: pyproject.toml + + - name: Cache pip packages + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-integration-${{ hashFiles('pyproject.toml') }} + restore-keys: | + ${{ runner.os }}-pip-integration- + ${{ runner.os }}-pip- + + - name: Install system dependencies + run: | + sudo apt-get update + sudo apt-get install -y \ + libgl1-mesa-glx \ + libxrender1 \ + libgomp1 \ + libglib2.0-0 \ + libsm6 \ + libxext6 \ + libxrandr2 \ + libxi6 + + - name: Upgrade pip and build tools + run: | + python -m pip install --upgrade pip setuptools wheel + + - name: Install PyTorch (CPU) + run: | + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu + + - name: Install package with test dependencies + run: | + pip install -e ".[test]" + + - name: Run integration tests + run: | + pytest tests/ -v \ + -m "integration and not requires_gpu" \ + --timeout=1800 + continue-on-error: true + + code-quality: + name: Code Quality Checks + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python 3.10 + uses: actions/setup-python@v5 + with: + python-version: 
"3.10" + cache: 'pip' + + - name: Install dev dependencies + run: | + python -m pip install --upgrade pip + pip install black isort ruff flake8 pylint mypy + + - name: Check code formatting with Black + run: | + black --check src/ tests/ || echo "Black check failed" + continue-on-error: true + + - name: Check import sorting with isort + run: | + isort --check-only src/ tests/ || echo "isort check failed" + continue-on-error: true + + - name: Lint with Ruff + run: | + ruff check src/ tests/ || echo "Ruff check failed" + continue-on-error: true + + - name: Lint with Flake8 + run: | + flake8 src/ tests/ || echo "Flake8 check failed" + continue-on-error: true + diff --git a/.gitignore b/.gitignore index 157bc21..6e23e4b 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ .coverage coverage.xml *~ +*.swp venv* .venv build @@ -17,6 +18,7 @@ network_weights # Data files *.gz +*.mat *.mha *.mhd *.zip @@ -41,4 +43,4 @@ results* tests/data/ tests/results/ tests/baseline_staging/ -test_output.txt \ No newline at end of file +test_output.txt diff --git a/pyproject.toml b/pyproject.toml index 6388c14..9b0bfc1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -121,7 +121,8 @@ dev = [ "ruff>=0.1.0", ] docs = [ - "linkify>=1.4", + "linkify-it-py>=2.0.0", + "uc-micro-py>=1.0.1", "sphinx>=7.0.0", "sphinx-rtd-theme>=2.0.0", "sphinx-autodoc-typehints>=1.25.0", @@ -288,9 +289,10 @@ addopts = [ ] testpaths = ["tests"] markers = [ + "unit: marks tests as unit tests (fast, isolated)", + "integration: marks tests as integration tests (slower, multiple components)", "slow: marks tests as slow (deselect with '-m \"not slow\"')", - "integration: marks tests as integration tests", - "unit: marks tests as unit tests", + "requires_gpu: marks tests that require GPU/CUDA support", "requires_data: marks tests that require external data download" ] timeout = 900 diff --git a/src/physiomotion4d/contour_tools.py b/src/physiomotion4d/contour_tools.py index e49e6e2..20e8d3b 100644 --- 
a/src/physiomotion4d/contour_tools.py +++ b/src/physiomotion4d/contour_tools.py @@ -2,21 +2,29 @@ Tools for creating and manipulating contours. """ +import logging + import itk import numpy as np import pyvista as pv import trimesh +from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase from physiomotion4d.transform_tools import TransformTools -class ContourTools: +class ContourTools(PhysioMotion4DBase): """ Tools for creating and manipulating contours. """ - def __init__(self): - pass + def __init__(self, log_level: int | str = logging.INFO): + """Initialize ContourTools. + + Args: + log_level: Logging level (default: logging.INFO) + """ + super().__init__(class_name=self.__class__.__name__, log_level=log_level) def extract_contours( self, @@ -87,7 +95,7 @@ def merge_meshes(self, meshes): pv.PolyData Merged mesh """ - print("Merging meshes...") + self.log_info("Merging meshes...") if hasattr(meshes[0], 'n_faces_strict'): meshes = [ trimesh.Trimesh( @@ -224,7 +232,7 @@ def create_contour_distance_map_from_mesh( edge_mask_image = edge_filter.GetOutput() # Compute signed distance map (positive inside, negative outside) - print(" Computing signed distance map...") + self.log_info("Computing signed distance map...") distance_filter = itk.SignedMaurerDistanceMapImageFilter.New( Input=edge_mask_image ) diff --git a/src/physiomotion4d/convert_nrrd_4d_to_3d.py b/src/physiomotion4d/convert_nrrd_4d_to_3d.py index 37f87ea..5c11f68 100644 --- a/src/physiomotion4d/convert_nrrd_4d_to_3d.py +++ b/src/physiomotion4d/convert_nrrd_4d_to_3d.py @@ -1,12 +1,21 @@ +import logging import os import itk import nrrd import numpy as np +from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase -class ConvertNRRD4DTo3D: - def __init__(self): + +class ConvertNRRD4DTo3D(PhysioMotion4DBase): + def __init__(self, log_level: int | str = logging.INFO): + """Initialize the NRRD 4D to 3D converter. 
+ + Args: + log_level: Logging level (default: logging.INFO) + """ + super().__init__(class_name=self.__class__.__name__, log_level=log_level) self.nrrd_4d = None self.img_3d = [] diff --git a/src/physiomotion4d/convert_vtk_4d_to_usd.py b/src/physiomotion4d/convert_vtk_4d_to_usd.py index 0cddc57..48cfe38 100644 --- a/src/physiomotion4d/convert_vtk_4d_to_usd.py +++ b/src/physiomotion4d/convert_vtk_4d_to_usd.py @@ -1,14 +1,17 @@ """Unified facade for VTK to USD conversion supporting both PolyData and UnstructuredGrid.""" +import logging + import pyvista as pv import vtk from pxr import Usd from .convert_vtk_4d_to_usd_polymesh import ConvertVTK4DToUSDPolyMesh from .convert_vtk_4d_to_usd_tetmesh import ConvertVTK4DToUSDTetMesh +from .physiomotion4d_base import PhysioMotion4DBase -class ConvertVTK4DToUSD: +class ConvertVTK4DToUSD(PhysioMotion4DBase): """ Unified converter supporting both PolyData and UnstructuredGrid. @@ -54,7 +57,7 @@ class ConvertVTK4DToUSD: >>> stage = converter.convert("output.usd") """ - def __init__(self, data_basename, input_polydata, mask_ids=None): + def __init__(self, data_basename, input_polydata, mask_ids=None, log_level: int | str = logging.INFO): """ Initialize converter and store parameters for later routing. @@ -66,7 +69,10 @@ def __init__(self, data_basename, input_polydata, mask_ids=None): mask_ids (dict or None): Optional mapping of label IDs to label names for organizing meshes by anatomical regions. 
Default: None + log_level: Logging level (default: logging.INFO) """ + super().__init__(class_name=self.__class__.__name__, log_level=log_level) + self.data_basename = data_basename self.input_polydata = input_polydata self.mask_ids = mask_ids @@ -166,9 +172,9 @@ def convert(self, output_usd_file, convert_to_surface=False) -> Usd.Stage: # Case 1: Only PolyData (or surface-converted UGrid) if has_polydata and not has_ugrid: - print("Routing to PolyMesh converter (surface meshes)") + self.log_info("Routing to PolyMesh converter (surface meshes)") converter = ConvertVTK4DToUSDPolyMesh( - self.data_basename, self.input_polydata, self.mask_ids + self.data_basename, self.input_polydata, self.mask_ids, log_level=self.log_level ) converter.set_colormap( self.color_by_array, self.colormap, self.intensity_range @@ -177,9 +183,9 @@ def convert(self, output_usd_file, convert_to_surface=False) -> Usd.Stage: # Case 2: Only UnstructuredGrid (tetmesh) elif has_ugrid and not has_polydata: - print("Routing to TetMesh converter (volumetric meshes)") + self.log_info("Routing to TetMesh converter (volumetric meshes)") converter = ConvertVTK4DToUSDTetMesh( - self.data_basename, self.input_polydata, self.mask_ids + self.data_basename, self.input_polydata, self.mask_ids, log_level=self.log_level ) converter.set_colormap( self.color_by_array, self.colormap, self.intensity_range diff --git a/src/physiomotion4d/convert_vtk_4d_to_usd_base.py b/src/physiomotion4d/convert_vtk_4d_to_usd_base.py index 62a75d8..a16cd51 100644 --- a/src/physiomotion4d/convert_vtk_4d_to_usd_base.py +++ b/src/physiomotion4d/convert_vtk_4d_to_usd_base.py @@ -1,5 +1,6 @@ """Abstract base class for converting 4D VTK data to animated USD meshes.""" +import logging import os from abc import ABC, abstractmethod @@ -9,6 +10,8 @@ import vtk from pxr import Gf, Sdf, Usd, UsdGeom +from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase + # VTK Cell Type Constants VTK_TRIANGLE = 5 VTK_QUAD = 9 @@ -18,7 +21,7 @@ 
VTK_PYRAMID = 14 -class ConvertVTK4DToUSDBase(ABC): +class ConvertVTK4DToUSDBase(PhysioMotion4DBase, ABC): """ Abstract base class for VTK to USD conversion. @@ -27,7 +30,7 @@ class ConvertVTK4DToUSDBase(ABC): mesh-specific processing and USD creation methods. """ - def __init__(self, data_basename, input_polydata, mask_ids=None, convert_to_surface=False): + def __init__(self, data_basename, input_polydata, mask_ids=None, convert_to_surface=False, log_level: int | str = logging.INFO): """ Initialize VTK to USD converter. @@ -41,7 +44,10 @@ def __init__(self, data_basename, input_polydata, mask_ids=None, convert_to_surf convert_to_surface (bool): If True, convert UnstructuredGrid meshes to surface PolyData before processing. Only applicable for PolyMesh converter. Default: False + log_level: Logging level (default: logging.INFO) """ + super().__init__(class_name=self.__class__.__name__, log_level=log_level) + self.data_basename = data_basename self.input_polydata = input_polydata self.mask_ids = mask_ids @@ -299,7 +305,7 @@ def _extract_color_array(self, mesh): color_values.append(scalar_value) return np.array(color_values) else: - print(f"Warning: Array '{self.color_by_array}' not found in point data") + self.log_warning("Array '%s' not found in point data", self.color_by_array) return None def _check_topology_changes(self, mesh_time_data): @@ -360,9 +366,9 @@ def _check_topology_changes(self, mesh_time_data): topology_changes[label] = has_change if has_change: - print( - f"Detected topology changes for label '{label}' - " - f"will use time-varying mesh approach" + self.log_info( + "Detected topology changes for label '%s' - will use time-varying mesh approach", + label ) return topology_changes @@ -497,7 +503,7 @@ def convert(self, output_usd_file, convert_to_surface=None) -> Usd.Stage: UsdGeom.Xform.Define(self.stage, root_path) basename = os.path.basename(output_usd_file).split(".")[0] - print(f"Converting {basename}") + self.log_info("Converting %s", basename) 
root_path = f"{root_path}/Transform_{basename}" UsdGeom.Xform.Define(self.stage, root_path) @@ -506,12 +512,14 @@ def convert(self, output_usd_file, convert_to_surface=None) -> Usd.Stage: # Collect the label data from each time point polydata_time_data = {} - for fnum in range(len(self.input_polydata)): + num_timepoints = len(self.input_polydata) + for fnum in range(num_timepoints): polydata_time_data[fnum] = self._process_mesh_data(self.input_polydata[fnum]) - print("Processed time point", fnum) + if fnum % 10 == 0 or fnum == num_timepoints - 1: + self.log_progress(fnum + 1, num_timepoints, prefix="Processing time points") # Check for topology changes across time steps - print("\nChecking for topology changes...") + self.log_info("Checking for topology changes...") topology_changes = self._check_topology_changes(polydata_time_data) # Assign a unique color to each label @@ -525,11 +533,12 @@ def convert(self, output_usd_file, convert_to_surface=None) -> Usd.Stage: first_data = polydata_time_data[0] # Create a mesh prim for each label group + num_labels = len(first_data.items()) for idx, (label, data) in enumerate(first_data.items()): - print(f"Processing {label} {idx}") + self.log_info("Processing %s (%d/%d)", label, idx + 1, num_labels) # Create a transform for each mesh transform_path = f"{root_path}/Transform_{label}" - print(f"Transform path: {transform_path}") + self.log_debug("Transform path: %s", transform_path) UsdGeom.Xform.Define(self.stage, transform_path) # Determine if topology changes for this label diff --git a/src/physiomotion4d/convert_vtk_4d_to_usd_polymesh.py b/src/physiomotion4d/convert_vtk_4d_to_usd_polymesh.py index d09cddb..d351b8b 100644 --- a/src/physiomotion4d/convert_vtk_4d_to_usd_polymesh.py +++ b/src/physiomotion4d/convert_vtk_4d_to_usd_polymesh.py @@ -91,15 +91,15 @@ def _create_usd_mesh( has_topology_change: Whether topology varies over time """ if has_topology_change: - print( - f"Creating time-varying UsdGeomMesh for label: {label} 
" - f"(topology changes detected)" + self.log_info( + "Creating time-varying UsdGeomMesh for label: %s (topology changes detected)", + label ) self._create_usd_polymesh_varying( transform_path, label, mesh_time_data, label_colors ) else: - print(f"Creating UsdGeomMesh for label: {label}") + self.log_info("Creating UsdGeomMesh for label: %s", label) self._create_usd_polymesh( transform_path, label, mesh_time_data, label_colors ) @@ -351,8 +351,10 @@ def _create_usd_polymesh(self, transform_path, label, mesh_time_data, label_colo global_vmin = float('inf') global_vmax = float('-inf') + num_times = len(self.times) for time_idx, time_code in enumerate(self.times): - print(f"Processing time sample: {time_code} for label: {label}") + if time_idx % 10 == 0 or time_idx == num_times - 1: + self.log_progress(time_idx + 1, num_times, prefix=f"Processing time samples for {label}") time_data = mesh_time_data[time_idx][label] # Compute per-vertex normals for this timestep (REQUIRED for IndeX renderer) @@ -471,8 +473,10 @@ def _create_usd_polymesh_varying( UsdGeom.Xform.Define(self.stage, parent_path) # Create separate mesh for each time step + num_times = len(self.times) for time_idx, time_code in enumerate(self.times): - print(f"Creating UsdGeomMesh for label: {label} at time: {time_code}") + if time_idx % 10 == 0 or time_idx == num_times - 1: + self.log_progress(time_idx + 1, num_times, prefix=f"Creating meshes for {label}") # Skip if label doesn't exist at this timestep if label not in mesh_time_data[time_idx]: continue diff --git a/src/physiomotion4d/convert_vtk_4d_to_usd_tetmesh.py b/src/physiomotion4d/convert_vtk_4d_to_usd_tetmesh.py index 4b4cbaf..2eb2376 100644 --- a/src/physiomotion4d/convert_vtk_4d_to_usd_tetmesh.py +++ b/src/physiomotion4d/convert_vtk_4d_to_usd_tetmesh.py @@ -87,15 +87,15 @@ def _create_usd_mesh( if mesh_type == 'tetmesh': if has_topology_change: - print( - f"Creating time-varying UsdGeomTetMesh for label: {label} " - f"(topology changes detected)" + 
self.log_info( + "Creating time-varying UsdGeomTetMesh for label: %s (topology changes detected)", + label ) self._create_usd_tetmesh_varying( transform_path, label, mesh_time_data, label_colors ) else: - print(f"Creating UsdGeomTetMesh for label: {label}") + self.log_info("Creating UsdGeomTetMesh for label: %s", label) self._create_usd_tetmesh( transform_path, label, mesh_time_data, label_colors ) @@ -340,8 +340,10 @@ def _create_usd_tetmesh(self, transform_path, label, mesh_time_data, label_color extent_attr = tetmesh.CreateExtentAttr() time_samples = {} + num_times = len(self.times) for time_idx, time_code in enumerate(self.times): - print(f"Processing time step {time_code} for label: {label}") + if time_idx % 10 == 0 or time_idx == num_times - 1: + self.log_progress(time_idx + 1, num_times, prefix=f"Processing time steps for {label}") time_data = mesh_time_data[time_idx][label] # Compute per-vertex normals for surface faces (REQUIRED for IndeX renderer) @@ -406,8 +408,10 @@ def _create_usd_tetmesh_varying( UsdGeom.Xform.Define(self.stage, parent_path) # Create separate tetmesh for each time step + num_times = len(self.times) for time_idx, time_code in enumerate(self.times): - print(f"Creating UsdGeomTetMesh for label: {label} at time: {time_code}") + if time_idx % 10 == 0 or time_idx == num_times - 1: + self.log_progress(time_idx + 1, num_times, prefix=f"Creating tetmeshes for {label}") # Skip if label doesn't exist at this timestep if label not in mesh_time_data[time_idx]: continue diff --git a/src/physiomotion4d/heart_gated_ct_to_usd_workflow.py b/src/physiomotion4d/heart_gated_ct_to_usd_workflow.py index f1723f9..19b3e90 100644 --- a/src/physiomotion4d/heart_gated_ct_to_usd_workflow.py +++ b/src/physiomotion4d/heart_gated_ct_to_usd_workflow.py @@ -5,6 +5,7 @@ as demonstrated in the Heart-GatedCT experiment notebooks. 
""" +import logging import os from typing import List, Optional @@ -16,6 +17,7 @@ from physiomotion4d.contour_tools import ContourTools from physiomotion4d.convert_nrrd_4d_to_3d import ConvertNRRD4DTo3D from physiomotion4d.convert_vtk_4d_to_usd import ConvertVTK4DToUSD +from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase from physiomotion4d.register_images_ants import RegisterImagesANTs from physiomotion4d.register_images_icon import RegisterImagesICON from physiomotion4d.segment_chest_total_segmentator import SegmentChestTotalSegmentator @@ -23,7 +25,7 @@ from physiomotion4d.usd_anatomy_tools import USDAnatomyTools -class HeartGatedCTToUSDWorkflow: +class HeartGatedCTToUSDWorkflow(PhysioMotion4DBase): """ Complete workflow for Heart-gated CT images to dynamic USD models. @@ -40,6 +42,7 @@ def __init__( reference_image_filename: Optional[str] = None, number_of_registration_iterations: Optional[int] = 1, registration_method: str = 'icon', + log_level: int | str = logging.INFO, ): """ Initialize the Heart-gated CT to USD workflow. 
@@ -53,7 +56,10 @@ def __init__( reference_image_filename (Optional[str]): Path to reference image file number_of_registration_iterations (Optional[int]): Number of registration iterations registration_method (str): Registration method to use: 'ants' or 'icon' (default: 'icon') + log_level: Logging level (default: logging.INFO) """ + super().__init__(class_name=self.__class__.__name__, log_level=log_level) + self.input_filenames = input_filenames self.contrast_enhanced = contrast_enhanced self.output_directory = output_directory @@ -73,14 +79,14 @@ def __init__( os.makedirs(output_directory, exist_ok=True) # Initialize processing components - self.converter = ConvertNRRD4DTo3D() - self.segmenter = SegmentChestTotalSegmentator() + self.converter = ConvertNRRD4DTo3D(log_level=log_level) + self.segmenter = SegmentChestTotalSegmentator(log_level=log_level) self.segmenter.contrast_threshold = 500 # Initialize registration method if self.registration_method == 'ants': - print(f"Initializing ANTs registration...") - self.registrar = RegisterImagesANTs() + self.log_info("Initializing ANTs registration...") + self.registrar = RegisterImagesANTs(log_level=log_level) self.registrar.set_modality('ct') self.registrar.set_transform_type('SyN') if ( @@ -95,8 +101,8 @@ def __init__( ) ) else: # icon (default) - print(f"Initializing ICON registration...") - self.registrar = RegisterImagesICON() + self.log_info("Initializing ICON registration...") + self.registrar = RegisterImagesICON(log_level=log_level) self.registrar.set_modality('ct') if ( number_of_registration_iterations is not None @@ -124,7 +130,7 @@ def process(self) -> str: Returns: str: Path to the final dynamic anatomy USD file """ - print("Starting Heart-gated CT processing pipeline...") + self.log_section("Heart-gated CT Processing Pipeline") # Load and convert data self._load_time_series() @@ -141,12 +147,12 @@ def process(self) -> str: # Create USD files self._create_usd_files() - print("Processing pipeline completed 
successfully") + self.log_info("Processing pipeline completed successfully") return f"{self.project_name}.dynamic_anatomy_painted.usd" def _load_time_series(self): """Load and convert 4D data to time series images.""" - print("Loading time series data...") + self.log_info("Loading time series data...") if len(self.input_filenames) == 1: self.converter.load_nrrd_4d(self.input_filenames[0]) @@ -156,7 +162,7 @@ def _load_time_series(self): ) ) else: - print(f"Loading {len(self.input_filenames)} 3D NRRD files") + self.log_info("Loading %d 3D NRRD files", len(self.input_filenames)) self.converter.load_nrrd_3d(self.input_filenames) self._num_time_points = self.converter.get_number_of_3d_images() @@ -178,14 +184,14 @@ def _load_time_series(self): compression=True, ) - print(f"Loaded {self._num_time_points} time points") + self.log_info("Loaded %d time points", self._num_time_points) def _segment_and_register_frames(self): """Segment each frame and register to reference image.""" - print("Segmenting and registering frames...") + self.log_info("Segmenting and registering frames...") # Segment reference image - print("Segmenting reference image...") + self.log_info("Segmenting reference image...") self._fixed_segmentation = self.segmenter.segment( self._fixed_image, contrast_enhanced_study=self.contrast_enhanced ) @@ -231,7 +237,7 @@ def _segment_and_register_frames(self): # Process each time point self._time_series_transforms = [] for i in range(self._num_time_points): - print(f"Processing frame {i+1}/{self._num_time_points}") + self.log_progress(i + 1, self._num_time_points, prefix="Processing frames") moving_image = self._time_series_images[i] @@ -309,7 +315,7 @@ def _segment_and_register_frames(self): def _generate_reference_contours(self): """Generate contour meshes from reference segmentation.""" - print("Generating reference contours...") + self.log_info("Generating reference contours...") ( labelmap_image, @@ -363,12 +369,12 @@ def _generate_reference_contours(self): 
def _transform_all_contours(self): """Transform contours for all time points using registration transforms.""" - print("Transforming contours for all time points...") + self.log_info("Transforming contours for all time points...") self._transformed_contours = {'all': [], 'dynamic': [], 'static': []} for i in range(self._num_time_points): - print(f"Transforming contours for frame {i+1}/{self._num_time_points}") + self.log_progress(i + 1, self._num_time_points, prefix="Transforming contours") frame_contours = {} for anatomy_type in ['all', 'dynamic', 'static']: @@ -387,25 +393,26 @@ def _transform_all_contours(self): def _create_usd_files(self): """Create painted USD files for all anatomy types.""" - print("Creating USD files...") + self.log_info("Creating USD files...") # Create USD for each anatomy type for anatomy_type in ['all', 'dynamic', 'static']: - print(f"Creating {anatomy_type} anatomy USD...") + self.log_info("Creating %s anatomy USD...", anatomy_type) # Convert VTK contours to USD converter = ConvertVTK4DToUSD( self.project_name, self._transformed_contours[anatomy_type], self.segmenter.all_mask_ids, - os.path.join( - self.output_directory, f"{self.project_name}.{anatomy_type}.usd" - ), + log_level=self.log_level + ) + usd_file = os.path.join( + self.output_directory, f"{self.project_name}.{anatomy_type}.usd" ) - stage = converter.convert() + stage = converter.convert(usd_file) # Paint the USD file - print(f"Painting {anatomy_type} anatomy USD...") + self.log_info("Painting %s anatomy USD...", anatomy_type) output_filename = os.path.join( self.output_directory, f"{self.project_name}.{anatomy_type}_painted.usd" ) diff --git a/src/physiomotion4d/image_tools.py b/src/physiomotion4d/image_tools.py index e4769fd..066129d 100644 --- a/src/physiomotion4d/image_tools.py +++ b/src/physiomotion4d/image_tools.py @@ -5,12 +5,16 @@ and performing image processing operations. 
""" +import logging + import itk import numpy as np import SimpleITK as sitk +from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase + -class ImageTools: +class ImageTools(PhysioMotion4DBase): """ Utilities for medical image format conversions and processing. @@ -26,9 +30,13 @@ class ImageTools: >>> itk_image_back = tools.convert_sitk_image_to_itk(sitk_image) """ - def __init__(self): - """Initialize ImageTools.""" - pass + def __init__(self, log_level: int | str = logging.INFO): + """Initialize ImageTools. + + Args: + log_level: Logging level (default: logging.INFO) + """ + super().__init__(class_name=self.__class__.__name__, log_level=log_level) def convert_itk_image_to_sitk(self, itk_image: itk.Image) -> sitk.Image: """ @@ -78,8 +86,8 @@ def convert_itk_image_to_sitk(self, itk_image: itk.Image) -> sitk.Image: # Set metadata # Convert origin and spacing to tuples (reverse order for SimpleITK: x, y, z) - sitk_image.SetOrigin(tuple(reversed(origin))) - sitk_image.SetSpacing(tuple(reversed(spacing))) + sitk_image.SetOrigin(tuple(origin)) + sitk_image.SetSpacing(tuple(spacing)) # Direction matrix needs to be flattened and reversed appropriately # ITK and SimpleITK use the same direction convention, but we need to handle diff --git a/src/physiomotion4d/register_images_ants.py b/src/physiomotion4d/register_images_ants.py index 254717e..81affed 100644 --- a/src/physiomotion4d/register_images_ants.py +++ b/src/physiomotion4d/register_images_ants.py @@ -10,11 +10,12 @@ """ import argparse +import logging +import os import ants import itk import numpy as np -from itk import TubeTK as ttk from physiomotion4d.register_images_base import RegisterImagesBase from physiomotion4d.transform_tools import TransformTools @@ -63,13 +64,16 @@ class RegisterImagesANTs(RegisterImagesBase): >>> phi_FM = result["phi_FM"] """ - def __init__(self): + def __init__(self, log_level: int | str = logging.INFO): """Initialize the ANTs image registration class. 
Calls the parent RegisterImagesBase constructor to set up common parameters. Default ANTs registration parameters are set to work well for medical images. + + Args: + log_level: Logging level (default: logging.INFO) """ - super().__init__() + super().__init__(log_level=log_level) self.number_of_iterations = [40, 20, 10] @@ -247,26 +251,130 @@ def _antsfile_to_itk_displacement_field_transform( self._itk_to_ants_image(ref_image, dtype='float'), ) - disp_field_itk = self._ants_to_itk_image(disp_field_ants) + disp_field_itk_raw = self._ants_to_itk_image(disp_field_ants) + + # Convert to the correct Image[Vector[D, 3], 3] type for DisplacementFieldTransform + # Use ImageTools helper to convert array to vector image with correct type + from physiomotion4d.image_tools import ImageTools + image_tools = ImageTools() + + disp_array = itk.array_from_image(disp_field_itk_raw) + disp_field_itk = image_tools.convert_array_to_image_of_vectors( + disp_array, itk.D, ref_image + ) + + # Create displacement field transform disp_tfm = itk.DisplacementFieldTransform[itk.D, 3].New() disp_tfm.SetDisplacementField(disp_field_itk) return disp_tfm - def itk_transform_to_ants_transform( - self, itk_tfm: itk.Transform, reference_image: itk.Image + def itk_affine_transform_to_ants_transform(self, itk_tfm): + """Convert ITK affine/rigid transform to ANTs affine transform. + + Converts an ITK MatrixOffsetTransformBase-derived transform (such as + AffineTransform or Rigid3DTransform) to an ANTs affine transform object. + + The conversion extracts: + - 3x3 transformation matrix (converted to row-major order for ANTs) + - Translation vector + - Center of rotation (fixed parameters) + + Args: + itk_tfm (itk.Transform): ITK affine or rigid transform derived from + itkMatrixOffsetTransformBase (e.g., itk.AffineTransform[itk.D, 3], + itk.Rigid3DTransform, etc.) 
+ + Returns: + ants.ANTsTransform: ANTs affine transform object + + Raises: + ValueError: If transform dimension is not 3D + + Example: + >>> # Create ITK affine transform + >>> affine_itk = itk.AffineTransform[itk.D, 3].New() + >>> affine_itk.SetIdentity() + >>> # Convert to ANTs + >>> affine_ants = registrar.itk_affine_transform_to_ants_transform(affine_itk) + >>> # Use in ANTs operations + >>> result = ants.apply_ants_transform(affine_ants, moving_image) + """ + # Get dimension of the transform + dimension = itk_tfm.GetInputSpaceDimension() + if dimension != 3: + raise ValueError( + f"Only 3D transforms are supported, got dimension: {dimension}" + ) + + # Check if transform has a matrix (Translation transforms don't) + if hasattr(itk_tfm, 'GetMatrix'): + # Extract matrix (ITK matrix is row-major) + matrix_itk = np.asarray(itk_tfm.GetMatrix()).reshape(3, 3) + else: + # For transforms without matrix (e.g., TranslationTransform), use identity matrix + matrix_itk = np.eye(3, dtype=np.float64) + + # Extract translation and center based on transform type: + # - MatrixOffsetTransformBase (Affine, Rigid): use GetTranslation() WITH GetCenter() + # - TranslationTransform: use GetOffset() WITHOUT GetCenter() + if hasattr(itk_tfm, 'GetTranslation'): + # MatrixOffsetTransformBase-derived transforms + translation_itk = np.asarray(itk_tfm.GetTranslation()) + center_itk = np.asarray(itk_tfm.GetCenter()) + elif hasattr(itk_tfm, 'GetOffset'): + # TranslationTransform - use GetOffset() WITHOUT GetCenter() + translation_itk = np.asarray(itk_tfm.GetOffset()) + center_itk = np.zeros(3, dtype=np.float64) # No center for translation + else: + # Fallback for unknown transform types + translation_itk = np.zeros(3, dtype=np.float64) + center_itk = np.zeros(3, dtype=np.float64) + + # ANTs affine transform parameters structure: + # For 3D: 12 parameters + # parameters[0:9]: 3x3 matrix in row-major order + # parameters[9:12]: translation vector + # fixed_parameters[0:3]: center of rotation + 
+ # Flatten matrix to row-major order for ANTs + params = np.zeros(12, dtype=np.float64) + params[0:9] = matrix_itk.flatten() # Already row-major + params[9:12] = translation_itk + + # Ensure fixed_params is also float64 + fixed_params = center_itk.astype(np.float64) + + # Create ANTs affine transform + # Note: dimension must be integer 3, not float + ants_tfm = ants.new_ants_transform( + precision='double', + transform_type='AffineTransform', + dimension=int(3), + parameters=params.tolist(), # Convert to list to ensure proper type + fixed_parameters=fixed_params.tolist(), + ) + + return ants_tfm + + def itk_transform_to_antsfile( + self, + itk_tfm: itk.Transform, + reference_image: itk.Image, + output_filename: str, ): - """Convert ITK transform to ANTsPy transform object. + """Convert ITK transform to ANTs transform file. This method converts any ITK transform (Affine, Rigid, DisplacementField, etc.) - to an ANTsPy transform object that can be used as initial_transform in + to an ANTs transform file that can be used as initial_transform in ants.registration() or ants.label_image_registration(). The conversion process: 1. Uses TransformTools to convert the ITK transform to a displacement field 2. Converts the displacement field image from ITK to ANTs format 3. Creates an ANTsPy transform object from the displacement field + 4. Writes the ANTs transform to a file Args: itk_tfm (itk.Transform): Input ITK transform to convert. Can be any @@ -274,42 +382,56 @@ def itk_transform_to_ants_transform( CompositeTransform, etc.) reference_image (itk.Image): Reference image that defines the spatial domain for the displacement field (spacing, size, origin, direction) + output_filename (str): Path where the ANTs transform file will be written. + Typically should have .mat extension for ANTs transforms. 
Returns: - ants.core.ANTsTransform: ANTsPy transform object suitable for use - as initial_transform parameter in ANTs registration functions + list[str]: List containing the path to the written ANTs transform file Example: - >>> # Convert ITK affine transform to ANTs + >>> # Convert ITK affine transform to ANTs file >>> affine_itk = itk.AffineTransform[itk.D, 3].New() >>> affine_itk.SetIdentity() - >>> affine_ants = registrar.itk_transform_to_ants_transform( - ... affine_itk, reference_image + >>> transform_files = registrar.itk_transform_to_antsfile( + ... affine_itk, reference_image, "initial_transform.mat" ... ) >>> >>> # Use in registration >>> result = ants.registration( ... fixed=fixed_ants, ... moving=moving_ants, - ... initial_transform=affine_ants + ... initial_transform=transform_files ... ) """ - # Use TransformTools to convert any ITK transform to displacement field - transform_tools = TransformTools() - disp_field_itk = transform_tools.convert_transform_to_displacement_field( - tfm=itk_tfm, - reference_image=reference_image, - np_component_type=np.float64, - use_reference_image_as_mask=False, - ) + if isinstance(itk_tfm, itk.DisplacementFieldTransform) or isinstance( + itk_tfm, itk.CompositeTransform + ): + transform_tools = TransformTools() + disp_field_itk = transform_tools.convert_transform_to_displacement_field( + tfm=itk_tfm, + reference_image=reference_image, + np_component_type=np.float32, # Use float32 for compatibility with ANTs + use_reference_image_as_mask=False, + ) - # Convert ITK displacement field to ANTs image format - disp_field_ants = self._itk_to_ants_image(disp_field_itk, dtype='double') + if "nii.gz" not in output_filename: + output_filename = os.path.splitext(output_filename)[0] + '.nii.gz' - # Create ANTs transform object from displacement field - ants_tfm = ants.transform_from_displacement_field(disp_field_ants) + # Write displacement field directly as nifti (ANTs can read this) + itk.imwrite(disp_field_itk, output_filename, 
compression=True) + self.log_info("Wrote ANTs displacement field to: %s", output_filename) - return ants_tfm + return [output_filename] + else: + ants_tfm = self.itk_affine_transform_to_ants_transform(itk_tfm) + if ".mat" not in output_filename: + output_filename = os.path.splitext(output_filename)[0] + '.mat' + + # Write transform to file + ants.write_transform(ants_tfm, output_filename) + self.log_info("Wrote ANTs transform to: %s", output_filename) + + return [output_filename] def _antsfiles_to_itk_transforms( self, @@ -432,17 +554,19 @@ def registration_method( # Convert initial ITK transform to ANTs format if provided initial_transform = "identity" if initial_phi_MF is not None: - print("Converting initial ITK transform to ANTs format...") - initial_transform = self.itk_transform_to_ants_transform( - itk_tfm=initial_phi_MF, reference_image=self.fixed_image + self.log_info("Converting initial ITK transform to ANTs format...") + initial_transform = self.itk_transform_to_antsfile( + itk_tfm=initial_phi_MF, + reference_image=self.fixed_image, + output_filename="initial_transform_temp.mat", ) - print("✓ Initial transform converted successfully") + self.log_info("Initial transform converted successfully") if images_are_labelmaps: registration_result = ants.label_image_registration( fixed_label_images=self._itk_to_ants_image(self.fixed_image), moving_label_images=self._itk_to_ants_image(self.moving_image), - initial_transform=initial_transform, + initial_transform=[initial_transform], verbose=True, ) else: @@ -452,7 +576,7 @@ def registration_method( mask=self._itk_to_ants_image(self.fixed_image_mask), moving=self._itk_to_ants_image(self.moving_image_pre), moving_mask=self._itk_to_ants_image(self.moving_image_mask), - initial_transform=initial_transform, + initial_transform=[initial_transform], type_of_transform="antsRegistrationSyNQuick[so]", use_histogram_matching=False, mask_all_stages=True, @@ -463,7 +587,7 @@ def registration_method( registration_result = 
ants.registration( fixed=self._itk_to_ants_image(self.fixed_image_pre), moving=self._itk_to_ants_image(self.moving_image_pre), - initial_transform=initial_transform, + initial_transform=[initial_transform], type_of_transform="antsRegistrationSyNQuick[so]", use_histogram_matching=False, verbose=True, @@ -485,7 +609,7 @@ def registration_method( # Important: ANTs does NOT include the initial_transform in the output transforms # We need to manually compose them if initial_phi_MF is not None: - print("Composing initial transform with registration result...") + self.log_info("Composing initial transform with registration result...") # For phi_MF (Moving -> Fixed): Apply initial_phi_MF first, then registration # Transform order: point -> initial_phi_MF -> phi_MF_reg @@ -517,7 +641,7 @@ def registration_method( ) phi_FM.AddTransform(initial_phi_FM) - print("✓ Transforms composed successfully") + self.log_info("Transforms composed successfully") else: # No initial transform, use registration results directly phi_MF = phi_MF_reg @@ -584,3 +708,4 @@ def parse_args(): # Convert back to ITK and save moving_image_reg = registrar._ants_to_itk_image(moving_image_reg_ants) itk.imwrite(moving_image_reg, args.output_image, compression=True) + itk.imwrite(moving_image_reg, args.output_image, compression=True) diff --git a/src/physiomotion4d/register_images_base.py b/src/physiomotion4d/register_images_base.py index dc4a9fc..0cd83e3 100644 --- a/src/physiomotion4d/register_images_base.py +++ b/src/physiomotion4d/register_images_base.py @@ -15,14 +15,17 @@ the register() method with their specific algorithm (e.g., Icon, ANTs, etc.). """ +import logging + import itk import numpy as np from itk import TubeTK as ttk +from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase from physiomotion4d.transform_tools import TransformTools -class RegisterImagesBase: +class RegisterImagesBase(PhysioMotion4DBase): """Base class for deformable image registration algorithms. 
This class provides a common interface and shared functionality for @@ -64,13 +67,18 @@ class and implement the register() method. >>> phi_MF = result["phi_MF"] """ - def __init__(self): + def __init__(self, log_level: int | str = logging.INFO): """Initialize the base image registration class. Sets up the common registration parameters with default values. Algorithm-specific components (like neural networks or optimization objects) should be initialized in the concrete implementation to avoid unnecessary resource allocation. + + Args: + log_level: Logging level (default: logging.INFO) """ + super().__init__(class_name=self.__class__.__name__, log_level=log_level) + self.net = None self.modality = 'ct' diff --git a/src/physiomotion4d/register_images_icon.py b/src/physiomotion4d/register_images_icon.py index 487645d..c5e7810 100644 --- a/src/physiomotion4d/register_images_icon.py +++ b/src/physiomotion4d/register_images_icon.py @@ -10,6 +10,7 @@ """ import argparse +import logging import icon_registration as icon import icon_registration.itk_wrapper @@ -57,14 +58,17 @@ class RegisterImagesICON(RegisterImagesBase): >>> phi_FM = result["phi_FM"] """ - def __init__(self): + def __init__(self, log_level: int | str = logging.INFO): """Initialize the ICON image registration class. Calls the parent RegisterImagesBase constructor to set up common parameters. The ICON deep learning network is initialized lazily on first use to avoid unnecessary GPU memory allocation. 
+ + Args: + log_level: Logging level (default: logging.INFO) """ - super().__init__() + super().__init__(log_level=log_level) self.net = None self.use_multi_modality = False diff --git a/src/physiomotion4d/register_model_to_model_icp.py b/src/physiomotion4d/register_model_to_model_icp.py index a8c985d..0aef36c 100644 --- a/src/physiomotion4d/register_model_to_model_icp.py +++ b/src/physiomotion4d/register_model_to_model_icp.py @@ -35,15 +35,18 @@ >>> phi_MF = result['phi_MF'] # Moving to fixed transform """ +import logging + import itk import numpy as np import pyvista as pv import vtk +from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase from physiomotion4d.transform_tools import TransformTools -class RegisterModelToModelICP: +class RegisterModelToModelICP(PhysioMotion4DBase): """Register anatomical models using Iterative Closest Point (ICP) algorithm. This class provides ICP-based alignment of 3D surface meshes with support for @@ -88,17 +91,21 @@ def __init__( self, moving_mesh: pv.PolyData, fixed_mesh: pv.PolyData, + log_level: int | str = logging.INFO, ): """Initialize ICP-based model registration. Args: moving_mesh: PyVista surface mesh to be aligned to fixed mesh fixed_mesh: PyVista target surface mesh + log_level: Logging level (default: logging.INFO) Note: The moving_mesh is typically extracted from a VTU model using mesh.extract_surface() before passing to this class. """ + super().__init__(class_name=self.__class__.__name__, log_level=log_level) + self.moving_mesh = moving_mesh self.fixed_mesh = fixed_mesh @@ -147,19 +154,17 @@ def register(self, mode: str = 'affine', max_iterations: int = 2000) -> dict: if mode not in ['rigid', 'affine']: raise ValueError(f"Invalid mode '{mode}'. Must be 'rigid' or 'affine'.") - print( - f"Performing {mode.upper()} ICP alignment of moving mesh to fixed mesh..." 
- ) + self.log_section("%s ICP Alignment", mode.upper()) # Step 1: Centroid alignment (common to both modes) self.registered_mesh = self.moving_mesh.copy(deep=True) moving_centroid = np.array(self.registered_mesh.center) - print(f" Moving mesh centroid: {moving_centroid}") + self.log_debug("Moving mesh centroid: %s", moving_centroid) fixed_centroid = np.array(self.fixed_mesh.center) - print(f" Fixed mesh centroid: {fixed_centroid}") + self.log_debug("Fixed mesh centroid: %s", fixed_centroid) translation = fixed_centroid - moving_centroid - print(f" Step 1: Translating by {translation} to align centroids...") + self.log_info("Step 1: Translating by %s to align centroids...", translation) # Create ITK affine transform with translation phi_ICP = itk.AffineTransform[itk.D, 3].New() @@ -173,10 +178,10 @@ def register(self, mode: str = 'affine', max_iterations: int = 2000) -> dict: with_deformation_magnitude=False, ) - print(f" Center after Step 1: {self.registered_mesh.center}") + self.log_debug("Center after Step 1: %s", self.registered_mesh.center) # Step 2: Rigid ICP (common to both modes) - print(f" Step 2: Performing rigid ICP (max iterations: {max_iterations})...") + self.log_info("Step 2: Performing rigid ICP (max iterations: %d)...", max_iterations) icp_rigid = vtk.vtkIterativeClosestPointTransform() icp_rigid.SetSource(self.registered_mesh) icp_rigid.SetTarget(self.fixed_mesh) @@ -197,13 +202,11 @@ def register(self, mode: str = 'affine', max_iterations: int = 2000) -> dict: with_deformation_magnitude=False, ) - print(f" Center after Step 2: {self.registered_mesh.center}") + self.log_debug("Center after Step 2: %s", self.registered_mesh.center) # Step 3: Affine ICP (only if affine mode) if mode == 'affine': - print( - f" Step 3: Performing affine ICP (max iterations: {max_iterations})..." 
- ) + self.log_info("Step 3: Performing affine ICP (max iterations: %d)...", max_iterations) icp_affine = vtk.vtkIterativeClosestPointTransform() icp_affine.SetSource(self.registered_mesh) icp_affine.SetTarget(self.fixed_mesh) @@ -224,13 +227,13 @@ def register(self, mode: str = 'affine', max_iterations: int = 2000) -> dict: with_deformation_magnitude=False, ) - print(f" Center after Step 3: {self.registered_mesh.center}") + self.log_debug("Center after Step 3: %s", self.registered_mesh.center) # Compute inverse transform self.phi_MF = phi_ICP.GetInverseTransform() self.phi_FM = phi_ICP - print(f" {mode.upper()} ICP registration complete!") + self.log_info("%s ICP registration complete!", mode.upper()) # Return results as dictionary return { diff --git a/src/physiomotion4d/register_model_to_model_masks.py b/src/physiomotion4d/register_model_to_model_masks.py index e8d8671..197246a 100644 --- a/src/physiomotion4d/register_model_to_model_masks.py +++ b/src/physiomotion4d/register_model_to_model_masks.py @@ -48,18 +48,21 @@ >>> phi_MF = result['phi_MF'] # Moving to fixed transform """ +import logging + import itk import numpy as np import pyvista as pv from itk import TubeTK as ttk from physiomotion4d.contour_tools import ContourTools +from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase from physiomotion4d.register_images_ants import RegisterImagesANTs from physiomotion4d.register_images_icon import RegisterImagesICON from physiomotion4d.transform_tools import TransformTools -class RegisterModelToModelMasks: +class RegisterModelToModelMasks(PhysioMotion4DBase): """Register anatomical models using mask-based deformable registration. This class provides mask-based alignment of 3D surface meshes with support for @@ -123,6 +126,7 @@ def __init__( fixed_mesh: pv.PolyData, reference_image: itk.Image, roi_dilation_mm: float = 10, + log_level: int | str = logging.INFO, ): """Initialize mask-based model registration. 
@@ -133,11 +137,14 @@ def __init__( for mask generation. Typically the patient CT/MRI image. roi_dilation_mm: Dilation amount in millimeters for ROI mask generation. Default: 20mm + log_level: Logging level (default: logging.INFO) Note: The moving_mesh and fixed_mesh are typically extracted from VTU models using mesh.extract_surface() before passing to this class. """ + super().__init__(class_name=self.__class__.__name__, log_level=log_level) + self.moving_mesh = moving_mesh self.fixed_mesh = fixed_mesh self.reference_image = reference_image @@ -148,8 +155,8 @@ def __init__( self.contour_tools = ContourTools() # Registration instances - self.registrar_ants = RegisterImagesANTs() - self.registrar_icon = RegisterImagesICON() + self.registrar_ants = RegisterImagesANTs(log_level=log_level) + self.registrar_icon = RegisterImagesICON(log_level=log_level) self.registrar_icon.set_modality('ct') self.registrar_icon.set_multi_modality(True) # For mask-based registration @@ -175,7 +182,7 @@ def _create_masks_from_meshes(self): Uses self.reference_image for coordinate frame (origin, spacing, direction). 
""" - print("Generating binary masks from meshes...") + self.log_info("Generating binary masks from meshes...") # Create fixed mask self.fixed_mask_image = ( @@ -188,7 +195,7 @@ def _create_masks_from_meshes(self): ) # Create fixed ROI mask with dilation - print(f" Dilating fixed mask by {self.roi_dilation_mm}mm for ROI...") + self.log_info("Dilating fixed mask by %.1fmm for ROI...", self.roi_dilation_mm) mask = self.contour_tools.create_mask_from_mesh( self.fixed_mesh, self.reference_image ) @@ -210,7 +217,7 @@ def _create_masks_from_meshes(self): ) # Create moving ROI mask with dilation - print(f" Dilating moving mask by {self.roi_dilation_mm}mm for ROI...") + self.log_info("Dilating moving mask by %.1fmm for ROI...", self.roi_dilation_mm) mask = self.contour_tools.create_mask_from_mesh( self.moving_mesh, self.reference_image ) @@ -218,7 +225,7 @@ def _create_masks_from_meshes(self): imMath.Dilate(dilation_voxels, 1, 0) self.moving_mask_roi_image = imMath.GetOutputUChar() - print(" Mask generation complete.") + self.log_info("Mask generation complete") def register( self, @@ -274,7 +281,7 @@ def register( f"Invalid mode '{mode}'. Must be 'rigid', 'affine', or 'deformable'." ) - print(f"Performing {mode.upper()} mask-based registration...") + self.log_section("%s Mask-based Registration", mode.upper()) # Step 1: Generate masks from meshes self._create_masks_from_meshes() @@ -287,7 +294,7 @@ def register( else: # deformable transform_type = "SyN" # Includes rigid + affine + deformable stages - print(f" Performing ANTs {mode} registration (type: {transform_type})...") + self.log_info("Performing ANTs %s registration (type: %s)...", mode, transform_type) self.registrar_ants.set_fixed_image(self.fixed_mask_image) self.registrar_ants.set_fixed_image_mask(self.fixed_mask_roi_image) @@ -305,9 +312,7 @@ def register( # Step 4: Optional ICON refinement if use_icon: - print( - f" Performing ICON refinement registration ({icon_iterations} iterations)..." 
- ) + self.log_info("Performing ICON refinement registration (%d iterations)...", icon_iterations) # Transform masks with ANTs result for ICON input moving_mask_ants_transformed = self.transform_tools.transform_image( @@ -343,14 +348,14 @@ def register( self.phi_FM = composed_phi_FM # Apply final transform to moving mesh - print(" Transforming moving mesh...") + self.log_info("Transforming moving mesh...") self.registered_mesh = self.transform_tools.transform_pvcontour( self.moving_mesh, self.phi_MF, with_deformation_magnitude=True, ) - print(f" {mode.upper()} mask-based registration complete!") + self.log_info("%s mask-based registration complete!", mode.upper()) # Return results as dictionary return { diff --git a/src/physiomotion4d/register_time_series_images.py b/src/physiomotion4d/register_time_series_images.py index 03dae08..e253c29 100644 --- a/src/physiomotion4d/register_time_series_images.py +++ b/src/physiomotion4d/register_time_series_images.py @@ -9,6 +9,8 @@ CT where sequential frames need to be registered to a common frame. """ +import logging + import itk from physiomotion4d.register_images_ants import RegisterImagesANTs @@ -64,22 +66,23 @@ class RegisterTimeSeriesImages(RegisterImagesBase): >>> losses = result["losses"] """ - def __init__(self, registration_method='ants'): + def __init__(self, registration_method='ants', log_level: int | str = logging.INFO): """Initialize the time series image registration class. Args: registration_method (str): Registration method to use. Options: 'ants' or 'icon'. 
Default: 'ants' + log_level: Logging level (default: logging.INFO) Raises: ValueError: If registration_method is not 'ants' or 'icon' """ - super().__init__() + super().__init__(log_level=log_level) self.registration_method = registration_method.lower() - self.registrar_ants = RegisterImagesANTs() - self.registrar_icon = RegisterImagesICON() + self.registrar_ants = RegisterImagesANTs(log_level=log_level) + self.registrar_icon = RegisterImagesICON(log_level=log_level) if self.registration_method == 'ants': self.number_of_iterations = [40, 20, 10] elif self.registration_method == 'icon': diff --git a/src/physiomotion4d/segment_chest_base.py b/src/physiomotion4d/segment_chest_base.py index 2b76da8..24cb734 100644 --- a/src/physiomotion4d/segment_chest_base.py +++ b/src/physiomotion4d/segment_chest_base.py @@ -5,12 +5,16 @@ preprocessing, postprocessing, and anatomical structure organization tasks. """ +import logging + import itk import numpy as np from itk import TubeTK as tube +from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase + -class SegmentChestBase: +class SegmentChestBase(PhysioMotion4DBase): """Base class for chest segmentation that provides common functionality for segmenting chest CT images. @@ -37,13 +41,18 @@ class SegmentChestBase: other_mask_ids (dict): Dictionary of remaining structure IDs """ - def __init__(self): + def __init__(self, log_level: int | str = logging.INFO): """Initialize the SegmentChestBase class. Sets up default parameters for image preprocessing and anatomical structure ID mappings. Subclasses should call this constructor and then override the mask ID dictionaries with their specific mappings. 
+ + Args: + log_level: Logging level (default: logging.INFO) """ + super().__init__(class_name=self.__class__.__name__, log_level=log_level) + self.target_spacing = 0 self.rescale_intensity_range = False @@ -159,13 +168,13 @@ def preprocess_input( + input_image.GetSpacing()[1] + input_image.GetSpacing()[2] ) / 3 - print( - " Resampling to", self.target_spacing, "isotropic spacing." + self.log_info( + "Resampling to %.3f isotropic spacing", self.target_spacing ) if resale_image: - print("WARNING: The input image should have isotropic spacing.") - print(" The input image has spacing:", input_image.GetSpacing()) - print(" Resampling to isotropic:", self.target_spacing) + self.log_warning("The input image should have isotropic spacing") + self.log_info("Input image has spacing: %s", str(input_image.GetSpacing())) + self.log_info("Resampling to isotropic: %.3f", self.target_spacing) interpolator = itk.LinearInterpolateImageFunction.New(input_image) results_image = itk.ResampleImageFilter( input_image, @@ -204,7 +213,7 @@ def preprocess_input( minv = results_image_arr.min() maxv = results_image_arr.max() if self.rescale_intensity_range: - print("Rescaling intensity range...") + self.log_info("Rescaling intensity range...") if ( self.input_intensity_scale_range is None or self.output_intensity_scale_range is None diff --git a/src/physiomotion4d/segment_chest_ensemble.py b/src/physiomotion4d/segment_chest_ensemble.py index 76969a9..5a919f9 100644 --- a/src/physiomotion4d/segment_chest_ensemble.py +++ b/src/physiomotion4d/segment_chest_ensemble.py @@ -8,6 +8,7 @@ # -v /tmp/data:/home/aylward/tmp/data nvcr.io/nim/nvidia/vista3d:latest import argparse +import logging import itk import numpy as np @@ -23,9 +24,13 @@ class SegmentChestEnsemble(SegmentChestBase): segmentation method using VISTA3D. """ - def __init__(self): - """Initialize the vista3d class.""" - super().__init__() + def __init__(self, log_level: int | str = logging.INFO): + """Initialize the vista3d class. 
+ + Args: + log_level: Logging level (default: logging.INFO) + """ + super().__init__(log_level=log_level) self.target_spacing = 0.0 @@ -316,15 +321,15 @@ def ensemble_segmentation( itk.image: The combined segmentation result. """ - print("Running ensemble segmentation: combining results") + self.log_info("Running ensemble segmentation: combining results") labelmap_vista_arr = itk.GetArrayFromImage(labelmap_vista) labelmap_totseg_arr = itk.GetArrayFromImage(labelmap_totseg) - print("Segs done", flush=True) + self.log_info("Segmentations loaded") results_arr = np.zeros_like(labelmap_vista_arr) - print("Setting interpolators") + self.log_info("Setting interpolators") labelmap_vista_interp = itk.LabelImageGaussianInterpolateImageFunction.New( labelmap_vista ) @@ -332,11 +337,13 @@ def ensemble_segmentation( labelmap_totseg ) - print("Iterating through labelmaps", flush=True) + self.log_info("Iterating through labelmaps") lastidx0 = -1 + total_slices = labelmap_vista_arr.shape[0] for idx in np.ndindex(labelmap_vista_arr.shape): if idx[0] != lastidx0: - print("Processing slice", idx[0], flush=True) + if idx[0] % 10 == 0 or idx[0] == total_slices - 1: + self.log_progress(idx[0] + 1, total_slices, prefix="Processing slices") lastidx0 = idx[0] # Skip if both are zero vista_label = labelmap_vista_arr[idx] diff --git a/src/physiomotion4d/segment_chest_total_segmentator.py b/src/physiomotion4d/segment_chest_total_segmentator.py index 6b0e551..e1b04d0 100644 --- a/src/physiomotion4d/segment_chest_total_segmentator.py +++ b/src/physiomotion4d/segment_chest_total_segmentator.py @@ -7,6 +7,7 @@ """ import argparse +import logging import os import tempfile @@ -54,14 +55,17 @@ class SegmentChestTotalSegmentator(SegmentChestBase): >>> heart_mask = result["heart"] """ - def __init__(self): + def __init__(self, log_level: int | str = logging.INFO): """Initialize the TotalSegmentator-based chest segmentation. 
Sets up the TotalSegmentator-specific anatomical structure ID mappings and processing parameters. The target spacing is set to 1.5mm which provides a good balance between accuracy and processing speed. + + Args: + log_level: Logging level (default: logging.INFO) """ - super().__init__() + super().__init__(log_level=log_level) self.target_spacing = 1.5 @@ -232,10 +236,10 @@ def segmentation_method(self, preprocessed_image: itk.image) -> itk.image: # For higher performance, you can use fast=True, which uses a # faster but less accurate model. - output_nib_image1 = totalsegmentator(nib_image, task="total", device="cuda") + output_nib_image1 = totalsegmentator(nib_image, task="total", device="gpu") labelmap_arr1 = output_nib_image1.get_fdata().astype(np.uint8) - output_nib_image2 = totalsegmentator(nib_image, task="body", device="cuda") + output_nib_image2 = totalsegmentator(nib_image, task="body", device="gpu") labelmap_arr2 = output_nib_image2.get_fdata().astype(np.uint8) # The data from nibabel is in RAS orientation with xyz axis order. diff --git a/src/physiomotion4d/segment_chest_vista_3d.py b/src/physiomotion4d/segment_chest_vista_3d.py index f23829e..e5b37d2 100644 --- a/src/physiomotion4d/segment_chest_vista_3d.py +++ b/src/physiomotion4d/segment_chest_vista_3d.py @@ -17,6 +17,7 @@ # -v /tmp/data:/home/aylward/tmp/data nvcr.io/nim/nvidia/vista3d:latest import argparse +import logging import os import shutil import tempfile @@ -65,7 +66,7 @@ class SegmentChestVista3D(SegmentChestBase): >>> result = segmenter.segment(ct_image) """ - def __init__(self): + def __init__(self, log_level: int | str = logging.INFO): """Initialize the VISTA-3D based chest segmentation. Sets up the VISTA-3D model including downloading weights from Hugging Face, @@ -75,11 +76,14 @@ def __init__(self): The initialization automatically downloads the VISTA-3D model weights from the MONAI/VISTA3D-HF repository on Hugging Face if not already present. 
+ Args: + log_level: Logging level (default: logging.INFO) + Raises: RuntimeError: If CUDA is not available for GPU acceleration ConnectionError: If model weights cannot be downloaded """ - super().__init__() + super().__init__(log_level=log_level) self.target_spacing = 0.0 self.resale_intensity_range = False diff --git a/src/physiomotion4d/segment_chest_vista_3d_nim.py b/src/physiomotion4d/segment_chest_vista_3d_nim.py index 9a22b6d..8103de0 100644 --- a/src/physiomotion4d/segment_chest_vista_3d_nim.py +++ b/src/physiomotion4d/segment_chest_vista_3d_nim.py @@ -9,6 +9,7 @@ import argparse import io +import logging import os import tempfile import zipfile @@ -26,9 +27,13 @@ class SegmentChestVista3DNIM(SegmentChestVista3D): segmentation method using VISTA3D. """ - def __init__(self): - """Initialize the vista3d class.""" - super().__init__() + def __init__(self, log_level: int | str = logging.INFO): + """Initialize the vista3d class. + + Args: + log_level: Logging level (default: logging.INFO) + """ + super().__init__(log_level=log_level) self.invoke_url = "http://localhost:8000/v1/vista3d/inference" self.wsl_docker_tmp_file = ( @@ -66,7 +71,7 @@ def segmentation_method(self, preprocessed_image: itk.image) -> itk.image: z.extractall(temp_dir) file_list = os.listdir(temp_dir) for filename in file_list: - print(filename) + self.log_debug("Found file: %s", filename) filepath = os.path.join(temp_dir, filename) if os.path.isfile(filepath) and filename.endswith(".nii.gz"): # SUCCESS: Return the results diff --git a/src/physiomotion4d/transform_tools.py b/src/physiomotion4d/transform_tools.py index 23469df..d703dd3 100644 --- a/src/physiomotion4d/transform_tools.py +++ b/src/physiomotion4d/transform_tools.py @@ -11,6 +11,8 @@ are used to track anatomical motion over time. 
""" +import logging + import itk import numpy as np import pyvista as pv @@ -18,9 +20,10 @@ from pxr import Gf, Usd, UsdGeom from .image_tools import ImageTools +from .physiomotion4d_base import PhysioMotion4DBase -class TransformTools: +class TransformTools(PhysioMotion4DBase): """ Utilities for transforming and manipulating ITK transforms. @@ -53,13 +56,71 @@ class TransformTools: >>> field = transform_tools.generate_field(transform, reference_image) """ - def __init__(self): + def __init__(self, log_level: int | str = logging.INFO): """Initialize the TransformTools class. - No parameters are required for initialization as all methods - operate on provided transforms and images. + Args: + log_level: Logging level (default: logging.INFO) + """ + super().__init__(class_name=self.__class__.__name__, log_level=log_level) + + def imreadVD3(self, filename: str) -> itk.Image: + """Read an ITK vector image with double precision vectors. + + ITK's imread is not wrapped for itk.Image[itk.Vector[itk.D,3],3], + so this method reads as itk.Image[itk.Vector[itk.F,3],3] and converts + to double precision. + + Args: + filename (str): Path to the image file to read + + Returns: + itk.Image[itk.Vector[itk.D,3],3]: Vector image with double precision + + Example: + >>> transform_tools = TransformTools() + >>> displacement_field = transform_tools.imreadVD3("deformation.mha") + """ + # Read as float precision vector image + image_float = itk.imread(filename, itk.Image[itk.Vector[itk.F, 3], 3]) + + # Convert to double precision + caster = itk.CastImageFilter[ + itk.Image[itk.Vector[itk.F, 3], 3], + itk.Image[itk.Vector[itk.D, 3], 3] + ].New() + caster.SetInput(image_float) + caster.Update() + + return caster.GetOutput() + + def imwriteVD3(self, image: itk.Image, filename: str, compression: bool = True): + """Write an ITK vector image with double precision vectors. 
+ + ITK's imwrite is not wrapped for itk.Image[itk.Vector[itk.D,3],3], + so this method converts to itk.Image[itk.Vector[itk.F,3],3] and writes. + + Args: + image (itk.Image[itk.Vector[itk.D,3],3]): Vector image to write + filename (str): Path to the output file + compression (bool): Whether to use compression (default: True) + + Example: + >>> transform_tools = TransformTools() + >>> transform_tools.imwriteVD3(displacement_field, "deformation.mha") """ - pass + # Convert to float precision for writing + caster = itk.CastImageFilter[ + itk.Image[itk.Vector[itk.D, 3], 3], + itk.Image[itk.Vector[itk.F, 3], 3] + ].New() + caster.SetInput(image) + caster.Update() + + image_float = caster.GetOutput() + + # Write the float image + itk.imwrite(image_float, filename, compression=compression) def combine_displacement_field_transforms( self, @@ -865,10 +926,10 @@ def convert_itk_transform_to_usd_visualization( # Save the stage stage.Save() - print(f"Created USD visualization: {output_filename}") - print(f" Type: {visualization_type}") - print(f" Points: {np.prod(subsampled_size)}") - print(f" Subsample factor: {subsample_factor}") + self.log_info("Created USD visualization: %s", output_filename) + self.log_info(" Type: %s", visualization_type) + self.log_info(" Points: %d", np.prod(subsampled_size)) + self.log_info(" Subsample factor: %d", subsample_factor) return output_filename @@ -921,7 +982,7 @@ def _create_arrow_visualization( ) arrow_count += 1 - print(f" Created {arrow_count} arrows") + self.log_info(" Created %d arrows", arrow_count) def _create_arrow_prim( self, stage, prim_path, position, displacement, magnitude, arrow_scale @@ -1050,7 +1111,7 @@ def _create_flowline_visualization( self._create_curve_prim(stage, curve_path, streamline_points) flowline_count += 1 - print(f" Created {flowline_count} flowlines") + self.log_info(" Created %d flowlines", flowline_count) def _trace_streamline( self, diff --git a/src/physiomotion4d/usd_anatomy_tools.py 
b/src/physiomotion4d/usd_anatomy_tools.py index de5c891..923907a 100644 --- a/src/physiomotion4d/usd_anatomy_tools.py +++ b/src/physiomotion4d/usd_anatomy_tools.py @@ -4,18 +4,28 @@ """ import argparse +import logging import os from pxr import Sdf, Usd, UsdGeom, UsdShade +from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase -class USDAnatomyTools: + +class USDAnatomyTools(PhysioMotion4DBase): """ This class is used to enhance the appearance of anatomy meshes in a USD file. """ - def __init__(self, stage): + def __init__(self, stage, log_level: int | str = logging.INFO): + """Initialize USDAnatomyTools. + + Args: + stage: USD stage to work with + log_level: Logging level (default: logging.INFO) + """ + super().__init__(class_name=self.__class__.__name__, log_level=log_level) self.stage = stage self.heart_params = { @@ -306,9 +316,9 @@ def enhance_meshes(self, segmentator): if transform_prim and not mesh_prim: current_prim_path = str(prim.GetPrimPath()) root_prim_path = "/".join(current_prim_path.split("/")[:-1]) - print(f"Root prim path: {root_prim_path}") + self.log_debug("Root prim path: %s", root_prim_path) anatomy_prim_path = "/".join([root_prim_path, "Anatomy"]) - print(f" Anatomy prim path: {anatomy_prim_path}") + self.log_debug(" Anatomy prim path: %s", anatomy_prim_path) if not self.stage.GetPrimAtPath(anatomy_prim_path): UsdGeom.Xform.Define(self.stage, anatomy_prim_path) anatomy_prim_path = "/".join( @@ -326,8 +336,8 @@ def enhance_meshes(self, segmentator): ] ) anatomy_prim_path = Sdf.Path(anatomy_prim_path) - print(f" Current prim path: {current_prim_path}") - print(f" Anatomy prim path: {anatomy_prim_path}") + self.log_debug(" Current prim path: %s", current_prim_path) + self.log_debug(" Anatomy prim path: %s", anatomy_prim_path) editor.MovePrimAtPath( current_prim_path, anatomy_prim_path, diff --git a/src/physiomotion4d/usd_tools.py b/src/physiomotion4d/usd_tools.py index 1bcc72a..3ddbcb1 100644 --- a/src/physiomotion4d/usd_tools.py +++ 
b/src/physiomotion4d/usd_tools.py @@ -10,13 +10,16 @@ anatomical structures need to be organized and visualized together. """ +import logging import os import numpy as np from pxr import Gf, Sdf, Usd, UsdGeom, UsdShade, UsdUtils +from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase -class USDTools: + +class USDTools(PhysioMotion4DBase): """ Utilities for manipulating Universal Scene Description (USD) files. @@ -54,13 +57,13 @@ class USDTools: ... ) """ - def __init__(self): + def __init__(self, log_level: int | str = logging.INFO): """Initialize the USDTools class. - No parameters are required for initialization as all methods - operate on provided USD files and stages. + Args: + log_level: Logging level (default: logging.INFO) """ - pass + super().__init__(class_name=self.__class__.__name__, log_level=log_level) def get_subtree_bounding_box( self, prim: UsdGeom.Xform @@ -152,7 +155,7 @@ def save_usd_file_arrangement(self, new_stage_name: str, usd_file_names: list[st n_objects = len(usd_file_names) n_rows = int(np.floor(np.sqrt(n_objects))) n_cols = int(np.ceil(n_objects / n_rows)) - print(f"n_rows: {n_rows}, n_cols: {n_cols}") + self.log_info("Grid layout: %d rows x %d cols", n_rows, n_cols) x_spacing = 400.0 y_spacing = 400.0 x_offset = -x_spacing * (n_cols - 1) / 2 @@ -165,17 +168,17 @@ def save_usd_file_arrangement(self, new_stage_name: str, usd_file_names: list[st source_root = source_stage.GetPrimAtPath("/World") children = source_root.GetChildren() for child in children: - print(f"Copying {usd_file_name}:{child.GetPrimPath()}") + self.log_info("Copying %s:%s", usd_file_name, child.GetPrimPath()) new_stage.DefinePrim(child.GetPrimPath()).GetReferences().AddReference( assetPath=usd_file_name, primPath=child.GetPrimPath(), ) # Apply translation to t for grandchild in child.GetAllChildren(): - print(f" Bounding box of {grandchild.GetPrimPath()}") + self.log_debug(" Bounding box of %s", grandchild.GetPrimPath()) bbox_min, bbox_max = 
self.get_subtree_bounding_box(grandchild) bbox_center = (bbox_min + bbox_max) / 2 - print(f" Bounding box center: {bbox_center}") + self.log_debug(" Bounding box center: %s", bbox_center) xform = UsdGeom.Xformable(grandchild) if not xform.GetOrderedXformOps(): @@ -192,7 +195,7 @@ def save_usd_file_arrangement(self, new_stage_name: str, usd_file_names: list[st grid_y - bbox_center[1], -bbox_center[2], ) - print(f" Translating {grandchild.GetPrimPath()} to {translate}") + self.log_debug(" Translating %s to %s", grandchild.GetPrimPath(), translate) xform_op.Set(translate, Usd.TimeCode.Default()) for prim in source_stage.Traverse(): @@ -206,8 +209,8 @@ def save_usd_file_arrangement(self, new_stage_name: str, usd_file_names: list[st and len(mesh_material) > 0 else str(mesh_material.GetPath()) ) - print( - f" Mesh {prim.GetPrimPath()} has material {material_path}" + self.log_debug( + " Mesh %s has material %s", prim.GetPrimPath(), material_path ) new_prim = new_stage.GetPrimAtPath(prim.GetPrimPath()) material = UsdShade.Material.Get(new_stage, material_path) @@ -215,11 +218,11 @@ def save_usd_file_arrangement(self, new_stage_name: str, usd_file_names: list[st binding_api = UsdShade.MaterialBindingAPI.Apply(new_prim) binding_api.Bind(material) else: - print( - f" Cannot bind. No new prim found for {prim.GetPrimPath()}" + self.log_warning( + " Cannot bind. 
No new prim found for %s", prim.GetPrimPath() ) - print("Exporting stage...") + self.log_info("Exporting stage...") new_stage.Export(new_stage_name) def merge_usd_files(self, output_filename: str, input_filenames_list: list[str]): @@ -290,7 +293,7 @@ def merge_usd_files(self, output_filename: str, input_filenames_list: list[str]) # Copy all root prims from input for prim in input_stage.GetPseudoRoot().GetAllChildren(): new_path = "/" + prim.GetName() - print(f"Copying {prim.GetPrimPath()} to {new_path}") + self.log_info("Copying %s to %s", prim.GetPrimPath(), new_path) # Recursively copy prim hierarchy with all attributes and time samples def _copy_prim(src_prim, target_path): @@ -356,8 +359,8 @@ def _copy_prim(src_prim, target_path): and len(mesh_material) > 0 else str(mesh_material.GetPath()) ) - print( - f" Binding material {material_path} to {prim.GetPrimPath()}" + self.log_debug( + " Binding material %s to %s", material_path, prim.GetPrimPath() ) # Get corresponding mesh prim and material in target stage new_prim = stage.GetPrimAtPath(prim.GetPrimPath()) @@ -367,12 +370,12 @@ def _copy_prim(src_prim, target_path): binding_api = UsdShade.MaterialBindingAPI.Apply(new_prim) binding_api.Bind(material) else: - print( - f" Warning: Material not found at {material_path} in target stage" + self.log_warning( + " Material not found at %s in target stage", material_path ) else: - print( - f" Warning: Cannot bind material. No mesh prim found at {prim.GetPrimPath()}" + self.log_warning( + " Cannot bind material. 
No mesh prim found at %s", prim.GetPrimPath() ) # Set stage time range metadata for animation playback @@ -383,8 +386,8 @@ def _copy_prim(src_prim, target_path): stage.SetTimeCodesPerSecond(time_codes_per_second) if frames_per_second is not None: stage.SetFramesPerSecond(frames_per_second) - print(f"\nSet stage time range: {global_start_time} to {global_end_time}") - print(f"Time codes per second: {time_codes_per_second}, Frames per second: {frames_per_second}") + self.log_info("Set stage time range: %.1f to %.1f", global_start_time, global_end_time) + self.log_info("Time codes per second: %s, Frames per second: %s", time_codes_per_second, frames_per_second) # Save with USDA format # stage.GetRootLayer().Export(output_path, args=['--usdFormat', 'usda']) @@ -445,8 +448,9 @@ def merge_usd_files_flattened( frames_per_second = None # Add references to all input files - for input_path in input_filenames_list: - print(f"Referencing {input_path}") + num_files = len(input_filenames_list) + for idx, input_path in enumerate(input_filenames_list): + self.log_progress(idx + 1, num_files, prefix="Referencing files") input_stage = Usd.Stage.Open(input_path, Usd.Stage.LoadAll) # Track time range from this input file @@ -463,7 +467,7 @@ def merge_usd_files_flattened( # Reference each top-level prim from the input file for prim in input_stage.GetPseudoRoot().GetAllChildren(): new_path = "/" + prim.GetName() - print(f" Adding reference: {prim.GetPrimPath()} -> {new_path}") + self.log_debug(" Adding reference: %s -> %s", prim.GetPrimPath(), new_path) # Create prim and add reference to source file temp_stage.DefinePrim(new_path).GetReferences().AddReference( @@ -479,12 +483,12 @@ def merge_usd_files_flattened( temp_stage.SetTimeCodesPerSecond(time_codes_per_second) if frames_per_second is not None: temp_stage.SetFramesPerSecond(frames_per_second) - print(f"Time range: {global_start_time} to {global_end_time}") - print(f"Time codes per second: {time_codes_per_second}, Frames per 
second: {frames_per_second}") + self.log_info("Time range: %.1f to %.1f", global_start_time, global_end_time) + self.log_info("Time codes per second: %s, Frames per second: %s", time_codes_per_second, frames_per_second) # Flatten the composed stage into a single layer # This resolves all references and bakes everything into one file - print("Flattening composed stage...") + self.log_info("Flattening composed stage...") flattened_layer = temp_stage.Flatten() # Create output stage from flattened layer @@ -497,11 +501,11 @@ def merge_usd_files_flattened( output_stage.SetEndTimeCode(global_end_time) if time_codes_per_second is not None: output_stage.SetTimeCodesPerSecond(time_codes_per_second) - print(f"Set output TimeCodesPerSecond: {time_codes_per_second}") + self.log_info("Set output TimeCodesPerSecond: %s", time_codes_per_second) if frames_per_second is not None: output_stage.SetFramesPerSecond(frames_per_second) - print(f"Set output FramesPerSecond: {frames_per_second}") + self.log_info("Set output FramesPerSecond: %s", frames_per_second) # Export the flattened layer with corrected metadata - print(f"Exporting to {output_filename}") + self.log_info("Exporting to %s", output_filename) output_stage.Export(output_filename) diff --git a/tests/test_image_tools.py b/tests/test_image_tools.py index bb9683f..22e47f3 100644 --- a/tests/test_image_tools.py +++ b/tests/test_image_tools.py @@ -16,222 +16,221 @@ class TestImageTools: """Test suite for ImageTools conversions.""" - + @pytest.fixture def image_tools(self): """Create ImageTools instance.""" return ImageTools() - + def test_itk_to_sitk_scalar_image(self, image_tools): """Test conversion of scalar ITK image to SimpleITK.""" # Create a simple 3D scalar ITK image size = [10, 20, 30] spacing = [1.0, 2.0, 3.0] origin = [0.0, 0.0, 0.0] - + # Create ITK image with known values ImageType = itk.Image[itk.F, 3] itk_image = ImageType.New() - + region = itk.ImageRegion[3]() region.SetSize(size) itk_image.SetRegions(region) 
itk_image.SetSpacing(spacing) itk_image.SetOrigin(origin) itk_image.Allocate() - + # Fill with test pattern itk_image.FillBuffer(42.0) - + # Convert to SimpleITK sitk_image = image_tools.convert_itk_image_to_sitk(itk_image) - + # Verify metadata - assert sitk_image.GetSize() == tuple(reversed(size)) - assert sitk_image.GetSpacing() == tuple(reversed(spacing)) - assert sitk_image.GetOrigin() == tuple(reversed(origin)) - + assert sitk_image.GetSize() == tuple(size) + assert sitk_image.GetSpacing() == tuple(spacing) + assert sitk_image.GetOrigin() == tuple(origin) + # Verify data array_sitk = sitk.GetArrayFromImage(sitk_image) assert np.allclose(array_sitk, 42.0) - + print("✓ ITK to SimpleITK scalar conversion successful") - + def test_sitk_to_itk_scalar_image(self, image_tools): """Test conversion of scalar SimpleITK image to ITK.""" # Create a simple 3D scalar SimpleITK image size = [10, 20, 30] spacing = [1.0, 2.0, 3.0] origin = [0.0, 0.0, 0.0] - + # Create SimpleITK image sitk_image = sitk.Image(size, sitk.sitkFloat32) sitk_image.SetSpacing(spacing) sitk_image.SetOrigin(origin) - + # Fill with test pattern array = np.ones((size[2], size[1], size[0]), dtype=np.float32) * 99.0 sitk_image = sitk.GetImageFromArray(array) sitk_image.SetSpacing(spacing) sitk_image.SetOrigin(origin) - + # Convert to ITK itk_image = image_tools.convert_sitk_image_to_itk(sitk_image) - + # Verify metadata - assert itk.size(itk_image) == tuple(reversed(size)) + assert itk.size(itk_image) == tuple(size) assert itk.spacing(itk_image) == tuple(spacing) assert itk.origin(itk_image) == tuple(origin) - + # Verify data array_itk = itk.array_from_image(itk_image) assert np.allclose(array_itk, 99.0) - + print("✓ SimpleITK to ITK scalar conversion successful") - + def test_roundtrip_scalar_image(self, image_tools): """Test roundtrip conversion: ITK -> SimpleITK -> ITK.""" # Create ITK image size = [15, 25, 35] spacing = [0.5, 1.5, 2.5] origin = [10.0, 20.0, 30.0] - + ImageType = itk.Image[itk.F, 3] 
itk_image_original = ImageType.New() - + region = itk.ImageRegion[3]() region.SetSize(size) itk_image_original.SetRegions(region) itk_image_original.SetSpacing(spacing) itk_image_original.SetOrigin(origin) itk_image_original.Allocate() - + # Fill with test data array_original = np.random.rand(size[2], size[1], size[0]).astype(np.float32) itk.array_view_from_image(itk_image_original)[:] = array_original - + # Roundtrip conversion sitk_image = image_tools.convert_itk_image_to_sitk(itk_image_original) itk_image_final = image_tools.convert_sitk_image_to_itk(sitk_image) - + # Verify metadata preserved assert itk.size(itk_image_final) == tuple(size) assert np.allclose(itk.spacing(itk_image_final), spacing) assert np.allclose(itk.origin(itk_image_final), origin) - + # Verify data preserved array_final = itk.array_from_image(itk_image_final) assert np.allclose(array_original, array_final) - + print("✓ Roundtrip scalar conversion successful") - + def test_itk_to_sitk_vector_image(self, image_tools): """Test conversion of vector ITK image to SimpleITK.""" # Create a 3D vector ITK image (like a displacement field) size = [8, 12, 16] spacing = [1.0, 1.0, 1.0] origin = [0.0, 0.0, 0.0] - + # Create vector image with 3 components VectorImageType = itk.Image[itk.Vector[itk.F, 3], 3] itk_image = VectorImageType.New() - + region = itk.ImageRegion[3]() region.SetSize(size) itk_image.SetRegions(region) itk_image.SetSpacing(spacing) itk_image.SetOrigin(origin) itk_image.Allocate() - + # Fill with test vector data array = np.random.rand(size[2], size[1], size[0], 3).astype(np.float32) itk.array_view_from_image(itk_image)[:] = array - + # Convert to SimpleITK sitk_image = image_tools.convert_itk_image_to_sitk(itk_image) - + # Verify it's a vector image assert sitk_image.GetNumberOfComponentsPerPixel() == 3 - + # Verify metadata - assert sitk_image.GetSize() == tuple(reversed(size)) - assert sitk_image.GetSpacing() == tuple(reversed(spacing)) - + assert sitk_image.GetSize() == tuple(size) 
+ assert sitk_image.GetSpacing() == tuple(spacing) + # Verify data array_sitk = sitk.GetArrayFromImage(sitk_image) assert np.allclose(array, array_sitk) - + print("✓ ITK to SimpleITK vector conversion successful") - + def test_sitk_to_itk_vector_image(self, image_tools): """Test conversion of vector SimpleITK image to ITK.""" # Create a 3D vector SimpleITK image size = [8, 12, 16] spacing = [1.0, 1.0, 1.0] origin = [0.0, 0.0, 0.0] - + # Create vector data array = np.random.rand(size[2], size[1], size[0], 3).astype(np.float32) - + # Create SimpleITK vector image sitk_image = sitk.GetImageFromArray(array, isVector=True) sitk_image.SetSpacing(spacing) sitk_image.SetOrigin(origin) - + # Convert to ITK itk_image = image_tools.convert_sitk_image_to_itk(sitk_image) - + # Verify it's a vector image assert itk_image.GetNumberOfComponentsPerPixel() == 3 - + # Verify metadata - assert itk.size(itk_image) == tuple(reversed(size)) + assert itk.size(itk_image) == tuple(size) assert itk.spacing(itk_image) == tuple(spacing) - + # Verify data array_itk = itk.array_from_image(itk_image) assert np.allclose(array, array_itk) - + print("✓ SimpleITK to ITK vector conversion successful") - + def test_roundtrip_vector_image(self, image_tools): """Test roundtrip conversion for vector images: ITK -> SimpleITK -> ITK.""" # Create ITK vector image size = [10, 15, 20] spacing = [0.8, 1.2, 1.6] origin = [5.0, 10.0, 15.0] - + VectorImageType = itk.Image[itk.Vector[itk.F, 3], 3] itk_image_original = VectorImageType.New() - + region = itk.ImageRegion[3]() region.SetSize(size) itk_image_original.SetRegions(region) itk_image_original.SetSpacing(spacing) itk_image_original.SetOrigin(origin) itk_image_original.Allocate() - + # Fill with test vector data array_original = np.random.rand(size[2], size[1], size[0], 3).astype(np.float32) itk.array_view_from_image(itk_image_original)[:] = array_original - + # Roundtrip conversion sitk_image = image_tools.convert_itk_image_to_sitk(itk_image_original) 
itk_image_final = image_tools.convert_sitk_image_to_itk(sitk_image) - + # Verify metadata preserved assert itk.size(itk_image_final) == tuple(size) assert np.allclose(itk.spacing(itk_image_final), spacing) assert np.allclose(itk.origin(itk_image_final), origin) assert itk_image_final.GetNumberOfComponentsPerPixel() == 3 - + # Verify data preserved array_final = itk.array_from_image(itk_image_final) assert np.allclose(array_original, array_final) - + print("✓ Roundtrip vector conversion successful") if __name__ == "__main__": pytest.main([__file__, "-v", "-s"]) - diff --git a/tests/test_register_images_ants.py b/tests/test_register_images_ants.py index a491f6a..252b855 100644 --- a/tests/test_register_images_ants.py +++ b/tests/test_register_images_ants.py @@ -263,21 +263,17 @@ def test_registration_with_initial_transform( moving_image = test_images[1] # Create initial translation transform - initial_tfm_FM = itk.TranslationTransform[itk.D, 3].New() - initial_tfm_FM.SetOffset([5.0, 5.0, 5.0]) - initial_tfm_MF = itk.TranslationTransform[itk.D, 3].New() initial_tfm_MF.SetOffset([-5.0, -5.0, -5.0]) print("\nRegistering with initial transform...") - print(" Initial offset: [5.0, 5.0, 5.0]") + print(" Initial offset: [-5.0, -5.0, -5.0]") registrar_ants.set_modality('ct') registrar_ants.set_fixed_image(fixed_image) result = registrar_ants.register( moving_image=moving_image, - initial_phi_FM=initial_tfm_FM, initial_phi_MF=initial_tfm_MF, ) @@ -519,27 +515,36 @@ def test_transform_conversion_cycle_affine(self, registrar_ants, test_images): print(f" Original center: {[center[i] for i in range(3)]}") print(f" Original translation: {[translation[i] for i in range(3)]}") - # Convert ITK -> ANTs - ants_tfm = registrar_ants.itk_transform_to_ants_transform( - affine_tfm, reference_image - ) - assert ants_tfm is not None, "ANTs transform is None" - print(f" ANTs transform type: {ants_tfm.transform_type}") - - # Convert back ANTs -> ITK via displacement field - # (ANTs stores as 
displacement field, so we convert back through that) - + # Convert ITK -> ANTs file with tempfile.TemporaryDirectory() as tmpdir: - # Save ANTs transform to file temp_tfm_file = os.path.join(tmpdir, "temp_transform.mat") - ants.write_transform(ants_tfm, temp_tfm_file) - - # Read back as displacement field - recovered_tfm = ( - registrar_ants._antsfile_to_itk_displacement_field_transform( - temp_tfm_file, reference_image - ) + transform_files = registrar_ants.itk_transform_to_antsfile( + affine_tfm, reference_image, temp_tfm_file ) + assert len(transform_files) == 1, "Should return one transform file" + # Note: The returned filename may have extension added/modified + assert os.path.exists( + transform_files[0] + ), f"Transform file not found: {transform_files[0]}" + print(f" ANTs transform written to: {transform_files[0]}") + + # Read the transform to verify (use the actual written file) + ants_tfm = ants.read_transform(transform_files[0]) + print(f" ANTs transform type: {ants_tfm.transform_type}") + + # Convert back ANTs -> ITK + # Affine transforms are stored as affine in ANTs, so read back as affine + if ants_tfm.transform_type == "AffineTransform": + recovered_tfm = registrar_ants._antsfile_to_itk_affine_transform( + transform_files[0] + ) + else: + # For displacement field transforms + recovered_tfm = ( + registrar_ants._antsfile_to_itk_displacement_field_transform( + transform_files[0], reference_image + ) + ) assert recovered_tfm is not None, "Recovered transform is None" print(f" Recovered transform type: {type(recovered_tfm).__name__}") @@ -587,18 +592,25 @@ def test_transform_conversion_cycle_displacement_field( print("\nTesting displacement field transform conversion cycle...") - # Create a simple displacement field - VectorImageType = itk.Image[itk.Vector[itk.F, 3], 3] - disp_field = VectorImageType.New() - disp_field.CopyInformation(reference_image) - disp_field.SetRegions(reference_image.GetLargestPossibleRegion()) - disp_field.Allocate() + # Create a 
simple displacement field with double precision + # Use ImageTools to create the correct type + from physiomotion4d.image_tools import ImageTools + + image_tools = ImageTools() - # Fill with a simple displacement pattern - disp_array = itk.array_from_image(disp_field) - # Create a smooth displacement field (small random displacements) - for i in range(3): - disp_array[..., i] = np.random.randn(*disp_array.shape[:-1]) * 0.5 + # Create displacement array (small random displacements) + ref_size = itk.size(reference_image) + disp_array = ( + np.random.randn( + int(ref_size[2]), int(ref_size[1]), int(ref_size[0]), 3 + ).astype(np.float64) + * 0.5 + ) + + # Convert to ITK image with correct vector type + disp_field = image_tools.convert_array_to_image_of_vectors( + disp_array, itk.D, reference_image + ) # Create displacement field transform disp_tfm = itk.DisplacementFieldTransform[itk.D, 3].New() @@ -609,22 +621,24 @@ def test_transform_conversion_cycle_displacement_field( f" Max displacement magnitude: {np.max(np.linalg.norm(disp_array, axis=-1)):.3f}" ) - # Convert ITK -> ANTs - ants_tfm = registrar_ants.itk_transform_to_ants_transform( - disp_tfm, reference_image - ) - assert ants_tfm is not None, "ANTs transform is None" - print(" ANTs transform created successfully") - - # Convert back ANTs -> ITK - + # Convert ITK -> ANTs file with tempfile.TemporaryDirectory() as tmpdir: temp_tfm_file = os.path.join(tmpdir, "temp_disp_transform.mat") - ants.write_transform(ants_tfm, temp_tfm_file) + transform_files = registrar_ants.itk_transform_to_antsfile( + disp_tfm, reference_image, temp_tfm_file + ) + assert len(transform_files) == 1, "Should return one transform file" + # Note: The returned filename may have extension added/modified + assert os.path.exists( + transform_files[0] + ), f"Transform file not found: {transform_files[0]}" + print(f" ANTs transform written successfully to: {transform_files[0]}") + + # Convert back ANTs -> ITK recovered_tfm = ( 
registrar_ants._antsfile_to_itk_displacement_field_transform( - temp_tfm_file, reference_image + transform_files[0], reference_image ) ) @@ -689,11 +703,18 @@ def test_transform_conversion_with_composite(self, registrar_ants, test_images): f" Composite transform with {composite_tfm.GetNumberOfTransforms()} transforms" ) - # Convert to ANTs - ants_tfm = registrar_ants.itk_transform_to_ants_transform( - composite_tfm, reference_image - ) - assert ants_tfm is not None, "ANTs transform is None" + # Convert to ANTs file + with tempfile.TemporaryDirectory() as tmpdir: + temp_tfm_file = os.path.join(tmpdir, "temp_composite_transform.mat") + transform_files = registrar_ants.itk_transform_to_antsfile( + composite_tfm, reference_image, temp_tfm_file + ) + assert len(transform_files) == 1, "Should return one transform file" + # Note: The returned filename may have extension added/modified + assert os.path.exists( + transform_files[0] + ), f"Transform file not found: {transform_files[0]}" + print(f" ANTs transform written to: {transform_files[0]}") # Test on sample points test_points = [ diff --git a/tests/test_register_images_icon.py b/tests/test_register_images_icon.py index 1708ca4..9e20c1a 100644 --- a/tests/test_register_images_icon.py +++ b/tests/test_register_images_icon.py @@ -30,8 +30,8 @@ def test_registrar_initialization(self, registrar_icon): registrar_icon, 'fixed_image_mask' ), "Missing fixed_image_mask attribute" assert hasattr( - registrar_icon, 'num_iterations' - ), "Missing num_iterations attribute" + registrar_icon, 'number_of_iterations' + ), "Missing number_of_iterations attribute" assert hasattr(registrar_icon, 'net'), "Missing net attribute (ICON network)" print("\nICON registrar initialized successfully") @@ -50,10 +50,12 @@ def test_set_modality(self, registrar_icon): def test_set_number_of_iterations(self, registrar_icon): """Test setting number of iterations.""" registrar_icon.set_number_of_iterations(10) - assert registrar_icon.num_iterations == 10, 
"Number of iterations not set" + assert registrar_icon.number_of_iterations == 10, "Number of iterations not set" registrar_icon.set_number_of_iterations(5) - assert registrar_icon.num_iterations == 5, "Number of iterations update failed" + assert ( + registrar_icon.number_of_iterations == 5 + ), "Number of iterations update failed" print("\nNumber of iterations setting works correctly") @@ -363,7 +365,6 @@ def test_registration_with_initial_transform( result = registrar_icon.register( moving_image=moving_image, - initial_phi_FM=initial_tfm_FM, initial_phi_MF=initial_tfm_MF, ) diff --git a/tests/test_register_time_series_images.py b/tests/test_register_time_series_images.py index b0aa15f..9b1abb8 100644 --- a/tests/test_register_time_series_images.py +++ b/tests/test_register_time_series_images.py @@ -24,7 +24,12 @@ def test_registrar_initialization_ants(self): registrar = RegisterTimeSeriesImages(registration_method='ants') assert registrar is not None, "Registrar not initialized" assert registrar.registration_method == 'ants', "Method not set correctly" - assert registrar.registrar is not None, "Internal registrar not created" + assert ( + registrar.registrar_ants is not None + ), "Internal ANTs registrar not created" + assert ( + registrar.registrar_icon is not None + ), "Internal ICON registrar not created" print("\n✓ Time series registrar initialized with ANTs") @@ -33,7 +38,12 @@ def test_registrar_initialization_icon(self): registrar = RegisterTimeSeriesImages(registration_method='icon') assert registrar is not None, "Registrar not initialized" assert registrar.registration_method == 'icon', "Method not set correctly" - assert registrar.registrar is not None, "Internal registrar not created" + assert ( + registrar.registrar_ants is not None + ), "Internal ANTs registrar not created" + assert ( + registrar.registrar_icon is not None + ), "Internal ICON registrar not created" print("\n✓ Time series registrar initialized with ICON") @@ -49,9 +59,6 @@ def 
test_set_modality(self): registrar = RegisterTimeSeriesImages(registration_method='ants') registrar.set_modality('ct') assert registrar.modality == 'ct', "Modality not set correctly" - assert ( - registrar.registrar.modality == 'ct' - ), "Modality not passed to internal registrar" print("\n✓ Modality setting works correctly") @@ -62,9 +69,6 @@ def test_set_fixed_image(self, test_images): registrar.set_fixed_image(fixed_image) assert registrar.fixed_image is not None, "Fixed image not set" - assert ( - registrar.registrar.fixed_image is not None - ), "Fixed image not passed through" print("\n✓ Fixed image set successfully") print(f" Image size: {itk.size(registrar.fixed_image)}") @@ -103,7 +107,7 @@ def test_register_time_series_basic(self, test_images, test_directories): result = registrar.register_time_series( moving_images=moving_images, starting_index=0, - register_start_to_reference=True, + register_start_to_fixed_image=True, portion_of_prior_transform_to_init_next_transform=0.0, ) @@ -123,7 +127,9 @@ def test_register_time_series_basic(self, test_images, test_directories): assert len(losses) == len(moving_images), "losses length mismatch" # Verify all transforms are valid - for i, (phi_MF, phi_FM) in enumerate(zip(phi_MF_list, phi_FM_list)): + for i, (phi_MF, phi_FM) in enumerate( + zip(phi_MF_list, phi_FM_list, strict=False) + ): assert phi_MF is not None, f"phi_MF[{i}] is None" assert phi_FM is not None, f"phi_FM[{i}] is None" @@ -150,7 +156,7 @@ def test_register_time_series_with_prior(self, test_images, test_directories): print("\nRegistering time series (with prior)...") print(f" Number of moving images: {len(moving_images)}") - print(f" Using prior transform weight: 0.5") + print(" Using prior transform weight: 0.5") registrar = RegisterTimeSeriesImages(registration_method='ants') registrar.set_modality('ct') @@ -160,7 +166,7 @@ def test_register_time_series_with_prior(self, test_images, test_directories): result = registrar.register_time_series( 
moving_images=moving_images, starting_index=1, # Start from middle - register_start_to_reference=True, + register_start_to_fixed_image=True, portion_of_prior_transform_to_init_next_transform=0.5, ) @@ -189,7 +195,7 @@ def test_register_time_series_identity_start(self, test_images): result = registrar.register_time_series( moving_images=moving_images, starting_index=0, - register_start_to_reference=False, # Use identity + register_start_to_fixed_image=False, # Use identity portion_of_prior_transform_to_init_next_transform=0.0, ) @@ -218,7 +224,7 @@ def test_register_time_series_different_starting_indices(self, test_images): result = registrar.register_time_series( moving_images=moving_images, starting_index=starting_index, - register_start_to_reference=True, + register_start_to_fixed_image=True, portion_of_prior_transform_to_init_next_transform=0.0, ) @@ -302,7 +308,7 @@ def test_transform_application_time_series(self, test_images, test_directories): result = registrar.register_time_series( moving_images=moving_images, starting_index=0, - register_start_to_reference=True, + register_start_to_fixed_image=True, portion_of_prior_transform_to_init_next_transform=0.0, ) @@ -342,7 +348,7 @@ def test_register_time_series_icon(self, test_images): result = registrar.register_time_series( moving_images=moving_images, starting_index=0, - register_start_to_reference=True, + register_start_to_fixed_image=True, portion_of_prior_transform_to_init_next_transform=0.0, ) @@ -354,8 +360,6 @@ def test_register_time_series_icon(self, test_images): def test_register_time_series_with_mask(self, test_images, test_directories): """Test time series registration with fixed image mask.""" - output_dir = test_directories["output"] - fixed_image = test_images[0] moving_images = test_images[1:3] @@ -389,7 +393,7 @@ def test_register_time_series_with_mask(self, test_images, test_directories): result = registrar.register_time_series( moving_images=moving_images, starting_index=0, - 
register_start_to_reference=True, + register_start_to_fixed_image=True, portion_of_prior_transform_to_init_next_transform=0.0, ) @@ -404,7 +408,7 @@ def test_bidirectional_registration(self, test_images): print("\nTesting bidirectional registration...") print(f" Total images: {len(moving_images)}") - print(f" Starting from middle (index 2)") + print(" Starting from middle (index 2)") registrar = RegisterTimeSeriesImages(registration_method='ants') registrar.set_modality('ct') @@ -414,7 +418,7 @@ def test_bidirectional_registration(self, test_images): result = registrar.register_time_series( moving_images=moving_images, starting_index=2, # Middle image - register_start_to_reference=True, + register_start_to_fixed_image=True, portion_of_prior_transform_to_init_next_transform=0.0, ) diff --git a/tests/test_transform_tools.py b/tests/test_transform_tools.py index 4b32cf8..4d9ec72 100644 --- a/tests/test_transform_tools.py +++ b/tests/test_transform_tools.py @@ -247,13 +247,75 @@ def test_convert_transform_to_displacement_field( print(f" Field size: {itk.size(deformation_field)}") print(f" Field shape: {field_arr.shape}") - # Save deformation field - itk.imwrite( + # Save deformation field using imwriteVD3 (for double precision vector images) + transform_tools.imwriteVD3( deformation_field, str(tfm_output_dir / "deformation_field.mha"), compression=True, ) + def test_imwrite_imread_vd3( + self, transform_tools, ants_registration_results, test_images, test_directories + ): + """Test reading and writing double precision vector images.""" + output_dir = test_directories["output"] + tfm_output_dir = output_dir / "transform_tools" + tfm_output_dir.mkdir(exist_ok=True) + + fixed_image = test_images[0] + phi_MF = ants_registration_results["phi_MF"] + + print("\nTesting imwriteVD3 and imreadVD3...") + + # Generate a deformation field + deformation_field = transform_tools.convert_transform_to_displacement_field( + phi_MF, fixed_image + ) + + # Verify it's double precision vector 
image + field_type = str(type(deformation_field)) + print(f" Original field type: {field_type}") + assert "VectorD" in field_type or "Vector[D" in field_type, \ + "Expected double precision vector image" + + # Get original data for comparison + original_arr = itk.array_from_image(deformation_field) + + # Write using imwriteVD3 + output_path = str(tfm_output_dir / "test_vector_field_vd3.mha") + transform_tools.imwriteVD3(deformation_field, output_path, compression=True) + + print(f" Wrote to: {output_path}") + + # Read back using imreadVD3 + field_read = transform_tools.imreadVD3(output_path) + + # Verify read field + assert field_read is not None, "Read field is None" + assert itk.size(field_read) == itk.size(deformation_field), "Size mismatch" + + # Verify it's double precision + read_type = str(type(field_read)) + print(f" Read field type: {read_type}") + assert "VectorD" in read_type or "Vector[D" in read_type, \ + "Expected double precision vector image after reading" + + # Compare data + read_arr = itk.array_from_image(field_read) + assert read_arr.shape == original_arr.shape, "Array shape mismatch" + + # Check numerical accuracy (should be very close, small float precision loss) + max_diff = np.max(np.abs(read_arr - original_arr)) + mean_diff = np.mean(np.abs(read_arr - original_arr)) + + print(f"✓ Vector field I/O test complete") + print(f" Max difference: {max_diff:.6e}") + print(f" Mean difference: {mean_diff:.6e}") + + # Differences should be very small (float precision conversion) + assert max_diff < 1e-5, f"Max difference too large: {max_diff}" + assert mean_diff < 1e-6, f"Mean difference too large: {mean_diff}" + def test_convert_vtk_matrix_to_itk_transform(self, transform_tools): """Test converting VTK matrix to ITK transform.""" # Create a VTK matrix From eac45811cdbbeb26edbe40778f40f091a2ae4b1a Mon Sep 17 00:00:00 2001 From: Stephen Aylward Date: Sun, 4 Jan 2026 07:16:47 -0500 Subject: [PATCH 02/13] ENH: Extensive changes for consistency. 
Improved PCA. --- .github/workflows/docs.yml | 13 + .github/workflows/test.yml | 12 +- README.md | 12 +- .../1-input_meshes_to_input_surfaces.ipynb | 2 +- ...2-input_surfaces_to_surfaces_aligned.ipynb | 676 ++++----- ...ces_aligned_correspond_to_pca_inputs.ipynb | 10 +- .../6-create_template_mask.ipynb | 66 + docs/LOGGING_API_REFERENCE.md | 72 +- docs/ants_initial_transform_guide.md | 47 +- docs/api/registration.rst | 48 +- docs/api/utilities.rst | 2 +- docs/conf.py | 2 +- docs/examples.rst | 11 +- docs/quickstart.rst | 4 +- docs/tutorials/basic_workflow.rst | 4 +- docs/user_guide/logging.rst | 18 +- docs/user_guide/registration.rst | 14 +- .../colormap_vtk_to_usd.ipynb | 221 +-- .../displacement_field_to_usd.ipynb | 520 +++---- .../0-download_and_convert_4d_to_3d.ipynb | 33 +- .../1-register_images.ipynb | 458 +++--- .../2-generate_segmentation.ipynb | 139 +- ...ransform_dynamic_and_static_contours.ipynb | 276 ++-- .../4-merge_dynamic_and_static_usd.ipynb | 30 +- .../test_vista3d_class.ipynb | 26 +- .../test_vista3d_inMem.ipynb | 166 +-- ...eart_model_to_model_registration_pca.ipynb | 1294 ++++++++++------- .../heart_model_to_patient_wip.ipynb | 620 +++++--- .../0-download_and_convert_4d_to_3d.ipynb | 27 +- .../1-heart_vtkseries_to_usd.ipynb | 84 +- .../0-register_dirlab_4dct.ipynb | 650 ++++----- .../1-make_dirlab_models.ipynb | 324 ++--- .../2-paint_dirlab_models.ipynb | 22 +- .../Experiment_ArrangeOnStage.ipynb | 52 + .../Experiment_CombineModels.ipynb | 107 ++ .../Experiment_SegReg.ipynb | 70 + .../Experiment_SubSurfaceScatter.ipynb | 52 + .../Lung-VesselsAirways/0-GenData.ipynb | 380 +---- .../Reconstruct4DCT/reconstruct_4d_ct.ipynb | 763 +++++----- .../reconstruct_4d_ct_class.ipynb | 701 +++++---- scripts/README.md | 4 +- src/physiomotion4d/__init__.py | 21 +- src/physiomotion4d/contour_tools.py | 209 ++- .../heart_gated_ct_to_usd_workflow.py | 40 +- .../heart_model_to_patient_workflow.py | 997 ++++++------- src/physiomotion4d/image_tools.py | 59 +- 
src/physiomotion4d/physiomotion4d_base.py | 8 +- src/physiomotion4d/register_images_ants.py | 174 ++- src/physiomotion4d/register_images_base.py | 109 +- src/physiomotion4d/register_images_icon.py | 80 +- .../register_model_to_image_pca.py | 1211 --------------- .../register_model_to_model_icp.py | 297 ---- .../register_model_to_model_masks.py | 448 ------ .../register_models_distance_maps.py | 381 +++++ src/physiomotion4d/register_models_icp.py | 250 ++++ src/physiomotion4d/register_models_pca.py | 818 +++++++++++ .../register_time_series_images.py | 312 ++-- src/physiomotion4d/segment_chest_base.py | 8 +- src/physiomotion4d/transform_tools.py | 59 - tests/conftest.py | 32 +- tests/test_register_images_ants.py | 84 +- tests/test_register_images_icon.py | 106 +- tests/test_register_time_series_images.py | 117 +- tests/test_transform_tools.py | 94 +- 64 files changed, 6538 insertions(+), 7408 deletions(-) create mode 100644 data/KCL-Heart-Model/6-create_template_mask.ipynb create mode 100644 experiments/Lung-GatedCT_To_USD/Experiment_ArrangeOnStage.ipynb create mode 100644 experiments/Lung-GatedCT_To_USD/Experiment_CombineModels.ipynb create mode 100644 experiments/Lung-GatedCT_To_USD/Experiment_SegReg.ipynb create mode 100644 experiments/Lung-GatedCT_To_USD/Experiment_SubSurfaceScatter.ipynb delete mode 100644 src/physiomotion4d/register_model_to_image_pca.py delete mode 100644 src/physiomotion4d/register_model_to_model_icp.py delete mode 100644 src/physiomotion4d/register_model_to_model_masks.py create mode 100644 src/physiomotion4d/register_models_distance_maps.py create mode 100644 src/physiomotion4d/register_models_icp.py create mode 100644 src/physiomotion4d/register_models_pca.py diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index ebc5b9c..4756469 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -20,6 +20,19 @@ jobs: uses: actions/setup-python@v5 with: python-version: '3.11' + + - name: Install system 
dependencies + run: | + sudo apt-get update + sudo apt-get install -y \ + libgl1 \ + libglib2.0-0 \ + libgomp1 \ + libsm6 \ + libxrender1 \ + libxext6 \ + libxrandr2 \ + libxi6 - name: Install dependencies run: | diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index bf24cfc..822ac5f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -42,11 +42,11 @@ jobs: run: | sudo apt-get update sudo apt-get install -y \ - libgl1-mesa-glx \ - libxrender1 \ - libgomp1 \ + libgl1 \ libglib2.0-0 \ + libgomp1 \ libsm6 \ + libxrender1 \ libxext6 \ libxrandr2 \ libxi6 @@ -209,11 +209,11 @@ jobs: run: | sudo apt-get update sudo apt-get install -y \ - libgl1-mesa-glx \ - libxrender1 \ - libgomp1 \ + libgl1 \ libglib2.0-0 \ + libgomp1 \ libsm6 \ + libxrender1 \ libxext6 \ libxrandr2 \ libxi6 diff --git a/README.md b/README.md index cec243a..ba016eb 100644 --- a/README.md +++ b/README.md @@ -98,9 +98,9 @@ print(f"PhysioMotion4D version: {physiomotion4d.__version__}") - `RegisterImagesANTs`: Classical deformable registration using ANTs - `RegisterTimeSeriesImages`: Specialized time series registration for 4D CT - Model-to-Image/Model Registration: - - `RegisterModelToImagePCA`: PCA-based statistical shape model registration - - `RegisterModelToModelICP`: ICP-based surface registration - - `RegisterModelToModelMasks`: Mask-based deformable model registration + - `RegisterModelsPCA`: PCA-based statistical shape model registration + - `RegisterModelsICP`: ICP-based surface registration + - `RegisterModelsDistanceMaps`: Mask-based deformable model registration - `RegisterImagesBase`: Base class for custom registration methods - **Base Classes**: Foundation classes providing common functionality - `PhysioMotion4DBase`: Base class providing standardized logging and debug settings @@ -227,8 +227,8 @@ transforms = time_series_reg.register_time_series( ) # Get forward and inverse displacement fields -phi_FM = results["phi_FM"] # Fixed to moving -phi_MF 
= results["phi_MF"] # Moving to fixed +inverse_transform = results["inverse_transform"] # Fixed to moving +forward_transform = results["forward_transform"] # Moving to fixed ``` ### Logging and Debug Control @@ -243,7 +243,7 @@ from physiomotion4d import HeartModelToPatientWorkflow, PhysioMotion4DBase PhysioMotion4DBase.set_log_level(logging.DEBUG) # Or filter to show logs from specific classes only -PhysioMotion4DBase.set_log_classes(["HeartModelToPatientWorkflow", "RegisterModelToImagePCA"]) +PhysioMotion4DBase.set_log_classes(["HeartModelToPatientWorkflow", "RegisterModelsPCA"]) # Show all classes again PhysioMotion4DBase.set_log_all_classes() diff --git a/data/KCL-Heart-Model/1-input_meshes_to_input_surfaces.ipynb b/data/KCL-Heart-Model/1-input_meshes_to_input_surfaces.ipynb index f76d5b6..4450258 100644 --- a/data/KCL-Heart-Model/1-input_meshes_to_input_surfaces.ipynb +++ b/data/KCL-Heart-Model/1-input_meshes_to_input_surfaces.ipynb @@ -78,7 +78,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Process average_mesh.vtk\n", + "# Process average_mesh.vtk provided from the KCL data collection.\n", "\n", "mesh = pv.read('average_mesh.vtk')\n", "\n", diff --git a/data/KCL-Heart-Model/2-input_surfaces_to_surfaces_aligned.ipynb b/data/KCL-Heart-Model/2-input_surfaces_to_surfaces_aligned.ipynb index 9addfa5..22e711c 100644 --- a/data/KCL-Heart-Model/2-input_surfaces_to_surfaces_aligned.ipynb +++ b/data/KCL-Heart-Model/2-input_surfaces_to_surfaces_aligned.ipynb @@ -1,341 +1,341 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# ICP Rigid Registration: Align Heart Models to Average\n", - "\n", - "This notebook performs ICP (Iterative Closest Point) rigid registration to align each individual heart model to the average model.\n", - "\n", - "**Workflow:**\n", - "1. Load the average mesh (`input_meshes/average.vtk`)\n", - "2. Load each individual mesh (`input_meshes/01.vtk` through `20.vtk`)\n", - "3. 
Use ICP rigid registration to align each mesh to the average\n", - "4. Save the aligned meshes to `icp_aligned_meshes/`\n", - "5. Visualize the results\n" - ] + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# ICP Rigid Registration: Align Heart Models to Average\n", + "\n", + "This notebook performs ICP (Iterative Closest Point) rigid registration to align each individual heart model to the average model.\n", + "\n", + "**Workflow:**\n", + "1. Load the average mesh (`input_meshes/average.vtk`)\n", + "2. Load each individual mesh (`input_meshes/01.vtk` through `20.vtk`)\n", + "3. Use ICP rigid registration to align each mesh to the average\n", + "4. Save the aligned meshes to `icp_aligned_meshes/`\n", + "5. Visualize the results\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "import os\n", + "from pathlib import Path\n", + "import numpy as np\n", + "import pyvista as pv\n", + "\n", + "# Add the src directory to the path to import the registration class\n", + "sys.path.insert(0, str(Path.cwd().parent.parent / 'src'))\n", + "\n", + "from physiomotion4d.register_model_to_model_icp import RegisterModelsICP\n", + "\n", + "# Enable interactive plotting\n", + "pv.set_jupyter_backend('trame')\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. 
Load the Average Mesh (Target)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load the average mesh - this will be our fixed target\n", + "template_mesh_path = Path('./average.vtk')\n", + "template_mesh = pv.read(template_mesh_path)\n", + "\n", + "print(f\"Average mesh loaded:\")\n", + "print(f\" Points: {template_mesh.n_points}\")\n", + "print(f\" Cells: {template_mesh.n_cells}\")\n", + "print(f\" Center: {template_mesh.center}\")\n", + "print(f\" Bounds: {template_mesh.bounds}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Find All Individual Mesh Files\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get all individual mesh files (excluding average.vtk)\n", + "input_meshes_dir = Path('surfaces')\n", + "mesh_files = sorted([f for f in input_meshes_dir.glob('??.vtp')])\n", + "\n", + "print(f\"Found {len(mesh_files)} individual mesh files:\")\n", + "for mesh_file in mesh_files:\n", + " print(f\" {mesh_file.name}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. 
Perform ICP Rigid Registration for Each Mesh\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create output directory for aligned meshes\n", + "output_dir = Path('surfaces_aligned')\n", + "output_dir.mkdir(exist_ok=True)\n", + "\n", + "# Store results\n", + "aligned_meshes = {}\n", + "transforms_point_forward = {} # Moving to Fixed point transforms (forward_point_transform)\n", + "transforms_point_inverse = {} # Fixed to Moving point transforms (inverse_point_transform)\n", + "\n", + "# Process each mesh\n", + "for mesh_file in mesh_files:\n", + " print(f\"\\n{'='*60}\")\n", + " print(f\"Processing: {mesh_file.name}\")\n", + " print(f\"{'='*60}\")\n", + " \n", + " # Load the moving mesh\n", + " moving_mesh = pv.read(mesh_file)\n", + " print(f\" Loaded mesh: {moving_mesh.n_points} points\")\n", + " \n", + " # Extract surface if needed (in case it's a volume mesh)\n", + " if isinstance(moving_mesh, pv.UnstructuredGrid):\n", + " print(f\" Extracting surface from volume mesh...\")\n", + " moving_mesh = moving_mesh.extract_surface()\n", + " print(f\" Surface mesh: {moving_mesh.n_points} points\")\n", + " \n", + " # Initialize registrar\n", + " registrar = RegisterModelsICP(\n", + " moving_mesh=moving_mesh,\n", + " fixed_mesh=template_mesh\n", + " )\n", + " \n", + " # Perform rigid ICP registration\n", + " result = registrar.register(mode='rigid', max_iterations=2000)\n", + " \n", + " # Store results\n", + " mesh_id = mesh_file.stem\n", + " aligned_meshes[mesh_id] = result['moving_mesh']\n", + " transforms_point_forward[mesh_id] = result['forward_point_transform']\n", + " transforms_point_inverse[mesh_id] = result['inverse_point_transform']\n", + " \n", + " # Save aligned mesh\n", + " output_path = output_dir / f\"{mesh_id}.vtp\"\n", + " result['moving_mesh'].save(output_path)\n", + " print(f\"\\n Saved aligned mesh: {output_path}\")\n", + "\n", + "print(f\"\\n{'='*60}\")\n", + "print(f\"ICP 
registration complete for all {len(mesh_files)} meshes!\")\n", + "print(f\"{'='*60}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Visualize Results: Before and After Registration\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Select a few examples to visualize (e.g., 01, 05, 10, 15, 20)\n", + "example_ids = ['01', '05', '10', '15', '20']\n", + "\n", + "for mesh_id in example_ids:\n", + " if mesh_id not in aligned_meshes:\n", + " continue\n", + " \n", + " # Load original mesh\n", + " original_mesh = pv.read(f'input_meshes/{mesh_id}.vtk')\n", + " if isinstance(original_mesh, pv.UnstructuredGrid):\n", + " original_mesh = original_mesh.extract_surface()\n", + " \n", + " # Create side-by-side comparison\n", + " plotter = pv.Plotter(shape=(1, 2))\n", + " \n", + " # Left: Before registration\n", + " plotter.subplot(0, 0)\n", + " plotter.add_mesh(template_mesh, color='lightblue', opacity=1.0, label='Average')\n", + " plotter.add_mesh(original_mesh, color='red', opacity=1.0, label=f'Original {mesh_id}')\n", + " plotter.add_text(f'Before ICP Registration - {mesh_id}', position='upper_left', font_size=10)\n", + " plotter.add_legend()\n", + " plotter.show_axes()\n", + " \n", + " # Right: After registration\n", + " plotter.subplot(0, 1)\n", + " plotter.add_mesh(template_mesh, color='lightblue', opacity=1.0, label='Average')\n", + " plotter.add_mesh(aligned_meshes[mesh_id], color='green', opacity=1.0, label=f'Aligned {mesh_id}')\n", + " plotter.add_text(f'After ICP Registration - {mesh_id}', position='upper_left', font_size=10)\n", + " plotter.add_legend()\n", + " plotter.show_axes()\n", + " \n", + " plotter.link_views()\n", + " plotter.show()\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. 
Calculate Registration Statistics\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "# Calculate statistics for each registration\n", + "stats_data = []\n", + "\n", + "for mesh_id, aligned_mesh in aligned_meshes.items():\n", + " # Calculate distance from aligned mesh to average mesh\n", + " # Using point-to-point distance as a metric\n", + " \n", + " # Get closest points on average mesh for each point in aligned mesh\n", + " closest_points = template_mesh.find_closest_cell(aligned_mesh.points, return_closest_point=True)[1]\n", + " \n", + " # Calculate distances\n", + " distances = np.linalg.norm(aligned_mesh.points - closest_points, axis=1)\n", + " \n", + " stats_data.append({\n", + " 'Mesh ID': mesh_id,\n", + " 'Mean Distance (mm)': np.mean(distances),\n", + " 'Median Distance (mm)': np.median(distances),\n", + " 'Std Distance (mm)': np.std(distances),\n", + " 'Max Distance (mm)': np.max(distances),\n", + " 'Min Distance (mm)': np.min(distances)\n", + " })\n", + "\n", + "# Create DataFrame and display\n", + "stats_df = pd.DataFrame(stats_data)\n", + "stats_df = stats_df.sort_values('Mesh ID')\n", + "\n", + "print(\"\\nRegistration Statistics (Distance from aligned mesh to average mesh):\")\n", + "print(\"=\"*80)\n", + "print(stats_df.to_string(index=False))\n", + "print(\"=\"*80)\n", + "\n", + "# Summary statistics\n", + "print(f\"\\nOverall Summary:\")\n", + "print(f\" Average mean distance: {stats_df['Mean Distance (mm)'].mean():.3f} mm\")\n", + "print(f\" Average median distance: {stats_df['Median Distance (mm)'].mean():.3f} mm\")\n", + "print(f\" Range of mean distances: {stats_df['Mean Distance (mm)'].min():.3f} - {stats_df['Mean Distance (mm)'].max():.3f} mm\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. 
Save Registration Statistics\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Save statistics to CSV\n", + "stats_csv_path = output_dir / 'registration_statistics.csv'\n", + "stats_df.to_csv(stats_csv_path, index=False)\n", + "print(f\"\\nStatistics saved to: {stats_csv_path}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 8. Visualize Distance Distributions\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "# Create bar plot of mean distances\n", + "fig, axes = plt.subplots(2, 1, figsize=(12, 8))\n", + "\n", + "# Plot 1: Mean distances\n", + "axes[0].bar(stats_df['Mesh ID'], stats_df['Mean Distance (mm)'], color='steelblue')\n", + "axes[0].set_xlabel('Mesh ID')\n", + "axes[0].set_ylabel('Mean Distance (mm)')\n", + "axes[0].set_title('Mean Distance from Aligned Mesh to Average Mesh (After ICP Registration)')\n", + "axes[0].grid(axis='y', alpha=0.3)\n", + "\n", + "# Plot 2: Box plot style visualization\n", + "axes[1].errorbar(\n", + " stats_df['Mesh ID'], \n", + " stats_df['Median Distance (mm)'],\n", + " yerr=stats_df['Std Distance (mm)'],\n", + " fmt='o',\n", + " capsize=5,\n", + " capthick=2,\n", + " color='coral',\n", + " ecolor='gray',\n", + " label='Median ± Std'\n", + ")\n", + "axes[1].set_xlabel('Mesh ID')\n", + "axes[1].set_ylabel('Distance (mm)')\n", + "axes[1].set_title('Median Distance ± Standard Deviation')\n", + "axes[1].legend()\n", + "axes[1].grid(axis='y', alpha=0.3)\n", + "\n", + "plt.tight_layout()\n", + "plt.savefig(output_dir / 'registration_statistics.png', dpi=150, bbox_inches='tight')\n", + "plt.show()\n", + "\n", + "print(f\"\\nPlot saved to: {output_dir / 'registration_statistics.png'}\")\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + 
"language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "import os\n", - "from pathlib import Path\n", - "import numpy as np\n", - "import pyvista as pv\n", - "\n", - "# Add the src directory to the path to import the registration class\n", - "sys.path.insert(0, str(Path.cwd().parent.parent / 'src'))\n", - "\n", - "from physiomotion4d.register_model_to_model_icp import RegisterModelToModelICP\n", - "\n", - "# Enable interactive plotting\n", - "pv.set_jupyter_backend('trame')\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1. Load the Average Mesh (Target)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Load the average mesh - this will be our fixed target\n", - "average_mesh_path = Path('./average.vtk')\n", - "average_mesh = pv.read(average_mesh_path)\n", - "\n", - "print(f\"Average mesh loaded:\")\n", - "print(f\" Points: {average_mesh.n_points}\")\n", - "print(f\" Cells: {average_mesh.n_cells}\")\n", - "print(f\" Center: {average_mesh.center}\")\n", - "print(f\" Bounds: {average_mesh.bounds}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2. 
Find All Individual Mesh Files\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Get all individual mesh files (excluding average.vtk)\n", - "input_meshes_dir = Path('surfaces')\n", - "mesh_files = sorted([f for f in input_meshes_dir.glob('??.vtp')])\n", - "\n", - "print(f\"Found {len(mesh_files)} individual mesh files:\")\n", - "for mesh_file in mesh_files:\n", - " print(f\" {mesh_file.name}\")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3. Perform ICP Rigid Registration for Each Mesh\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Create output directory for aligned meshes\n", - "output_dir = Path('surfaces_aligned')\n", - "output_dir.mkdir(exist_ok=True)\n", - "\n", - "# Store results\n", - "aligned_meshes = {}\n", - "transforms_MF = {} # Moving to Fixed transforms\n", - "transforms_FM = {} # Fixed to Moving transforms\n", - "\n", - "# Process each mesh\n", - "for mesh_file in mesh_files:\n", - " print(f\"\\n{'='*60}\")\n", - " print(f\"Processing: {mesh_file.name}\")\n", - " print(f\"{'='*60}\")\n", - " \n", - " # Load the moving mesh\n", - " moving_mesh = pv.read(mesh_file)\n", - " print(f\" Loaded mesh: {moving_mesh.n_points} points\")\n", - " \n", - " # Extract surface if needed (in case it's a volume mesh)\n", - " if isinstance(moving_mesh, pv.UnstructuredGrid):\n", - " print(f\" Extracting surface from volume mesh...\")\n", - " moving_mesh = moving_mesh.extract_surface()\n", - " print(f\" Surface mesh: {moving_mesh.n_points} points\")\n", - " \n", - " # Initialize registrar\n", - " registrar = RegisterModelToModelICP(\n", - " moving_mesh=moving_mesh,\n", - " fixed_mesh=average_mesh\n", - " )\n", - " \n", - " # Perform rigid ICP registration\n", - " result = registrar.register(mode='rigid', max_iterations=2000)\n", - " \n", - " # Store results\n", - " mesh_id = mesh_file.stem\n", - 
" aligned_meshes[mesh_id] = result['moving_mesh']\n", - " transforms_MF[mesh_id] = result['phi_MF']\n", - " transforms_FM[mesh_id] = result['phi_FM']\n", - " \n", - " # Save aligned mesh\n", - " output_path = output_dir / f\"{mesh_id}.vtp\"\n", - " result['moving_mesh'].save(output_path)\n", - " print(f\"\\n Saved aligned mesh: {output_path}\")\n", - "\n", - "print(f\"\\n{'='*60}\")\n", - "print(f\"ICP registration complete for all {len(mesh_files)} meshes!\")\n", - "print(f\"{'='*60}\")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 4. Visualize Results: Before and After Registration\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Select a few examples to visualize (e.g., 01, 05, 10, 15, 20)\n", - "example_ids = ['01', '05', '10', '15', '20']\n", - "\n", - "for mesh_id in example_ids:\n", - " if mesh_id not in aligned_meshes:\n", - " continue\n", - " \n", - " # Load original mesh\n", - " original_mesh = pv.read(f'input_meshes/{mesh_id}.vtk')\n", - " if isinstance(original_mesh, pv.UnstructuredGrid):\n", - " original_mesh = original_mesh.extract_surface()\n", - " \n", - " # Create side-by-side comparison\n", - " plotter = pv.Plotter(shape=(1, 2))\n", - " \n", - " # Left: Before registration\n", - " plotter.subplot(0, 0)\n", - " plotter.add_mesh(average_mesh, color='lightblue', opacity=1.0, label='Average')\n", - " plotter.add_mesh(original_mesh, color='red', opacity=1.0, label=f'Original {mesh_id}')\n", - " plotter.add_text(f'Before ICP Registration - {mesh_id}', position='upper_left', font_size=10)\n", - " plotter.add_legend()\n", - " plotter.show_axes()\n", - " \n", - " # Right: After registration\n", - " plotter.subplot(0, 1)\n", - " plotter.add_mesh(average_mesh, color='lightblue', opacity=1.0, label='Average')\n", - " plotter.add_mesh(aligned_meshes[mesh_id], color='green', opacity=1.0, label=f'Aligned {mesh_id}')\n", - " plotter.add_text(f'After ICP 
Registration - {mesh_id}', position='upper_left', font_size=10)\n", - " plotter.add_legend()\n", - " plotter.show_axes()\n", - " \n", - " plotter.link_views()\n", - " plotter.show()\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 6. Calculate Registration Statistics\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import pandas as pd\n", - "\n", - "# Calculate statistics for each registration\n", - "stats_data = []\n", - "\n", - "for mesh_id, aligned_mesh in aligned_meshes.items():\n", - " # Calculate distance from aligned mesh to average mesh\n", - " # Using point-to-point distance as a metric\n", - " \n", - " # Get closest points on average mesh for each point in aligned mesh\n", - " closest_points = average_mesh.find_closest_cell(aligned_mesh.points, return_closest_point=True)[1]\n", - " \n", - " # Calculate distances\n", - " distances = np.linalg.norm(aligned_mesh.points - closest_points, axis=1)\n", - " \n", - " stats_data.append({\n", - " 'Mesh ID': mesh_id,\n", - " 'Mean Distance (mm)': np.mean(distances),\n", - " 'Median Distance (mm)': np.median(distances),\n", - " 'Std Distance (mm)': np.std(distances),\n", - " 'Max Distance (mm)': np.max(distances),\n", - " 'Min Distance (mm)': np.min(distances)\n", - " })\n", - "\n", - "# Create DataFrame and display\n", - "stats_df = pd.DataFrame(stats_data)\n", - "stats_df = stats_df.sort_values('Mesh ID')\n", - "\n", - "print(\"\\nRegistration Statistics (Distance from aligned mesh to average mesh):\")\n", - "print(\"=\"*80)\n", - "print(stats_df.to_string(index=False))\n", - "print(\"=\"*80)\n", - "\n", - "# Summary statistics\n", - "print(f\"\\nOverall Summary:\")\n", - "print(f\" Average mean distance: {stats_df['Mean Distance (mm)'].mean():.3f} mm\")\n", - "print(f\" Average median distance: {stats_df['Median Distance (mm)'].mean():.3f} mm\")\n", - "print(f\" Range of mean distances: {stats_df['Mean Distance 
(mm)'].min():.3f} - {stats_df['Mean Distance (mm)'].max():.3f} mm\")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 7. Save Registration Statistics\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Save statistics to CSV\n", - "stats_csv_path = output_dir / 'registration_statistics.csv'\n", - "stats_df.to_csv(stats_csv_path, index=False)\n", - "print(f\"\\nStatistics saved to: {stats_csv_path}\")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 8. Visualize Distance Distributions\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import matplotlib.pyplot as plt\n", - "\n", - "# Create bar plot of mean distances\n", - "fig, axes = plt.subplots(2, 1, figsize=(12, 8))\n", - "\n", - "# Plot 1: Mean distances\n", - "axes[0].bar(stats_df['Mesh ID'], stats_df['Mean Distance (mm)'], color='steelblue')\n", - "axes[0].set_xlabel('Mesh ID')\n", - "axes[0].set_ylabel('Mean Distance (mm)')\n", - "axes[0].set_title('Mean Distance from Aligned Mesh to Average Mesh (After ICP Registration)')\n", - "axes[0].grid(axis='y', alpha=0.3)\n", - "\n", - "# Plot 2: Box plot style visualization\n", - "axes[1].errorbar(\n", - " stats_df['Mesh ID'], \n", - " stats_df['Median Distance (mm)'],\n", - " yerr=stats_df['Std Distance (mm)'],\n", - " fmt='o',\n", - " capsize=5,\n", - " capthick=2,\n", - " color='coral',\n", - " ecolor='gray',\n", - " label='Median ± Std'\n", - ")\n", - "axes[1].set_xlabel('Mesh ID')\n", - "axes[1].set_ylabel('Distance (mm)')\n", - "axes[1].set_title('Median Distance ± Standard Deviation')\n", - "axes[1].legend()\n", - "axes[1].grid(axis='y', alpha=0.3)\n", - "\n", - "plt.tight_layout()\n", - "plt.savefig(output_dir / 'registration_statistics.png', dpi=150, bbox_inches='tight')\n", - "plt.show()\n", - "\n", - "print(f\"\\nPlot saved to: {output_dir / 
'registration_statistics.png'}\")\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.11" - } - }, - "nbformat": 4, - "nbformat_minor": 2 + "nbformat": 4, + "nbformat_minor": 2 } diff --git a/data/KCL-Heart-Model/4-surfaces_aligned_correspond_to_pca_inputs.ipynb b/data/KCL-Heart-Model/4-surfaces_aligned_correspond_to_pca_inputs.ipynb index c7a2b8d..1a63a70 100644 --- a/data/KCL-Heart-Model/4-surfaces_aligned_correspond_to_pca_inputs.ipynb +++ b/data/KCL-Heart-Model/4-surfaces_aligned_correspond_to_pca_inputs.ipynb @@ -33,7 +33,7 @@ "import numpy as np\n", "import pyvista as pv\n", "\n", - "from physiomotion4d.register_model_to_model_masks import RegisterModelToModelMasks\n", + "from physiomotion4d.register_model_to_model_masks import RegisterModelsDistanceMaps\n", "\n", "# Enable interactive plotting\n", "pv.set_jupyter_backend('trame')\n" @@ -110,7 +110,7 @@ " return None\n", " surface_mesh = pv.read(surface_file)\n", "\n", - " registrar = RegisterModelToModelMasks(\n", + " registrar = RegisterModelsDistanceMaps(\n", " moving_mesh=correspond_mesh,\n", " fixed_mesh=surface_mesh,\n", " reference_image=ref_image,\n", @@ -145,8 +145,8 @@ "processed_meshes = {}\n", "failed_files = []\n", "\n", - "average_mesh = pv.read(\"average_surface.vtp\")\n", - "bounds = average_mesh.bounds\n", + "template_mesh = pv.read(\"average_surface.vtp\")\n", + "bounds = template_mesh.bounds\n", "xmin = bounds[0]\n", "xmax = bounds[1]\n", "ymin = bounds[2]\n", @@ -220,4 +220,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} +} \ No newline at end of file diff --git a/data/KCL-Heart-Model/6-create_template_mask.ipynb b/data/KCL-Heart-Model/6-create_template_mask.ipynb new file 
mode 100644 index 0000000..eea4963 --- /dev/null +++ b/data/KCL-Heart-Model/6-create_template_mask.ipynb @@ -0,0 +1,66 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "id": "2616aac8", + "metadata": {}, + "outputs": [], + "source": [ + "import itk\n", + "import pyvista as pv\n", + "\n", + "from physiomotion4d import ContourTools\n", + "\n", + "template_model_surface = pv.read('average_surface.vtp')\n", + "\n", + "contour_tools = ContourTools()\n", + "\n", + "reference_image = contour_tools.create_reference_image(\n", + " template_model_surface,\n", + ")\n", + "\n", + "template_mask = contour_tools.create_mask_from_mesh(\n", + " template_model_surface,\n", + " reference_image,\n", + ")\n", + "\n", + "itk.imwrite(template_mask, \"average_binary_mask.nii.gz\")\n", + "\n", + "\n", + "# Then use 3D Slicer to assign the following labels:\n", + "# 1: heart muscle\n", + "# 2: right ventricle\n", + "# 3: left ventricle\n", + "# 4: right atrium\n", + "# 5: left atrium\n", + "# 6: background rim around heart muscle\n", + "\n", + "# save the 3D Slicer labelmap file as \"average_labelmap.nii.gz\"\n", + "\n", + "\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/LOGGING_API_REFERENCE.md b/docs/LOGGING_API_REFERENCE.md index c25f294..6d0fdbb 100644 --- a/docs/LOGGING_API_REFERENCE.md +++ b/docs/LOGGING_API_REFERENCE.md @@ -15,7 +15,7 @@ PhysioMotion4DBase.set_log_level(logging.INFO) PhysioMotion4DBase.set_log_level('DEBUG') # Can use string too # Filter to show only specific classes -PhysioMotion4DBase.set_log_classes(["RegisterModelToImagePCA", 
"HeartModelToPatientWorkflow"]) +PhysioMotion4DBase.set_log_classes(["RegisterModelsPCA", "HeartModelToPatientWorkflow"]) # Show all classes (disable filtering) PhysioMotion4DBase.set_log_all_classes() @@ -44,24 +44,24 @@ obj.log_progress(current, total, prefix="Processing") ### PhysioMotion4DBase Class Methods -| Method | Parameters | Description | -|--------|------------|-------------| -| `set_log_level(log_level)` | `log_level: int \| str` | Set logging level for all classes | -| `set_log_classes(class_names)` | `class_names: list[str]` | Show logs only from specified classes | -| `set_log_all_classes()` | None | Show logs from all classes | -| `get_log_classes()` | None | Get list of currently filtered classes | +| Method | Parameters | Description | +| ------------------------------ | ------------------------ | -------------------------------------- | +| `set_log_level(log_level)` | `log_level: int \| str` | Set logging level for all classes | +| `set_log_classes(class_names)` | `class_names: list[str]` | Show logs only from specified classes | +| `set_log_all_classes()` | None | Show logs from all classes | +| `get_log_classes()` | None | Get list of currently filtered classes | ### Instance Methods -| Method | Parameters | Description | -|--------|------------|-------------| -| `log_debug(message)` | `message: str` | Log DEBUG level message | -| `log_info(message)` | `message: str` | Log INFO level message | -| `log_warning(message)` | `message: str` | Log WARNING level message | -| `log_error(message)` | `message: str` | Log ERROR level message | -| `log_critical(message)` | `message: str` | Log CRITICAL level message | -| `log_section(title, width, char)` | `title: str, width: int=70, char: str='='` | Log formatted section header | -| `log_progress(current, total, prefix)` | `current: int, total: int, prefix: str='Progress'` | Log progress information | +| Method | Parameters | Description | +| -------------------------------------- | 
-------------------------------------------------- | ---------------------------- | +| `log_debug(message)` | `message: str` | Log DEBUG level message | +| `log_info(message)` | `message: str` | Log INFO level message | +| `log_warning(message)` | `message: str` | Log WARNING level message | +| `log_error(message)` | `message: str` | Log ERROR level message | +| `log_critical(message)` | `message: str` | Log CRITICAL level message | +| `log_section(title, width, char)` | `title: str, width: int=70, char: str='='` | Log formatted section header | +| `log_progress(current, total, prefix)` | `current: int, total: int, prefix: str='Progress'` | Log progress information | ## Usage Patterns @@ -108,7 +108,7 @@ def process_items(self, items): ### Pattern 4: Global Log Control ```python # Create multiple objects -pca_reg = RegisterModelToImagePCA(..., log_level=logging.INFO) +pca_reg = RegisterModelsPCA(..., log_level=logging.INFO) workflow = HeartModelToPatientWorkflow(..., log_level=logging.INFO) # Change log level for both at once @@ -125,14 +125,14 @@ PhysioMotion4DBase.set_log_level(logging.WARNING) ### Pattern 5: Selective Class Filtering ```python # Show only PCA registration logs -PhysioMotion4DBase.set_log_classes(["RegisterModelToImagePCA"]) +PhysioMotion4DBase.set_log_classes(["RegisterModelsPCA"]) pca_reg.log_info("This is shown") workflow.log_info("This is hidden") # Show both classes PhysioMotion4DBase.set_log_classes([ - "RegisterModelToImagePCA", + "RegisterModelsPCA", "HeartModelToPatientWorkflow" ]) @@ -151,7 +151,7 @@ workflow.run_workflow() # Focus on specific class for debugging PhysioMotion4DBase.set_log_level(logging.DEBUG) -PhysioMotion4DBase.set_log_classes(["RegisterModelToImagePCA"]) +PhysioMotion4DBase.set_log_classes(["RegisterModelsPCA"]) # Now only PCA registration shows detailed debug output workflow.run_workflow() @@ -163,13 +163,13 @@ PhysioMotion4DBase.set_log_all_classes() ## Log Levels -| Level | Numeric Value | When to Use | 
-|-------|---------------|-------------| -| `logging.DEBUG` | 10 | Detailed diagnostic information for debugging | -| `logging.INFO` | 20 | General informational messages (default) | -| `logging.WARNING` | 30 | Warning messages about potential issues | -| `logging.ERROR` | 40 | Error messages for serious problems | -| `logging.CRITICAL` | 50 | Critical errors that may cause termination | +| Level | Numeric Value | When to Use | +| ------------------ | ------------- | --------------------------------------------- | +| `logging.DEBUG` | 10 | Detailed diagnostic information for debugging | +| `logging.INFO` | 20 | General informational messages (default) | +| `logging.WARNING` | 30 | Warning messages about potential issues | +| `logging.ERROR` | 40 | Error messages for serious problems | +| `logging.CRITICAL` | 50 | Critical errors that may cause termination | ## Output Format @@ -180,15 +180,15 @@ TIMESTAMP - PhysioMotion4D - LEVEL - [ClassName] Message Example: ``` -2025-12-13 11:35:27 - PhysioMotion4D - INFO - [RegisterModelToImagePCA] Converting mean shape points... +2025-12-13 11:35:27 - PhysioMotion4D - INFO - [RegisterModelsPCA] Converting mean shape points... 2025-12-13 11:35:27 - PhysioMotion4D - DEBUG - [HeartModelToPatientWorkflow] Auto-generating masks... 
-2025-12-13 11:35:27 - PhysioMotion4D - WARNING - [RegisterModelToImagePCA] No points found within threshold +2025-12-13 11:35:27 - PhysioMotion4D - WARNING - [RegisterModelsPCA] No points found within threshold ``` ## Available Classes Current PhysioMotion4D classes with logging support: -- `RegisterModelToImagePCA` - PCA-based model-to-image registration +- `RegisterModelsPCA` - PCA-based model-to-image registration - `HeartModelToPatientWorkflow` - Multi-stage heart model registration - (More classes will be added as they are converted) @@ -197,7 +197,7 @@ Current PhysioMotion4D classes with logging support: ### Use Case 1: Normal Operation ```python # Default: INFO level, all classes shown -registrar = RegisterModelToImagePCA(..., log_level=logging.INFO) +registrar = RegisterModelsPCA(..., log_level=logging.INFO) result = registrar.register(...) ``` @@ -205,15 +205,15 @@ result = registrar.register(...) ```python # WARNING level: only warnings and errors PhysioMotion4DBase.set_log_level(logging.WARNING) -registrar = RegisterModelToImagePCA(...) +registrar = RegisterModelsPCA(...) result = registrar.register(...) ``` ### Use Case 3: Debug Specific Component ```python -# Debug only RegisterModelToImagePCA +# Debug only RegisterModelsPCA PhysioMotion4DBase.set_log_level(logging.DEBUG) -PhysioMotion4DBase.set_log_classes(["RegisterModelToImagePCA"]) +PhysioMotion4DBase.set_log_classes(["RegisterModelsPCA"]) workflow = HeartModelToPatientWorkflow(...) 
# Only PCA component will show debug messages @@ -223,7 +223,7 @@ workflow.run_workflow() ### Use Case 4: Production Logging to File ```python # Log everything to file for analysis -registrar = RegisterModelToImagePCA( +registrar = RegisterModelsPCA( ..., log_level=logging.DEBUG, log_to_file="registration.log" @@ -239,7 +239,7 @@ PhysioMotion4DBase.set_log_level(logging.INFO) PhysioMotion4DBase.set_log_level(logging.DEBUG) # Focus on specific class -PhysioMotion4DBase.set_log_classes(["RegisterModelToImagePCA"]) +PhysioMotion4DBase.set_log_classes(["RegisterModelsPCA"]) # Re-run the problematic part registrar.optimize_rigid_alignment(...) diff --git a/docs/ants_initial_transform_guide.md b/docs/ants_initial_transform_guide.md index a97a8dc..f30f427 100644 --- a/docs/ants_initial_transform_guide.md +++ b/docs/ants_initial_transform_guide.md @@ -30,12 +30,12 @@ def itk_transform_to_ants_transform( 3. Creates an ANTsPy transform object using `ants.transform_from_displacement_field()` 4. Returns the ANTsPy transform object (no disk I/O) -### Updated Method: `registration_method()` +### Updated Method: `register()` -The registration method now accepts an `initial_phi_MF` parameter: +The registration method now accepts an `initial_forward_transform` parameter: **New Parameter:** -- `initial_phi_MF` (itk.Transform, optional): Initial transform from moving to fixed space. Can be any ITK transform type. +- `initial_forward_transform` (itk.Transform, optional): Initial transform from moving to fixed space. Can be any ITK transform type. **Supported Transform Types:** - `itk.AffineTransform` @@ -69,7 +69,7 @@ registrar.set_fixed_image(fixed_image) result = registrar.register( moving_image=moving_image, - initial_phi_MF=initial_tfm # Pass ITK transform directly! + initial_forward_transform=initial_tfm # Pass ITK transform directly! 
) ``` @@ -86,7 +86,7 @@ registrar_deform = RegisterImagesANTs() registrar_deform.set_fixed_image(fixed_image) result_deform = registrar_deform.register( moving_image=moving_image, - initial_phi_MF=result_rigid["phi_MF"] # Use previous result + initial_forward_transform=result_rigid["forward_transform"] # Use previous result ) ``` @@ -105,7 +105,7 @@ registrar = RegisterImagesANTs() registrar.set_fixed_image(fixed_image) result = registrar.register( moving_image=moving_image, - initial_phi_MF=disp_tfm + initial_forward_transform=disp_tfm ) ``` @@ -118,8 +118,7 @@ registrar.set_fixed_image(fixed_labels) result = registrar.register( moving_image=moving_labels, - images_are_labelmaps=True, - initial_phi_MF=initial_transform # Works with label registration too! + initial_forward_transform=initial_transform ) ``` @@ -144,7 +143,7 @@ registrar = RegisterImagesANTs() registrar.set_fixed_image(fixed_image) result = registrar.register( moving_image=moving_image, - initial_phi_MF=composite # Composite is automatically converted + initial_forward_transform=composite # Composite is automatically converted ) ``` @@ -163,21 +162,21 @@ When you pass an `initial_transform` to `ants.registration()`: ```python result = registrar.register( moving_image=moving_image, - initial_phi_MF=initial_transform + initial_forward_transform=initial_transform ) # The returned transforms INCLUDE the initial transform! 
-phi_MF = result["phi_MF"] # = initial_phi_MF ∘ registration_refinement -phi_FM = result["phi_FM"] # = registration_refinement^(-1) ∘ initial_phi_MF^(-1) +forward_transform = result["forward_transform"] +inverse_transform = result["inverse_transform"] ``` The composition is done as follows: -- **phi_MF**: `initial_phi_MF` → `registration_result` (applied in sequence) -- **phi_FM**: `registration_result_inverse` → `initial_phi_MF_inverse` (applied in sequence) +- **forward_transform**: `initial_forward_transform` → `registration_result` (applied in sequence) +- **inverse_transform**: `registration_result_inverse` → `initial_forward_transform_inverse` (applied in sequence) ### Coordinate Systems -- The initial transform should map points from **moving space to fixed space** (phi_MF) +- The initial forward transform maps points from **moving space to fixed space** - The displacement field represents the displacement at each voxel in the **fixed image space** - ANTs internally handles the transform composition with its optimization @@ -189,8 +188,8 @@ The composition is done as follows: ### Transform Direction -The parameter `initial_phi_MF` represents: -- **phi_MF**: Transform from Moving → Fixed space +The parameter `initial_forward_transform` represents: +- **forward_transform**: Transform from Moving → Fixed space - This aligns with ANTs' expected initial transform direction ### Interpolation @@ -231,7 +230,7 @@ result = ants.registration( The implementation handles several edge cases: -1. **Null transforms**: If `initial_phi_MF=None`, uses identity transform (default behavior) +1. **Null transforms**: If `initial_forward_transform=None`, uses identity transform (default behavior) 2. **List transforms**: If ITK returns a list with one transform, extracts the single transform 3. **Transform composition**: Composite transforms are flattened to a single displacement field 4. 
**Type checking**: Validates that inputs are proper ITK transforms @@ -256,11 +255,11 @@ registrar = RegisterImagesANTs() registrar.set_fixed_image(fixed_image) result = registrar.register( moving_image=moving_image, - initial_phi_MF=initial_affine # Pass initial alignment + initial_forward_transform=initial_affine # Pass initial alignment ) # Step 4: The returned transform includes BOTH the initial affine AND the deformation -phi_MF = result["phi_MF"] +forward_transform = result["forward_transform"] # phi_MF is a CompositeTransform containing: # 1. initial_affine (applied first) # 2. deformation from ANTs registration (applied second) @@ -307,11 +306,11 @@ warped = ants.apply_transforms( # ✅ Composition is automatic! result = registrar.register( moving_image=moving_image, - initial_phi_MF=initial_tfm + initial_forward_transform=initial_tfm ) # The returned transform already includes everything -phi_MF = result["phi_MF"] # Complete transform ready to use +forward_transform = result["forward_transform"] # Complete transform ready to use ``` ## Comparison with File-Based Approach @@ -342,11 +341,11 @@ os.remove("temp_transform.mat") # ✅ No file I/O, automatic composition! 
result = registrar.register( moving_image=moving_image, - initial_phi_MF=itk_tfm # ITK transform object + initial_forward_transform=itk_tfm # ITK transform object ) # Transform is complete and ready to use -phi_MF = result["phi_MF"] +forward_transform = result["forward_transform"] ``` ## References diff --git a/docs/api/registration.rst b/docs/api/registration.rst index 81a6163..2e411b7 100644 --- a/docs/api/registration.rst +++ b/docs/api/registration.rst @@ -73,15 +73,15 @@ ICON (Inverse Consistent Optimization Network) results = registerer.register(moving_image) # Get results - phi_FM = results["phi_FM"] # Forward deformation field - phi_MF = results["phi_MF"] # Inverse deformation field + forward_transform = results["forward_transform"] # Forward deformation field + inverse_transform = results["inverse_transform"] # Inverse deformation field registered_image = results["registered_image"] similarity = results["similarity_score"] **Output Dictionary:** - * ``phi_FM``: Forward deformation field (fixed → moving) - * ``phi_MF``: Inverse deformation field (moving → fixed) + * ``forward_transform``: Used to warp an image from moving to fixed space + * ``inverse_transform``: Used to warp an image from fixed to moving space * ``registered_image``: Moving image warped to fixed space * ``similarity_score``: Registration quality metric * ``inverse_consistency_error``: Inverse consistency metric @@ -189,9 +189,9 @@ Time Series Registration ) # Access results - phi_MF_list = results["phi_MF_list"] # List of transforms - phi_FM_list = results["phi_FM_list"] # List of inverse transforms - losses = results["losses"] # List of registration losses + forward_transforms_list = results["forward_transforms"] # List of transforms + inverse_transforms_list = results["inverse_transforms"] # List of inverse transforms + registration_losses = results["losses"] # List of registration losses **Use Cases:** @@ -218,7 +218,7 @@ These methods register statistical shape models or segmentation 
meshes to medica PCA-based Registration ---------------------- -.. autoclass:: physiomotion4d.RegisterModelToImagePCA +.. autoclass:: physiomotion4d.RegisterModelsPCA :members: :undoc-members: :show-inheritance: @@ -236,10 +236,10 @@ PCA-based Registration .. code-block:: python - from physiomotion4d import RegisterModelToImagePCA + from physiomotion4d import RegisterModelsPCA import itk - registerer = RegisterModelToImagePCA() + registerer = RegisterModelsPCA() # Load shape model and image shape_model = load_statistical_shape_model("heart_model.pkl") @@ -292,7 +292,7 @@ Model-to-Model Registration ICP Registration ---------------- -.. autoclass:: physiomotion4d.RegisterModelToModelICP +.. autoclass:: physiomotion4d.RegisterModelsICP :members: :undoc-members: :show-inheritance: @@ -310,25 +310,29 @@ ICP Registration .. code-block:: python - from physiomotion4d import RegisterModelToModelICP + from physiomotion4d import RegisterModelsICP import pyvista as pv - registerer = RegisterModelToModelICP() - # Load meshes fixed_mesh = pv.read("reference_mesh.vtp") moving_mesh = pv.read("target_mesh.vtp") - registerer.set_fixed_mesh(fixed_mesh) - registerer.set_max_iterations(100) + # Initialize registrar + registrar = RegisterModelsICP( + moving_mesh=moving_mesh, + fixed_mesh=fixed_mesh + ) # Register - aligned_mesh, transform = registerer.register(moving_mesh) + result = registrar.register(mode='rigid', max_iterations=100) + aligned_mesh = result['moving_mesh'] + forward_point_transform = result['forward_point_transform'] + inverse_point_transform = result['inverse_point_transform'] Mask-based Registration ----------------------- -.. autoclass:: physiomotion4d.RegisterModelToModelMasks +.. 
autoclass:: physiomotion4d.RegisterModelsDistanceMaps :members: :undoc-members: :show-inheritance: @@ -396,10 +400,10 @@ PhysioMotion4D provides utilities for evaluating registration quality: tre = compute_target_registration_error(fixed_landmarks, warped_landmarks) # Inverse consistency (bidirectional registration) - ice = compute_inverse_consistency_error(phi_FM, phi_MF) + ice = compute_inverse_consistency_error(inverse_transform, forward_transform) # Jacobian determinant (topology preservation) - jac_det = compute_jacobian_determinant(phi_FM) + jac_det = compute_jacobian_determinant(inverse_transform) folding_points = jac_det < 0 # Locations with folding Best Practices @@ -419,12 +423,12 @@ Choosing a Registration Method 3. **For initialization/coarse alignment**: - * Start with :class:`RegisterModelToModelICP` + * Start with :class:`RegisterModelsICP` * Then refine with image-based registration 4. **For shape model fitting**: - * Use :class:`RegisterModelToImagePCA` + * Use :class:`RegisterModelsPCA` * Especially when prior knowledge available Parameter Selection diff --git a/docs/api/utilities.rst b/docs/api/utilities.rst index 927a5c9..b36ecc0 100644 --- a/docs/api/utilities.rst +++ b/docs/api/utilities.rst @@ -481,7 +481,7 @@ Class-Based Filtering # Show logs only from specific classes PhysioMotion4DBase.set_log_classes([ - "RegisterModelToImagePCA", + "RegisterModelsPCA", "HeartModelToPatientWorkflow" ]) diff --git a/docs/conf.py b/docs/conf.py index 243dfa6..ff210cb 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -55,7 +55,7 @@ html_theme_options = { 'logo_only': False, - 'display_version': True, + # 'display_version': True, # Deprecated in sphinx_rtd_theme >= 1.0 'prev_next_buttons_location': 'bottom', 'style_external_links': False, 'vcs_pageview_mode': '', diff --git a/docs/examples.rst b/docs/examples.rst index 395b48b..d2e0cec 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -149,13 +149,14 @@ Fast GPU-accelerated registration: results = 
registerer.register(moving) # Get results - phi_FM = results["phi_FM"] - phi_MF = results["phi_MF"] + inverse_transform = results["inverse_transform"] + forward_transform = results["forward_transform"] registered = results["registered_image"] # Save itk.imwrite(registered, "registered.mha") - itk.imwrite(phi_FM, "transform_forward.mha") + itk.transformwrite(forward_transform, "transform_forward.hdf") + itk.transformwrite(inverse_transform, "transform_inverse.hdf") Multi-Phase Cardiac Registration --------------------------------- @@ -182,7 +183,7 @@ Register all cardiac phases to reference: for frame_file in frame_files[1:]: # Skip reference moving = itk.imread(frame_file) results = registerer.register(moving) - transforms.append(results["phi_FM"]) + transforms.append(results["inverse_transform"]) print(f"Registered {frame_file}: similarity = {results['similarity_score']:.3f}") @@ -577,7 +578,7 @@ Mix and match different components: results = registerer.register(frame) warped_mesh = transform_tools.apply_transform_to_contour( reference_mesh, - results["phi_FM"] + results["inverse_transform"] ) meshes.append(warped_mesh) diff --git a/docs/quickstart.rst b/docs/quickstart.rst index d327b27..ec6d8df 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -158,8 +158,8 @@ For standalone registration: results = registerer.register(moving_image) # Get transformation fields - phi_FM = results["phi_FM"] # Forward transform - phi_MF = results["phi_MF"] # Inverse transform + inverse_transform = results["inverse_transform"] # Fixed to moving space + forward_transform = results["forward_transform"] # Moving to fixed space VTK to USD Conversion --------------------- diff --git a/docs/tutorials/basic_workflow.rst b/docs/tutorials/basic_workflow.rst index d5bceb6..3b95e05 100644 --- a/docs/tutorials/basic_workflow.rst +++ b/docs/tutorials/basic_workflow.rst @@ -402,8 +402,8 @@ The workflow generates several intermediate files: │ ├── frame_001.mha │ └── ... 
├── transforms/ # Registration transforms - │ ├── phi_FM_001.mha - │ ├── phi_MF_001.mha + │ ├── inverse_transform_001.mha + │ ├── forward_transform_001.mha │ └── ... ├── masks/ # Segmentation masks │ ├── heart_mask.nrrd diff --git a/docs/user_guide/logging.rst b/docs/user_guide/logging.rst index a1ea0d0..48d2f19 100644 --- a/docs/user_guide/logging.rst +++ b/docs/user_guide/logging.rst @@ -51,7 +51,7 @@ Example output: .. code-block:: text - 2025-12-13 11:24:49 - PhysioMotion4D - INFO - [RegisterModelToImagePCA] Processing started + 2025-12-13 11:24:49 - PhysioMotion4D - INFO - [RegisterModelsPCA] Processing started 2025-12-13 11:24:49 - PhysioMotion4D - DEBUG - [HeartModelToPatientWorkflow] Detailed debug info Multiple Log Levels @@ -79,12 +79,12 @@ Filter to show logs from only specific classes: from physiomotion4d import PhysioMotion4DBase - # Show only RegisterModelToImagePCA logs - PhysioMotion4DBase.set_log_classes(["RegisterModelToImagePCA"]) + # Show only RegisterModelsPCA logs + PhysioMotion4DBase.set_log_classes(["RegisterModelsPCA"]) # Show logs from multiple classes PhysioMotion4DBase.set_log_classes([ - "RegisterModelToImagePCA", + "RegisterModelsPCA", "HeartModelToPatientWorkflow" ]) @@ -225,13 +225,13 @@ Class Filtering for Selective Debugging from physiomotion4d import ( PhysioMotion4DBase, - RegisterModelToImagePCA, + RegisterModelsPCA, HeartModelToPatientWorkflow ) import logging # Create multiple objects - registrar1 = RegisterModelToImagePCA(..., log_level=logging.INFO) + registrar1 = RegisterModelsPCA(..., log_level=logging.INFO) registrar2 = HeartModelToPatientWorkflow(..., log_level=logging.INFO) # Show logs from all classes (default) @@ -239,14 +239,14 @@ Class Filtering for Selective Debugging registrar2.log_info("Message from Workflow") # Both messages are shown - # Filter to show only RegisterModelToImagePCA - PhysioMotion4DBase.set_log_classes(["RegisterModelToImagePCA"]) + # Filter to show only RegisterModelsPCA + 
PhysioMotion4DBase.set_log_classes(["RegisterModelsPCA"]) registrar1.log_info("This is shown") registrar2.log_info("This is hidden") # Show multiple specific classes PhysioMotion4DBase.set_log_classes([ - "RegisterModelToImagePCA", + "RegisterModelsPCA", "HeartModelToPatientWorkflow" ]) diff --git a/docs/user_guide/registration.rst b/docs/user_guide/registration.rst index b4df917..e45ca70 100644 --- a/docs/user_guide/registration.rst +++ b/docs/user_guide/registration.rst @@ -37,19 +37,19 @@ Model-to-Image/Model Registration For registering anatomical models to patient data: -* **ICP (RegisterModelToModelICP)**: Iterative Closest Point for surface alignment +* **ICP (RegisterModelsICP)**: Iterative Closest Point for surface alignment - Best for: Initial rough alignment, mesh registration - Speed: Very fast (<10 seconds) - Type: Rigid/affine registration -* **Mask-based (RegisterModelToModelMasks)**: Deformable registration using binary masks +* **Mask-based (RegisterModelsDistanceMaps)**: Deformable registration using binary masks - Best for: Model-to-patient fitting - Features: Dice coefficient optimization - Type: Deformable registration -* **PCA-based (RegisterModelToImagePCA)**: Statistical shape model registration +* **PCA-based (RegisterModelsPCA)**: Statistical shape model registration - Best for: Shape prior constraints - Features: Low-dimensional optimization @@ -87,8 +87,8 @@ Image Registration with ICON results = registerer.register(moving) # Get displacement fields - phi_FM = results["phi_FM"] # Fixed to moving - phi_MF = results["phi_MF"] # Moving to fixed + inverse_transform = results["inverse_transform"] # Fixed to moving + forward_transform = results["forward_transform"] # Moving to fixed Time Series Registration ------------------------- @@ -108,8 +108,8 @@ Time Series Registration results = reg.register_time_series(image_filenames=image_files) # Access transforms - transforms_FM = results["phi_FM_list"] - transforms_MF = results["phi_MF_list"] + 
transforms_inverse = results["inverse_transforms"] + transforms_forward = results["forward_transforms"] Model to Patient Registration ------------------------------ diff --git a/experiments/Colormap-VTK_To_USD/colormap_vtk_to_usd.ipynb b/experiments/Colormap-VTK_To_USD/colormap_vtk_to_usd.ipynb index c4539ac..fc8b78b 100644 --- a/experiments/Colormap-VTK_To_USD/colormap_vtk_to_usd.ipynb +++ b/experiments/Colormap-VTK_To_USD/colormap_vtk_to_usd.ipynb @@ -35,18 +35,7 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", - "import pyvista as pv\n", - "from pathlib import Path\n", - "\n", - "from physiomotion4d import ConvertVTK4DToUSD\n", - "\n", - "# Create output directory\n", - "output_dir = Path(\"./output\")\n", - "output_dir.mkdir(exist_ok=True)\n", - "\n", - "print(\"PhysioMotion4D Colormap Examples\")\n", - "print(\"=\" * 50)" + "import numpy as np\nimport pyvista as pv\nfrom pathlib import Path\n\nfrom physiomotion4d import ConvertVTK4DToUSD\n\n# Create output directory\noutput_dir = Path(\"./output\")\noutput_dir.mkdir(exist_ok=True)\n\nprint(\"PhysioMotion4D Colormap Examples\")\nprint(\"=\" * 50)" ] }, { @@ -64,39 +53,7 @@ "metadata": {}, "outputs": [], "source": [ - "def create_example_mesh_with_data(time_step):\n", - " \"\"\"\n", - " Create a sphere mesh with synthetic data for demonstration.\n", - " \n", - " Parameters\n", - " ----------\n", - " time_step : int\n", - " Current time step for animation\n", - " \n", - " Returns\n", - " -------\n", - " pyvista.PolyData\n", - " Sphere mesh with point data arrays\n", - " \"\"\"\n", - " # Create a sphere\n", - " sphere = pv.Sphere(radius=1.0, theta_resolution=30, phi_resolution=30)\n", - " \n", - " # Add synthetic point data (e.g., simulating transmembrane potential)\n", - " points = sphere.points\n", - " z_coords = points[:, 2]\n", - " theta = 2 * np.pi * time_step / 10.0 # Full cycle every 10 frames\n", - " \n", - " # Simulate transmembrane potential: -80 mV (rest) to +20 mV 
(depolarized)\n", - " potential = -80.0 + 100.0 * (0.5 + 0.5 * np.sin(3 * z_coords + theta))\n", - " sphere.point_data['transmembrane_potential'] = potential\n", - " \n", - " # Add temperature data example\n", - " temperature = 20.0 + 15.0 * (0.5 + 0.5 * np.cos(2 * z_coords - theta))\n", - " sphere.point_data['temperature'] = temperature\n", - " \n", - " return sphere\n", - "\n", - "print(\"Helper function defined successfully\")" + "def create_example_mesh_with_data(time_step):\n \"\"\"\n Create a sphere mesh with synthetic data for demonstration.\n \n Parameters\n ----------\n time_step : int\n Current time step for animation\n \n Returns\n -------\n pyvista.PolyData\n Sphere mesh with point data arrays\n \"\"\"\n # Create a sphere\n sphere = pv.Sphere(radius=1.0, theta_resolution=30, phi_resolution=30)\n \n # Add synthetic point data (e.g., simulating transmembrane potential)\n points = sphere.points\n z_coords = points[:, 2]\n theta = 2 * np.pi * time_step / 10.0 # Full cycle every 10 frames\n \n # Simulate transmembrane potential: -80 mV (rest) to +20 mV (depolarized)\n potential = -80.0 + 100.0 * (0.5 + 0.5 * np.sin(3 * z_coords + theta))\n sphere.point_data['transmembrane_potential'] = potential\n \n # Add temperature data example\n temperature = 20.0 + 15.0 * (0.5 + 0.5 * np.cos(2 * z_coords - theta))\n sphere.point_data['temperature'] = temperature\n \n return sphere\n\nprint(\"Helper function defined successfully\")" ] }, { @@ -114,36 +71,7 @@ "metadata": {}, "outputs": [], "source": [ - "print(\"\\nExample 1: Plasma colormap (default) with auto range\")\n", - "print(\"-\" * 50)\n", - "\n", - "# Create time series of meshes\n", - "meshes = [create_example_mesh_with_data(t) for t in range(20)]\n", - "\n", - "# Initialize converter\n", - "converter = ConvertVTK4DToUSD(\n", - " data_basename=\"CardiacModel\",\n", - " input_polydata=meshes,\n", - " mask_ids=None\n", - ")\n", - "\n", - "# List available arrays for coloring\n", - "print(\"Available point data 
arrays:\")\n", - "available = converter.list_available_arrays()\n", - "for name, info in available.items():\n", - " print(f\" - {name}: range={info['range']}, dtype={info['dtype']}\")\n", - "\n", - "# Set colormap (automatic range)\n", - "converter.set_colormap(\n", - " color_by_array=\"transmembrane_potential\",\n", - " colormap=\"plasma\",\n", - " intensity_range=None # Auto-detect from data\n", - ")\n", - "\n", - "# Convert to USD\n", - "output_file = output_dir / \"example1_plasma_auto.usd\"\n", - "stage = converter.convert(str(output_file))\n", - "print(f\"\\n✓ Created: {output_file}\")" + "print(\"\\nExample 1: Plasma colormap (default) with auto range\")\nprint(\"-\" * 50)\n\n# Create time series of meshes\nmeshes = [create_example_mesh_with_data(t) for t in range(20)]\n\n# Initialize converter\nconverter = ConvertVTK4DToUSD(\n data_basename=\"CardiacModel\",\n input_polydata=meshes,\n mask_ids=None\n)\n\n# List available arrays for coloring\nprint(\"Available point data arrays:\")\navailable = converter.list_available_arrays()\nfor name, info in available.items():\n print(f\" - {name}: range={info['range']}, dtype={info['dtype']}\")\n\n# Set colormap (automatic range)\nconverter.set_colormap(\n color_by_array=\"transmembrane_potential\",\n colormap=\"plasma\",\n intensity_range=None # Auto-detect from data\n)\n\n# Convert to USD\noutput_file = output_dir / \"example1_plasma_auto.usd\"\nstage = converter.convert(str(output_file))\nprint(f\"\\n✓ Created: {output_file}\")" ] }, { @@ -161,26 +89,7 @@ "metadata": {}, "outputs": [], "source": [ - "print(\"\\nExample 2: Rainbow colormap with custom range [-80, 20] mV\")\n", - "print(\"-\" * 50)\n", - "\n", - "meshes = [create_example_mesh_with_data(t) for t in range(20)]\n", - "\n", - "converter = ConvertVTK4DToUSD(\n", - " data_basename=\"CardiacModel\",\n", - " input_polydata=meshes,\n", - " mask_ids=None\n", - ")\n", - "\n", - "converter.set_colormap(\n", - " color_by_array=\"transmembrane_potential\",\n", - " 
colormap=\"rainbow\",\n", - " intensity_range=(-80.0, 20.0) # Physiological range for action potential\n", - ")\n", - "\n", - "output_file = output_dir / \"example2_rainbow_custom.usd\"\n", - "stage = converter.convert(str(output_file))\n", - "print(f\"✓ Created: {output_file}\")" + "print(\"\\nExample 2: Rainbow colormap with custom range [-80, 20] mV\")\nprint(\"-\" * 50)\n\nmeshes = [create_example_mesh_with_data(t) for t in range(20)]\n\nconverter = ConvertVTK4DToUSD(\n data_basename=\"CardiacModel\",\n input_polydata=meshes,\n mask_ids=None\n)\n\nconverter.set_colormap(\n color_by_array=\"transmembrane_potential\",\n colormap=\"rainbow\",\n intensity_range=(-80.0, 20.0) # Physiological range for action potential\n)\n\noutput_file = output_dir / \"example2_rainbow_custom.usd\"\nstage = converter.convert(str(output_file))\nprint(f\"✓ Created: {output_file}\")" ] }, { @@ -198,26 +107,7 @@ "metadata": {}, "outputs": [], "source": [ - "print(\"\\nExample 3: Heat colormap for temperature visualization\")\n", - "print(\"-\" * 50)\n", - "\n", - "meshes = [create_example_mesh_with_data(t) for t in range(20)]\n", - "\n", - "converter = ConvertVTK4DToUSD(\n", - " data_basename=\"TemperatureModel\",\n", - " input_polydata=meshes,\n", - " mask_ids=None\n", - ")\n", - "\n", - "converter.set_colormap(\n", - " color_by_array=\"temperature\",\n", - " colormap=\"heat\",\n", - " intensity_range=(15.0, 40.0) # Temperature range in Celsius\n", - ")\n", - "\n", - "output_file = output_dir / \"example3_heat_temperature.usd\"\n", - "stage = converter.convert(str(output_file))\n", - "print(f\"✓ Created: {output_file}\")" + "print(\"\\nExample 3: Heat colormap for temperature visualization\")\nprint(\"-\" * 50)\n\nmeshes = [create_example_mesh_with_data(t) for t in range(20)]\n\nconverter = ConvertVTK4DToUSD(\n data_basename=\"TemperatureModel\",\n input_polydata=meshes,\n mask_ids=None\n)\n\nconverter.set_colormap(\n color_by_array=\"temperature\",\n colormap=\"heat\",\n 
intensity_range=(15.0, 40.0) # Temperature range in Celsius\n)\n\noutput_file = output_dir / \"example3_heat_temperature.usd\"\nstage = converter.convert(str(output_file))\nprint(f\"✓ Created: {output_file}\")" ] }, { @@ -235,26 +125,7 @@ "metadata": {}, "outputs": [], "source": [ - "print(\"\\nExample 4: Coolwarm colormap for diverging data\")\n", - "print(\"-\" * 50)\n", - "\n", - "meshes = [create_example_mesh_with_data(t) for t in range(20)]\n", - "\n", - "converter = ConvertVTK4DToUSD(\n", - " data_basename=\"CardiacModel\",\n", - " input_polydata=meshes,\n", - " mask_ids=None\n", - ")\n", - "\n", - "converter.set_colormap(\n", - " color_by_array=\"transmembrane_potential\",\n", - " colormap=\"coolwarm\",\n", - " intensity_range=(-80.0, 20.0)\n", - ")\n", - "\n", - "output_file = output_dir / \"example4_coolwarm_diverging.usd\"\n", - "stage = converter.convert(str(output_file))\n", - "print(f\"✓ Created: {output_file}\")" + "print(\"\\nExample 4: Coolwarm colormap for diverging data\")\nprint(\"-\" * 50)\n\nmeshes = [create_example_mesh_with_data(t) for t in range(20)]\n\nconverter = ConvertVTK4DToUSD(\n data_basename=\"CardiacModel\",\n input_polydata=meshes,\n mask_ids=None\n)\n\nconverter.set_colormap(\n color_by_array=\"transmembrane_potential\",\n colormap=\"coolwarm\",\n intensity_range=(-80.0, 20.0)\n)\n\noutput_file = output_dir / \"example4_coolwarm_diverging.usd\"\nstage = converter.convert(str(output_file))\nprint(f\"✓ Created: {output_file}\")" ] }, { @@ -272,26 +143,7 @@ "metadata": {}, "outputs": [], "source": [ - "print(\"\\nExample 5: Grayscale colormap\")\n", - "print(\"-\" * 50)\n", - "\n", - "meshes = [create_example_mesh_with_data(t) for t in range(20)]\n", - "\n", - "converter = ConvertVTK4DToUSD(\n", - " data_basename=\"CardiacModel\",\n", - " input_polydata=meshes,\n", - " mask_ids=None\n", - ")\n", - "\n", - "converter.set_colormap(\n", - " color_by_array=\"transmembrane_potential\",\n", - " colormap=\"grayscale\",\n", - " 
intensity_range=None\n", - ")\n", - "\n", - "output_file = output_dir / \"example5_grayscale.usd\"\n", - "stage = converter.convert(str(output_file))\n", - "print(f\"✓ Created: {output_file}\")" + "print(\"\\nExample 5: Grayscale colormap\")\nprint(\"-\" * 50)\n\nmeshes = [create_example_mesh_with_data(t) for t in range(20)]\n\nconverter = ConvertVTK4DToUSD(\n data_basename=\"CardiacModel\",\n input_polydata=meshes,\n mask_ids=None\n)\n\nconverter.set_colormap(\n color_by_array=\"transmembrane_potential\",\n colormap=\"grayscale\",\n intensity_range=None\n)\n\noutput_file = output_dir / \"example5_grayscale.usd\"\nstage = converter.convert(str(output_file))\nprint(f\"✓ Created: {output_file}\")" ] }, { @@ -309,32 +161,7 @@ "metadata": {}, "outputs": [], "source": [ - "print(\"\\nExample 6: Random colormap for categorical visualization\")\n", - "print(\"-\" * 50)\n", - "\n", - "meshes = [create_example_mesh_with_data(t) for t in range(20)]\n", - "\n", - "# Add categorical-like data (discrete regions)\n", - "for mesh in meshes:\n", - " z_values = mesh.points[:, 2]\n", - " regions = np.floor(3 * (z_values + 1) / 2) # 3 regions\n", - " mesh.point_data['region_id'] = regions\n", - "\n", - "converter = ConvertVTK4DToUSD(\n", - " data_basename=\"RegionModel\",\n", - " input_polydata=meshes,\n", - " mask_ids=None\n", - ")\n", - "\n", - "converter.set_colormap(\n", - " color_by_array=\"region_id\",\n", - " colormap=\"random\",\n", - " intensity_range=None\n", - ")\n", - "\n", - "output_file = output_dir / \"example6_random_categorical.usd\"\n", - "stage = converter.convert(str(output_file))\n", - "print(f\"✓ Created: {output_file}\")" + "print(\"\\nExample 6: Random colormap for categorical visualization\")\nprint(\"-\" * 50)\n\nmeshes = [create_example_mesh_with_data(t) for t in range(20)]\n\n# Add categorical-like data (discrete regions)\nfor mesh in meshes:\n z_values = mesh.points[:, 2]\n regions = np.floor(3 * (z_values + 1) / 2) # 3 regions\n 
mesh.point_data['region_id'] = regions\n\nconverter = ConvertVTK4DToUSD(\n data_basename=\"RegionModel\",\n input_polydata=meshes,\n mask_ids=None\n)\n\nconverter.set_colormap(\n color_by_array=\"region_id\",\n colormap=\"random\",\n intensity_range=None\n)\n\noutput_file = output_dir / \"example6_random_categorical.usd\"\nstage = converter.convert(str(output_file))\nprint(f\"✓ Created: {output_file}\")" ] }, { @@ -352,29 +179,7 @@ "metadata": {}, "outputs": [], "source": [ - "print(\"\\nExample 7: Method chaining with viridis colormap\")\n", - "print(\"-\" * 50)\n", - "\n", - "meshes = [create_example_mesh_with_data(t) for t in range(20)]\n", - "\n", - "output_file = output_dir / \"example7_viridis_chained.usd\"\n", - "\n", - "# Method chaining for concise code\n", - "stage = (\n", - " ConvertVTK4DToUSD(\n", - " data_basename=\"CardiacModel\",\n", - " input_polydata=meshes,\n", - " mask_ids=None\n", - " )\n", - " .set_colormap(\n", - " color_by_array=\"transmembrane_potential\",\n", - " colormap=\"viridis\",\n", - " intensity_range=(-80.0, 20.0)\n", - " )\n", - " .convert(str(output_file))\n", - ")\n", - "\n", - "print(f\"✓ Created: {output_file}\")" + "print(\"\\nExample 7: Method chaining with viridis colormap\")\nprint(\"-\" * 50)\n\nmeshes = [create_example_mesh_with_data(t) for t in range(20)]\n\noutput_file = output_dir / \"example7_viridis_chained.usd\"\n\n# Method chaining for concise code\nstage = (\n ConvertVTK4DToUSD(\n data_basename=\"CardiacModel\",\n input_polydata=meshes,\n mask_ids=None\n )\n .set_colormap(\n color_by_array=\"transmembrane_potential\",\n colormap=\"viridis\",\n intensity_range=(-80.0, 20.0)\n )\n .convert(str(output_file))\n)\n\nprint(f\"✓ Created: {output_file}\")" ] }, { @@ -413,11 +218,7 @@ "metadata": {}, "outputs": [], "source": [ - "print(\"\\n\" + \"=\" * 50)\n", - "print(\"All examples completed!\")\n", - "print(\"=\" * 50)\n", - "print(f\"\\nOutput files created in: {output_dir.absolute()}\")\n", - "print(\"\\nView these 
USD files in NVIDIA Omniverse to see the colormap visualizations.\")" + "print(\"\\n\" + \"=\" * 50)\nprint(\"All examples completed!\")\nprint(\"=\" * 50)\nprint(f\"\\nOutput files created in: {output_dir.absolute()}\")\nprint(\"\\nView these USD files in NVIDIA Omniverse to see the colormap visualizations.\")" ] } ], @@ -442,4 +243,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/experiments/DisplacementField_To_USD/displacement_field_to_usd.ipynb b/experiments/DisplacementField_To_USD/displacement_field_to_usd.ipynb index a531c22..f9a44c7 100644 --- a/experiments/DisplacementField_To_USD/displacement_field_to_usd.ipynb +++ b/experiments/DisplacementField_To_USD/displacement_field_to_usd.ipynb @@ -1,325 +1,199 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Displacement Field to USD for Omniverse Visualization\n", - "\n", - "This notebook demonstrates how to convert time-varying 3D displacement fields (stored as NIfTI images) into USD format for visualization in NVIDIA Omniverse using the PhysicsNeMo extension.\n", - "\n", - "## Pipeline Overview\n", - "\n", - "1. Load 3D vector fields from NIfTI files using ITK\n", - "2. Convert ITK images to VTK data structures\n", - "3. Create PhysicsNeMo-compatible USD stages for time-varying visualization\n", - "4. 
Export animated USD stage for Omniverse Create/Kit\n", - "\n", - "## Architecture\n", - "\n", - "The `DisplacementFieldToUSD` class encapsulates all pipeline logic for converting medical imaging displacement fields to Omniverse-compatible USD format.\n", - "\n", - "## Required Libraries\n", - "\n", - "- ITK (InsightToolkit) for medical image I/O\n", - "- VTK (Visualization Toolkit) for data structure conversion \n", - "- numpy for array processing\n", - "- PhysicsNeMo and PhysicsNeMo-Sym for physics-based visualization\n", - "- Omniverse USD Python API for stage creation\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1. Install Dependencies\n", - "\n", - "Install the latest versions of PhysicsNeMo and PhysicsNeMo-Sym from GitHub.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Install PhysicsNeMo and PhysicsNeMo-Sym from GitHub\n", - "# Uncomment these lines if running for the first time\n", - "\n", - "# !pip install git+https://github.com/NVIDIA/physicsnemo.git\n", - "# !pip install git+https://github.com/NVIDIA/physicsnemo-sym.git\n", - "\n", - "# Install other required packages\n", - "# !pip install itk vtk numpy pxr\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2. Import Libraries\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "from pathlib import Path\n", - "from typing import List, Optional, Tuple\n", - "\n", - "import itk\n", - "import numpy as np\n", - "import vtk\n", - "from pxr import Gf, Usd, UsdGeom, Vt\n", - "\n", - "print(\"Libraries imported successfully!\")\n", - "print(f\"ITK version: {itk.Version.GetITKVersion()}\")\n", - "print(f\"VTK version: {vtk.vtkVersion.GetVTKVersion()}\")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3. 
Import DisplacementFieldToUSD Class\n", - "\n", - "Import the converter class from the local module.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Import the DisplacementFieldToUSD class from the local module\n", - "from displacement_field_converter import DisplacementFieldToUSD\n", - "\n", - "# Display class documentation\n", - "help(DisplacementFieldToUSD)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 4. Helper: Generate Sample Displacement Fields\n", - "\n", - "For demonstration purposes, this function creates synthetic time-varying displacement fields.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def generate_sample_displacement_fields(\n", - " output_dir: str,\n", - " n_timesteps: int = 10,\n", - " size: Tuple[int, int, int] = (32, 32, 32)\n", - ") -> List[str]:\n", - " \"\"\"\n", - " Generate synthetic time-varying displacement fields for demonstration.\n", - " \n", - " Creates a rotating/pulsating vector field pattern.\n", - " \n", - " Args:\n", - " output_dir: Directory to save NIfTI files\n", - " n_timesteps: Number of time steps to generate\n", - " size: 3D size of the displacement field\n", - " \n", - " Returns:\n", - " List of file paths to generated NIfTI files\n", - " \"\"\"\n", - " os.makedirs(output_dir, exist_ok=True)\n", - " file_paths = []\n", - " \n", - " # Create coordinate grid\n", - " z, y, x = np.meshgrid(\n", - " np.linspace(-1, 1, size[2]),\n", - " np.linspace(-1, 1, size[1]),\n", - " np.linspace(-1, 1, size[0]),\n", - " indexing='ij'\n", - " )\n", - " \n", - " for t in range(n_timesteps):\n", - " # Time-varying rotation angle\n", - " theta = 2 * np.pi * t / n_timesteps\n", - " \n", - " # Create rotating vector field with radial component\n", - " r = np.sqrt(x**2 + y**2 + z**2)\n", - " \n", - " # Displacement components (rotating + pulsating)\n", - " 
displacement_x = -y * np.cos(theta) + z * np.sin(theta)\n", - " displacement_y = x * np.cos(theta) - r * 0.2 * np.sin(theta)\n", - " displacement_z = -x * np.sin(theta) + y * np.cos(theta)\n", - " \n", - " # Scale by distance from center (creates flow pattern)\n", - " scale_factor = 5.0 * (1 - r / np.max(r))\n", - " displacement_x *= scale_factor\n", - " displacement_y *= scale_factor\n", - " displacement_z *= scale_factor\n", - " \n", - " # Stack into vector field (z, y, x, 3)\n", - " displacement_field = np.stack(\n", - " [displacement_x, displacement_y, displacement_z],\n", - " axis=-1\n", - " ).astype(np.float32)\n", - " \n", - " # Convert to ITK image\n", - " itk_image = itk.image_from_array(displacement_field, is_vector=True)\n", - " itk_image.SetSpacing([1.0, 1.0, 1.0])\n", - " itk_image.SetOrigin([0.0, 0.0, 0.0])\n", - " \n", - " # Save as NIfTI\n", - " file_path = os.path.join(output_dir, f\"displacement_t{t:03d}.nii.gz\")\n", - " itk.imwrite(itk_image, file_path, compression=True)\n", - " file_paths.append(file_path)\n", - " \n", - " print(f\"Generated: {file_path}\")\n", - " \n", - " return file_paths\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 5. 
Example Usage\n", - "\n", - "Demonstrate the complete pipeline with synthetic data.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Configuration\n", - "output_dir = \"./sample_displacement_fields\"\n", - "usd_output_path = \"./displacement_field_animation.usd\"\n", - "n_timesteps = 10\n", - "\n", - "# Generate sample data\n", - "print(\"Generating sample displacement fields...\")\n", - "nifti_files = generate_sample_displacement_fields(\n", - " output_dir,\n", - " n_timesteps=n_timesteps,\n", - " size=(32, 32, 32)\n", - ")\n", - "\n", - "print(f\"\\\\nGenerated {len(nifti_files)} sample files\")\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Create converter instance\n", - "converter = DisplacementFieldToUSD(\n", - " subsample_factor=4,\n", - " vector_scale=2.0\n", - ")\n", - "\n", - "# Run complete pipeline\n", - "stage = converter.process_pipeline(\n", - " nifti_files=nifti_files,\n", - " output_path=usd_output_path,\n", - " fps=24.0\n", - ")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Example with real displacement field data\n", - "# Uncomment and modify paths to use your own data:\n", - "\n", - "# nifti_files = [\n", - "# \"path/to/displacement_t00.nii.gz\",\n", - "# \"path/to/displacement_t01.nii.gz\",\n", - "# \"path/to/displacement_t02.nii.gz\",\n", - "# # ... more files\n", - "# ]\n", - "#\n", - "# converter = DisplacementFieldToUSD(subsample_factor=8, vector_scale=5.0)\n", - "# stage = converter.process_pipeline(\n", - "# nifti_files=nifti_files,\n", - "# output_path=\"cardiac_motion.usd\",\n", - "# fps=10.0\n", - "# )\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 7. 
Visualization in Omniverse\n", - "\n", - "### Steps to Visualize:\n", - "\n", - "1. **Open Omniverse Create or Kit**\n", - " - Launch NVIDIA Omniverse Create\n", - " - File → Open and select the generated USD file\n", - "\n", - "2. **Enable PhysicsNeMo Extension**\n", - " - Window → Extensions\n", - " - Search for \"PhysicsNeMo\"\n", - " - Enable the extension\n", - "\n", - "3. **Visualize the Vector Field**\n", - " - Select `/DisplacementField/VectorField` prim in the stage\n", - " - In PhysicsNeMo panel, choose visualization mode:\n", - " - **Streamlines**: Flow trajectories\n", - " - **Vector Glyphs**: Direction and magnitude at points\n", - " - **Volume Rendering**: Field magnitude as volume\n", - " - **Flow Particles**: Animated particles\n", - "\n", - "4. **Play Animation**\n", - " - Use timeline controls to play through time steps\n", - " - Adjust playback speed as needed\n", - "\n", - "### Class Methods Summary:\n", - "\n", - "- `load_nifti_files()`: Load displacement fields from NIfTI\n", - "- `convert_to_vtk()`: Convert ITK images to VTK format\n", - "- `extract_all_vector_fields()`: Extract subsampled data\n", - "- `create_usd_stage()`: Create time-varying USD stage\n", - "- `process_pipeline()`: Run complete pipeline\n", - "\n", - "### Key Features:\n", - "\n", - "✅ Class-based architecture for clean, reusable code\n", - "✅ Complete ITK → VTK → USD pipeline\n", - "✅ Time-varying animation support \n", - "✅ PhysicsNeMo-compatible velocities attribute\n", - "✅ Configurable subsampling and vector scaling\n", - "✅ Production-ready for medical imaging workflows\n", - "\n", - "This implementation encapsulates all logic in the `DisplacementFieldToUSD` class, making it easy to integrate into larger pipelines or customize for specific use cases.\n" - ] - } - ], - "metadata": { - "language_info": { - "name": "python" - } + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Displacement Field to USD for Omniverse Visualization\n", + 
"\n", + "This notebook demonstrates how to convert time-varying 3D displacement fields (stored as NIfTI images) into USD format for visualization in NVIDIA Omniverse using the PhysicsNeMo extension.\n", + "\n", + "## Pipeline Overview\n", + "\n", + "1. Load 3D vector fields from NIfTI files using ITK\n", + "2. Convert ITK images to VTK data structures\n", + "3. Create PhysicsNeMo-compatible USD stages for time-varying visualization\n", + "4. Export animated USD stage for Omniverse Create/Kit\n", + "\n", + "## Architecture\n", + "\n", + "The `DisplacementFieldToUSD` class encapsulates all pipeline logic for converting medical imaging displacement fields to Omniverse-compatible USD format.\n", + "\n", + "## Required Libraries\n", + "\n", + "- ITK (InsightToolkit) for medical image I/O\n", + "- VTK (Visualization Toolkit) for data structure conversion \n", + "- numpy for array processing\n", + "- PhysicsNeMo and PhysicsNeMo-Sym for physics-based visualization\n", + "- Omniverse USD Python API for stage creation\n" + ] }, - "nbformat": 4, - "nbformat_minor": 2 -} + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Install Dependencies\n", + "\n", + "Install the latest versions of PhysicsNeMo and PhysicsNeMo-Sym from GitHub.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Install PhysicsNeMo and PhysicsNeMo-Sym from GitHub\n# Uncomment these lines if running for the first time\n\n# !pip install git+https://github.com/NVIDIA/physicsnemo.git\n# !pip install git+https://github.com/NVIDIA/physicsnemo-sym.git\n\n# Install other required packages\n# !pip install itk vtk numpy pxr\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. 
Import Libraries\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\nfrom pathlib import Path\nfrom typing import List, Optional, Tuple\n\nimport itk\nimport numpy as np\nimport vtk\nfrom pxr import Gf, Usd, UsdGeom, Vt\n\nprint(\"Libraries imported successfully!\")\nprint(f\"ITK version: {itk.Version.GetITKVersion()}\")\nprint(f\"VTK version: {vtk.vtkVersion.GetVTKVersion()}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Import DisplacementFieldToUSD Class\n", + "\n", + "Import the converter class from the local module.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Import the DisplacementFieldToUSD class from the local module\nfrom displacement_field_converter import DisplacementFieldToUSD\n\n# Display class documentation\nhelp(DisplacementFieldToUSD)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. 
Helper: Generate Sample Displacement Fields\n", + "\n", + "For demonstration purposes, this function creates synthetic time-varying displacement fields.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def generate_sample_displacement_fields(\n output_dir: str,\n n_timesteps: int = 10,\n size: Tuple[int, int, int] = (32, 32, 32)\n) -> List[str]:\n \"\"\"\n Generate synthetic time-varying displacement fields for demonstration.\n \n Creates a rotating/pulsating vector field pattern.\n \n Args:\n output_dir: Directory to save NIfTI files\n n_timesteps: Number of time steps to generate\n size: 3D size of the displacement field\n \n Returns:\n List of file paths to generated NIfTI files\n \"\"\"\n os.makedirs(output_dir, exist_ok=True)\n file_paths = []\n \n # Create coordinate grid\n z, y, x = np.meshgrid(\n np.linspace(-1, 1, size[2]),\n np.linspace(-1, 1, size[1]),\n np.linspace(-1, 1, size[0]),\n indexing='ij'\n )\n \n for t in range(n_timesteps):\n # Time-varying rotation angle\n theta = 2 * np.pi * t / n_timesteps\n \n # Create rotating vector field with radial component\n r = np.sqrt(x**2 + y**2 + z**2)\n \n # Displacement components (rotating + pulsating)\n displacement_x = -y * np.cos(theta) + z * np.sin(theta)\n displacement_y = x * np.cos(theta) - r * 0.2 * np.sin(theta)\n displacement_z = -x * np.sin(theta) + y * np.cos(theta)\n \n # Scale by distance from center (creates flow pattern)\n scale_factor = 5.0 * (1 - r / np.max(r))\n displacement_x *= scale_factor\n displacement_y *= scale_factor\n displacement_z *= scale_factor\n \n # Stack into vector field (z, y, x, 3)\n displacement_field = np.stack(\n [displacement_x, displacement_y, displacement_z],\n axis=-1\n ).astype(np.float32)\n \n # Convert to ITK image\n itk_image = itk.image_from_array(displacement_field, is_vector=True)\n itk_image.SetSpacing([1.0, 1.0, 1.0])\n itk_image.SetOrigin([0.0, 0.0, 0.0])\n \n # Save as NIfTI\n file_path 
= os.path.join(output_dir, f\"displacement_t{t:03d}.nii.gz\")\n itk.imwrite(itk_image, file_path, compression=True)\n file_paths.append(file_path)\n \n print(f\"Generated: {file_path}\")\n \n return file_paths\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Example Usage\n", + "\n", + "Demonstrate the complete pipeline with synthetic data.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Configuration\noutput_dir = \"./sample_displacement_fields\"\nusd_output_path = \"./displacement_field_animation.usd\"\nn_timesteps = 10\n\n# Generate sample data\nprint(\"Generating sample displacement fields...\")\nnifti_files = generate_sample_displacement_fields(\n output_dir,\n n_timesteps=n_timesteps,\n size=(32, 32, 32)\n)\n\nprint(f\"\\\\nGenerated {len(nifti_files)} sample files\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create converter instance\nconverter = DisplacementFieldToUSD(\n subsample_factor=4,\n vector_scale=2.0\n)\n\n# Run complete pipeline\nstage = converter.process_pipeline(\n nifti_files=nifti_files,\n output_path=usd_output_path,\n fps=24.0\n)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Example with real displacement field data\n# Uncomment and modify paths to use your own data:\n\n# nifti_files = [\n# \"path/to/displacement_t00.nii.gz\",\n# \"path/to/displacement_t01.nii.gz\",\n# \"path/to/displacement_t02.nii.gz\",\n# # ... more files\n# ]\n#\n# converter = DisplacementFieldToUSD(subsample_factor=8, vector_scale=5.0)\n# stage = converter.process_pipeline(\n# nifti_files=nifti_files,\n# output_path=\"cardiac_motion.usd\",\n# fps=10.0\n# )\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. 
Visualization in Omniverse\n", + "\n", + "### Steps to Visualize:\n", + "\n", + "1. **Open Omniverse Create or Kit**\n", + " - Launch NVIDIA Omniverse Create\n", + " - File → Open and select the generated USD file\n", + "\n", + "2. **Enable PhysicsNeMo Extension**\n", + " - Window → Extensions\n", + " - Search for \"PhysicsNeMo\"\n", + " - Enable the extension\n", + "\n", + "3. **Visualize the Vector Field**\n", + " - Select `/DisplacementField/VectorField` prim in the stage\n", + " - In PhysicsNeMo panel, choose visualization mode:\n", + " - **Streamlines**: Flow trajectories\n", + " - **Vector Glyphs**: Direction and magnitude at points\n", + " - **Volume Rendering**: Field magnitude as volume\n", + " - **Flow Particles**: Animated particles\n", + "\n", + "4. **Play Animation**\n", + " - Use timeline controls to play through time steps\n", + " - Adjust playback speed as needed\n", + "\n", + "### Class Methods Summary:\n", + "\n", + "- `load_nifti_files()`: Load displacement fields from NIfTI\n", + "- `convert_to_vtk()`: Convert ITK images to VTK format\n", + "- `extract_all_vector_fields()`: Extract subsampled data\n", + "- `create_usd_stage()`: Create time-varying USD stage\n", + "- `process_pipeline()`: Run complete pipeline\n", + "\n", + "### Key Features:\n", + "\n", + "✅ Class-based architecture for clean, reusable code\n", + "✅ Complete ITK → VTK → USD pipeline\n", + "✅ Time-varying animation support \n", + "✅ PhysicsNeMo-compatible velocities attribute\n", + "✅ Configurable subsampling and vector scaling\n", + "✅ Production-ready for medical imaging workflows\n", + "\n", + "This implementation encapsulates all logic in the `DisplacementFieldToUSD` class, making it easy to integrate into larger pipelines or customize for specific use cases.\n" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} \ No newline at end of file diff --git 
a/experiments/Heart-GatedCT_To_USD/0-download_and_convert_4d_to_3d.ipynb b/experiments/Heart-GatedCT_To_USD/0-download_and_convert_4d_to_3d.ipynb index f457620..4becf5a 100644 --- a/experiments/Heart-GatedCT_To_USD/0-download_and_convert_4d_to_3d.ipynb +++ b/experiments/Heart-GatedCT_To_USD/0-download_and_convert_4d_to_3d.ipynb @@ -6,11 +6,7 @@ "metadata": {}, "outputs": [], "source": [ - "import os\n", - "import shutil\n", - "import urllib\n", - "\n", - "from physiomotion4d.convert_nrrd_4d_to_3d import ConvertNRRD4DTo3D" + "import os\nimport shutil\nimport urllib\n\nfrom physiomotion4d.convert_nrrd_4d_to_3d import ConvertNRRD4DTo3D" ] }, { @@ -19,14 +15,7 @@ "metadata": {}, "outputs": [], "source": [ - "data_dir = \"../../data/Slicer-Heart-CT\"\n", - "output_dir = \"./results/\"\n", - "\n", - "if not os.path.exists(data_dir):\n", - " os.makedirs(data_dir)\n", - "\n", - "if not os.path.exists(output_dir):\n", - " os.makedirs(output_dir)" + "data_dir = \"../../data/Slicer-Heart-CT\"\noutput_dir = \"./results/\"\n\nif not os.path.exists(data_dir):\n os.makedirs(data_dir)\n\nif not os.path.exists(output_dir):\n os.makedirs(output_dir)" ] }, { @@ -35,14 +24,7 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "\n", - "\n", - "input_image_url = \"https://github.com/Slicer-Heart-CT/Slicer-Heart-CT/releases/download/TestingData/TruncalValve_4DCT.seq.nrrd\"\n", - "input_image_filename = os.path.join(data_dir, \"TruncalValve_4DCT.seq.nrrd\")\n", - "\n", - "if not os.path.exists(input_image_filename):\n", - " urllib.request.urlretrieve(input_image_url, input_image_filename)" + "\n\n\ninput_image_url = \"https://github.com/Slicer-Heart-CT/Slicer-Heart-CT/releases/download/TestingData/TruncalValve_4DCT.seq.nrrd\"\ninput_image_filename = os.path.join(data_dir, \"TruncalValve_4DCT.seq.nrrd\")\n\nif not os.path.exists(input_image_filename):\n urllib.request.urlretrieve(input_image_url, input_image_filename)" ] }, { @@ -51,12 +33,7 @@ "metadata": {}, "outputs": [], "source": 
[ - "conv = ConvertNRRD4DTo3D()\n", - "conv.load_nrrd_4d(f\"{data_dir}/TruncalValve_4DCT.seq.nrrd\")\n", - "conv.save_3d_images(f\"{data_dir}/slice\")\n", - "\n", - "# Save the mid-stroke slice as the fixed/reference image\n", - "shutil.copyfile(f\"{data_dir}/slice_007.mha\", f\"{output_dir}/slice_fixed.mha\")" + "conv = ConvertNRRD4DTo3D()\nconv.load_nrrd_4d(f\"{data_dir}/TruncalValve_4DCT.seq.nrrd\")\nconv.save_3d_images(f\"{data_dir}/slice\")\n\n# Save the mid-stroke slice as the fixed/reference image\nshutil.copyfile(f\"{data_dir}/slice_007.mha\", f\"{output_dir}/slice_fixed.mha\")" ] } ], @@ -81,4 +58,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} +} \ No newline at end of file diff --git a/experiments/Heart-GatedCT_To_USD/1-register_images.ipynb b/experiments/Heart-GatedCT_To_USD/1-register_images.ipynb index 69aa147..3488590 100644 --- a/experiments/Heart-GatedCT_To_USD/1-register_images.ipynb +++ b/experiments/Heart-GatedCT_To_USD/1-register_images.ipynb @@ -1,231 +1,231 @@ { - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "3ce61753-11ad-4ade-9afe-6ad1bc748e25", - "metadata": {}, - "outputs": [], - "source": [ - "import itk\n", - "import os\n", - "\n", - "from physiomotion4d.register_images_icon import RegisterImagesICON\n", - "from physiomotion4d.segment_chest_total_segmentator import SegmentChestTotalSegmentator\n", - "\n", - "from physiomotion4d.transform_tools import TransformTools" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9d2e5d21", - "metadata": {}, - "outputs": [], - "source": [ - "data_dir = \"../../data/Slicer-Heart-CT\"\n", - "\n", - "output_dir = os.path.join(\".\", \"results\")\n", - "if not os.path.exists(output_dir):\n", - " os.makedirs(output_dir)\n", - "\n", - "fixed_image_filename = os.path.join(output_dir, \"slice_fixed.mha\")\n", - "fixed_image = itk.imread(fixed_image_filename)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b35f90c6", - "metadata": {}, - 
"outputs": [], - "source": [ - "seg = SegmentChestTotalSegmentator()\n", - "seg.contrast_threshold = 500\n", - "result = seg.segment(fixed_image, contrast_enhanced_study=True)\n", - "labelmap_mask = result[\"labelmap\"]\n", - "lung_mask = result[\"lung\"]\n", - "heart_mask = result[\"heart\"]\n", - "major_vessels_mask = result[\"major_vessels\"]\n", - "bone_mask = result[\"bone\"]\n", - "soft_tissue_mask = result[\"soft_tissue\"]\n", - "other_mask = result[\"other\"]\n", - "contrast_mask = result[\"contrast\"]\n", - "\n", - "fixed_image_labelmap = labelmap_mask\n", - "itk.imwrite(\n", - " fixed_image_labelmap,\n", - " os.path.join(output_dir, \"slice_fixed_mask.mha\"),\n", - " compression=True,\n", - ")\n", - "\n", - "heart_arr = itk.GetArrayFromImage(heart_mask)\n", - "contrast_arr = itk.GetArrayFromImage(contrast_mask)\n", - "major_vessels_arr = itk.GetArrayFromImage(major_vessels_mask)\n", - "fixed_image_dynamic_anatomy_mask = itk.GetImageFromArray(\n", - " heart_arr + contrast_arr + major_vessels_arr\n", - ")\n", - "fixed_image_dynamic_anatomy_mask.CopyInformation(fixed_image)\n", - "itk.imwrite(\n", - " fixed_image_dynamic_anatomy_mask,\n", - " os.path.join(output_dir, \"slice_fixed.dynamic_anatomy_mask.mha\"),\n", - " compression=True,\n", - ")\n", - "\n", - "lung_arr = itk.GetArrayFromImage(lung_mask)\n", - "bone_arr = itk.GetArrayFromImage(bone_mask)\n", - "other_arr = itk.GetArrayFromImage(other_mask)\n", - "fixed_image_static_mask = itk.GetImageFromArray(lung_arr + bone_arr + other_arr)\n", - "fixed_image_static_mask.CopyInformation(fixed_image)\n", - "itk.imwrite(\n", - " fixed_image_static_mask,\n", - " os.path.join(output_dir, \"slice_fixed.static_anatomy_mask.mha\"),\n", - " compression=True,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "10ffbbf6", - "metadata": {}, - "outputs": [], - "source": [ - "reg = RegisterImagesICON()\n", - "reg.set_mask_dilation(5)" - ] - }, - { - "cell_type": "code", - "execution_count": 
null, - "id": "cc9418da", - "metadata": {}, - "outputs": [], - "source": [ - "for i in range(21):\n", - " print(f\"Processing slice {i:03d}\")\n", - " moving_image = itk.imread(os.path.join(data_dir, f\"slice_{i:03d}.mha\"))\n", - " result = seg.segment(moving_image, contrast_enhanced_study=True)\n", - " labelmap_mask = result[\"labelmap\"]\n", - " lung_mask = result[\"lung\"]\n", - " heart_mask = result[\"heart\"]\n", - " major_vessels_mask = result[\"major_vessels\"]\n", - " bone_mask = result[\"bone\"]\n", - " soft_tissue_mask = result[\"soft_tissue\"]\n", - " other_mask = result[\"other\"]\n", - " contrast_mask = result[\"contrast\"]\n", - " itk.imwrite(\n", - " labelmap_mask,\n", - " os.path.join(output_dir, f\"slice_{i:03d}_mask.mha\"),\n", - " compression=True,\n", - " )\n", - "\n", - " # Register the whole image\n", - " reg.set_fixed_image(fixed_image)\n", - " results = reg.register(moving_image)\n", - " phi_FM = results[\"phi_FM\"]\n", - " phi_MF = results[\"phi_MF\"]\n", - " moving_image_reg = TransformTools().transform_image(\n", - " moving_image, phi_MF, fixed_image, \"sinc\"\n", - " ) # Final resampling with sinc\n", - " itk.imwrite(moving_image_reg, os.path.join(output_dir, f\"slice_{i:03d}.reg_all.mha\"), compression=True)\n", - " itk.transformwrite([phi_FM], os.path.join(output_dir, f\"slice_{i:03d}.reg_all.phi_FM.hdf\"), compression=True)\n", - " itk.transformwrite([phi_MF], os.path.join(output_dir, f\"slice_{i:03d}.reg_all.phi_MF.hdf\"), compression=True)\n", - "\n", - " # Register the dynamic anatomy mask\n", - " heart_arr = itk.GetArrayFromImage(heart_mask)\n", - " contrast_arr = itk.GetArrayFromImage(contrast_mask)\n", - " major_vessels_arr = itk.GetArrayFromImage(major_vessels_mask)\n", - " dynamic_anatomy_arr = heart_arr + contrast_arr + major_vessels_arr\n", - " moving_image_dynamic_anatomy_mask = itk.GetImageFromArray(dynamic_anatomy_arr)\n", - " moving_image_dynamic_anatomy_mask.CopyInformation(moving_image)\n", - " 
reg.set_fixed_image(fixed_image)\n", - " reg.set_fixed_image_mask(fixed_image_dynamic_anatomy_mask)\n", - " results = reg.register(\n", - " moving_image, moving_image_dynamic_anatomy_mask\n", - " )\n", - " phi_FM = results[\"phi_FM\"]\n", - " phi_MF = results[\"phi_MF\"]\n", - " moving_image_reg_dynamic_anatomy = TransformTools().transform_image(\n", - " moving_image, phi_MF, fixed_image, \"sinc\"\n", - " ) # Final resampling with sinc\n", - " itk.imwrite(\n", - " moving_image_dynamic_anatomy_mask,\n", - " os.path.join(output_dir, f\"slice_{i:03d}.dynamic_anatomy_mask.mha\"),\n", - " compression=True,\n", - " )\n", - " itk.imwrite(\n", - " moving_image_reg_dynamic_anatomy,\n", - " os.path.join(output_dir, f\"slice_{i:03d}.reg_dynamic_anatomy.mha\"),\n", - " compression=True,\n", - " )\n", - " itk.transformwrite(\n", - " [phi_FM],\n", - " os.path.join(output_dir, f\"slice_{i:03d}.reg_dynamic_anatomy.phi_FM.hdf\"),\n", - " compression=True,\n", - " )\n", - " itk.transformwrite(\n", - " [phi_MF],\n", - " os.path.join(output_dir, f\"slice_{i:03d}.reg_dynamic_anatomy.phi_MF.hdf\"),\n", - " compression=True,\n", - " )\n", - "\n", - " # Register the static anatomy mask\n", - " lung_arr = itk.GetArrayFromImage(lung_mask)\n", - " bone_arr = itk.GetArrayFromImage(bone_mask)\n", - " other_arr = itk.GetArrayFromImage(other_mask)\n", - " moving_image_static_mask = itk.GetImageFromArray(lung_arr + bone_arr + other_arr)\n", - " moving_image_static_mask.CopyInformation(moving_image)\n", - " reg.set_fixed_image(fixed_image)\n", - " reg.set_fixed_image_mask(fixed_image_static_mask)\n", - " results = reg.register(moving_image, moving_image_static_mask)\n", - " phi_FM = results[\"phi_FM\"]\n", - " phi_MF = results[\"phi_MF\"]\n", - " moving_image_reg_static = TransformTools().transform_image(\n", - " moving_image, phi_MF, fixed_image, \"sinc\"\n", - " ) # Final resampling with sinc\n", - " itk.imwrite(\n", - " moving_image_static_mask,\n", - " os.path.join(output_dir, 
f\"slice_{i:03d}.static_anatomy_mask.mha\"),\n", - " compression=True,\n", - " )\n", - " itk.imwrite(\n", - " moving_image_reg_static,\n", - " os.path.join(output_dir, f\"slice_{i:03d}.reg_static_anatomy.mha\"),\n", - " compression=True,\n", - " )\n", - " itk.transformwrite(\n", - " [phi_FM],\n", - " os.path.join(output_dir, f\"slice_{i:03d}.reg_static_anatomy.phi_FM.hdf\"),\n", - " compression=True,\n", - " )\n", - " itk.transformwrite(\n", - " [phi_MF],\n", - " os.path.join(output_dir, f\"slice_{i:03d}.reg_static_anatomy.phi_MF.hdf\"),\n", - " compression=True,\n", - " )" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.11" - } - }, - "nbformat": 4, - "nbformat_minor": 5 + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "3ce61753-11ad-4ade-9afe-6ad1bc748e25", + "metadata": {}, + "outputs": [], + "source": [ + "import itk\n", + "import os\n", + "\n", + "from physiomotion4d.register_images_icon import RegisterImagesICON\n", + "from physiomotion4d.segment_chest_total_segmentator import SegmentChestTotalSegmentator\n", + "\n", + "from physiomotion4d.transform_tools import TransformTools" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9d2e5d21", + "metadata": {}, + "outputs": [], + "source": [ + "data_dir = \"../../data/Slicer-Heart-CT\"\n", + "\n", + "output_dir = os.path.join(\".\", \"results\")\n", + "if not os.path.exists(output_dir):\n", + " os.makedirs(output_dir)\n", + "\n", + "fixed_image_filename = os.path.join(output_dir, \"slice_fixed.mha\")\n", + "fixed_image = itk.imread(fixed_image_filename)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b35f90c6", + "metadata": 
{}, + "outputs": [], + "source": [ + "seg = SegmentChestTotalSegmentator()\n", + "seg.contrast_threshold = 500\n", + "result = seg.segment(fixed_image, contrast_enhanced_study=True)\n", + "labelmap_mask = result[\"labelmap\"]\n", + "lung_mask = result[\"lung\"]\n", + "heart_mask = result[\"heart\"]\n", + "major_vessels_mask = result[\"major_vessels\"]\n", + "bone_mask = result[\"bone\"]\n", + "soft_tissue_mask = result[\"soft_tissue\"]\n", + "other_mask = result[\"other\"]\n", + "contrast_mask = result[\"contrast\"]\n", + "\n", + "fixed_image_labelmap = labelmap_mask\n", + "itk.imwrite(\n", + " fixed_image_labelmap,\n", + " os.path.join(output_dir, \"slice_fixed_mask.mha\"),\n", + " compression=True,\n", + ")\n", + "\n", + "heart_arr = itk.GetArrayFromImage(heart_mask)\n", + "contrast_arr = itk.GetArrayFromImage(contrast_mask)\n", + "major_vessels_arr = itk.GetArrayFromImage(major_vessels_mask)\n", + "fixed_image_dynamic_anatomy_mask = itk.GetImageFromArray(\n", + " heart_arr + contrast_arr + major_vessels_arr\n", + ")\n", + "fixed_image_dynamic_anatomy_mask.CopyInformation(fixed_image)\n", + "itk.imwrite(\n", + " fixed_image_dynamic_anatomy_mask,\n", + " os.path.join(output_dir, \"slice_fixed.dynamic_anatomy_mask.mha\"),\n", + " compression=True,\n", + ")\n", + "\n", + "lung_arr = itk.GetArrayFromImage(lung_mask)\n", + "bone_arr = itk.GetArrayFromImage(bone_mask)\n", + "other_arr = itk.GetArrayFromImage(other_mask)\n", + "fixed_image_static_mask = itk.GetImageFromArray(lung_arr + bone_arr + other_arr)\n", + "fixed_image_static_mask.CopyInformation(fixed_image)\n", + "itk.imwrite(\n", + " fixed_image_static_mask,\n", + " os.path.join(output_dir, \"slice_fixed.static_anatomy_mask.mha\"),\n", + " compression=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10ffbbf6", + "metadata": {}, + "outputs": [], + "source": [ + "reg = RegisterImagesICON()\n", + "reg.set_mask_dilation(5)" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "id": "cc9418da", + "metadata": {}, + "outputs": [], + "source": [ + "for i in range(21):\n", + " print(f\"Processing slice {i:03d}\")\n", + " moving_image = itk.imread(os.path.join(data_dir, f\"slice_{i:03d}.mha\"))\n", + " result = seg.segment(moving_image, contrast_enhanced_study=True)\n", + " labelmap_mask = result[\"labelmap\"]\n", + " lung_mask = result[\"lung\"]\n", + " heart_mask = result[\"heart\"]\n", + " major_vessels_mask = result[\"major_vessels\"]\n", + " bone_mask = result[\"bone\"]\n", + " soft_tissue_mask = result[\"soft_tissue\"]\n", + " other_mask = result[\"other\"]\n", + " contrast_mask = result[\"contrast\"]\n", + " itk.imwrite(\n", + " labelmap_mask,\n", + " os.path.join(output_dir, f\"slice_{i:03d}_mask.mha\"),\n", + " compression=True,\n", + " )\n", + "\n", + " # Register the whole image\n", + " reg.set_fixed_image(fixed_image)\n", + " results = reg.register(moving_image)\n", + " inverse_transform = results[\"inverse_transform\"]\n", + " forward_transform = results[\"forward_transform\"]\n", + " moving_image_reg = TransformTools().transform_image(\n", + " moving_image, forward_transform, fixed_image, \"sinc\"\n", + " ) # Final resampling with sinc\n", + " itk.imwrite(moving_image_reg, os.path.join(output_dir, f\"slice_{i:03d}.reg_all.mha\"), compression=True)\n", + " itk.transformwrite([forward_transform], os.path.join(output_dir, f\"slice_{i:03d}.reg_all.forward.hdf\"), compression=True)\n", + " itk.transformwrite([inverse_transform], os.path.join(output_dir, f\"slice_{i:03d}.reg_all.inverse.hdf\"), compression=True)\n", + "\n", + " # Register the dynamic anatomy mask\n", + " heart_arr = itk.GetArrayFromImage(heart_mask)\n", + " contrast_arr = itk.GetArrayFromImage(contrast_mask)\n", + " major_vessels_arr = itk.GetArrayFromImage(major_vessels_mask)\n", + " dynamic_anatomy_arr = heart_arr + contrast_arr + major_vessels_arr\n", + " moving_image_dynamic_anatomy_mask = itk.GetImageFromArray(dynamic_anatomy_arr)\n", + 
" moving_image_dynamic_anatomy_mask.CopyInformation(moving_image)\n", + " reg.set_fixed_image(fixed_image)\n", + " reg.set_fixed_image_mask(fixed_image_dynamic_anatomy_mask)\n", + " results = reg.register(\n", + " moving_image, moving_image_dynamic_anatomy_mask\n", + " )\n", + " inverse_transform = results[\"inverse_transform\"]\n", + " forward_transform = results[\"forward_transform\"]\n", + " moving_image_reg_dynamic_anatomy = TransformTools().transform_image(\n", + " moving_image, forward_transform, fixed_image, \"sinc\"\n", + " ) # Final resampling with sinc\n", + " itk.imwrite(\n", + " moving_image_dynamic_anatomy_mask,\n", + " os.path.join(output_dir, f\"slice_{i:03d}.dynamic_anatomy_mask.mha\"),\n", + " compression=True,\n", + " )\n", + " itk.imwrite(\n", + " moving_image_reg_dynamic_anatomy,\n", + " os.path.join(output_dir, f\"slice_{i:03d}.reg_dynamic_anatomy.mha\"),\n", + " compression=True,\n", + " )\n", + " itk.transformwrite(\n", + " [forward_transform],\n", + " os.path.join(output_dir, f\"slice_{i:03d}.reg_dynamic_anatomy.forward.hdf\"),\n", + " compression=True,\n", + " )\n", + " itk.transformwrite(\n", + " [inverse_transform],\n", + " os.path.join(output_dir, f\"slice_{i:03d}.reg_dynamic_anatomy.inverse.hdf\"),\n", + " compression=True,\n", + " )\n", + "\n", + " # Register the static anatomy mask\n", + " lung_arr = itk.GetArrayFromImage(lung_mask)\n", + " bone_arr = itk.GetArrayFromImage(bone_mask)\n", + " other_arr = itk.GetArrayFromImage(other_mask)\n", + " moving_image_static_mask = itk.GetImageFromArray(lung_arr + bone_arr + other_arr)\n", + " moving_image_static_mask.CopyInformation(moving_image)\n", + " reg.set_fixed_image(fixed_image)\n", + " reg.set_fixed_image_mask(fixed_image_static_mask)\n", + " results = reg.register(moving_image, moving_image_static_mask)\n", + " inverse_transform = results[\"inverse_transform\"]\n", + " forward_transform = results[\"forward_transform\"]\n", + " moving_image_reg_static = 
TransformTools().transform_image(\n", + " moving_image, forward_transform, fixed_image, \"sinc\"\n", + " ) # Final resampling with sinc\n", + " itk.imwrite(\n", + " moving_image_static_mask,\n", + " os.path.join(output_dir, f\"slice_{i:03d}.static_anatomy_mask.mha\"),\n", + " compression=True,\n", + " )\n", + " itk.imwrite(\n", + " moving_image_reg_static,\n", + " os.path.join(output_dir, f\"slice_{i:03d}.reg_static_anatomy.mha\"),\n", + " compression=True,\n", + " )\n", + " itk.transformwrite(\n", + " [forward_transform],\n", + " os.path.join(output_dir, f\"slice_{i:03d}.reg_static_anatomy.forward.hdf\"),\n", + " compression=True,\n", + " )\n", + " itk.transformwrite(\n", + " [inverse_transform],\n", + " os.path.join(output_dir, f\"slice_{i:03d}.reg_static_anatomy.inverse.hdf\"),\n", + " compression=True,\n", + " )" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 } diff --git a/experiments/Heart-GatedCT_To_USD/2-generate_segmentation.ipynb b/experiments/Heart-GatedCT_To_USD/2-generate_segmentation.ipynb index 285879b..04371a4 100644 --- a/experiments/Heart-GatedCT_To_USD/2-generate_segmentation.ipynb +++ b/experiments/Heart-GatedCT_To_USD/2-generate_segmentation.ipynb @@ -7,14 +7,7 @@ "metadata": {}, "outputs": [], "source": [ - "import os\n", - "\n", - "import itk\n", - "import numpy as np\n", - "import pyvista as pv\n", - "\n", - "from physiomotion4d.contour_tools import ContourTools\n", - "from physiomotion4d.segment_chest_total_segmentator import SegmentChestTotalSegmentator\n" + "import os\n\nimport itk\nimport numpy as np\nimport pyvista as pv\n\nfrom physiomotion4d.contour_tools import 
ContourTools\nfrom physiomotion4d.segment_chest_total_segmentator import SegmentChestTotalSegmentator\n" ] }, { @@ -24,10 +17,7 @@ "metadata": {}, "outputs": [], "source": [ - "# When re-running, you can bypass certain long-running steps\n", - "re_run_image_max = False\n", - "use_fixed_image = True\n", - "re_run_image_segmentation = True" + "# When re-running, you can bypass certain long-running steps\nre_run_image_max = False\nuse_fixed_image = True\nre_run_image_segmentation = True" ] }, { @@ -37,35 +27,7 @@ "metadata": {}, "outputs": [], "source": [ - "output_dir = \"./results\"\n", - "max_image = None\n", - "print(\"Computing max image...\")\n", - "if re_run_image_max and not use_fixed_image:\n", - " # Compute max of all images\n", - " image = None\n", - " try:\n", - " image = itk.imread(os.path.join(output_dir, \"slice_000.reg_dynamic_anatomy.mha\"))\n", - " except (FileNotFoundError, OSError):\n", - " print(\"No image found. Aborting. Please run 1-generate_images.ipynb first.\")\n", - " exit(1)\n", - " arr = itk.array_from_image(image)\n", - " print(arr.shape)\n", - " arr = np.where(arr == 0, -1000, arr)\n", - " for i in range(1, 21):\n", - " print(f\"Processing slice {i:03d}...\")\n", - " tmp_arr = itk.array_from_image(\n", - " itk.imread(os.path.join(output_dir, f\"slice_{i:03d}.reg_dynamic_anatomy.mha\"))\n", - " )\n", - " tmp_arr = np.where(tmp_arr == 0, -1000, tmp_arr)\n", - " arr = np.maximum(arr, tmp_arr)\n", - " print(\"Max image computed.\")\n", - " max_image = itk.image_from_array(arr)\n", - " max_image.CopyInformation(image)\n", - " itk.imwrite(\n", - " max_image,\n", - " os.path.join(output_dir, \"slice_max.reg_dynamic_anatomy.mha\"),\n", - " compression=True,\n", - " )" + "output_dir = \"./results\"\nmax_image = None\nprint(\"Computing max image...\")\nif re_run_image_max and not use_fixed_image:\n # Compute max of all images\n image = None\n try:\n image = itk.imread(os.path.join(output_dir, \"slice_000.reg_dynamic_anatomy.mha\"))\n except 
(FileNotFoundError, OSError):\n print(\"No image found. Aborting. Please run 1-generate_images.ipynb first.\")\n exit(1)\n arr = itk.array_from_image(image)\n print(arr.shape)\n arr = np.where(arr == 0, -1000, arr)\n for i in range(1, 21):\n print(f\"Processing slice {i:03d}...\")\n tmp_arr = itk.array_from_image(\n itk.imread(os.path.join(output_dir, f\"slice_{i:03d}.reg_dynamic_anatomy.mha\"))\n )\n tmp_arr = np.where(tmp_arr == 0, -1000, tmp_arr)\n arr = np.maximum(arr, tmp_arr)\n print(\"Max image computed.\")\n max_image = itk.image_from_array(arr)\n max_image.CopyInformation(image)\n itk.imwrite(\n max_image,\n os.path.join(output_dir, \"slice_max.reg_dynamic_anatomy.mha\"),\n compression=True,\n )" ] }, { @@ -75,34 +37,7 @@ "metadata": {}, "outputs": [], "source": [ - "if use_fixed_image:\n", - " max_image = itk.imread(os.path.join(output_dir, \"slice_fixed.mha\"))\n", - " outname = \"slice_fixed\"\n", - "else:\n", - " max_image = itk.imread(os.path.join(output_dir, \"slice_max.reg_dynamic_anatomy.mha\"))\n", - " outname = \"slice_max\"\n", - "\n", - "seg = SegmentChestTotalSegmentator()\n", - "seg.contrast_threshold = 500\n", - "if re_run_image_segmentation:\n", - " result = seg.segment(max_image, contrast_enhanced_study=True)\n", - " labelmap_image = result[\"labelmap\"]\n", - " lung_mask = result[\"lung\"]\n", - " heart_mask = result[\"heart\"]\n", - " major_vessels_mask = result[\"major_vessels\"]\n", - " bone_mask = result[\"bone\"]\n", - " soft_tissue_mask = result[\"soft_tissue\"]\n", - " other_mask = result[\"other\"]\n", - " contrast_mask = result[\"contrast\"]\n", - " itk.imwrite(\n", - " labelmap_image,\n", - " os.path.join(output_dir, f\"{outname}.all_mask.mha\"),\n", - " compression=True,\n", - " )\n", - "else:\n", - " labelmap_image = itk.imread(\n", - " os.path.join(output_dir, f\"{outname}.all_mask.mha\")\n", - " )" + "if use_fixed_image:\n max_image = itk.imread(os.path.join(output_dir, \"slice_fixed.mha\"))\n outname = 
\"slice_fixed\"\nelse:\n max_image = itk.imread(os.path.join(output_dir, \"slice_max.reg_dynamic_anatomy.mha\"))\n outname = \"slice_max\"\n\nseg = SegmentChestTotalSegmentator()\nseg.contrast_threshold = 500\nif re_run_image_segmentation:\n result = seg.segment(max_image, contrast_enhanced_study=True)\n labelmap_image = result[\"labelmap\"]\n lung_mask = result[\"lung\"]\n heart_mask = result[\"heart\"]\n major_vessels_mask = result[\"major_vessels\"]\n bone_mask = result[\"bone\"]\n soft_tissue_mask = result[\"soft_tissue\"]\n other_mask = result[\"other\"]\n contrast_mask = result[\"contrast\"]\n itk.imwrite(\n labelmap_image,\n os.path.join(output_dir, f\"{outname}.all_mask.mha\"),\n compression=True,\n )\nelse:\n labelmap_image = itk.imread(\n os.path.join(output_dir, f\"{outname}.all_mask.mha\")\n )" ] }, { @@ -112,9 +47,7 @@ "metadata": {}, "outputs": [], "source": [ - "con = ContourTools()\n", - "all_contours = con.extract_contours(labelmap_image)\n", - "all_contours.save(os.path.join(output_dir, f\"{outname}.all_mask.vtp\"))" + "con = ContourTools()\nall_contours = con.extract_contours(labelmap_image)\nall_contours.save(os.path.join(output_dir, f\"{outname}.all_mask.vtp\"))" ] }, { @@ -124,14 +57,7 @@ "metadata": {}, "outputs": [], "source": [ - "label_arr = itk.array_from_image(labelmap_image)\n", - "lung_arr = itk.array_from_image(lung_mask)\n", - "heart_arr = itk.array_from_image(heart_mask)\n", - "major_vessels_arr = itk.array_from_image(major_vessels_mask)\n", - "bone_arr = itk.array_from_image(bone_mask)\n", - "soft_tissue_arr = itk.array_from_image(soft_tissue_mask)\n", - "other_arr = itk.array_from_image(other_mask)\n", - "contrast_arr = itk.array_from_image(contrast_mask)" + "label_arr = itk.array_from_image(labelmap_image)\nlung_arr = itk.array_from_image(lung_mask)\nheart_arr = itk.array_from_image(heart_mask)\nmajor_vessels_arr = itk.array_from_image(major_vessels_mask)\nbone_arr = itk.array_from_image(bone_mask)\nsoft_tissue_arr = 
itk.array_from_image(soft_tissue_mask)\nother_arr = itk.array_from_image(other_mask)\ncontrast_arr = itk.array_from_image(contrast_mask)" ] }, { @@ -141,19 +67,7 @@ "metadata": {}, "outputs": [], "source": [ - "dynamic_anatomy_arr = np.maximum(heart_arr, contrast_arr)\n", - "dynamic_anatomy_arr = np.maximum(dynamic_anatomy_arr, major_vessels_arr)\n", - "dynamic_anatomy_arr = np.where(dynamic_anatomy_arr, label_arr, 0)\n", - "dynamic_anatomy_image = itk.image_from_array(dynamic_anatomy_arr.astype(np.int16))\n", - "dynamic_anatomy_image.CopyInformation(labelmap_image)\n", - "itk.imwrite(\n", - " dynamic_anatomy_image,\n", - " os.path.join(output_dir, f\"{outname}.dynamic_anatomy_mask.mha\"),\n", - " compression=True,\n", - ")\n", - "\n", - "contours = con.extract_contours(dynamic_anatomy_image)\n", - "contours.save(os.path.join(output_dir, f\"{outname}.dynamic_anatomy_mask.vtp\"))\n" + "dynamic_anatomy_arr = np.maximum(heart_arr, contrast_arr)\ndynamic_anatomy_arr = np.maximum(dynamic_anatomy_arr, major_vessels_arr)\ndynamic_anatomy_arr = np.where(dynamic_anatomy_arr, label_arr, 0)\ndynamic_anatomy_image = itk.image_from_array(dynamic_anatomy_arr.astype(np.int16))\ndynamic_anatomy_image.CopyInformation(labelmap_image)\nitk.imwrite(\n dynamic_anatomy_image,\n os.path.join(output_dir, f\"{outname}.dynamic_anatomy_mask.mha\"),\n compression=True,\n)\n\ncontours = con.extract_contours(dynamic_anatomy_image)\ncontours.save(os.path.join(output_dir, f\"{outname}.dynamic_anatomy_mask.vtp\"))\n" ] }, { @@ -163,18 +77,7 @@ "metadata": {}, "outputs": [], "source": [ - "static_anatomy_arr = lung_arr + bone_arr + soft_tissue_arr + other_arr\n", - "static_anatomy_arr = np.where(static_anatomy_arr, label_arr, 0)\n", - "static_anatomy_image = itk.image_from_array(static_anatomy_arr.astype(np.int16))\n", - "static_anatomy_image.CopyInformation(labelmap_image)\n", - "itk.imwrite(\n", - " static_anatomy_image,\n", - " os.path.join(output_dir, 
f\"{outname}.static_anatomy_mask.mha\"),\n", - " compression=True,\n", - ")\n", - "\n", - "contours = con.extract_contours(static_anatomy_image)\n", - "contours.save(os.path.join(output_dir, f\"{outname}.static_anatomy_mask.vtp\"))" + "static_anatomy_arr = lung_arr + bone_arr + soft_tissue_arr + other_arr\nstatic_anatomy_arr = np.where(static_anatomy_arr, label_arr, 0)\nstatic_anatomy_image = itk.image_from_array(static_anatomy_arr.astype(np.int16))\nstatic_anatomy_image.CopyInformation(labelmap_image)\nitk.imwrite(\n static_anatomy_image,\n os.path.join(output_dir, f\"{outname}.static_anatomy_mask.mha\"),\n compression=True,\n)\n\ncontours = con.extract_contours(static_anatomy_image)\ncontours.save(os.path.join(output_dir, f\"{outname}.static_anatomy_mask.vtp\"))" ] }, { @@ -184,29 +87,7 @@ "metadata": {}, "outputs": [], "source": [ - "input_image = None\n", - "if use_fixed_image:\n", - " input_image = itk.imread(os.path.join(output_dir, \"slice_fixed.mha\"), itk.SS)\n", - "else:\n", - " input_image = itk.imread(os.path.join(output_dir, \"slice_max.reg_dynamic_anatomy.mha\"), itk.SS)\n", - "arr = itk.array_from_image(input_image)\n", - "flipped_input_image = itk.image_from_array(arr)\n", - "flipped_input_image.CopyInformation(input_image)\n", - "\n", - "image = pv.wrap(itk.vtk_image_from_image(flipped_input_image))\n", - "\n", - "pl = pv.Plotter()\n", - "pl.add_mesh(image.slice(normal=\"z\"), cmap=\"bone\", show_scalar_bar=False, opacity=0.5)\n", - "pl.add_mesh(\n", - " contours.slice(normal=\"z\"),\n", - " cmap=\"pink\",\n", - " clim=[50, 800],\n", - " show_scalar_bar=False,\n", - " opacity=1.0,\n", - ")\n", - "pl.set_background(\"black\")\n", - "pl.camera_position = \"xy\"\n", - "pl.show()" + "input_image = None\nif use_fixed_image:\n input_image = itk.imread(os.path.join(output_dir, \"slice_fixed.mha\"), itk.SS)\nelse:\n input_image = itk.imread(os.path.join(output_dir, \"slice_max.reg_dynamic_anatomy.mha\"), itk.SS)\narr = 
itk.array_from_image(input_image)\nflipped_input_image = itk.image_from_array(arr)\nflipped_input_image.CopyInformation(input_image)\n\nimage = pv.wrap(itk.vtk_image_from_image(flipped_input_image))\n\npl = pv.Plotter()\npl.add_mesh(image.slice(normal=\"z\"), cmap=\"bone\", show_scalar_bar=False, opacity=0.5)\npl.add_mesh(\n contours.slice(normal=\"z\"),\n cmap=\"pink\",\n clim=[50, 800],\n show_scalar_bar=False,\n opacity=1.0,\n)\npl.set_background(\"black\")\npl.camera_position = \"xy\"\npl.show()" ] } ], @@ -231,4 +112,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/experiments/Heart-GatedCT_To_USD/3-transform_dynamic_and_static_contours.ipynb b/experiments/Heart-GatedCT_To_USD/3-transform_dynamic_and_static_contours.ipynb index 83f1b4d..040237f 100644 --- a/experiments/Heart-GatedCT_To_USD/3-transform_dynamic_and_static_contours.ipynb +++ b/experiments/Heart-GatedCT_To_USD/3-transform_dynamic_and_static_contours.ipynb @@ -1,140 +1,140 @@ { - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "3ce61753-11ad-4ade-9afe-6ad1bc748e25", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "\n", - "import itk\n", - "import pyvista as pv\n", - "\n", - "from physiomotion4d.contour_tools import ContourTools\n", - "from physiomotion4d.convert_vtk_4d_to_usd import ConvertVTK4DToUSD\n", - "from physiomotion4d.segment_chest_total_segmentator import SegmentChestTotalSegmentator\n", - "from physiomotion4d.usd_anatomy_tools import USDAnatomyTools\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "240f1d14", - "metadata": {}, - "outputs": [], - "source": [ - "data_dir = \"../../data\"\n", - "output_dir = \"results\"\n", - "\n", - "base_name = \"slice_fixed\"\n", - "#base_name = \"slice_max.reg_dynamic_anatomy\"\n", - "\n", - "project_name = \"Slicer_CardiacGatedCT\"\n", - "\n", - "do_transform_contours = True" - ] - }, - { - "cell_type": "code", - "execution_count": null, - 
"id": "ef6002e8", - "metadata": {}, - "outputs": [], - "source": [ - "def transform_contours(contours, transform_filenames, base_name, output_dir, project_name):\n", - " con = ContourTools()\n", - " for i, transform_filename in enumerate(transform_filenames):\n", - " phi_MF = itk.transformread(transform_filename)[0]\n", - " print(f\"Applying transform {transform_filename} to {base_name}\")\n", - "\n", - " new_contours = con.transform_contours(contours, phi_MF, with_deformation_magnitude=True)\n", - " new_contours.save(\n", - " os.path.join(output_dir, f\"slice_{i:03d}.reg_{base_name}_inv.{base_name}_mask.vtp\")\n", - " )\n", - "\n", - " files = [\n", - " f\"{output_dir}/slice_{i:03d}.reg_{base_name}_inv.{base_name}_mask.vtp\"\n", - " for i in range(21)\n", - " ]\n", - " seg = SegmentChestTotalSegmentator()\n", - " all_mask_ids = seg.all_mask_ids\n", - "\n", - " polydata = [pv.read(f) for f in files]\n", - "\n", - " print(\"Converting vtp models to USD\")\n", - " converter = ConvertVTK4DToUSD(\n", - " project_name,\n", - " polydata,\n", - " all_mask_ids,\n", - " )\n", - " stage = converter.convert(\n", - " os.path.join(output_dir, f\"{project_name}.{base_name}.usd\"),\n", - " )\n", - "\n", - " painter = USDAnatomyTools(stage)\n", - " painter.enhance_meshes(seg)\n", - " if os.path.exists(os.path.join(output_dir, f\"{project_name}.{base_name}_painted.usd\")):\n", - " os.remove(os.path.join(output_dir, f\"{project_name}.{base_name}_painted.usd\"))\n", - " stage.Export(os.path.join(output_dir, f\"{project_name}.{base_name}_painted.usd\"))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "90c291bb", - "metadata": {}, - "outputs": [], - "source": [ - "dynamic_transform_filenames = [\n", - " os.path.join(output_dir, f\"slice_{i:03d}.reg_dynamic_anatomy.phi_MF.hdf\")\n", - " for i in range(21)\n", - "]\n", - "static_transform_filenames = [\n", - " os.path.join(output_dir, f\"slice_{i:03d}.reg_static_anatomy.phi_MF.hdf\")\n", - " for i in range(21)\n", 
- "]\n", - "all_transform_filenames = [\n", - " os.path.join(output_dir, f\"slice_{i:03d}.reg_all.phi_MF.hdf\")\n", - " for i in range(21)\n", - "]\n", - "\n", - "dynamic_anatomy_contours = pv.read(\n", - " os.path.join(output_dir, f\"{base_name}.dynamic_anatomy_mask.vtp\")\n", - ")\n", - "static_anatomy_contours = pv.read(\n", - " os.path.join(output_dir, f\"{base_name}.static_anatomy_mask.vtp\")\n", - ")\n", - "all_contours = pv.read(\n", - " os.path.join(output_dir, f\"{base_name}.all_mask.vtp\")\n", - ")\n", - "\n", - "transform_contours(all_contours, all_transform_filenames, \"all\", output_dir, project_name)\n", - "transform_contours(dynamic_anatomy_contours, dynamic_transform_filenames, \"dynamic_anatomy\", output_dir, project_name)\n", - "transform_contours(static_anatomy_contours, static_transform_filenames, \"static_anatomy\", output_dir, project_name)\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.11" - } - }, - "nbformat": 4, - "nbformat_minor": 5 + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "3ce61753-11ad-4ade-9afe-6ad1bc748e25", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "import itk\n", + "import pyvista as pv\n", + "\n", + "from physiomotion4d.contour_tools import ContourTools\n", + "from physiomotion4d.convert_vtk_4d_to_usd import ConvertVTK4DToUSD\n", + "from physiomotion4d.segment_chest_total_segmentator import SegmentChestTotalSegmentator\n", + "from physiomotion4d.usd_anatomy_tools import USDAnatomyTools\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "240f1d14", + "metadata": {}, + "outputs": [], + "source": [ + "data_dir = 
\"../../data\"\n", + "output_dir = \"results\"\n", + "\n", + "base_name = \"slice_fixed\"\n", + "#base_name = \"slice_max.reg_dynamic_anatomy\"\n", + "\n", + "project_name = \"Slicer_CardiacGatedCT\"\n", + "\n", + "do_transform_contours = True" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ef6002e8", + "metadata": {}, + "outputs": [], + "source": [ + "def transform_contours(contours, transform_filenames, base_name, output_dir, project_name):\n", + " con = ContourTools()\n", + " for i, transform_filename in enumerate(transform_filenames):\n", + " forward_transform = itk.transformread(transform_filename)[0]\n", + " print(f\"Applying transform {transform_filename} to {base_name}\")\n", + "\n", + " new_contours = con.transform_contours(contours, forward_transform, with_deformation_magnitude=True)\n", + " new_contours.save(\n", + " os.path.join(output_dir, f\"slice_{i:03d}.reg_{base_name}_inv.{base_name}_mask.vtp\")\n", + " )\n", + "\n", + " files = [\n", + " f\"{output_dir}/slice_{i:03d}.reg_{base_name}_inv.{base_name}_mask.vtp\"\n", + " for i in range(21)\n", + " ]\n", + " seg = SegmentChestTotalSegmentator()\n", + " all_mask_ids = seg.all_mask_ids\n", + "\n", + " polydata = [pv.read(f) for f in files]\n", + "\n", + " print(\"Converting vtp models to USD\")\n", + " converter = ConvertVTK4DToUSD(\n", + " project_name,\n", + " polydata,\n", + " all_mask_ids,\n", + " )\n", + " stage = converter.convert(\n", + " os.path.join(output_dir, f\"{project_name}.{base_name}.usd\"),\n", + " )\n", + "\n", + " painter = USDAnatomyTools(stage)\n", + " painter.enhance_meshes(seg)\n", + " if os.path.exists(os.path.join(output_dir, f\"{project_name}.{base_name}_painted.usd\")):\n", + " os.remove(os.path.join(output_dir, f\"{project_name}.{base_name}_painted.usd\"))\n", + " stage.Export(os.path.join(output_dir, f\"{project_name}.{base_name}_painted.usd\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "90c291bb", + "metadata": {}, + 
"outputs": [], + "source": [ + "dynamic_transform_filenames = [\n", + " os.path.join(output_dir, f\"slice_{i:03d}.reg_dynamic_anatomy.forward.hdf\")\n", + " for i in range(21)\n", + "]\n", + "static_transform_filenames = [\n", + " os.path.join(output_dir, f\"slice_{i:03d}.reg_static_anatomy.forward.hdf\")\n", + " for i in range(21)\n", + "]\n", + "all_transform_filenames = [\n", + " os.path.join(output_dir, f\"slice_{i:03d}.reg_all.forward.hdf\")\n", + " for i in range(21)\n", + "]\n", + "\n", + "dynamic_anatomy_contours = pv.read(\n", + " os.path.join(output_dir, f\"{base_name}.dynamic_anatomy_mask.vtp\")\n", + ")\n", + "static_anatomy_contours = pv.read(\n", + " os.path.join(output_dir, f\"{base_name}.static_anatomy_mask.vtp\")\n", + ")\n", + "all_contours = pv.read(\n", + " os.path.join(output_dir, f\"{base_name}.all_mask.vtp\")\n", + ")\n", + "\n", + "transform_contours(all_contours, all_transform_filenames, \"all\", output_dir, project_name)\n", + "transform_contours(dynamic_anatomy_contours, dynamic_transform_filenames, \"dynamic_anatomy\", output_dir, project_name)\n", + "transform_contours(static_anatomy_contours, static_transform_filenames, \"static_anatomy\", output_dir, project_name)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 } diff --git a/experiments/Heart-GatedCT_To_USD/4-merge_dynamic_and_static_usd.ipynb b/experiments/Heart-GatedCT_To_USD/4-merge_dynamic_and_static_usd.ipynb index fec3f0b..35e4f34 100644 --- a/experiments/Heart-GatedCT_To_USD/4-merge_dynamic_and_static_usd.ipynb +++ b/experiments/Heart-GatedCT_To_USD/4-merge_dynamic_and_static_usd.ipynb @@ -7,9 +7,7 @@ 
"metadata": {}, "outputs": [], "source": [ - "import os\n", - "\n", - "from physiomotion4d.usd_tools import USDTools" + "import os\n\nfrom physiomotion4d.usd_tools import USDTools" ] }, { @@ -19,29 +17,7 @@ "metadata": {}, "outputs": [], "source": [ - "usd_tools = USDTools()\n", - "\n", - "if os.path.exists(\"Slicer_CardiacGatedCT.merged_painted.usd\"):\n", - " os.remove(\"Slicer_CardiacGatedCT.merged_painted.usd\")\n", - "\n", - "usd_tools.merge_usd_files(\n", - " \"Slicer_CardiacGatedCT.merged_painted.usd\",\n", - " [\n", - " \"results/Slicer_CardiacGatedCT.dynamic_anatomy_painted.usd\",\n", - " \"results/Slicer_CardiacGatedCT.static_anatomy_painted.usd\",\n", - " ],\n", - ")\n", - "\n", - "if os.path.exists(\"Slicer_CardiacGatedCT.flattened_merged_painted.usd\"):\n", - " os.remove(\"Slicer_CardiacGatedCT.flattened_merged_painted.usd\")\n", - "\n", - "usd_tools.merge_usd_files_flattened(\n", - " \"Slicer_CardiacGatedCT.flattened_merged_painted.usd\",\n", - " [\n", - " \"results/Slicer_CardiacGatedCT.dynamic_anatomy_painted.usd\",\n", - " \"results/Slicer_CardiacGatedCT.static_anatomy_painted.usd\",\n", - " ],\n", - ")" + "usd_tools = USDTools()\n\nif os.path.exists(\"Slicer_CardiacGatedCT.merged_painted.usd\"):\n os.remove(\"Slicer_CardiacGatedCT.merged_painted.usd\")\n\nusd_tools.merge_usd_files(\n \"Slicer_CardiacGatedCT.merged_painted.usd\",\n [\n \"results/Slicer_CardiacGatedCT.dynamic_anatomy_painted.usd\",\n \"results/Slicer_CardiacGatedCT.static_anatomy_painted.usd\",\n ],\n)\n\nif os.path.exists(\"Slicer_CardiacGatedCT.flattened_merged_painted.usd\"):\n os.remove(\"Slicer_CardiacGatedCT.flattened_merged_painted.usd\")\n\nusd_tools.merge_usd_files_flattened(\n \"Slicer_CardiacGatedCT.flattened_merged_painted.usd\",\n [\n \"results/Slicer_CardiacGatedCT.dynamic_anatomy_painted.usd\",\n \"results/Slicer_CardiacGatedCT.static_anatomy_painted.usd\",\n ],\n)" ] } ], @@ -66,4 +42,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff 
--git a/experiments/Heart-GatedCT_To_USD/test_vista3d_class.ipynb b/experiments/Heart-GatedCT_To_USD/test_vista3d_class.ipynb index f7615b1..0da013e 100644 --- a/experiments/Heart-GatedCT_To_USD/test_vista3d_class.ipynb +++ b/experiments/Heart-GatedCT_To_USD/test_vista3d_class.ipynb @@ -7,13 +7,7 @@ "metadata": {}, "outputs": [], "source": [ - "import itk\n", - "import os\n", - "\n", - "from physiomotion4d.segment_chest_vista_3d import SegmentChestVista3D\n", - "\n", - "output_dir = \"./results\"\n", - "max_image = itk.imread(os.path.join(output_dir, \"slice_max.reg_dynamic_anatomy.mha\"))" + "import itk\nimport os\n\nfrom physiomotion4d.segment_chest_vista_3d import SegmentChestVista3D\n\noutput_dir = \"./results\"\nmax_image = itk.imread(os.path.join(output_dir, \"slice_max.reg_dynamic_anatomy.mha\"))" ] }, { @@ -62,21 +56,7 @@ } ], "source": [ - "seg = SegmentChestVista3D()\n", - "result = seg.segment(max_image, contrast_enhanced_study=True)\n", - "labelmap_image = result[\"labelmap\"]\n", - "lung_mask = result[\"lung\"]\n", - "heart_mask = result[\"heart\"]\n", - "major_vessels_mask = result[\"major_vessels\"]\n", - "bone_mask = result[\"bone\"]\n", - "soft_tissue_mask = result[\"soft_tissue\"]\n", - "other_mask = result[\"other\"]\n", - "contrast_mask = result[\"contrast\"]\n", - "itk.imwrite(\n", - " labelmap_image,\n", - " os.path.join(output_dir, \"slice_max.reg_dynamic_anatomy.all_mask-test.mha\"),\n", - " compression=True,\n", - ")" + "seg = SegmentChestVista3D()\nresult = seg.segment(max_image, contrast_enhanced_study=True)\nlabelmap_image = result[\"labelmap\"]\nlung_mask = result[\"lung\"]\nheart_mask = result[\"heart\"]\nmajor_vessels_mask = result[\"major_vessels\"]\nbone_mask = result[\"bone\"]\nsoft_tissue_mask = result[\"soft_tissue\"]\nother_mask = result[\"other\"]\ncontrast_mask = result[\"contrast\"]\nitk.imwrite(\n labelmap_image,\n os.path.join(output_dir, \"slice_max.reg_dynamic_anatomy.all_mask-test.mha\"),\n compression=True,\n)" ] } ], 
@@ -101,4 +81,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/experiments/Heart-GatedCT_To_USD/test_vista3d_inMem.ipynb b/experiments/Heart-GatedCT_To_USD/test_vista3d_inMem.ipynb index fc79e0b..9e253ff 100644 --- a/experiments/Heart-GatedCT_To_USD/test_vista3d_inMem.ipynb +++ b/experiments/Heart-GatedCT_To_USD/test_vista3d_inMem.ipynb @@ -7,160 +7,7 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", - "import torch\n", - "\n", - "\n", - "def vista3d_inference_from_itk(\n", - " itk_image,\n", - " label_prompt=None,\n", - " points=None,\n", - " point_labels=None,\n", - " device=None,\n", - " bundle_path=None,\n", - " model_cache_dir=None,\n", - "):\n", - " # 1. Import dependencies\n", - " import itk\n", - " from monai.bundle import download\n", - " from monai.data.itk_torch_bridge import itk_image_to_metatensor\n", - " from monai.transforms import (\n", - " EnsureChannelFirst,\n", - " Spacing,\n", - " ScaleIntensityRange,\n", - " CropForeground,\n", - " EnsureType,\n", - " )\n", - " from monai.inferers import sliding_window_inference\n", - " from monai.networks.nets import vista3d132\n", - " from monai.utils import set_determinism\n", - "\n", - " set_determinism(seed=42)\n", - " if device is None:\n", - " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", - "\n", - " # 2. Handle \"no prompts\" case: segment all classes\n", - " if label_prompt is None and points is None:\n", - " everything_labels = list(\n", - " set([i + 1 for i in range(132)]) - set([2, 16, 18, 20, 21, 23, 24, 25, 26])\n", - " )\n", - " label_prompt = everything_labels\n", - " print(f\"No prompt provided. Using everything_labels for {len(everything_labels)} classes.\")\n", - "\n", - " if points is not None and point_labels is None:\n", - " raise ValueError(\"point_labels must be provided when points are specified\")\n", - "\n", - " # 3. 
Download model bundle if needed\n", - " if bundle_path is None:\n", - " import tempfile\n", - "\n", - " if model_cache_dir is None:\n", - " model_cache_dir = tempfile.mkdtemp()\n", - " try:\n", - " download(name=\"vista3d\", bundle_dir=model_cache_dir, source=\"monaihosting\")\n", - " except Exception:\n", - " download(name=\"vista3d\", bundle_dir=model_cache_dir, source=\"github\")\n", - " bundle_path = f\"{model_cache_dir}/vista3d\"\n", - "\n", - " # 4. ITK->MetaTensor (in memory)\n", - " meta_tensor = itk_image_to_metatensor(itk_image, channel_dim=None, dtype=torch.float32)\n", - "\n", - " input_size = itk_image.GetLargestPossibleRegion().GetSize()\n", - "\n", - " # 5. Preprocessing pipeline\n", - " processed = meta_tensor\n", - " processed = EnsureChannelFirst(channel_dim=None)(processed)\n", - " processed = EnsureType(dtype=torch.float32)(processed)\n", - " processed = Spacing(pixdim=[1.5, 1.5, 1.5], mode=\"bilinear\")(processed)\n", - " processed = ScaleIntensityRange(a_min=-1024, a_max=1024, b_min=0.0, b_max=1.0, clip=True)(\n", - " processed\n", - " )\n", - " processed = CropForeground()(processed)\n", - "\n", - " # 6. Load VISTA3D\n", - " model = vista3d132(encoder_embed_dim=48, in_channels=1)\n", - " model_path = f\"{bundle_path}/models/model.pt\"\n", - " checkpoint = torch.load(model_path, map_location=device)\n", - " model.load_state_dict(checkpoint)\n", - " model.eval()\n", - " model.to(device)\n", - "\n", - " # 7. Prepare input tensor\n", - " input_tensor = processed\n", - " if not isinstance(input_tensor, torch.Tensor):\n", - " input_tensor = torch.tensor(np.asarray(input_tensor), dtype=torch.float32)\n", - " if input_tensor.dim() == 3:\n", - " input_tensor = input_tensor.unsqueeze(0)\n", - " if input_tensor.dim() == 4:\n", - " input_tensor = input_tensor.unsqueeze(0)\n", - " input_tensor = input_tensor.to(device)\n", - "\n", - " # 8. 
Prepare model inputs\n", - " model_inputs = {\"image\": input_tensor}\n", - " if label_prompt is not None:\n", - " label_prompt_tensor = torch.tensor(label_prompt, dtype=torch.long, device=device)\n", - " model_inputs[\"label_prompt\"] = label_prompt_tensor\n", - " print('label_prompt_tensor shape', label_prompt_tensor.shape)\n", - " if points is not None:\n", - " point_coords = torch.tensor(points, dtype=torch.float32, device=device).unsqueeze(0)\n", - " point_labels_tensor = torch.tensor(\n", - " point_labels, dtype=torch.float32, device=device\n", - " ).unsqueeze(0)\n", - " model_inputs[\"points\"] = point_coords\n", - " model_inputs[\"point_labels\"] = point_labels_tensor\n", - " print('point_coords shape', point_coords.shape)\n", - "\n", - " # 9. Sliding window inference for large images\n", - " def predictor_fn(x):\n", - " args = {k: v for k, v in model_inputs.items() if k != \"image\"}\n", - " print(x.shape)\n", - " return model(x, **args)\n", - "\n", - " with torch.no_grad():\n", - " if any(dim > 128 for dim in input_tensor.shape[2:]):\n", - " print(\"Sliding window inference\")\n", - " output = sliding_window_inference(\n", - " input_tensor,\n", - " roi_size=[128, 128, 128],\n", - " sw_batch_size=1,\n", - " predictor=predictor_fn,\n", - " overlap=0.5,\n", - " mode=\"gaussian\",\n", - " device=device,\n", - " )\n", - " else:\n", - " print(\"Single window inference\")\n", - " output = model(input_tensor, **{k: v for k, v in model_inputs.items() if k != \"image\"})\n", - "\n", - " print('output shape', output.shape)\n", - " # 10. 
Postprocess: multi-class to label map\n", - " output = output.cpu()\n", - " if hasattr(output, \"detach\"):\n", - " output = output.detach()\n", - " if isinstance(output, dict):\n", - " if \"pred\" in output:\n", - " output = output[\"pred\"]\n", - " else:\n", - " output = list(output.values())[0]\n", - "\n", - " if output.shape[1] > 1:\n", - " label_map = torch.argmax(output, dim=1).squeeze(0).numpy().astype(np.uint16)\n", - " else:\n", - " label_map = (output > 0.5).squeeze(0).cpu().numpy().astype(np.uint8)\n", - "\n", - " # Ensure output is zyx order for ITK\n", - " if label_map.shape != tuple(reversed(input_size)):\n", - " # Some transforms may flip axes; reorder as needed.\n", - " label_map_for_itk = np.transpose(label_map, axes=range(label_map.ndim)[::-1])\n", - " else:\n", - " label_map_for_itk = label_map\n", - "\n", - " # ITK expects z,y,x ordering for GetImageFromArray\n", - " output_itk = itk.GetImageFromArray(label_map_for_itk)\n", - " itk.imwrite(output_itk, 'output_itk.mha')\n", - "\n", - " # Return output in ITK format matching the input (size, spacing, origin, direction, type)\n", - " return output_itk" + "import numpy as np\nimport torch\n\n\ndef vista3d_inference_from_itk(\n itk_image,\n label_prompt=None,\n points=None,\n point_labels=None,\n device=None,\n bundle_path=None,\n model_cache_dir=None,\n):\n # 1. Import dependencies\n import itk\n from monai.bundle import download\n from monai.data.itk_torch_bridge import itk_image_to_metatensor\n from monai.transforms import (\n EnsureChannelFirst,\n Spacing,\n ScaleIntensityRange,\n CropForeground,\n EnsureType,\n )\n from monai.inferers import sliding_window_inference\n from monai.networks.nets import vista3d132\n from monai.utils import set_determinism\n\n set_determinism(seed=42)\n if device is None:\n device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n # 2. 
Handle \"no prompts\" case: segment all classes\n if label_prompt is None and points is None:\n everything_labels = list(\n set([i + 1 for i in range(132)]) - set([2, 16, 18, 20, 21, 23, 24, 25, 26])\n )\n label_prompt = everything_labels\n print(f\"No prompt provided. Using everything_labels for {len(everything_labels)} classes.\")\n\n if points is not None and point_labels is None:\n raise ValueError(\"point_labels must be provided when points are specified\")\n\n # 3. Download model bundle if needed\n if bundle_path is None:\n import tempfile\n\n if model_cache_dir is None:\n model_cache_dir = tempfile.mkdtemp()\n try:\n download(name=\"vista3d\", bundle_dir=model_cache_dir, source=\"monaihosting\")\n except Exception:\n download(name=\"vista3d\", bundle_dir=model_cache_dir, source=\"github\")\n bundle_path = f\"{model_cache_dir}/vista3d\"\n\n # 4. ITK->MetaTensor (in memory)\n meta_tensor = itk_image_to_metatensor(itk_image, channel_dim=None, dtype=torch.float32)\n\n input_size = itk_image.GetLargestPossibleRegion().GetSize()\n\n # 5. Preprocessing pipeline\n processed = meta_tensor\n processed = EnsureChannelFirst(channel_dim=None)(processed)\n processed = EnsureType(dtype=torch.float32)(processed)\n processed = Spacing(pixdim=[1.5, 1.5, 1.5], mode=\"bilinear\")(processed)\n processed = ScaleIntensityRange(a_min=-1024, a_max=1024, b_min=0.0, b_max=1.0, clip=True)(\n processed\n )\n processed = CropForeground()(processed)\n\n # 6. Load VISTA3D\n model = vista3d132(encoder_embed_dim=48, in_channels=1)\n model_path = f\"{bundle_path}/models/model.pt\"\n checkpoint = torch.load(model_path, map_location=device)\n model.load_state_dict(checkpoint)\n model.eval()\n model.to(device)\n\n # 7. 
Prepare input tensor\n input_tensor = processed\n if not isinstance(input_tensor, torch.Tensor):\n input_tensor = torch.tensor(np.asarray(input_tensor), dtype=torch.float32)\n if input_tensor.dim() == 3:\n input_tensor = input_tensor.unsqueeze(0)\n if input_tensor.dim() == 4:\n input_tensor = input_tensor.unsqueeze(0)\n input_tensor = input_tensor.to(device)\n\n # 8. Prepare model inputs\n model_inputs = {\"image\": input_tensor}\n if label_prompt is not None:\n label_prompt_tensor = torch.tensor(label_prompt, dtype=torch.long, device=device)\n model_inputs[\"label_prompt\"] = label_prompt_tensor\n print('label_prompt_tensor shape', label_prompt_tensor.shape)\n if points is not None:\n point_coords = torch.tensor(points, dtype=torch.float32, device=device).unsqueeze(0)\n point_labels_tensor = torch.tensor(\n point_labels, dtype=torch.float32, device=device\n ).unsqueeze(0)\n model_inputs[\"points\"] = point_coords\n model_inputs[\"point_labels\"] = point_labels_tensor\n print('point_coords shape', point_coords.shape)\n\n # 9. Sliding window inference for large images\n def predictor_fn(x):\n args = {k: v for k, v in model_inputs.items() if k != \"image\"}\n print(x.shape)\n return model(x, **args)\n\n with torch.no_grad():\n if any(dim > 128 for dim in input_tensor.shape[2:]):\n print(\"Sliding window inference\")\n output = sliding_window_inference(\n input_tensor,\n roi_size=[128, 128, 128],\n sw_batch_size=1,\n predictor=predictor_fn,\n overlap=0.5,\n mode=\"gaussian\",\n device=device,\n )\n else:\n print(\"Single window inference\")\n output = model(input_tensor, **{k: v for k, v in model_inputs.items() if k != \"image\"})\n\n print('output shape', output.shape)\n # 10. 
Postprocess: multi-class to label map\n output = output.cpu()\n if hasattr(output, \"detach\"):\n output = output.detach()\n if isinstance(output, dict):\n if \"pred\" in output:\n output = output[\"pred\"]\n else:\n output = list(output.values())[0]\n\n if output.shape[1] > 1:\n label_map = torch.argmax(output, dim=1).squeeze(0).numpy().astype(np.uint16)\n else:\n label_map = (output > 0.5).squeeze(0).cpu().numpy().astype(np.uint8)\n\n # Ensure output is zyx order for ITK\n if label_map.shape != tuple(reversed(input_size)):\n # Some transforms may flip axes; reorder as needed.\n label_map_for_itk = np.transpose(label_map, axes=range(label_map.ndim)[::-1])\n else:\n label_map_for_itk = label_map\n\n # ITK expects z,y,x ordering for GetImageFromArray\n output_itk = itk.GetImageFromArray(label_map_for_itk)\n itk.imwrite(output_itk, 'output_itk.mha')\n\n # Return output in ITK format matching the input (size, spacing, origin, direction, type)\n return output_itk" ] }, { @@ -228,14 +75,7 @@ } ], "source": [ - "import itk\n", - "\n", - "# Load an ITK image\n", - "image = itk.imread('results/slice.reg_max.mha')\n", - "\n", - "spleen_segmentation = vista3d_inference_from_itk(image, model_cache_dir='./network_weights')\n", - "\n", - "itk.imwrite(spleen_segmentation, 'totalSegmentation2.mha')" + "import itk\n\n# Load an ITK image\nimage = itk.imread('results/slice.reg_max.mha')\n\nspleen_segmentation = vista3d_inference_from_itk(image, model_cache_dir='./network_weights')\n\nitk.imwrite(spleen_segmentation, 'totalSegmentation2.mha')" ] } ], @@ -260,4 +100,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/experiments/Heart-Model_To_Patient/heart_model_to_model_registration_pca.ipynb b/experiments/Heart-Model_To_Patient/heart_model_to_model_registration_pca.ipynb index 698e6c3..fdb9019 100644 --- a/experiments/Heart-Model_To_Patient/heart_model_to_model_registration_pca.ipynb +++ 
b/experiments/Heart-Model_To_Patient/heart_model_to_model_registration_pca.ipynb @@ -1,563 +1,737 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# PCA-based Heart Model to Image Registration Experiment\n", - "\n", - "This notebook demonstrates using the `RegisterModelToImagePCA` class to register\n", - "a statistical shape model to patient CT images using PCA-based shape variation.\n", - "\n", - "## Overview\n", - "- Uses the KCL Heart Model PCA statistical shape model\n", - "- Registers to the same Duke Heart CT data as the original notebook\n", - "- Two-stage optimization: rigid alignment + PCA shape fitting\n", - "- Converts segmentation mask to intensity image for registration" - ] + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# PCA-based Heart Model to Image Registration Experiment\n", + "\n", + "This notebook demonstrates using the `RegisterModelToImagePCA` class to register\n", + "a statistical shape model to patient CT images using PCA-based shape variation.\n", + "\n", + "## Overview\n", + "- Uses the KCL Heart Model PCA statistical shape model\n", + "- Registers to the same Duke Heart CT data as the original notebook\n", + "- Two-stage optimization: rigid alignment + PCA shape fitting\n", + "- Converts segmentation mask to intensity image for registration" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup and Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# PCA-based Heart Model to Image Registration Experiment\n", + "\n", + "import os\n", + "from pathlib import Path\n", + "\n", + "import itk\n", + "import numpy as np\n", + "import pyvista as pv\n", + "from itk import TubeTK as ttk\n", + "\n", + "# Import from PhysioMotion4D package\n", + "from physiomotion4d import (\n", + " ContourTools,\n", + " RegisterModelsICP,\n", + " RegisterModelsPCA,\n", + " TransformTools,\n", + ")\n" + 
] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define File Paths" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Patient data: c:\\src\\Projects\\PhysioMotion\\PhysioMotion4D\\data\\Slicer-Heart-CT\n", + "PCA Model data: c:\\src\\Projects\\PhysioMotion\\PhysioMotion4D\\data\\KCL-Heart-Model\\pca\n", + "Output directory: c:\\src\\Projects\\PhysioMotion\\PhysioMotion4D\\experiments\\Heart-Model_To_Patient\\results_pca\n" + ] + } + ], + "source": [ + "# Patient CT image (defines coordinate frame)\n", + "patient_data_dir = Path.cwd().parent.parent / 'data' / 'Slicer-Heart-CT'\n", + "patient_ct_path = patient_data_dir / 'patient_img.mha'\n", + "patient_ct_heart_mask_path = patient_data_dir / 'patient_heart_wall_mask.nii.gz'\n", + "\n", + "# Average heart model mesh (full volumetric model)\n", + "heart_model_data_dir = Path.cwd().parent.parent / 'data' / 'KCL-Heart-Model'\n", + "heart_model_path = heart_model_data_dir / 'average_mesh.vtk'\n", + "\n", + "# PCA heart model data\n", + "template_model_data_dir = Path.cwd().parent.parent / 'data' / 'KCL-Heart-Model' / 'pca'\n", + "template_model_surface_path = template_model_data_dir / 'pca_All_mean.vtk'\n", + "pca_json_path = template_model_data_dir / 'pca.json'\n", + "pca_group_key = 'All'\n", + "\n", + "# Output directory\n", + "output_dir = Path.cwd() / 'results_pca'\n", + "os.makedirs(output_dir, exist_ok=True)\n", + "\n", + "print(f\"Patient data: {patient_data_dir}\")\n", + "print(f\"PCA Model data: {template_model_data_dir}\")\n", + "print(f\"Output directory: {output_dir}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load and Preprocess Patient Image" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading patient CT image...\n", + " Original size: itkSize3 
([307, 234, 152])\n", + " Original spacing: itkVectorD3 ([1, 1, 1])\n", + "Resampling to isotropic...\n", + " Resampled size: itkSize3 ([307, 234, 152])\n", + " Resampled spacing: itkVectorD3 ([1, 1, 1])\n", + "✓ Saved preprocessed image\n" + ] + } + ], + "source": [ + "# Load patient CT image\n", + "print(\"Loading patient CT image...\")\n", + "patient_image = itk.imread(str(patient_ct_path))\n", + "print(f\" Original size: {itk.size(patient_image)}\")\n", + "print(f\" Original spacing: {itk.spacing(patient_image)}\")\n", + "\n", + "# Resample to 1mm isotropic spacing\n", + "print(\"Resampling to isotropic...\")\n", + "resampler = ttk.ResampleImage.New(Input=patient_image)\n", + "resampler.SetMakeHighResIso(True)\n", + "resampler.Update()\n", + "patient_image = resampler.GetOutput()\n", + "\n", + "print(f\" Resampled size: {itk.size(patient_image)}\")\n", + "print(f\" Resampled spacing: {itk.spacing(patient_image)}\")\n", + "\n", + "# Save preprocessed image\n", + "itk.imwrite(patient_image, str(output_dir / 'patient_image.mha'), compression=True)\n", + "print(f\"✓ Saved preprocessed image\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load and Process Heart Segmentation Mask" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading heart segmentation mask...\n", + " Mask size: itkSize3 ([307, 234, 152])\n", + " Mask spacing: itkVectorD3 ([1, 1, 1])\n" + ] + } + ], + "source": [ + "# Load heart segmentation mask\n", + "print(\"Loading heart segmentation mask...\")\n", + "patient_heart_mask_image = itk.imread(str(patient_ct_heart_mask_path))\n", + "\n", + "print(f\" Mask size: {itk.size(patient_heart_mask_image)}\")\n", + "print(f\" Mask spacing: {itk.spacing(patient_heart_mask_image)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": 
"stream", + "text": [ + "Flipping image axes: True, True, False\n", + "✓ Images flipped to standard orientation\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "__array__ implementation doesn't accept a copy keyword, so passing copy=False failed. __array__ must implement 'dtype' and 'copy' keyword arguments. To learn more, see the migration guide https://numpy.org/devdocs/numpy_2_0_migration_guide.html#adapting-to-changes-in-the-copy-keyword\n", + "__array__ implementation doesn't accept a copy keyword, so passing copy=False failed. __array__ must implement 'dtype' and 'copy' keyword arguments. To learn more, see the migration guide https://numpy.org/devdocs/numpy_2_0_migration_guide.html#adapting-to-changes-in-the-copy-keyword\n", + "__array__ implementation doesn't accept a copy keyword, so passing copy=False failed. __array__ must implement 'dtype' and 'copy' keyword arguments. To learn more, see the migration guide https://numpy.org/devdocs/numpy_2_0_migration_guide.html#adapting-to-changes-in-the-copy-keyword\n" + ] + } + ], + "source": [ + "# Handle image orientation (flip if needed)\n", + "flip0 = np.array(patient_heart_mask_image.GetDirection())[0,0] < 0\n", + "flip1 = np.array(patient_heart_mask_image.GetDirection())[1,1] < 0\n", + "flip2 = np.array(patient_heart_mask_image.GetDirection())[2,2] < 0\n", + "\n", + "if flip0 or flip1 or flip2:\n", + " print(f\"Flipping image axes: {flip0}, {flip1}, {flip2}\")\n", + "\n", + " # Flip CT image\n", + " flip_filter = itk.FlipImageFilter.New(Input=patient_image)\n", + " flip_filter.SetFlipAxes([int(flip0), int(flip1), int(flip2)])\n", + " flip_filter.SetFlipAboutOrigin(True)\n", + " flip_filter.Update()\n", + " patient_image = flip_filter.GetOutput()\n", + " id_mat = itk.Matrix[itk.D, 3, 3]()\n", + " id_mat.SetIdentity()\n", + " patient_image.SetDirection(id_mat)\n", + "\n", + " # Flip mask image\n", + " flip_filter = itk.FlipImageFilter.New(Input=patient_heart_mask_image)\n", + " 
flip_filter.SetFlipAxes([int(flip0), int(flip1), int(flip2)])\n", + " flip_filter.SetFlipAboutOrigin(True)\n", + " flip_filter.Update()\n", + " patient_heart_mask_image = flip_filter.GetOutput()\n", + " patient_heart_mask_image.SetDirection(id_mat)\n", + "\n", + " print(\"✓ Images flipped to standard orientation\")\n", + "\n", + "# Save oriented images\n", + "itk.imwrite(patient_image, str(output_dir / 'patient_image_oriented.mha'), compression=True)\n", + "itk.imwrite(patient_heart_mask_image, str(output_dir / 'patient_heart_mask_oriented.mha'), compression=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Convert Segmentation Mask to a Surface" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "contour_tools = ContourTools()\n", + "patient_surface = contour_tools.extract_contours(patient_heart_mask_image)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Perform Initial ICP Rigid Registration\n", + "\n", + "Use ICP (Iterative Closest Point) with rigid mode to align the model surface to the patient surface extracted from the segmentation mask. This provides a good initial alignment for the PCA-based registration.\n", + "\n", + "The ICP registration pipeline:\n", + "1. Centroid alignment (automatic)\n", + "2. Rigid ICP alignment\n", + "\n", + "The PCA registration will then refine this initial alignment with shape model constraints."
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading PCA heart model...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-12-31 15:38:09 INFO RegisterModelsICP ======================================================================\n", + "2025-12-31 15:38:09 INFO RegisterModelsICP RIGID ICP Alignment\n", + "2025-12-31 15:38:09 INFO RegisterModelsICP ======================================================================\n", + "2025-12-31 15:38:09 INFO RegisterModelsICP Step 1: Translating by [648.62508965 318.68302345 962.48420525] to align centroids...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Template surface: 167240 points\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-12-31 15:38:11 INFO RegisterModelsICP Step 2: Performing rigid ICP (max iterations: 2000)...\n", + "2025-12-31 15:38:22 INFO RegisterModelsICP RIGID ICP registration complete!\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "✓ ICP affine registration complete\n", + " Transform = AffineTransform (000001879D6168B0)\n", + " RTTI typeinfo: class itk::AffineTransform\n", + " Reference Count: 1\n", + " Modified Time: 1613\n", + " Debug: Off\n", + " Object Name: \n", + " Observers: \n", + " none\n", + " Matrix: \n", + " 0.987013 -0.0990425 -0.126474 \n", + " 0.114333 0.98617 0.119986 \n", + " 0.112841 -0.132888 0.984687 \n", + " Offset: [648.026, 318.947, 958.768]\n", + " Center: [0, 0, 0]\n", + " Translation: [648.026, 318.947, 958.768]\n", + " Inverse: \n", + " 0.987013 0.114333 0.112841 \n", + " -0.0990425 0.98617 -0.132888 \n", + " -0.126474 0.119986 0.984687 \n", + " Singular: 0\n", + "\n", + " Saved ICP-aligned model surface\n", + " Saved ICP transform\n" + ] + } + ], + "source": [ + "#Load the pca model\n", + "print(\"Loading PCA heart 
model...\")\n", + "template_model = pv.read(str(heart_model_path))\n", + "\n", + "template_model_surface = pv.read(template_model_surface_path)\n", + "print(f\" Template surface: {template_model_surface.n_points} points\")\n", + "\n", + "icp_registrar = RegisterModelsICP(\n", + " moving_model=template_model_surface,\n", + " fixed_model=patient_surface\n", + ")\n", + "\n", + "icp_result = icp_registrar.register(mode='rigid', max_iterations=2000)\n", + "\n", + "# Get the aligned mesh and transform\n", + "icp_registered_model_surface = icp_result['registered_model']\n", + "icp_forward_point_transform = icp_result['forward_point_transform']\n", + "\n", + "print(\"\\n✓ ICP rigid registration complete\")\n", + "print(\" Transform =\", icp_result['forward_point_transform'])\n", + "\n", + "# Save aligned model \n", + "icp_registered_model_surface.save(str(output_dir / 'icp_registered_model_surface.vtp'))\n", + "print(\" Saved ICP-aligned model surface\")\n", + "\n", + "itk.transformwrite([icp_result['forward_point_transform']], str(output_dir / 'icp_transform.hdf'), compression=True)\n", + "print(\" Saved ICP transform\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "✓ Applied ICP transform to full model mesh\n" + ] + } + ], + "source": [ + "\n", + "# Apply ICP transform to the full average mesh (not just surface)\n", + "# This gets the volumetric mesh into patient space for PCA registration\n", + "transform_tools = TransformTools()\n", + "icp_registered_model = transform_tools.transform_pvcontour(\n", + " template_model,\n", + " icp_forward_point_transform\n", + ")\n", + "icp_registered_model.save(str(output_dir / 'icp_registered_model.vtk'))\n", + "print(\"\\n✓ Applied ICP transform to full model mesh\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Initialize PCA Registration" + ] + }, + { + "cell_type": "code", + 
"execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-12-31 15:38:26 INFO Loading PCA data from SlicerSALT format...\n", + "2025-12-31 15:38:26 INFO JSON file: c:\\src\\Projects\\PhysioMotion\\PhysioMotion4D\\data\\KCL-Heart-Model\\pca\\pca.json\n", + "2025-12-31 15:38:26 INFO Group key: All\n", + "2025-12-31 15:38:26 INFO Reading JSON file...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "======================================================================\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-12-31 15:38:34 INFO Loaded 20 standard deviations\n", + "2025-12-31 15:38:34 INFO Loaded pca_eigenvectors with shape (20, 501720)\n", + "2025-12-31 15:38:34 INFO ✓ Data validation successful!\n", + "2025-12-31 15:38:34 INFO SlicerSALT PCA data loaded successfully!\n", + "2025-12-31 15:38:34 INFO ContourTools Computing signed distance map...\n", + "2025-12-31 15:38:37 INFO RegisterModelsPCA Converting mean shape points to ITK format...\n", + "2025-12-31 15:38:37 INFO RegisterModelsPCA Converted 167240 points to ITK format\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✓ PCA registrar initialized\n", + " Using ICP-aligned mesh as starting point\n", + " Number of points: 167240\n", + " Number of PCA modes: 10\n" + ] + } + ], + "source": [ + "## Initialize PCA Registration\n", + "print(\"=\"*70)\n", + "\n", + "# Use the ICP-aligned mesh as the starting point for PCA registration\n", + "pca_registrar = RegisterModelsPCA.from_slicersalt(\n", + " pca_template_model=template_model_surface,\n", + " pca_json_filename=pca_json_path,\n", + " pca_group_key=pca_group_key,\n", + " pca_number_of_modes=10,\n", + " post_pca_transform=icp_forward_point_transform,\n", + " fixed_model=patient_surface,\n", + " reference_image=patient_image\n", + ")\n", + "\n", + "itk.imwrite(pca_registrar.fixed_distance_map, 
str(output_dir / \"distance_map.mha\"))\n", + "\n", + "print(\"✓ PCA registrar initialized\")\n", + "print(\" Using ICP-aligned mesh as starting point\")\n", + "print(f\" Number of points: {len(pca_registrar.pca_template_model.points)}\")\n", + "print(f\" Number of PCA modes: {pca_registrar.pca_number_of_modes}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Run PCA-Based Shape Optimization\n", + "\n", + "Now that we have a good initial alignment from ICP affine registration, we run the PCA-based registration to optimize the shape parameters." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-12-31 15:38:38 INFO RegisterModelsPCA ======================================================================\n", + "2025-12-31 15:38:38 INFO RegisterModelsPCA PCA-BASED MODEL-TO-IMAGE REGISTRATION\n", + "2025-12-31 15:38:38 INFO RegisterModelsPCA ======================================================================\n", + "2025-12-31 15:38:38 INFO RegisterModelsPCA Number of points: 167240\n", + "2025-12-31 15:38:38 INFO RegisterModelsPCA Modes to use: 10\n", + "2025-12-31 15:38:38 INFO RegisterModelsPCA Number of PCA modes: 10\n", + "2025-12-31 15:38:38 INFO RegisterModelsPCA PCA coefficient bounds: ±3.0 std deviations\n", + "2025-12-31 15:38:38 INFO RegisterModelsPCA Optimization method: L-BFGS-B\n", + "2025-12-31 15:38:38 INFO RegisterModelsPCA Max iterations: 50\n", + "2025-12-31 15:38:38 INFO RegisterModelsPCA Running optimization...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "======================================================================\n", + "PCA-BASED SHAPE OPTIMIZATION\n", + "======================================================================\n", + "\n", + "Running complete PCA registration pipeline...\n", + " (Starting from ICP-aligned mesh)\n" + ] + }, + { + "name": 
"stderr", + "output_type": "stream", + "text": [ + "2025-12-31 15:38:38 INFO RegisterModelsPCA Metric 1: [592.30134313 332.25256351 940.1120455 ] -> 1.793727\n", + "2025-12-31 15:38:38 INFO RegisterModelsPCA Params [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + "2025-12-31 15:39:11 INFO RegisterModelsPCA Metric 101: [592.55321111 332.67948908 941.28025053] -> 0.802813\n", + "2025-12-31 15:39:11 INFO RegisterModelsPCA Params [-0.12394911 0.08781701 -0.4095937 1.10702867 -0.54654654 0.76025192\n", + " 0.40569734 0.26662927 -0.6470759 -0.14716453]\n", + "2025-12-31 15:39:35 INFO RegisterModelsPCA Metric 201: [592.58212814 332.67439075 941.2759575 ] -> 0.802497\n", + "2025-12-31 15:39:35 INFO RegisterModelsPCA Params [-0.11703578 0.09849052 -0.39701523 1.08269627 -0.56759478 0.72699977\n", + " 0.4454969 0.288553 -0.66492292 -0.17534791]\n", + "2025-12-31 15:39:55 INFO RegisterModelsPCA Metric 301: [592.5824063 332.67423722 941.27587751] -> 0.802497\n", + "2025-12-31 15:39:55 INFO RegisterModelsPCA Params [-0.11693249 0.09863351 -0.39694914 1.08219954 -0.56760953 0.72688042\n", + " 0.44609067 0.2888831 -0.66536179 -0.17487209]\n", + "2025-12-31 15:39:56 INFO RegisterModelsPCA Optimization completed!\n", + "2025-12-31 15:39:56 INFO RegisterModelsPCA Optimized PCA coefficients: [-0.11693249 0.09863351 -0.39694915 1.08219954 -0.56760953 0.72688042\n", + " 0.44609067 0.2888831 -0.66536179 -0.17487209]\n", + "2025-12-31 15:39:56 INFO RegisterModelsPCA Final mean intensity: 0.80\n", + "2025-12-31 15:39:56 INFO RegisterModelsPCA Creating final registered model...\n", + "2025-12-31 15:39:56 INFO RegisterModelsPCA Transforming points: 1/167240 (0.0%)\n", + "2025-12-31 15:39:57 INFO RegisterModelsPCA Transforming points: 16725/167240 (10.0%)\n", + "2025-12-31 15:39:58 INFO RegisterModelsPCA Transforming points: 33449/167240 (20.0%)\n", + "2025-12-31 15:39:59 INFO RegisterModelsPCA Transforming points: 50173/167240 (30.0%)\n", + "2025-12-31 15:40:00 INFO RegisterModelsPCA Transforming 
points: 66897/167240 (40.0%)\n", + "2025-12-31 15:40:01 INFO RegisterModelsPCA Transforming points: 83621/167240 (50.0%)\n", + "2025-12-31 15:40:02 INFO RegisterModelsPCA Transforming points: 100345/167240 (60.0%)\n", + "2025-12-31 15:40:03 INFO RegisterModelsPCA Transforming points: 117069/167240 (70.0%)\n", + "2025-12-31 15:40:04 INFO RegisterModelsPCA Transforming points: 133793/167240 (80.0%)\n", + "2025-12-31 15:40:05 INFO RegisterModelsPCA Transforming points: 150517/167240 (90.0%)\n", + "2025-12-31 15:40:06 INFO RegisterModelsPCA Transforming points: 167240/167240 (100.0%)\n", + "2025-12-31 15:40:06 INFO RegisterModelsPCA Registered model created with 167240 points\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "✓ PCA registration complete\n" + ] + } + ], + "source": [ + "print(\"\\n\" + \"=\"*70)\n", + "print(\"PCA-BASED SHAPE OPTIMIZATION\")\n", + "print(\"=\"*70)\n", + "print(\"\\nRunning complete PCA registration pipeline...\")\n", + "print(\" (Starting from ICP-aligned mesh)\")\n", + "\n", + "result = pca_registrar.register(\n", + " pca_number_of_modes=10, # Use first 10 PCA modes\n", + ")\n", + "\n", + "pca_registered_model_surface = result['registered_model']\n", + "\n", + "print(\"\\n✓ PCA registration complete\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Display Registration Results\n", + "\n", + "Review the optimization results from the PCA registration pipeline.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "======================================================================\n", + "REGISTRATION RESULTS\n", + "======================================================================\n", + "\n", + "Final Registration Metrics:\n", + " Final mean intensity: 0.8025\n", + "\n", + "Optimized PCA Coefficients (in units of std deviations):\n", + " Mode 1: 
-0.1169\n", + " Mode 2: 0.0986\n", + " Mode 3: -0.3969\n", + " Mode 4: 1.0822\n", + " Mode 5: -0.5676\n", + " Mode 6: 0.7269\n", + " Mode 7: 0.4461\n", + " Mode 8: 0.2889\n", + " Mode 9: -0.6654\n", + " Mode 10: -0.1749\n", + "\n", + "✓ Registration pipeline complete!\n" + ] + } + ], + "source": [ + "print(\"\\n\" + \"=\"*70)\n", + "print(\"REGISTRATION RESULTS\")\n", + "print(\"=\"*70)\n", + "\n", + "# Display results\n", + "print(\"\\nFinal Registration Metrics:\")\n", + "print(f\" Final mean intensity: {result['mean_distance']:.4f}\")\n", + "\n", + "print(\"\\nOptimized PCA Coefficients (in units of std deviations):\")\n", + "for i, coef in enumerate(result['pca_coefficients']):\n", + " print(f\" Mode {i+1:2d}: {coef:7.4f}\")\n", + "\n", + "print(\"\\n✓ Registration pipeline complete!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Save Registration Results\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Saving results...\n", + " Saved final PCA-registered mesh\n", + " Saved PCA coefficients\n" + ] + } + ], + "source": [ + "print(\"\\nSaving results...\")\n", + "\n", + "# Save final PCA-registered mesh\n", + "pca_registered_model_surface.save(str(output_dir / 'pca_registered_model_surface.vtk'))\n", + "print(f\" Saved final PCA-registered mesh\")\n", + "\n", + "# Save PCA coefficients\n", + "np.savetxt(\n", + " str(output_dir / 'pca_coefficients.txt'),\n", + " result['pca_coefficients'],\n", + " header=f\"PCA coefficients for {len(result['pca_coefficients'])} modes\"\n", + ")\n", + "print(f\" Saved PCA coefficients\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualize Results\n" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": 
"50b118fdb8f34e689d199bee01c0b31d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Widget(value='